diff --git a/contracts_interfaces.go b/contracts_interfaces.go index baf6dc2..6e1ee58 100644 --- a/contracts_interfaces.go +++ b/contracts_interfaces.go @@ -3,11 +3,11 @@ package httprouter import "net/http" type routeResolver interface { - // Resolve returns an instance of http.Handler and with a flag indicating if the route was understood. + // Resolve returns an instance of http.Handler and a bitmask of the methods allowed at the matched path. // If the http.Handler instance is not nil, the route was fully resolved and can be invoked. - // If the http.Handler instance is nil AND the flag is true, the route was found, but the method isn't compatible (e.g. "POST /", but only a "GET /" was found). - // If the http.Handler instance is nil AND the flag is false, the route was not found. - Resolve(method, path string) (http.Handler, bool) + // If the http.Handler instance is nil AND allowed > 0, the route was found, but the method isn't compatible (e.g. "POST /", but only a "GET /" was found). + // If the http.Handler instance is nil AND allowed == 0, the route was not found. + Resolve(method, path string) (http.Handler, Method) } type RecoveryFunc func(response http.ResponseWriter, request *http.Request, recovered any) diff --git a/contracts_methods.go b/contracts_methods.go index 3600a4b..aa5eab9 100644 --- a/contracts_methods.go +++ b/contracts_methods.go @@ -42,6 +42,19 @@ func (this Method) String() string { return result } func (this Method) GoString() string { return this.String() } +func (this Method) HeaderValue() string { + var result strings.Builder + for _, key := range orderedMethods { + if key&this != key { + continue + } + if result.Len() > 0 { + result.WriteString(", ") + } + result.WriteString(methodValues[key]) + } + return result.String() +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/router.go b/router.go index dfcfe9b..96b3455 100644 --- a/router.go +++ b/router.go @@ -16,9 +16,6 @@ func newRouter(resolver routeResolver, notFound, methodNotAllowed http.Handler, return &defaultRouter{resolver: resolver, notFound: notFound, methodNotAllowed: methodNotAllowed, monitor: monitor} } func (this *defaultRouter) ServeHTTP(response http.ResponseWriter, request *http.Request) { - this.resolve(request).ServeHTTP(response, request) -} -func (this *defaultRouter) resolve(request *http.Request) http.Handler { rawPath := request.RequestURI if len(rawPath) == 0 { rawPath = request.URL.Path @@ -26,15 +23,17 @@ func (this *defaultRouter) resolve(request *http.Request) http.Handler { rawPath = rawPath[0:index] } - if handler, resolved := this.resolver.Resolve(request.Method, rawPath); handler != nil { + handler, allowed := this.resolver.Resolve(request.Method, rawPath) + if handler != nil { this.monitor.Routed(request) - return handler - } else if resolved { + handler.ServeHTTP(response, request) + } else if allowed > 0 { this.monitor.MethodNotAllowed(request) - return this.methodNotAllowed + response.Header().Set("Allow", allowed.HeaderValue()) + this.methodNotAllowed.ServeHTTP(response, request) } else { this.monitor.NotFound(request) - return this.notFound + this.notFound.ServeHTTP(response, request) } } diff --git a/router_test.go b/router_test.go index 22a4165..a69777c 100644 --- a/router_test.go +++ b/router_test.go @@ -38,43 +38,43 @@ func TestRouting(t *testing.T) { ), ) - assertRoute(t, router, "GET ", "/", 404, "Not Found\n") - - assertRoute(t, router, "GET ", "/test1/path/to/document ", 200, "1") - assertRoute(t, router, "GET ", "/test1/path/to/document/", 404, "Not Found\n") - assertRoute(t, router, "GET ", "/test1/path/to/doc ", 404, "Not Found\n") - assertRoute(t, router, "GET ", "/test1/path/to/ ", 404, "Not Found\n") - assertRoute(t, router, "PUT ", "/test1/path/to/document ", 405, "Method Not Allowed\n") - assertRoute(t, router, "POST ", "/test1/path/to/document ", 200, "2") - assertRoute(t, router, "OPTIONS", "/test1/path/to/document ", 405, "Method Not Allowed\n") - assertRoute(t, router, "DELETE ", "/test1/path/to/document ", 200, "3") - assertRoute(t, router, "PATCH ", "/test1/path/to/document ", 200, "18") - assertRoute(t, router, "BOGUS ", "/test1/path/to/document ", 405, "Method Not Allowed\n") - - assertRoute(t, router, "GET ", "/test2/path/to/document ", 200, "4") - assertRoute(t, router, "PUT ", "/test2/path/to/document ", 200, "5") - assertRoute(t, router, "DELETE ", "/test2/path/to/document ", 200, "6") - assertRoute(t, router, "PATCH ", "/test2/path/to/document ", 405, "Method Not Allowed\n") - assertRoute(t, router, "DELETE ", "/test2/path/to/document/does-not-exist", 405, "Method Not Allowed\n") // greedy GET /test2/* - - assertRoute(t, router, "GET ", "/variable1/variable1/test3/path/to/document", 200, "7") - - assertRoute(t, router, "CONNECT", "/test4 ", 200, "10") - assertRoute(t, router, "HEAD ", "/test4 ", 405, "Method Not Allowed\n") - assertRoute(t, router, "CONNECT", "/test4/ ", 200, "11") - assertRoute(t, router, "CONNECT", "/test4/wildcard", 200, "12") - assertRoute(t, router, "DELETE ", "/test4/wildcard", 405, "Method Not Allowed\n") - - assertRoute(t, router, "TRACE ", "/test5/static/child/variable-name-here/grandchild ", 200, "13") - assertRoute(t, router, "TRACE ", "/test5/static/child/variable-name-here/grandchild/does-not-exist", 200, "15") // greedy wildcard - assertRoute(t, router, "TRACE ", "/test5/variable-name-here/child/static/grandchild ", 200, "14") - assertRoute(t, router, "TRACE ", "/test5/variable-name-here/child/wildcard ", 200, "15") - - assertRoute(t, router, "GET ", "/test5/variable-1-here/variable-2-here/variable-3-here/static", 200, "16") - assertRoute(t, router, "DELETE ", "/test5/variable-1-here/variable-2-here/variable-3-here/static", 405, "Method Not Allowed\n") - assertRoute(t, router, "GET ", "/test5/variable-1-here/variable-2-here/static/child ", 200, "17") + assertRoute(t, router, "GET ", "/", 404, "Not Found\n", "") + + assertRoute(t, router, "GET ", "/test1/path/to/document ", 200, "1", "") + assertRoute(t, router, "GET ", "/test1/path/to/document/", 404, "Not Found\n", "") + assertRoute(t, router, "GET ", "/test1/path/to/doc ", 404, "Not Found\n", "") + assertRoute(t, router, "GET ", "/test1/path/to/ ", 404, "Not Found\n", "") + assertRoute(t, router, "PUT ", "/test1/path/to/document ", 405, "Method Not Allowed\n", "GET, HEAD, POST, DELETE, PATCH") + assertRoute(t, router, "POST ", "/test1/path/to/document ", 200, "2", "") + assertRoute(t, router, "OPTIONS", "/test1/path/to/document ", 405, "Method Not Allowed\n", "GET, HEAD, POST, DELETE, PATCH") + assertRoute(t, router, "DELETE ", "/test1/path/to/document ", 200, "3", "") + assertRoute(t, router, "PATCH ", "/test1/path/to/document ", 200, "18", "") + assertRoute(t, router, "BOGUS ", "/test1/path/to/document ", 405, "Method Not Allowed\n", "GET, HEAD, POST, DELETE, PATCH") + + assertRoute(t, router, "GET ", "/test2/path/to/document ", 200, "4", "") + assertRoute(t, router, "PUT ", "/test2/path/to/document ", 200, "5", "") + assertRoute(t, router, "DELETE ", "/test2/path/to/document ", 200, "6", "") + assertRoute(t, router, "PATCH ", "/test2/path/to/document ", 405, "Method Not Allowed\n", "GET, PUT, DELETE") + assertRoute(t, router, "DELETE ", "/test2/path/to/document/does-not-exist", 405, "Method Not Allowed\n", "GET") // greedy GET /test2/* + + assertRoute(t, router, "GET ", "/variable1/variable1/test3/path/to/document", 200, "7", "") + + assertRoute(t, router, "CONNECT", "/test4 ", 200, "10", "") + assertRoute(t, router, "HEAD ", "/test4 ", 405, "Method Not Allowed\n", "CONNECT") + assertRoute(t, router, "CONNECT", "/test4/ ", 200, "11", "") + assertRoute(t, router, "CONNECT", "/test4/wildcard", 200, "12", "") + assertRoute(t, router, "DELETE ", "/test4/wildcard", 405, "Method Not Allowed\n", "CONNECT") + + assertRoute(t, router, "TRACE ", "/test5/static/child/variable-name-here/grandchild ", 200, "13", "") + assertRoute(t, router, "TRACE ", "/test5/static/child/variable-name-here/grandchild/does-not-exist", 200, "15", "") // greedy wildcard + assertRoute(t, router, "TRACE ", "/test5/variable-name-here/child/static/grandchild ", 200, "14", "") + assertRoute(t, router, "TRACE ", "/test5/variable-name-here/child/wildcard ", 200, "15", "") + + assertRoute(t, router, "GET ", "/test5/variable-1-here/variable-2-here/variable-3-here/static", 200, "16", "") + assertRoute(t, router, "DELETE ", "/test5/variable-1-here/variable-2-here/variable-3-here/static", 405, "Method Not Allowed\n", "GET") + assertRoute(t, router, "GET ", "/test5/variable-1-here/variable-2-here/static/child ", 200, "17", "") } -func assertRoute(t *testing.T, router http.Handler, method, path string, expectedStatus int, expectedBody string) { +func assertRoute(t *testing.T, router http.Handler, method, path string, expectedStatus int, expectedBody, expectedAllow string) { t.Helper() t.Run(fmt.Sprintf("%s:%s:%d", method, path, expectedStatus), func(t *testing.T) { t.Helper() @@ -92,6 +92,11 @@ func assertRoute(t *testing.T, router http.Handler, method, path string, expecte t.Errorf("expected body [%s], actual body: [%s] for test [%s %s]", expectedBody, actualBody, method, path) } } + + actualAllow := recorder.Header().Get("Allow") + if actualAllow != expectedAllow { + t.Errorf("expected Allow [%s], actual Allow: [%s] for test [%s %s]", expectedAllow, actualAllow, method, path) + } }) } diff --git a/tree.go b/tree.go index 9f1ad68..f8967ac 100644 --- a/tree.go +++ b/tree.go @@ -103,13 +103,12 @@ func hasOnlyAllowedCharacters(input string) bool { return true } -func (this *treeNode) Resolve(method, incomingPath string) (http.Handler, bool) { +func (this *treeNode) Resolve(method, incomingPath string) (http.Handler, Method) { if len(incomingPath) == 0 { if this.handlers == nil { - return nil, false - } else { - return this.handlers.Resolve(method), true + return nil, 0 } + return this.handlers.Resolve(method), this.handlers.allowed } if incomingPath[0] == '/' { @@ -117,7 +116,7 @@ func (this *treeNode) Resolve(method, incomingPath string) (http.Handler, bool) } var handler http.Handler - var staticResourceExists, variableResourceExists bool + var staticAllowed, variableAllowed Method var pathFragment = parsePathFragment(incomingPath) for _, staticChild := range this.static { @@ -125,27 +124,27 @@ func (this *treeNode) Resolve(method, incomingPath string) (http.Handler, bool) continue } - // the path fragment DOES match... remainingPath := incomingPath[len(pathFragment):] - if handler, staticResourceExists = staticChild.Resolve(method, remainingPath); handler != nil { - return handler, staticResourceExists + if handler, staticAllowed = staticChild.Resolve(method, remainingPath); handler != nil { + return handler, MethodNone } - break // don't bother checking any more of siblings of the static child, they don't match + break } if this.variable != nil { remainingPath := incomingPath[len(pathFragment):] - if handler, variableResourceExists = this.variable.Resolve(method, remainingPath); handler != nil { - return handler, variableResourceExists + if handler, variableAllowed = this.variable.Resolve(method, remainingPath); handler != nil { + return handler, MethodNone } } if this.wildcard != nil { - return this.wildcard.Resolve(method, "") + wildcardHandler, wildcardAllowed := this.wildcard.Resolve(method, "") + return wildcardHandler, staticAllowed | variableAllowed | wildcardAllowed } - return nil, staticResourceExists || variableResourceExists + return nil, staticAllowed | variableAllowed } func parsePathFragment(value string) string { if index := strings.IndexByte(value, '/'); index == -1 { @@ -176,6 +175,7 @@ var allowedCharacters = map[rune]struct{}{ //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type methodHandlers struct { + allowed Method Get http.Handler Head http.Handler Post http.Handler @@ -229,6 +229,7 @@ func (this *methodHandlers) Add(allowed Method, handler http.Handler) error { this.Patch = handler } + this.allowed |= allowed return nil } func (this *methodHandlers) Resolve(method string) http.Handler {