This commit is contained in:
118
internal/api/middleware/current_user.go
Normal file
118
internal/api/middleware/current_user.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"clintonambulance.com/calculate_negative_points/internal/config"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
func verifyTokenAndGetClaims(config *config.ApplicationConfig, ctx context.Context, token string) (*oidc.IDToken, Claims, error) {
|
||||
idToken, err := config.OidcConfig.Verifier.Verify(ctx, token)
|
||||
claims := Claims{Groups: []string{}}
|
||||
|
||||
if err != nil {
|
||||
return idToken, claims, err
|
||||
}
|
||||
|
||||
err = idToken.Claims(&claims)
|
||||
|
||||
if err != nil {
|
||||
return idToken, claims, err
|
||||
}
|
||||
|
||||
return idToken, claims, nil
|
||||
}
|
||||
|
||||
func serveWithTokenAndClaims(next http.Handler, r *http.Request, w http.ResponseWriter, claims Claims, token *oidc.IDToken) {
|
||||
ctx := context.WithValue(r.Context(), "user", token)
|
||||
ctx = context.WithValue(ctx, "claims", claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
func CurrentUserMiddleware(config *config.ApplicationConfig) (func(http.Handler) http.Handler, error) {
|
||||
middleware := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, _ := config.CookieStore.Get(r, config.SessionName)
|
||||
rawIDToken, ok := session.Values["id_token"].(string)
|
||||
|
||||
if !ok {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// In test environment, bypass OIDC verification and use mock claims
|
||||
if config.Environment == "test" || config.Environment == "testEnvironment" {
|
||||
mockClaims := Claims{
|
||||
Sub: "test-user-id",
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
Groups: []string{"calculate-negative-points-users"},
|
||||
}
|
||||
serveWithTokenAndClaims(next, r, w, mockClaims, nil)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, claims, err := verifyTokenAndGetClaims(config, r.Context(), rawIDToken)
|
||||
|
||||
if err == nil {
|
||||
serveWithTokenAndClaims(next, r, w, claims, idToken)
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken, ok := session.Values["refresh_token"].(string)
|
||||
if !ok {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt refresh
|
||||
tokenSrc := config.OidcConfig.OAuth2.TokenSource(r.Context(), &oauth2.Token{
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
|
||||
newToken, err := tokenSrc.Token()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to refresh token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Save new tokens
|
||||
newRawIDToken, ok := newToken.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "missing id_token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, claims, err = verifyTokenAndGetClaims(config, r.Context(), newRawIDToken)
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "invalid refreshed token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session.Values["id_token"] = newRawIDToken
|
||||
if newToken.RefreshToken != "" {
|
||||
session.Values["refresh_token"] = newToken.RefreshToken
|
||||
}
|
||||
if scope, ok := newToken.Extra("scope").(string); ok {
|
||||
session.Values["scope"] = scope
|
||||
}
|
||||
session.Save(r, w)
|
||||
|
||||
serveWithTokenAndClaims(next, r, w, claims, idToken)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
return middleware, nil
|
||||
}
|
||||
Reference in New Issue
Block a user