Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 8 additions & 252 deletions app/allowlist.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -81,6 +80,9 @@ func SetAllowlist(name string, callers []CallerConfig) error {
if method == "" {
continue
}
if err := compileBodySchemaInto(&cons); err != nil {
return fmt.Errorf("invalid body schema for %s %s: %w", r.Path, method, err)
}
methods[strings.ToUpper(method)] = cons
}
r.Methods = methods
Expand Down Expand Up @@ -201,11 +203,11 @@ func validateRequestReason(r *http.Request, c RequestConstraint) (bool, string)
}
ct := strings.ToLower(r.Header.Get("Content-Type"))
if strings.Contains(ct, "application/json") {
var data map[string]interface{}
if err := json.Unmarshal(bodyBytes, &data); err != nil {
data, err := decodeJSONBody(bodyBytes)
if err != nil {
return false, "invalid json"
}
if ok, reason := matchBodyMapReason(data, c.Body); !ok {
if ok, reason := validateBodySchemaCompiled(c.BodySchema, c.Body, data); !ok {
return false, reason
}
return true, ""
Expand All @@ -215,7 +217,8 @@ func validateRequestReason(r *http.Request, c RequestConstraint) (bool, string)
if err != nil {
return false, "invalid form encoding"
}
if ok, reason := matchFormReason(vals, c.Body); !ok {
data := formValuesToJSON(vals)
if ok, reason := validateBodySchemaCompiled(c.BodySchema, c.Body, data); !ok {
return false, reason
}
return true, ""
Expand All @@ -224,253 +227,6 @@ func validateRequestReason(r *http.Request, c RequestConstraint) (bool, string)
return true, ""
}

func matchForm(vals url.Values, rule map[string]interface{}) bool {
for k, v := range rule {
present, ok := vals[k]
if !ok {
return false
}
switch want := v.(type) {
case string:
found := false
for _, got := range present {
if got == want {
found = true
break
}
}
if !found {
return false
}
case []interface{}:
for _, elem := range want {
s, ok := elem.(string)
if !ok {
return false
}
found := false
for _, got := range present {
if got == s {
found = true
break
}
}
if !found {
return false
}
}
default:
return false
}
}
return true
}

func matchFormReason(vals url.Values, rule map[string]interface{}) (bool, string) {
for k, v := range rule {
present, ok := vals[k]
if !ok {
return false, "missing form field " + k
}
switch want := v.(type) {
case string:
found := false
for _, got := range present {
if got == want {
found = true
break
}
}
if !found {
return false, fmt.Sprintf("form field %s=%s not found", k, want)
}
case []interface{}:
for _, elem := range want {
s, ok := elem.(string)
if !ok {
return false, "invalid rule"
}
found := false
for _, got := range present {
if got == s {
found = true
break
}
}
if !found {
return false, fmt.Sprintf("form field %s=%s not found", k, s)
}
}
default:
return false, "invalid rule"
}
}
return true, ""
}

func matchQuery(vals url.Values, rule map[string][]string) bool {
for k, wantVals := range rule {
present, ok := vals[k]
if !ok {
return false
}
for _, want := range wantVals {
found := false
for _, got := range present {
if got == want {
found = true
break
}
}
if !found {
return false
}
}
}
return true
}

func matchBodyMap(data map[string]interface{}, rule map[string]interface{}) bool {
return matchValue(data, rule)
}

func matchBodyMapReason(data map[string]interface{}, rule map[string]interface{}) (bool, string) {
return matchValueReason(data, rule, "")
}

func matchValueReason(data, rule interface{}, path string) (bool, string) {
switch rv := rule.(type) {
case map[string]interface{}:
dm, ok := data.(map[string]interface{})
if !ok {
return false, fmt.Sprintf("body field %s not object", path)
}
for k, v := range rv {
dv, ok := dm[k]
p := k
if path != "" {
p = path + "." + k
}
if !ok {
return false, "missing body field " + p
}
if ok2, reason := matchValueReason(dv, v, p); !ok2 {
return false, reason
}
}
return true, ""
case []interface{}:
da, ok := data.([]interface{})
if !ok {
return false, fmt.Sprintf("body field %s not array", path)
}
for _, want := range rv {
found := false
for _, elem := range da {
if ok2, _ := matchValueReason(elem, want, path); ok2 {
found = true
break
}
}
if !found {
return false, fmt.Sprintf("body field %s missing element", path)
}
}
return true, ""
default:
if df, ok := toFloat(data); ok {
if rf, ok2 := toFloat(rv); ok2 {
if df == rf {
return true, ""
}
}
}
if data == rule {
return true, ""
}
return false, fmt.Sprintf("body field %s value mismatch", path)
}
}

func matchValue(data, rule interface{}) bool {
switch rv := rule.(type) {
case map[string]interface{}:
dm, ok := data.(map[string]interface{})
if !ok {
return false
}
for k, v := range rv {
dv, ok := dm[k]
if !ok {
return false
}
if !matchValue(dv, v) {
return false
}
}
return true
case []interface{}:
da, ok := data.([]interface{})
if !ok {
return false
}
for _, want := range rv {
found := false
for _, elem := range da {
if matchValue(elem, want) {
found = true
break
}
}
if !found {
return false
}
}
return true
default:
// YAML unmarshals numbers without decimals as ints while JSON
// decoding uses float64. Normalize numeric types so the values
// compare equal regardless of how they were parsed.
if df, ok := toFloat(data); ok {
if rf, ok2 := toFloat(rv); ok2 {
return df == rf
}
}
return data == rule
}
}

func toFloat(v interface{}) (float64, bool) {
switch n := v.(type) {
case int:
return float64(n), true
case int8:
return float64(n), true
case int16:
return float64(n), true
case int32:
return float64(n), true
case int64:
return float64(n), true
case uint:
return float64(n), true
case uint8:
return float64(n), true
case uint16:
return float64(n), true
case uint32:
return float64(n), true
case uint64:
return float64(n), true
case float32:
return float64(n), true
case float64:
return n, true
default:
return 0, false
}
}

// findConstraint returns the RequestConstraint for the given caller, path and
// method if one exists.
func findConstraint(i *Integration, callerID, pth, method string) (RequestConstraint, bool) {
Expand Down
Loading
Loading