about summary refs log tree commit diff
path: root/vendor/github.com/gorilla/handlers/cors.go
diff options
context:
space:
mode:
authorEmile <git@emile.space>2024-08-16 19:50:26 +0200
committerEmile <git@emile.space>2024-08-16 19:50:26 +0200
commit1a57267a17c2fc17fb6e104846fabc3e363c326c (patch)
tree1e574e3a80622086dc3c81ff9cba65ef7049b1a9 /vendor/github.com/gorilla/handlers/cors.go
initial commit
Diffstat (limited to 'vendor/github.com/gorilla/handlers/cors.go')
-rw-r--r--vendor/github.com/gorilla/handlers/cors.go352
1 files changed, 352 insertions, 0 deletions
diff --git a/vendor/github.com/gorilla/handlers/cors.go b/vendor/github.com/gorilla/handlers/cors.go
new file mode 100644
index 0000000..8af9c09
--- /dev/null
+++ b/vendor/github.com/gorilla/handlers/cors.go
@@ -0,0 +1,352 @@
+package handlers
+
+import (
+	"net/http"
+	"strconv"
+	"strings"
+)
+
+// CORSOption represents a functional option for configuring the CORS middleware.
+type CORSOption func(*cors) error
+
+type cors struct {
+	h                      http.Handler
+	allowedHeaders         []string
+	allowedMethods         []string
+	allowedOrigins         []string
+	allowedOriginValidator OriginValidator
+	exposedHeaders         []string
+	maxAge                 int
+	ignoreOptions          bool
+	allowCredentials       bool
+	optionStatusCode       int
+}
+
+// OriginValidator takes an origin string and returns whether or not that origin is allowed.
+type OriginValidator func(string) bool
+
+var (
+	defaultCorsOptionStatusCode = http.StatusOK
+	defaultCorsMethods          = []string{http.MethodGet, http.MethodHead, http.MethodPost}
+	defaultCorsHeaders          = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
+	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests).
+)
+
+const (
+	corsOptionMethod           string = http.MethodOptions
+	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
+	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
+	corsMaxAgeHeader           string = "Access-Control-Max-Age"
+	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
+	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
+	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
+	corsRequestMethodHeader    string = "Access-Control-Request-Method"
+	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
+	corsOriginHeader           string = "Origin"
+	corsVaryHeader             string = "Vary"
+	corsOriginMatchAll         string = "*"
+)
+
+func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	origin := r.Header.Get(corsOriginHeader)
+	if !ch.isOriginAllowed(origin) {
+		if r.Method != corsOptionMethod || ch.ignoreOptions {
+			ch.h.ServeHTTP(w, r)
+		}
+
+		return
+	}
+
+	if r.Method == corsOptionMethod {
+		if ch.ignoreOptions {
+			ch.h.ServeHTTP(w, r)
+			return
+		}
+
+		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
+			w.WriteHeader(http.StatusBadRequest)
+			return
+		}
+
+		method := r.Header.Get(corsRequestMethodHeader)
+		if !ch.isMatch(method, ch.allowedMethods) {
+			w.WriteHeader(http.StatusMethodNotAllowed)
+			return
+		}
+
+		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
+		allowedHeaders := []string{}
+		for _, v := range requestHeaders {
+			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
+			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
+				continue
+			}
+
+			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
+				w.WriteHeader(http.StatusForbidden)
+				return
+			}
+
+			allowedHeaders = append(allowedHeaders, canonicalHeader)
+		}
+
+		if len(allowedHeaders) > 0 {
+			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
+		}
+
+		if ch.maxAge > 0 {
+			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
+		}
+
+		if !ch.isMatch(method, defaultCorsMethods) {
+			w.Header().Set(corsAllowMethodsHeader, method)
+		}
+	} else if len(ch.exposedHeaders) > 0 {
+		w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
+	}
+
+	if ch.allowCredentials {
+		w.Header().Set(corsAllowCredentialsHeader, "true")
+	}
+
+	if len(ch.allowedOrigins) > 1 {
+		w.Header().Set(corsVaryHeader, corsOriginHeader)
+	}
+
+	returnOrigin := origin
+	if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
+		returnOrigin = "*"
+	} else {
+		for _, o := range ch.allowedOrigins {
+			// A configuration of * is different than explicitly setting an allowed
+			// origin. Returning arbitrary origin headers in an access control allow
+			// origin header is unsafe and is not required by any use case.
+			if o == corsOriginMatchAll {
+				returnOrigin = "*"
+				break
+			}
+		}
+	}
+	w.Header().Set(corsAllowOriginHeader, returnOrigin)
+
+	if r.Method == corsOptionMethod {
+		w.WriteHeader(ch.optionStatusCode)
+		return
+	}
+	ch.h.ServeHTTP(w, r)
+}
+
+// CORS provides Cross-Origin Resource Sharing middleware.
+// Example:
+//
+//	import (
+//	    "net/http"
+//
+//	    "github.com/gorilla/handlers"
+//	    "github.com/gorilla/mux"
+//	)
+//
+//	func main() {
+//	    r := mux.NewRouter()
+//	    r.HandleFunc("/users", UserEndpoint)
+//	    r.HandleFunc("/projects", ProjectEndpoint)
+//
+//	    // Apply the CORS middleware to our top-level router, with the defaults.
+//	    http.ListenAndServe(":8000", handlers.CORS()(r))
+//	}
+func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
+	return func(h http.Handler) http.Handler {
+		ch := parseCORSOptions(opts...)
+		ch.h = h
+		return ch
+	}
+}
+
+func parseCORSOptions(opts ...CORSOption) *cors {
+	ch := &cors{
+		allowedMethods:   defaultCorsMethods,
+		allowedHeaders:   defaultCorsHeaders,
+		allowedOrigins:   []string{},
+		optionStatusCode: defaultCorsOptionStatusCode,
+	}
+
+	for _, option := range opts {
+		_ = option(ch) //TODO: @bharat-rajani, return error to caller if not nil?
+	}
+
+	return ch
+}
+
+//
+// Functional options for configuring CORS.
+//
+
+// AllowedHeaders adds the provided headers to the list of allowed headers in a
+// CORS request.
+// This is an append operation so the headers Accept, Accept-Language,
+// and Content-Language are always allowed.
+// Content-Type must be explicitly declared if accepting Content-Types other than
+// application/x-www-form-urlencoded, multipart/form-data, or text/plain.
+func AllowedHeaders(headers []string) CORSOption {
+	return func(ch *cors) error {
+		for _, v := range headers {
+			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
+			if normalizedHeader == "" {
+				continue
+			}
+
+			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
+				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
+			}
+		}
+
+		return nil
+	}
+}
+
+// AllowedMethods can be used to explicitly allow methods in the
+// Access-Control-Allow-Methods header.
+// This is a replacement operation so you must also
+// pass GET, HEAD, and POST if you wish to support those methods.
+func AllowedMethods(methods []string) CORSOption {
+	return func(ch *cors) error {
+		ch.allowedMethods = []string{}
+		for _, v := range methods {
+			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
+			if normalizedMethod == "" {
+				continue
+			}
+
+			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
+				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
+			}
+		}
+
+		return nil
+	}
+}
+
+// AllowedOrigins sets the allowed origins for CORS requests, as used in the
+// 'Allow-Access-Control-Origin' HTTP header.
+// Note: Passing in a []string{"*"} will allow any domain.
+func AllowedOrigins(origins []string) CORSOption {
+	return func(ch *cors) error {
+		for _, v := range origins {
+			if v == corsOriginMatchAll {
+				ch.allowedOrigins = []string{corsOriginMatchAll}
+				return nil
+			}
+		}
+
+		ch.allowedOrigins = origins
+		return nil
+	}
+}
+
+// AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
+// 'Allow-Access-Control-Origin' HTTP header.
+func AllowedOriginValidator(fn OriginValidator) CORSOption {
+	return func(ch *cors) error {
+		ch.allowedOriginValidator = fn
+		return nil
+	}
+}
+
+// OptionStatusCode sets a custom status code on the OPTIONS requests.
+// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
+// and can be used if you need a custom status code (i.e 204).
+//
+// More informations on the spec:
+// https://fetch.spec.whatwg.org/#cors-preflight-fetch
+func OptionStatusCode(code int) CORSOption {
+	return func(ch *cors) error {
+		ch.optionStatusCode = code
+		return nil
+	}
+}
+
+// ExposedHeaders can be used to specify headers that are available
+// and will not be stripped out by the user-agent.
+func ExposedHeaders(headers []string) CORSOption {
+	return func(ch *cors) error {
+		ch.exposedHeaders = []string{}
+		for _, v := range headers {
+			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
+			if normalizedHeader == "" {
+				continue
+			}
+
+			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
+				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
+			}
+		}
+
+		return nil
+	}
+}
+
+// MaxAge determines the maximum age (in seconds) between preflight requests. A
+// maximum of 10 minutes is allowed. An age above this value will default to 10
+// minutes.
+func MaxAge(age int) CORSOption {
+	return func(ch *cors) error {
+		// Maximum of 10 minutes.
+		if age > 600 {
+			age = 600
+		}
+
+		ch.maxAge = age
+		return nil
+	}
+}
+
+// IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
+// passing them through to the next handler. This is useful when your application
+// or framework has a pre-existing mechanism for responding to OPTIONS requests.
+func IgnoreOptions() CORSOption {
+	return func(ch *cors) error {
+		ch.ignoreOptions = true
+		return nil
+	}
+}
+
+// AllowCredentials can be used to specify that the user agent may pass
+// authentication details along with the request.
+func AllowCredentials() CORSOption {
+	return func(ch *cors) error {
+		ch.allowCredentials = true
+		return nil
+	}
+}
+
+func (ch *cors) isOriginAllowed(origin string) bool {
+	if origin == "" {
+		return false
+	}
+
+	if ch.allowedOriginValidator != nil {
+		return ch.allowedOriginValidator(origin)
+	}
+
+	if len(ch.allowedOrigins) == 0 {
+		return true
+	}
+
+	for _, allowedOrigin := range ch.allowedOrigins {
+		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
+			return true
+		}
+	}
+
+	return false
+}
+
+func (ch *cors) isMatch(needle string, haystack []string) bool {
+	for _, v := range haystack {
+		if v == needle {
+			return true
+		}
+	}
+
+	return false
+}