Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions api.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package main

import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strconv"
"time"

jwt "github.com/golang-jwt/jwt/v4"
"github.com/gorilla/mux"
Expand All @@ -29,7 +31,7 @@ func (s *APIServer) Run() {

router.HandleFunc("/login", makeHTTPHandleFunc(s.handleLogin))
router.HandleFunc("/account", makeHTTPHandleFunc(s.handleAccount))
router.HandleFunc("/account/{id}", withJWTAuth(makeHTTPHandleFunc(s.handleGetAccountByID), s.store))
router.HandleFunc("/account/{id}", withJWTAuth(makeHTTPHandleFunc(s.handleGetAccountByID)))
router.HandleFunc("/transfer", makeHTTPHandleFunc(s.handleTransfer))

log.Println("JSON API server running on port: ", s.listenAddr)
Expand Down Expand Up @@ -97,6 +99,11 @@ func (s *APIServer) handleGetAccountByID(w http.ResponseWriter, r *http.Request)
return err
}

jwtUserID := r.Context().Value(ContextUserIDKey).(int)
if id != jwtUserID {
return fmt.Errorf("permission denied")
}

account, err := s.store.GetAccountByID(id)
if err != nil {
return err
Expand Down Expand Up @@ -160,9 +167,12 @@ func WriteJSON(w http.ResponseWriter, status int, v any) error {
}

func createJWT(account *Account) (string, error) {
claims := &jwt.MapClaims{
"expiresAt": 15000,
"accountNumber": account.Number,
claims := &JWTClaims{
UserID: account.ID,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "gobank",
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
},
}

secret := os.Getenv("JWT_SECRET")
Expand All @@ -177,7 +187,7 @@ func permissionDenied(w http.ResponseWriter) {

// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2NvdW50TnVtYmVyIjo0OTgwODEsImV4cGlyZXNBdCI6MTUwMDB9.TdQ907o9yhUI2KU0TngrqO-xbfNgHAfZI6Jngia15UE

func withJWTAuth(handlerFunc http.HandlerFunc, s Storage) http.HandlerFunc {
func withJWTAuth(handlerFunc http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fmt.Println("calling JWT auth middleware")

Expand All @@ -187,40 +197,24 @@ func withJWTAuth(handlerFunc http.HandlerFunc, s Storage) http.HandlerFunc {
permissionDenied(w)
return
}
if !token.Valid {
permissionDenied(w)
return
}
userID, err := getID(r)
if err != nil {
permissionDenied(w)
return
}
account, err := s.GetAccountByID(userID)
if err != nil {
permissionDenied(w)
return
}

claims := token.Claims.(jwt.MapClaims)
if account.Number != int64(claims["accountNumber"].(float64)) {
var userID int
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
userID = claims.UserID
} else {
permissionDenied(w)
return
}

if err != nil {
WriteJSON(w, http.StatusForbidden, ApiError{Error: "invalid token"})
return
}

handlerFunc(w, r)
ctx := context.WithValue(r.Context(), ContextUserIDKey, userID)
handlerFunc(w, r.WithContext(ctx))
}
}

func validateJWT(tokenString string) (*jwt.Token, error) {
secret := os.Getenv("JWT_SECRET")

return jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand Down Expand Up @@ -253,3 +247,7 @@ func getID(r *http.Request) (int, error) {
}
return id, nil
}

type ContextKey string

const ContextUserIDKey ContextKey = "userid"
6 changes: 6 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math/rand"
"time"

jwt "github.com/golang-jwt/jwt/v4"
"golang.org/x/crypto/bcrypt"
)

Expand Down Expand Up @@ -56,3 +57,8 @@ func NewAccount(firstName, lastName, password string) (*Account, error) {
CreatedAt: time.Now().UTC(),
}, nil
}

type JWTClaims struct {
UserID int `json:"userid"`
jwt.RegisteredClaims
}