119 lines
3.1 KiB
Go
119 lines
3.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"clintonambulance.com/calculate_negative_points/internal/config"
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// Claims is the subset of ID-token claims this package decodes. Each field is
// filled from the JSON claim named in its struct tag.
type Claims struct {
	// Sub is decoded from the "sub" claim.
	Sub string `json:"sub"`
	// Name is decoded from the "name" claim.
	Name string `json:"name"`
	// Email is decoded from the "email" claim.
	Email string `json:"email"`
	// Groups is decoded from the "groups" claim.
	Groups []string `json:"groups"`
}
|
|
|
|
func verifyTokenAndGetClaims(config *config.ApplicationConfig, ctx context.Context, token string) (*oidc.IDToken, Claims, error) {
|
|
idToken, err := config.OidcConfig.Verifier.Verify(ctx, token)
|
|
claims := Claims{Groups: []string{}}
|
|
|
|
if err != nil {
|
|
return idToken, claims, err
|
|
}
|
|
|
|
err = idToken.Claims(&claims)
|
|
|
|
if err != nil {
|
|
return idToken, claims, err
|
|
}
|
|
|
|
return idToken, claims, nil
|
|
}
|
|
|
|
func serveWithTokenAndClaims(next http.Handler, r *http.Request, w http.ResponseWriter, claims Claims, token *oidc.IDToken) {
|
|
ctx := context.WithValue(r.Context(), "user", token)
|
|
ctx = context.WithValue(ctx, "claims", claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
|
|
func CurrentUserMiddleware(config *config.ApplicationConfig) (func(http.Handler) http.Handler, error) {
|
|
middleware := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
session, _ := config.CookieStore.Get(r, config.SessionName)
|
|
rawIDToken, ok := session.Values["id_token"].(string)
|
|
|
|
if !ok {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// In test environment, bypass OIDC verification and use mock claims
|
|
if config.Environment == "test" || config.Environment == "testEnvironment" {
|
|
mockClaims := Claims{
|
|
Sub: "test-user-id",
|
|
Name: "Test User",
|
|
Email: "test@example.com",
|
|
Groups: []string{"calculate-negative-points-users"},
|
|
}
|
|
serveWithTokenAndClaims(next, r, w, mockClaims, nil)
|
|
return
|
|
}
|
|
|
|
idToken, claims, err := verifyTokenAndGetClaims(config, r.Context(), rawIDToken)
|
|
|
|
if err == nil {
|
|
serveWithTokenAndClaims(next, r, w, claims, idToken)
|
|
return
|
|
}
|
|
|
|
refreshToken, ok := session.Values["refresh_token"].(string)
|
|
if !ok {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Attempt refresh
|
|
tokenSrc := config.OidcConfig.OAuth2.TokenSource(r.Context(), &oauth2.Token{
|
|
RefreshToken: refreshToken,
|
|
})
|
|
|
|
newToken, err := tokenSrc.Token()
|
|
if err != nil {
|
|
http.Error(w, "failed to refresh token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Save new tokens
|
|
newRawIDToken, ok := newToken.Extra("id_token").(string)
|
|
if !ok {
|
|
http.Error(w, "missing id_token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
idToken, claims, err = verifyTokenAndGetClaims(config, r.Context(), newRawIDToken)
|
|
|
|
if err != nil {
|
|
http.Error(w, "invalid refreshed token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
session.Values["id_token"] = newRawIDToken
|
|
if newToken.RefreshToken != "" {
|
|
session.Values["refresh_token"] = newToken.RefreshToken
|
|
}
|
|
if scope, ok := newToken.Extra("scope").(string); ok {
|
|
session.Values["scope"] = scope
|
|
}
|
|
session.Save(r, w)
|
|
|
|
serveWithTokenAndClaims(next, r, w, claims, idToken)
|
|
return
|
|
})
|
|
}
|
|
|
|
return middleware, nil
|
|
}
|