diff --git a/api.go b/api.go index 2a7c469..f7365d8 100644 --- a/api.go +++ b/api.go @@ -1,12 +1,14 @@ package main import ( + "context" "encoding/json" "fmt" "log" "net/http" "os" "strconv" + "time" jwt "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" @@ -29,7 +31,7 @@ func (s *APIServer) Run() { router.HandleFunc("/login", makeHTTPHandleFunc(s.handleLogin)) router.HandleFunc("/account", makeHTTPHandleFunc(s.handleAccount)) - router.HandleFunc("/account/{id}", withJWTAuth(makeHTTPHandleFunc(s.handleGetAccountByID), s.store)) + router.HandleFunc("/account/{id}", withJWTAuth(makeHTTPHandleFunc(s.handleGetAccountByID))) router.HandleFunc("/transfer", makeHTTPHandleFunc(s.handleTransfer)) log.Println("JSON API server running on port: ", s.listenAddr) @@ -97,6 +99,11 @@ func (s *APIServer) handleGetAccountByID(w http.ResponseWriter, r *http.Request) return err } + jwtUserID := r.Context().Value(ContextUserIDKey).(int) + if id != jwtUserID { + return fmt.Errorf("permission denied") + } + account, err := s.store.GetAccountByID(id) if err != nil { return err @@ -160,9 +167,12 @@ func WriteJSON(w http.ResponseWriter, status int, v any) error { } func createJWT(account *Account) (string, error) { - claims := &jwt.MapClaims{ - "expiresAt": 15000, - "accountNumber": account.Number, + claims := &JWTClaims{ + UserID: account.ID, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "gobank", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + }, } secret := os.Getenv("JWT_SECRET") @@ -177,7 +187,7 @@ func permissionDenied(w http.ResponseWriter) { // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2NvdW50TnVtYmVyIjo0OTgwODEsImV4cGlyZXNBdCI6MTUwMDB9.TdQ907o9yhUI2KU0TngrqO-xbfNgHAfZI6Jngia15UE -func withJWTAuth(handlerFunc http.HandlerFunc, s Storage) http.HandlerFunc { +func withJWTAuth(handlerFunc http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { fmt.Println("calling JWT auth middleware") @@ -187,40 +197,24 @@ func withJWTAuth(handlerFunc http.HandlerFunc, s Storage) http.HandlerFunc { permissionDenied(w) return } - if !token.Valid { - permissionDenied(w) - return - } - userID, err := getID(r) - if err != nil { - permissionDenied(w) - return - } - account, err := s.GetAccountByID(userID) - if err != nil { - permissionDenied(w) - return - } - claims := token.Claims.(jwt.MapClaims) - if account.Number != int64(claims["accountNumber"].(float64)) { + var userID int + if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + userID = claims.UserID + } else { permissionDenied(w) return } - if err != nil { - WriteJSON(w, http.StatusForbidden, ApiError{Error: "invalid token"}) - return - } - - handlerFunc(w, r) + ctx := context.WithValue(r.Context(), ContextUserIDKey, userID) + handlerFunc(w, r.WithContext(ctx)) } } func validateJWT(tokenString string) (*jwt.Token, error) { secret := os.Getenv("JWT_SECRET") - return jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) @@ -253,3 +247,7 @@ func getID(r *http.Request) (int, error) { } return id, nil } + +type ContextKey string + +const ContextUserIDKey ContextKey = "userid" diff --git a/types.go b/types.go index 405f36d..fe66622 100644 --- a/types.go +++ b/types.go @@ -4,6 +4,7 @@ import ( "math/rand" "time" + jwt "github.com/golang-jwt/jwt/v4" "golang.org/x/crypto/bcrypt" ) @@ -56,3 +57,8 @@ func NewAccount(firstName, lastName, password string) (*Account, error) { CreatedAt: time.Now().UTC(), }, nil } + +type JWTClaims struct { + UserID int `json:"userid"` + jwt.RegisteredClaims +}