From 47f6eb51630b3a4039d9e943530efad9a0e59626 Mon Sep 17 00:00:00 2001
From: Daniel <dhaavi@users.noreply.github.com>
Date: Fri, 6 Oct 2023 10:47:39 +0200
Subject: [PATCH] Improve mime type selection

---
 formats/dsd/http.go      | 65 ++++++++++++++++++++++++----------------
 formats/dsd/http_test.go | 25 +++++++++++-----
 2 files changed, 56 insertions(+), 34 deletions(-)

diff --git a/formats/dsd/http.go b/formats/dsd/http.go
index 6962cac..85aab16 100644
--- a/formats/dsd/http.go
+++ b/formats/dsd/http.go
@@ -95,10 +95,10 @@ func DumpToHTTPResponse(w http.ResponseWriter, r *http.Request, t interface{}) e
 	return nil
 }
 
-// MimeLoad loads the given data into the interface based on the given mime type.
-func MimeLoad(data []byte, mimeType string, t interface{}) (format uint8, err error) {
+// MimeLoad loads the given data into the interface based on the given mime type accept header.
+func MimeLoad(data []byte, accept string, t interface{}) (format uint8, err error) {
 	// Find format.
-	format = FormatFromMime(mimeType)
+	format = FormatFromAccept(accept)
 	if format == 0 {
 		return 0, ErrIncompatibleFormat
 	}
@@ -111,40 +111,53 @@ func MimeLoad(data []byte, mimeType string, t interface{}) (format uint8, err er
 // MimeDump dumps the given interface based on the given mime type accept header.
 func MimeDump(t any, accept string) (data []byte, mimeType string, format uint8, err error) {
 	// Find format.
-	accept = extractMimeType(accept)
-	switch accept {
-	case "", "*":
-		format = DefaultSerializationFormat
-	default:
-		format = MimeTypeToFormat[accept]
-		if format == 0 {
-			return nil, "", 0, ErrIncompatibleFormat
-		}
+	format = FormatFromAccept(accept)
+	if format == AUTO {
+		return nil, "", 0, ErrIncompatibleFormat
 	}
-	mimeType = FormatToMimeType[format]
 
 	// Serialize and return.
 	data, err = dumpWithoutIdentifier(t, format, "")
 	return data, mimeType, format, err
 }
 
-// FormatFromMime returns the format for the given mime type.
-// Will return AUTO format for unsupported or unrecognized mime types.
-func FormatFromMime(mimeType string) (format uint8) {
-	return MimeTypeToFormat[extractMimeType(mimeType)]
-}
+// FormatFromAccept returns the format for the given accept definition.
+// The accept parameter matches the format of the HTTP Accept header.
+// Special cases, in this order:
+// - If accept is an empty string: returns default serialization format.
+// - If accept contains no supported format, but a wildcard: returns default serialization format.
+// - If accept contains no supported format, and no wildcard: returns AUTO format.
+func FormatFromAccept(accept string) (format uint8) {
+	if accept == "" {
+		return DefaultSerializationFormat
+	}
 
-func extractMimeType(mimeType string) string {
-	if strings.Contains(mimeType, ",") {
-		mimeType, _, _ = strings.Cut(mimeType, ",")
-	}
-	if strings.Contains(mimeType, ";") {
+	var foundWildcard bool
+	for _, mimeType := range strings.Split(accept, ",") {
+		// Clean mime type.
+		mimeType = strings.TrimSpace(mimeType)
 		mimeType, _, _ = strings.Cut(mimeType, ";")
+		if strings.Contains(mimeType, "/") {
+			_, mimeType, _ = strings.Cut(mimeType, "/")
+		}
+		mimeType = strings.ToLower(mimeType)
+
+		// Check if mime type is supported.
+		format, ok := MimeTypeToFormat[mimeType]
+		if ok {
+			return format
+		}
+
+		// Return default mime type as fallback if any mimetype is okay.
+		if mimeType == "*" {
+			foundWildcard = true
+		}
 	}
-	if strings.Contains(mimeType, "/") {
-		_, mimeType, _ = strings.Cut(mimeType, "/")
+
+	if foundWildcard {
+		return DefaultSerializationFormat
 	}
-	return strings.ToLower(mimeType)
+	return AUTO
 }
 
 // Format and MimeType mappings.
diff --git a/formats/dsd/http_test.go b/formats/dsd/http_test.go
index ce74e56..32651ac 100644
--- a/formats/dsd/http_test.go
+++ b/formats/dsd/http_test.go
@@ -23,14 +23,23 @@ func TestMimeTypes(t *testing.T) {
 	}
 
 	// Test assumptions.
-	for mimeType, mimeTypeCleaned := range map[string]string{
-		"application/xml, image/webp":       "xml",
-		"application/xml;q=0.9, image/webp": "xml",
-		"*":                                 "*",
-		"*/*":                               "*",
-		"text/yAMl":                         "yaml",
+	for accept, format := range map[string]uint8{
+		"application/json, image/webp":       JSON,
+		"image/webp, application/json":       JSON,
+		"application/json;q=0.9, image/webp": JSON,
+		"*":                                  DefaultSerializationFormat,
+		"*/*":                                DefaultSerializationFormat,
+		"text/yAMl":                          YAML,
+		" * , yaml ":                         YAML,
+		"yaml;charset ,*":                    YAML,
+		"xml,*":                              DefaultSerializationFormat,
+		"text/xml, text/other":               AUTO,
+		"text/*":                             DefaultSerializationFormat,
+		"yaml ;charset":                      AUTO, // Invalid mimetype format.
+		"":                                   DefaultSerializationFormat,
+		"x":                                  AUTO,
 	} {
-		cleaned := extractMimeType(mimeType)
-		assert.Equal(t, mimeTypeCleaned, cleaned, "assumption for %q should hold", mimeType)
+		derivedFormat := FormatFromAccept(accept)
+		assert.Equal(t, format, derivedFormat, "assumption for %q should hold", accept)
 	}
 }