diff options
Diffstat (limited to 'vendor/go.mau.fi/util/exhttp/handleerrors.go')
-rw-r--r-- | vendor/go.mau.fi/util/exhttp/handleerrors.go | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/util/exhttp/handleerrors.go b/vendor/go.mau.fi/util/exhttp/handleerrors.go new file mode 100644 index 0000000..d2d37b1 --- /dev/null +++ b/vendor/go.mau.fi/util/exhttp/handleerrors.go @@ -0,0 +1,58 @@ +package exhttp + +import "net/http" + +type ErrorBodyGenerators struct { + NotFound func() []byte + MethodNotAllowed func() []byte +} + +func HandleErrors(next http.Handler, gen ErrorBodyGenerators) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(&bodyOverrider{ + ResponseWriter: w, + statusNotFoundBodyGenerator: gen.NotFound, + statusMethodNotAllowedBodyGenerator: gen.MethodNotAllowed, + }, r) + }) +} + +type bodyOverrider struct { + http.ResponseWriter + + code int + override bool + + statusNotFoundBodyGenerator func() []byte + statusMethodNotAllowedBodyGenerator func() []byte +} + +var _ http.ResponseWriter = (*bodyOverrider)(nil) + +func (b *bodyOverrider) WriteHeader(code int) { + if b.Header().Get("Content-Type") == "text/plain; charset=utf-8" { + b.Header().Set("Content-Type", "application/json") + + b.override = true + } + + b.code = code + b.ResponseWriter.WriteHeader(code) +} + +func (b *bodyOverrider) Write(body []byte) (int, error) { + if b.override { + switch b.code { + case http.StatusNotFound: + if b.statusNotFoundBodyGenerator != nil { + body = b.statusNotFoundBodyGenerator() + } + case http.StatusMethodNotAllowed: + if b.statusMethodNotAllowedBodyGenerator != nil { + body = b.statusMethodNotAllowedBodyGenerator() + } + } + } + + return b.ResponseWriter.Write(body) +} |