diff --git a/_examples/auth/main.go b/_examples/auth/main.go index ce5687d..b9ff5ea 100644 --- a/_examples/auth/main.go +++ b/_examples/auth/main.go @@ -8,7 +8,8 @@ import ( "github.com/gofiber/fiber/v2" ) -// Authentication service with role management +// Authentication service with role management. +// Implements AuthorizationService, BasicAuthValidator, APIKeyValidator, and AWSSignatureValidator. type ExampleAuthService struct{} func (s *ExampleAuthService) ValidateToken(token string) (*fiberoapi.AuthContext, error) { @@ -116,6 +117,75 @@ func (s *ExampleAuthService) GetUserPermissions(ctx *fiberoapi.AuthContext, reso }, nil } +// ValidateBasicAuth implements BasicAuthValidator for HTTP Basic authentication (curl --user). +func (s *ExampleAuthService) ValidateBasicAuth(username, password string) (*fiberoapi.AuthContext, error) { + // Example credentials + users := map[string]string{ + "admin": "admin-pass", + "user": "user-pass", + } + + expectedPassword, exists := users[username] + if !exists || password != expectedPassword { + return nil, fmt.Errorf("invalid credentials for user: %s", username) + } + + roles := []string{"user"} + scopes := []string{"read", "write"} + if username == "admin" { + roles = []string{"admin", "user"} + scopes = []string{"read", "write", "delete", "share"} + } + + return &fiberoapi.AuthContext{ + UserID: username, + Roles: roles, + Scopes: scopes, + }, nil +} + +// ValidateAPIKey implements APIKeyValidator for API Key authentication. +func (s *ExampleAuthService) ValidateAPIKey(key string, location string, paramName string) (*fiberoapi.AuthContext, error) { + validKeys := map[string]string{ + "my-secret-api-key": "apikey-user-1", + "another-api-key": "apikey-user-2", + } + + userID, exists := validKeys[key] + if !exists { + return nil, fmt.Errorf("invalid API key") + } + + return &fiberoapi.AuthContext{ + UserID: userID, + Roles: []string{"user"}, + Scopes: []string{"read"}, + }, nil +} + +// ValidateAWSSignature implements AWSSignatureValidator for AWS SigV4 authentication. +func (s *ExampleAuthService) ValidateAWSSignature(params *fiberoapi.AWSSignatureParams) (*fiberoapi.AuthContext, error) { + // In a real implementation, you would verify the HMAC-SHA256 signature + // using the secret key associated with the AccessKeyID. + validKeys := map[string]bool{ + "AKIAIOSFODNN7EXAMPLE": true, + } + + if !validKeys[params.AccessKeyID] { + return nil, fmt.Errorf("invalid access key: %s", params.AccessKeyID) + } + + return &fiberoapi.AuthContext{ + UserID: "aws-service-" + params.AccessKeyID, + Roles: []string{"service"}, + Scopes: []string{"read", "write"}, + Claims: map[string]interface{}{ + "region": params.Region, + "service": params.Service, + }, + }, nil +} + type CreateUserRequest struct { Name string `json:"name" validate:"required,min=2,max=50"` } @@ -175,11 +245,31 @@ func main() { Type: "http", Scheme: "bearer", BearerFormat: "JWT", - Description: "JWT Bearer token", + Description: "JWT Bearer token authentication", + }, + "basicAuth": { + Type: "http", + Scheme: "basic", + Description: "HTTP Basic authentication (curl --user user:pass)", + }, + "apiKeyAuth": { + Type: "apiKey", + In: "header", + Name: "X-API-Key", + Description: "API Key authentication via header", + }, + "awsSigV4": { + Type: "http", + Scheme: "AWS4-HMAC-SHA256", + Description: "AWS Signature V4 authentication", }, }, + // Any of these schemes can be used (OR semantics) DefaultSecurity: []map[string][]string{ {"bearerAuth": {}}, + {"basicAuth": {}}, + {"apiKeyAuth": {}}, + {"awsSigV4": {}}, }, } @@ -411,44 +501,41 @@ func main() { fmt.Println("📚 Documentation: http://localhost:3002/docs") fmt.Println("📄 OpenAPI JSON: http://localhost:3002/openapi.json") fmt.Println("") + fmt.Println("🔑 Méthodes d'authentification supportées:") + fmt.Println(" Bearer Token: Authorization: Bearer ") + fmt.Println(" Basic Auth: Authorization: Basic base64(user:pass) (curl --user user:pass)") + fmt.Println(" API Key: X-API-Key: ") + fmt.Println(" AWS SigV4: Authorization: AWS4-HMAC-SHA256 Credential=...") + fmt.Println("") fmt.Println("🔑 Tokens de test disponibles:") fmt.Println(" admin-token -> rôles: [admin, user], scopes: [read, write, delete, share]") fmt.Println(" editor-token -> rôles: [editor, user], scopes: [read, write, share]") fmt.Println(" user-token -> rôles: [user], scopes: [read, write]") fmt.Println(" readonly-token -> rôles: [user], scopes: [read]") fmt.Println("") - fmt.Println("🌍 Endpoints par niveau d'accès:") - fmt.Println(" GET /health (public)") - fmt.Println(" GET /me (auth simple)") - fmt.Println(" GET /documents/:id (user + read)") - fmt.Println(" PUT /documents/:id (user + write)") - fmt.Println(" POST /documents/:id/share (scope: share)") - fmt.Println(" DELETE /documents/:id (admin + delete)") - fmt.Println(" POST /users (admin + write)") + fmt.Println("🔑 Comptes Basic Auth:") + fmt.Println(" admin:admin-pass -> rôles: [admin, user]") + fmt.Println(" user:user-pass -> rôles: [user]") fmt.Println("") - fmt.Println("🧪 Tests suggérés:") - fmt.Println(" # Test admin - création d'utilisateur") - fmt.Println(` curl -X POST -H 'Authorization: Bearer admin-token' -H 'Content-Type: application/json' -d '{"name":"John Doe"}' http://localhost:3002/users`) - fmt.Println("") - fmt.Println(" # Test utilisateur normal (devrait échouer)") - fmt.Println(` curl -X POST -H 'Authorization: Bearer readonly-token' -H 'Content-Type: application/json' -d '{"name":"Jane Doe"}' http://localhost:3002/users`) + fmt.Println("🔑 API Keys:") + fmt.Println(" my-secret-api-key -> read only") + fmt.Println(" another-api-key -> read only") fmt.Println("") - fmt.Println(" # Test lecture document") - fmt.Println(" curl -H 'Authorization: Bearer user-token' http://localhost:3002/documents/33cd10d7-d80f-4fd2-9107-7423997393d2") + fmt.Println("🧪 Tests suggérés:") + fmt.Println(" # Bearer Token") + fmt.Println(" curl -H 'Authorization: Bearer admin-token' http://localhost:3002/me") fmt.Println("") - fmt.Println(" # Test modification document") - fmt.Println(` curl -X PUT -H 'Authorization: Bearer user-token' -H 'Content-Type: application/json' -d '{"title":"Mon Document","content":"Contenu modifié"}' http://localhost:3002/documents/33cd10d7-d80f-4fd2-9107-7423997393d2`) + fmt.Println(" # Basic Auth (curl --user)") + fmt.Println(" curl --user admin:admin-pass http://localhost:3002/me") fmt.Println("") - fmt.Println(" # Test partage (éditeur/admin seulement)") - fmt.Println(" curl -X POST -H 'Authorization: Bearer editor-token' http://localhost:3002/documents/33cd10d7-d80f-4fd2-9107-7423997393d2/share") + fmt.Println(" # API Key") + fmt.Println(" curl -H 'X-API-Key: my-secret-api-key' http://localhost:3002/documents/doc-1") fmt.Println("") - fmt.Println(" # Test suppression (admin seulement)") - fmt.Println(" curl -X DELETE -H 'Authorization: Bearer admin-token' http://localhost:3002/documents/33cd10d7-d80f-4fd2-9107-7423997393d2") + fmt.Println(" # AWS SigV4") + fmt.Println(" curl -H 'Authorization: AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20250101/us-east-1/execute-api/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc123' http://localhost:3002/me") fmt.Println("") - fmt.Println(" # Test endpoints publics") + fmt.Println(" # Public endpoint") fmt.Println(" curl http://localhost:3002/health") - fmt.Println(" curl -H 'Authorization: Bearer user-token' http://localhost:3002/me") - fmt.Println(" curl -H 'Authorization: Bearer user-token' http://localhost:3002/status") app.Listen(":3002") } diff --git a/auth.go b/auth.go index ab8a78f..0258b5f 100644 --- a/auth.go +++ b/auth.go @@ -1,6 +1,7 @@ package fiberoapi import ( + "errors" "fmt" "reflect" "strings" @@ -139,36 +140,57 @@ func RoleGuard(validator AuthorizationService, requiredRoles ...string) fiber.Ha } } -// validateAuthorization validates permissions based on tags -func validateAuthorization(c *fiber.Ctx, input interface{}, authService AuthorizationService) error { +// validateAuthorization validates permissions based on configured security schemes. +// When SecuritySchemes is empty, it falls back to Bearer-only validation for backward compatibility. +func validateAuthorization(c *fiber.Ctx, input interface{}, authService AuthorizationService, config *Config) error { if authService == nil { return nil } - // Extract and validate the token directly - authHeader := c.Get("Authorization") - if authHeader == "" { - return fmt.Errorf("authentication required") + // Backward compatibility: if no SecuritySchemes are configured, + // fall back to Bearer-only validation (original behavior). + if config == nil || len(config.SecuritySchemes) == 0 { + authCtx, err := validateBearerToken(c, authService) + if err != nil { + return &AuthError{StatusCode: 401, Message: err.Error()} + } + c.Locals("auth", authCtx) + return validateResourceAccess(c, authCtx, input, authService) } - // Check Bearer format - if !strings.HasPrefix(authHeader, "Bearer ") { - return fmt.Errorf("invalid authorization header format") + // Multi-scheme validation path + securityReqs := config.DefaultSecurity + if len(securityReqs) == 0 { + securityReqs = buildDefaultFromSchemes(config.SecuritySchemes) } - token := strings.TrimPrefix(authHeader, "Bearer ") - - // Validate the token - authCtx, err := authService.ValidateToken(token) - if err != nil { - return fmt.Errorf("invalid token: %v", err) + // Try each security requirement (OR semantics per OpenAPI spec). + // Server configuration errors (5xx) short-circuit immediately since + // no alternative requirement can fix a misconfigured scheme. + var lastErr error + for _, requirement := range securityReqs { + authCtx, err := validateSecurityRequirement(c, requirement, config.SecuritySchemes, authService) + if err == nil { + c.Locals("auth", authCtx) + return validateResourceAccess(c, authCtx, input, authService) + } + var authErr *AuthError + if errors.As(err, &authErr) && authErr.StatusCode >= 500 { + return err + } + lastErr = err } - // Store auth context for later use - c.Locals("auth", authCtx) - - // Analyze authorization tags in the struct - return validateResourceAccess(c, authCtx, input, authService) + // Propagate typed errors (AuthError, ScopeError) without re-wrapping + var existingAuthErr *AuthError + if errors.As(lastErr, &existingAuthErr) { + return lastErr + } + var scopeErr *ScopeError + if errors.As(lastErr, &scopeErr) { + return &AuthError{StatusCode: 403, Message: lastErr.Error()} + } + return &AuthError{StatusCode: 401, Message: lastErr.Error()} } // validateResourceAccess validates resource access based on tags @@ -202,11 +224,11 @@ func validateResourceAccess(c *fiber.Ctx, authCtx *AuthContext, input interface{ canAccess, err := authService.CanAccessResource(authCtx, resourceTag, resourceID, actionTag) if err != nil { - return fmt.Errorf("authorization check failed: %w", err) + return &AuthError{StatusCode: 500, Message: fmt.Sprintf("authorization check failed: %v", err)} } if !canAccess { - return fmt.Errorf("insufficient permissions for %s %s on %s", actionTag, resourceTag, resourceID) + return &AuthError{StatusCode: 403, Message: fmt.Sprintf("insufficient permissions for %s %s on %s", actionTag, resourceTag, resourceID)} } } } diff --git a/auth_schemes.go b/auth_schemes.go new file mode 100644 index 0000000..f1baa85 --- /dev/null +++ b/auth_schemes.go @@ -0,0 +1,345 @@ +package fiberoapi + +import ( + "encoding/base64" + "fmt" + "sort" + "strings" + + "github.com/gofiber/fiber/v2" +) + +// BasicAuthValidator is an optional interface for services that support +// HTTP Basic authentication. Implement this alongside AuthorizationService +// to enable Basic Auth validation. +type BasicAuthValidator interface { + ValidateBasicAuth(username, password string) (*AuthContext, error) +} + +// APIKeyValidator is an optional interface for services that support +// API Key authentication (in header, query, or cookie). +type APIKeyValidator interface { + ValidateAPIKey(key string, location string, paramName string) (*AuthContext, error) +} + +// AWSSignatureValidator is an optional interface for services that support +// AWS Signature V4 authentication. The library parses the Authorization header +// and passes structured data; the implementation handles the actual +// cryptographic verification. +type AWSSignatureValidator interface { + ValidateAWSSignature(params *AWSSignatureParams) (*AuthContext, error) +} + +// AWSSignatureParams contains the parsed components of an AWS SigV4 Authorization header. +type AWSSignatureParams struct { + // Parsed from "Credential=AKID/date/region/service/aws4_request" + AccessKeyID string + Date string + Region string + Service string + + // Parsed from "SignedHeaders=host;x-amz-date;..." + SignedHeaders []string + + // The raw signature hex string + Signature string + + // The raw Authorization header for custom verification + RawHeader string + + // Request metadata needed for signature verification + Method string + Path string + QueryString string + Headers map[string]string + Body []byte +} + +// validateBearerToken validates a Bearer token from the Authorization header. +func validateBearerToken(c *fiber.Ctx, authService AuthorizationService) (*AuthContext, error) { + authHeader := c.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authentication required: Bearer token expected") + } + + if !strings.HasPrefix(authHeader, "Bearer ") { + return nil, fmt.Errorf("invalid authorization header: Bearer prefix expected") + } + + token := strings.TrimPrefix(authHeader, "Bearer ") + return authService.ValidateToken(token) +} + +// validateBasicAuth validates Basic Auth credentials from the Authorization header. +func validateBasicAuth(c *fiber.Ctx, authService AuthorizationService) (*AuthContext, error) { + basicValidator, ok := authService.(BasicAuthValidator) + if !ok { + return nil, &AuthError{StatusCode: 500, Message: "Basic Auth scheme configured but AuthService does not implement BasicAuthValidator"} + } + + authHeader := c.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authentication required: Basic auth expected") + } + + if !strings.HasPrefix(authHeader, "Basic ") { + return nil, fmt.Errorf("invalid authorization header: Basic prefix expected") + } + + encoded := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("invalid Basic auth encoding: %w", err) + } + + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid Basic auth format: expected username:password") + } + + return basicValidator.ValidateBasicAuth(parts[0], parts[1]) +} + +// validateAPIKey validates an API key from header, query, or cookie. +func validateAPIKey(c *fiber.Ctx, scheme SecurityScheme, authService AuthorizationService) (*AuthContext, error) { + apiKeyValidator, ok := authService.(APIKeyValidator) + if !ok { + return nil, &AuthError{StatusCode: 500, Message: "API Key scheme configured but AuthService does not implement APIKeyValidator"} + } + + var key string + switch scheme.In { + case "header": + key = c.Get(scheme.Name) + case "query": + key = c.Query(scheme.Name) + case "cookie": + key = c.Cookies(scheme.Name) + default: + return nil, &AuthError{StatusCode: 500, Message: fmt.Sprintf("unsupported API Key location: %s", scheme.In)} + } + + if key == "" { + return nil, fmt.Errorf("API key not found in %s parameter '%s'", scheme.In, scheme.Name) + } + + return apiKeyValidator.ValidateAPIKey(key, scheme.In, scheme.Name) +} + +// validateAWSSigV4 validates an AWS Signature V4 Authorization header. +func validateAWSSigV4(c *fiber.Ctx, authService AuthorizationService) (*AuthContext, error) { + awsValidator, ok := authService.(AWSSignatureValidator) + if !ok { + return nil, &AuthError{StatusCode: 500, Message: "AWS SigV4 scheme configured but AuthService does not implement AWSSignatureValidator"} + } + + authHeader := c.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authentication required: AWS4-HMAC-SHA256 signature expected") + } + + if !strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 ") { + return nil, fmt.Errorf("invalid authorization header: AWS4-HMAC-SHA256 prefix expected") + } + + params, err := parseAWSSigV4Header(authHeader) + if err != nil { + return nil, fmt.Errorf("failed to parse AWS SigV4 header: %w", err) + } + + // Populate request metadata + params.Method = c.Method() + params.Path = c.Path() + params.QueryString = string(c.Request().URI().QueryString()) + params.Body = c.Body() + params.RawHeader = authHeader + + // Collect all headers that were signed + params.Headers = make(map[string]string) + for _, headerName := range params.SignedHeaders { + params.Headers[headerName] = c.Get(headerName) + } + + return awsValidator.ValidateAWSSignature(params) +} + +// parseAWSSigV4Header parses an AWS SigV4 Authorization header into its components. +// Format: AWS4-HMAC-SHA256 Credential=AKID/20250101/us-east-1/s3/aws4_request, +// +// SignedHeaders=host;x-amz-date, Signature=abcdef... +func parseAWSSigV4Header(header string) (*AWSSignatureParams, error) { + params := &AWSSignatureParams{} + content := strings.TrimPrefix(header, "AWS4-HMAC-SHA256 ") + + parts := strings.Split(content, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "Credential": + credParts := strings.Split(kv[1], "/") + if len(credParts) >= 5 { + params.AccessKeyID = credParts[0] + params.Date = credParts[1] + params.Region = credParts[2] + params.Service = credParts[3] + } + case "SignedHeaders": + params.SignedHeaders = strings.Split(kv[1], ";") + case "Signature": + params.Signature = kv[1] + } + } + + if params.AccessKeyID == "" || params.Signature == "" || len(params.SignedHeaders) == 0 { + return nil, fmt.Errorf("incomplete AWS SigV4 header: missing Credential, SignedHeaders, or Signature") + } + + return params, nil +} + +// validateWithScheme dispatches validation to the appropriate scheme handler. +func validateWithScheme(c *fiber.Ctx, scheme SecurityScheme, authService AuthorizationService) (*AuthContext, error) { + switch { + case scheme.Type == "http" && strings.EqualFold(scheme.Scheme, "bearer"): + return validateBearerToken(c, authService) + case scheme.Type == "http" && strings.EqualFold(scheme.Scheme, "basic"): + return validateBasicAuth(c, authService) + case scheme.Type == "apiKey": + return validateAPIKey(c, scheme, authService) + case scheme.Type == "http" && strings.EqualFold(scheme.Scheme, "aws4-hmac-sha256"): + return validateAWSSigV4(c, authService) + default: + return nil, &AuthError{StatusCode: 500, Message: fmt.Sprintf("unsupported security scheme: type=%s scheme=%s", scheme.Type, scheme.Scheme)} + } +} + +// AuthError represents an authentication or authorization failure with an HTTP status code. +type AuthError struct { + StatusCode int + Message string +} + +func (e *AuthError) Error() string { + return e.Message +} + +// ScopeError represents an authorization failure due to missing scopes (403, not 401). +type ScopeError struct { + Scope string +} + +func (e *ScopeError) Error() string { + return fmt.Sprintf("missing required scope: %s", e.Scope) +} + +// validateSecurityRequirement validates a single OpenAPI security requirement. +// A requirement is a map of scheme-name -> required-scopes. +// ALL schemes in a requirement must validate (AND semantics). +// When multiple schemes are present, their AuthContexts are merged: UserIDs must +// match (or be empty), and roles/scopes/claims are combined. +func validateSecurityRequirement(c *fiber.Ctx, requirement map[string][]string, schemes map[string]SecurityScheme, authService AuthorizationService) (*AuthContext, error) { + if len(requirement) == 0 { + return nil, &AuthError{StatusCode: 500, Message: "empty security requirement"} + } + + // Sort scheme names for deterministic validation order + schemeNames := make([]string, 0, len(requirement)) + for name := range requirement { + schemeNames = append(schemeNames, name) + } + sort.Strings(schemeNames) + + var merged *AuthContext + + for _, schemeName := range schemeNames { + requiredScopes := requirement[schemeName] + + scheme, exists := schemes[schemeName] + if !exists { + return nil, &AuthError{StatusCode: 500, Message: fmt.Sprintf("unknown security scheme: %s", schemeName)} + } + + authCtx, err := validateWithScheme(c, scheme, authService) + if err != nil { + return nil, err + } + + // Check required scopes + for _, scope := range requiredScopes { + if !authService.HasScope(authCtx, scope) { + return nil, &ScopeError{Scope: scope} + } + } + + if merged == nil { + // First scheme — clone the context as the base + merged = &AuthContext{ + UserID: authCtx.UserID, + Roles: append([]string{}, authCtx.Roles...), + Scopes: append([]string{}, authCtx.Scopes...), + } + if authCtx.Claims != nil { + merged.Claims = make(map[string]interface{}, len(authCtx.Claims)) + for k, v := range authCtx.Claims { + merged.Claims[k] = v + } + } + } else { + // Subsequent schemes — verify identity consistency and merge + if authCtx.UserID != "" && merged.UserID != "" && authCtx.UserID != merged.UserID { + return nil, fmt.Errorf("security scheme conflict: scheme %s resolved to user %q, expected %q", schemeName, authCtx.UserID, merged.UserID) + } + if merged.UserID == "" && authCtx.UserID != "" { + merged.UserID = authCtx.UserID + } + merged.Roles = appendUnique(merged.Roles, authCtx.Roles...) + merged.Scopes = appendUnique(merged.Scopes, authCtx.Scopes...) + if authCtx.Claims != nil { + if merged.Claims == nil { + merged.Claims = make(map[string]interface{}) + } + for k, v := range authCtx.Claims { + merged.Claims[k] = v + } + } + } + } + + return merged, nil +} + +// appendUnique appends values to a slice, skipping duplicates. +func appendUnique(base []string, values ...string) []string { + seen := make(map[string]struct{}, len(base)) + for _, v := range base { + seen[v] = struct{}{} + } + for _, v := range values { + if _, exists := seen[v]; !exists { + base = append(base, v) + seen[v] = struct{}{} + } + } + return base +} + +// buildDefaultFromSchemes generates security requirements from configured schemes. +// Each scheme becomes a separate alternative (OR semantics). +// Schemes are sorted by name for deterministic ordering. +func buildDefaultFromSchemes(schemes map[string]SecurityScheme) []map[string][]string { + names := make([]string, 0, len(schemes)) + for name := range schemes { + names = append(names, name) + } + sort.Strings(names) + + result := make([]map[string][]string, 0, len(names)) + for _, name := range names { + result = append(result, map[string][]string{name: {}}) + } + return result +} diff --git a/auth_schemes_test.go b/auth_schemes_test.go new file mode 100644 index 0000000..a520ee9 --- /dev/null +++ b/auth_schemes_test.go @@ -0,0 +1,939 @@ +package fiberoapi + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" +) + +// --- Mock services --- + +// MockBasicAuthService extends MockAuthService with Basic Auth support. +type MockBasicAuthService struct { + MockAuthService + users map[string]string // username -> password +} + +func NewMockBasicAuthService() *MockBasicAuthService { + return &MockBasicAuthService{ + MockAuthService: *NewMockAuthService(), + users: map[string]string{ + "admin": "secret", + "user": "password", + }, + } +} + +func (m *MockBasicAuthService) ValidateBasicAuth(username, password string) (*AuthContext, error) { + expectedPassword, exists := m.users[username] + if !exists { + return nil, fmt.Errorf("unknown user: %s", username) + } + if password != expectedPassword { + return nil, fmt.Errorf("invalid password for user: %s", username) + } + return &AuthContext{ + UserID: username, + Roles: []string{"user"}, + Scopes: []string{"read", "write"}, + }, nil +} + +// MockAPIKeyAuthService extends MockAuthService with API Key support. +type MockAPIKeyAuthService struct { + MockAuthService + validKeys map[string]bool +} + +func NewMockAPIKeyAuthService() *MockAPIKeyAuthService { + return &MockAPIKeyAuthService{ + MockAuthService: *NewMockAuthService(), + validKeys: map[string]bool{ + "my-api-key-123": true, + "test-key-456": true, + }, + } +} + +func (m *MockAPIKeyAuthService) ValidateAPIKey(key string, location string, paramName string) (*AuthContext, error) { + if !m.validKeys[key] { + return nil, fmt.Errorf("invalid API key") + } + return &AuthContext{ + UserID: "apikey-user", + Roles: []string{"user"}, + Scopes: []string{"read"}, + }, nil +} + +// MockBearerAndAPIKeyAuthService implements both Bearer (ValidateToken) and API Key validation. +// Used for testing AND-semantics (multi-scheme requirements). +type MockBearerAndAPIKeyAuthService struct { + MockAuthService + validKeys map[string]string // key -> userID +} + +func NewMockBearerAndAPIKeyAuthService() *MockBearerAndAPIKeyAuthService { + return &MockBearerAndAPIKeyAuthService{ + MockAuthService: *NewMockAuthService(), + validKeys: map[string]string{ + "my-api-key-123": "user-123", // same UserID as MockAuthService + }, + } +} + +func (m *MockBearerAndAPIKeyAuthService) ValidateAPIKey(key string, location string, paramName string) (*AuthContext, error) { + userID, exists := m.validKeys[key] + if !exists { + return nil, fmt.Errorf("invalid API key") + } + return &AuthContext{ + UserID: userID, + Roles: []string{"api-client"}, + Scopes: []string{"api-access"}, + Claims: map[string]interface{}{"key_location": location}, + }, nil +} + +// MockConflictingAPIKeyAuthService returns a different UserID than Bearer to test conflict detection. +type MockConflictingAPIKeyAuthService struct { + MockAuthService +} + +func (m *MockConflictingAPIKeyAuthService) ValidateAPIKey(key string, location string, paramName string) (*AuthContext, error) { + return &AuthContext{ + UserID: "different-user-999", + Roles: []string{"other"}, + Scopes: []string{"other"}, + }, nil +} + +// MockAWSAuthService extends MockAuthService with AWS SigV4 support. +type MockAWSAuthService struct { + MockAuthService + validAccessKeys map[string]bool +} + +func NewMockAWSAuthService() *MockAWSAuthService { + return &MockAWSAuthService{ + MockAuthService: *NewMockAuthService(), + validAccessKeys: map[string]bool{ + "AKIAIOSFODNN7EXAMPLE": true, + }, + } +} + +func (m *MockAWSAuthService) ValidateAWSSignature(params *AWSSignatureParams) (*AuthContext, error) { + if !m.validAccessKeys[params.AccessKeyID] { + return nil, fmt.Errorf("invalid access key: %s", params.AccessKeyID) + } + return &AuthContext{ + UserID: "aws-user-" + params.AccessKeyID, + Roles: []string{"service"}, + Scopes: []string{"read", "write"}, + Claims: map[string]interface{}{ + "region": params.Region, + "service": params.Service, + }, + }, nil +} + +// --- Basic Auth tests --- + +func TestValidateBasicAuth_ValidCredentials(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + creds := base64.StdEncoding.EncodeToString([]byte("admin:secret")) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestValidateBasicAuth_InvalidCredentials(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + creds := base64.StdEncoding.EncodeToString([]byte("admin:wrongpassword")) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateBasicAuth_MalformedBase64(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic %%%not-base64%%%") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateBasicAuth_MissingColon(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + creds := base64.StdEncoding.EncodeToString([]byte("usernameonly")) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateBasicAuth_MissingHeader(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateBasicAuth_ServiceDoesNotImplement(t *testing.T) { + app := fiber.New() + // Use plain MockAuthService which does NOT implement BasicAuthValidator + authService := NewMockAuthService() + app.Use(BasicAuthMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + creds := base64.StdEncoding.EncodeToString([]byte("admin:secret")) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("Expected status 500 (server misconfiguration), got %d", resp.StatusCode) + } +} + +// --- API Key tests --- + +func TestValidateAPIKey_InHeader_Valid(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "header", Name: "X-API-Key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-API-Key", "my-api-key-123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestValidateAPIKey_InQuery_Valid(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "query", Name: "api_key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test?api_key=my-api-key-123", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestValidateAPIKey_InCookie_Valid(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "cookie", Name: "api_key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.AddCookie(&http.Cookie{Name: "api_key", Value: "my-api-key-123"}) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestValidateAPIKey_Missing(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "header", Name: "X-API-Key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateAPIKey_Invalid(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "header", Name: "X-API-Key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-API-Key", "invalid-key") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateAPIKey_ServiceDoesNotImplement(t *testing.T) { + app := fiber.New() + authService := NewMockAuthService() + scheme := SecurityScheme{Type: "apiKey", In: "header", Name: "X-API-Key"} + app.Use(APIKeyMiddleware(authService, scheme)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-API-Key", "my-api-key-123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("Expected status 500 (server misconfiguration), got %d", resp.StatusCode) + } +} + +// --- AWS SigV4 tests --- + +func TestValidateAWSSigV4_ValidSignature(t *testing.T) { + app := fiber.New() + authService := NewMockAWSAuthService() + app.Use(AWSSignatureMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20250101/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=abcdef1234567890") + req.Header.Set("Host", "example.com") + req.Header.Set("X-Amz-Date", "20250101T000000Z") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestValidateAWSSigV4_InvalidAccessKey(t *testing.T) { + app := fiber.New() + authService := NewMockAWSAuthService() + app.Use(AWSSignatureMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=INVALIDKEY/20250101/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=abcdef1234567890") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateAWSSigV4_MalformedHeader(t *testing.T) { + app := fiber.New() + authService := NewMockAWSAuthService() + app.Use(AWSSignatureMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 garbage-data") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateAWSSigV4_MissingHeader(t *testing.T) { + app := fiber.New() + authService := NewMockAWSAuthService() + app.Use(AWSSignatureMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +func TestValidateAWSSigV4_ServiceDoesNotImplement(t *testing.T) { + app := fiber.New() + authService := NewMockAuthService() + app.Use(AWSSignatureMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20250101/us-east-1/s3/aws4_request, SignedHeaders=host, Signature=abc123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("Expected status 500 (server misconfiguration), got %d", resp.StatusCode) + } +} + +// --- parseAWSSigV4Header unit tests --- + +func TestParseAWSSigV4Header(t *testing.T) { + t.Run("Valid header", func(t *testing.T) { + header := "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20250101/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-date;x-amz-content-sha256, Signature=abcdef1234567890" + params, err := parseAWSSigV4Header(header) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if params.AccessKeyID != "AKIAIOSFODNN7EXAMPLE" { + t.Errorf("Expected AccessKeyID AKIAIOSFODNN7EXAMPLE, got %s", params.AccessKeyID) + } + if params.Date != "20250101" { + t.Errorf("Expected Date 20250101, got %s", params.Date) + } + if params.Region != "us-east-1" { + t.Errorf("Expected Region us-east-1, got %s", params.Region) + } + if params.Service != "s3" { + t.Errorf("Expected Service s3, got %s", params.Service) + } + if len(params.SignedHeaders) != 3 { + t.Errorf("Expected 3 signed headers, got %d", len(params.SignedHeaders)) + } + if params.Signature != "abcdef1234567890" { + t.Errorf("Expected Signature abcdef1234567890, got %s", params.Signature) + } + }) + + t.Run("Missing Credential", func(t *testing.T) { + header := "AWS4-HMAC-SHA256 SignedHeaders=host, Signature=abc123" + _, err := parseAWSSigV4Header(header) + if err == nil { + t.Error("Expected error for missing Credential") + } + }) + + t.Run("Missing Signature", func(t *testing.T) { + header := "AWS4-HMAC-SHA256 Credential=AKID/20250101/us-east-1/s3/aws4_request, SignedHeaders=host" + _, err := parseAWSSigV4Header(header) + if err == nil { + t.Error("Expected error for missing Signature") + } + }) + + t.Run("Missing SignedHeaders", func(t *testing.T) { + header := "AWS4-HMAC-SHA256 Credential=AKID/20250101/us-east-1/s3/aws4_request, Signature=abc123" + _, err := parseAWSSigV4Header(header) + if err == nil { + t.Error("Expected error for missing SignedHeaders") + } + }) +} + +// --- Multi-scheme dispatch tests --- + +func TestMultiScheme_BearerStillWorks(t *testing.T) { + app := fiber.New() + authService := NewMockAuthService() + config := Config{ + SecuritySchemes: map[string]SecurityScheme{ + "bearerAuth": {Type: "http", Scheme: "bearer", BearerFormat: "JWT"}, + }, + DefaultSecurity: []map[string][]string{ + {"bearerAuth": {}}, + }, + } + app.Use(MultiSchemeAuthMiddleware(authService, config)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestMultiScheme_FallbackToSecondScheme(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + config := Config{ + SecuritySchemes: map[string]SecurityScheme{ + "bearerAuth": {Type: "http", Scheme: "bearer"}, + "apiKey": {Type: "apiKey", In: "header", Name: "X-API-Key"}, + }, + DefaultSecurity: []map[string][]string{ + {"bearerAuth": {}}, + {"apiKey": {}}, + }, + } + app.Use(MultiSchemeAuthMiddleware(authService, config)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + // Send API Key instead of Bearer token + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-API-Key", "my-api-key-123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestMultiScheme_AllSchemesFail(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + config := Config{ + SecuritySchemes: map[string]SecurityScheme{ + "bearerAuth": {Type: "http", Scheme: "bearer"}, + "apiKey": {Type: "apiKey", In: "header", Name: "X-API-Key"}, + }, + DefaultSecurity: []map[string][]string{ + {"bearerAuth": {}}, + {"apiKey": {}}, + }, + } + app.Use(MultiSchemeAuthMiddleware(authService, config)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{"message": "should not reach here"}) + }) + + // No auth provided at all + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) + } +} + +// --- Backward compatibility tests --- + +func TestBackwardCompat_ExistingMockAuthService(t *testing.T) { + // Existing MockAuthService (which does NOT implement any new interfaces) + // should continue to work with Bearer token via validateAuthorization + app := fiber.New() + authService := NewMockAuthService() + + oapi := New(app, Config{ + EnableValidation: true, + EnableAuthorization: true, + AuthService: authService, + // No SecuritySchemes configured - should fallback to Bearer-only + }) + + Get(oapi, "/test", func(c *fiber.Ctx, input struct{}) (fiber.Map, *ErrorResponse) { + authCtx, err := GetAuthContext(c) + if err != nil { + return nil, &ErrorResponse{Code: 500, Details: err.Error()} + } + return fiber.Map{"user_id": authCtx.UserID}, nil + }, OpenAPIOptions{Summary: "Test"}) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200 for backward compat, got %d", resp.StatusCode) + } +} + +func TestBackwardCompat_BearerTokenMiddleware(t *testing.T) { + // BearerTokenMiddleware should still work independently + app := fiber.New() + authService := NewMockAuthService() + app.Use(BearerTokenMiddleware(authService)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +// --- SmartAuthMiddleware with SecuritySchemes --- + +func TestSmartAuthMiddleware_WithSecuritySchemes(t *testing.T) { + app := fiber.New() + authService := NewMockBasicAuthService() + config := Config{ + EnableOpenAPIDocs: true, + OpenAPIDocsPath: "/docs", + OpenAPIJSONPath: "/openapi.json", + OpenAPIYamlPath: "/openapi.yaml", + SecuritySchemes: map[string]SecurityScheme{ + "basicAuth": {Type: "http", Scheme: "basic"}, + }, + DefaultSecurity: []map[string][]string{ + {"basicAuth": {}}, + }, + } + app.Use(SmartAuthMiddleware(authService, config)) + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, _ := GetAuthContext(c) + return c.JSON(fiber.Map{"user_id": authCtx.UserID}) + }) + app.Get("/docs", func(c *fiber.Ctx) error { + return c.SendString("docs") + }) + + // Protected route with Basic Auth + creds := base64.StdEncoding.EncodeToString([]byte("admin:secret")) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Basic "+creds) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Docs path should be excluded from auth + req = httptest.NewRequest("GET", "/docs", nil) + resp, err = app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected /docs to be accessible without auth, got %d", resp.StatusCode) + } +} + +// --- AND-semantics context merging tests --- + +func TestValidateSecurityRequirement_ANDMergesContexts(t *testing.T) { + app := fiber.New() + authService := NewMockBearerAndAPIKeyAuthService() + + schemes := map[string]SecurityScheme{ + "apiKey": {Type: "apiKey", In: "header", Name: "X-API-Key"}, + "bearerAuth": {Type: "http", Scheme: "bearer"}, + } + // AND semantics: both Bearer AND API Key must be present + requirement := map[string][]string{ + "bearerAuth": {}, + "apiKey": {}, + } + + app.Get("/test", func(c *fiber.Ctx) error { + authCtx, err := validateSecurityRequirement(c, requirement, schemes, authService) + if err != nil { + return c.Status(401).JSON(fiber.Map{"error": err.Error()}) + } + return c.JSON(fiber.Map{ + "user_id": authCtx.UserID, + "roles": authCtx.Roles, + "scopes": authCtx.Scopes, + "claims": authCtx.Claims, + }) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + req.Header.Set("X-API-Key", "my-api-key-123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + // Parse response to verify merging + var result map[string]interface{} + if err := parseJSONResponse(resp, &result); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // UserID should be consistent (both return "user-123") + if result["user_id"] != "user-123" { + t.Errorf("Expected user_id 'user-123', got %v", result["user_id"]) + } + + // Roles should be merged: ["user"] from Bearer + ["api-client"] from API Key + roles, ok := result["roles"].([]interface{}) + if !ok { + t.Fatalf("Expected roles to be array, got %T", result["roles"]) + } + roleSet := make(map[string]bool) + for _, r := range roles { + roleSet[r.(string)] = true + } + for _, expected := range []string{"user", "api-client"} { + if !roleSet[expected] { + t.Errorf("Expected role %q in merged context, got roles: %v", expected, roles) + } + } + + // Scopes should be merged (with dedup): ["read", "write"] from Bearer + ["api-access"] from API Key + scopes, ok := result["scopes"].([]interface{}) + if !ok { + t.Fatalf("Expected scopes to be array, got %T", result["scopes"]) + } + scopeSet := make(map[string]bool) + for _, s := range scopes { + scopeSet[s.(string)] = true + } + for _, expected := range []string{"read", "write", "api-access"} { + if !scopeSet[expected] { + t.Errorf("Expected scope %q in merged context, got scopes: %v", expected, scopes) + } + } + + // Claims should contain API Key's claims + claims, ok := result["claims"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected claims to be map, got %T", result["claims"]) + } + if claims["key_location"] != "header" { + t.Errorf("Expected claim key_location='header', got %v", claims["key_location"]) + } +} + +func TestValidateSecurityRequirement_ANDConflictingUserID(t *testing.T) { + app := fiber.New() + authService := &MockConflictingAPIKeyAuthService{ + MockAuthService: *NewMockAuthService(), + } + + schemes := map[string]SecurityScheme{ + "apiKey": {Type: "apiKey", In: "header", Name: "X-API-Key"}, + "bearerAuth": {Type: "http", Scheme: "bearer"}, + } + // AND semantics: both must pass, but they return different UserIDs + requirement := map[string][]string{ + "bearerAuth": {}, + "apiKey": {}, + } + + app.Get("/test", func(c *fiber.Ctx) error { + _, err := validateSecurityRequirement(c, requirement, schemes, authService) + if err != nil { + return c.Status(401).JSON(fiber.Map{"error": err.Error()}) + } + return c.SendStatus(200) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + req.Header.Set("X-API-Key", "any-key") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 401 { + t.Errorf("Expected status 401 for conflicting UserIDs, got %d", resp.StatusCode) + } +} + +// --- Unsupported API Key location test --- + +func TestValidateAPIKey_UnsupportedLocation(t *testing.T) { + app := fiber.New() + authService := NewMockAPIKeyAuthService() + config := Config{ + SecuritySchemes: map[string]SecurityScheme{ + "badKey": {Type: "apiKey", In: "body", Name: "api_key"}, // "body" is not a valid location + }, + DefaultSecurity: []map[string][]string{ + {"badKey": {}}, + }, + } + app.Use(MultiSchemeAuthMiddleware(authService, config)) + app.Get("/test", func(c *fiber.Ctx) error { + return c.SendStatus(200) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("Expected status 500 for unsupported API Key location, got %d", resp.StatusCode) + } +} + +// --- Per-route security requirements test --- + +func TestPerRouteSecurity_OverridesGlobalDefault(t *testing.T) { + app := fiber.New() + apiKeyService := NewMockAPIKeyAuthService() + + oapi := New(app, Config{ + EnableValidation: true, + EnableAuthorization: true, + AuthService: apiKeyService, + SecuritySchemes: map[string]SecurityScheme{ + "bearerAuth": {Type: "http", Scheme: "bearer"}, + "apiKey": {Type: "apiKey", In: "header", Name: "X-API-Key"}, + }, + // Global default requires Bearer + DefaultSecurity: []map[string][]string{ + {"bearerAuth": {}}, + }, + }) + + // Route with per-route security requiring API Key instead of Bearer + routeSecurity := []map[string][]string{ + {"apiKey": {}}, + } + Get(oapi, "/api-key-route", func(c *fiber.Ctx, input struct{}) (fiber.Map, *ErrorResponse) { + authCtx, err := GetAuthContext(c) + if err != nil { + return nil, &ErrorResponse{Code: 500, Details: err.Error()} + } + return fiber.Map{"user_id": authCtx.UserID}, nil + }, WithSecurity(OpenAPIOptions{Summary: "API Key route"}, routeSecurity)) + + // Request with API Key (no Bearer) should succeed on the per-route security route + req := httptest.NewRequest("GET", "/api-key-route", nil) + req.Header.Set("X-API-Key", "my-api-key-123") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("Expected status 200 for per-route API Key auth, got %d", resp.StatusCode) + } +} + +func parseJSONResponse(resp *http.Response, target interface{}) error { + defer resp.Body.Close() + return json.NewDecoder(resp.Body).Decode(target) +} diff --git a/auth_test.go b/auth_test.go index 25326df..ee30043 100644 --- a/auth_test.go +++ b/auth_test.go @@ -210,8 +210,8 @@ func TestAuthenticationMiddleware(t *testing.T) { t.Fatalf("Request failed: %v", err) } - if resp.StatusCode != 400 { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) } }) @@ -257,8 +257,8 @@ func TestAuthenticationMiddleware(t *testing.T) { t.Fatalf("Request failed: %v", err) } - if resp.StatusCode != 400 { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) } }) @@ -271,8 +271,8 @@ func TestAuthenticationMiddleware(t *testing.T) { t.Fatalf("Request failed: %v", err) } - if resp.StatusCode != 400 { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) } }) } @@ -559,8 +559,8 @@ func TestAuthServiceFailure(t *testing.T) { t.Fatalf("Request failed: %v", err) } - if resp.StatusCode != 400 { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != 401 { + t.Errorf("Expected status 401, got %d", resp.StatusCode) } }) } diff --git a/common.go b/common.go index baec2b2..1f99bd7 100644 --- a/common.go +++ b/common.go @@ -118,7 +118,12 @@ func parseInput[TInput any](app *OApiApp, c *fiber.Ctx, path string, options *Op if securityValue, ok := options.Security.(string); ok && securityValue == "disabled" { // Skip authorization for this route } else { - err = validateAuthorization(c, input, app.Config().AuthService) + cfg := app.Config() + // Use per-route security requirements when specified, otherwise fall back to global defaults + if routeSecurity, ok := options.Security.([]map[string][]string); ok && len(routeSecurity) > 0 { + cfg.DefaultSecurity = routeSecurity + } + err = validateAuthorization(c, input, cfg.AuthService, &cfg) if err != nil { return input, err } diff --git a/conditional_auth.go b/conditional_auth.go index 41516ef..287af21 100644 --- a/conditional_auth.go +++ b/conditional_auth.go @@ -1,6 +1,7 @@ package fiberoapi import ( + "errors" "strings" "github.com/gofiber/fiber/v2" @@ -13,7 +14,7 @@ func ConditionalAuthMiddleware(authMiddleware fiber.Handler, excludePaths ...str // Verify if the current path is in the exclude list for _, excludePath := range excludePaths { - if path == excludePath || strings.HasPrefix(path, excludePath) { + if excludePath != "" && (path == excludePath || strings.HasPrefix(path, excludePath)) { return c.Next() // Skip authentication } } @@ -23,9 +24,16 @@ func ConditionalAuthMiddleware(authMiddleware fiber.Handler, excludePaths ...str } } -// SmartAuthMiddleware creates middleware that automatically excludes documentation routes +// SmartAuthMiddleware creates middleware that automatically excludes documentation routes. +// When SecuritySchemes are configured, it uses MultiSchemeAuthMiddleware for dispatch. +// Otherwise, it falls back to BearerTokenMiddleware for backward compatibility. func SmartAuthMiddleware(authService AuthorizationService, config Config) fiber.Handler { - authMiddleware := BearerTokenMiddleware(authService) + var authMiddleware fiber.Handler + if len(config.SecuritySchemes) > 0 { + authMiddleware = MultiSchemeAuthMiddleware(authService, config) + } else { + authMiddleware = BearerTokenMiddleware(authService) + } // Paths to exclude from authentication excludePaths := []string{ @@ -36,3 +44,125 @@ func SmartAuthMiddleware(authService AuthorizationService, config Config) fiber. return ConditionalAuthMiddleware(authMiddleware, excludePaths...) } + +// MultiSchemeAuthMiddleware creates middleware that tries configured security schemes. +// It iterates over DefaultSecurity requirements (OR semantics) and validates +// using the appropriate scheme handler. +func MultiSchemeAuthMiddleware(authService AuthorizationService, config Config) fiber.Handler { + return func(c *fiber.Ctx) error { + securityReqs := config.DefaultSecurity + if len(securityReqs) == 0 { + securityReqs = buildDefaultFromSchemes(config.SecuritySchemes) + } + + // Server configuration errors (5xx) short-circuit immediately since + // no alternative requirement can fix a misconfigured scheme. + var lastErr error + for _, requirement := range securityReqs { + authCtx, err := validateSecurityRequirement(c, requirement, config.SecuritySchemes, authService) + if err == nil { + c.Locals("auth", authCtx) + return c.Next() + } + var authErr *AuthError + if errors.As(err, &authErr) && authErr.StatusCode >= 500 { + return c.Status(authErr.StatusCode).JSON(fiber.Map{ + "error": "Server configuration error", + "details": authErr.Message, + }) + } + lastErr = err + } + + if lastErr == nil { + // No security requirements were configured — this is a server misconfiguration, + // not a client authentication failure. + return c.Status(500).JSON(fiber.Map{ + "error": "Server configuration error", + "details": "no security schemes configured", + }) + } + + status := 401 + errorLabel := "Authentication failed" + var scopeErr *ScopeError + if errors.As(lastErr, &scopeErr) { + status = 403 + errorLabel = "Authorization failed" + } + + return c.Status(status).JSON(fiber.Map{ + "error": errorLabel, + "details": lastErr.Error(), + }) + } +} + +// BasicAuthMiddleware creates a standalone middleware for HTTP Basic authentication. +// The authService must implement the BasicAuthValidator interface. +func BasicAuthMiddleware(validator AuthorizationService) fiber.Handler { + return func(c *fiber.Ctx) error { + authCtx, err := validateBasicAuth(c, validator) + if err != nil { + status, label := classifyAuthError(err) + return c.Status(status).JSON(fiber.Map{ + "error": label, + "details": err.Error(), + }) + } + + c.Locals("auth", authCtx) + return c.Next() + } +} + +// APIKeyMiddleware creates a standalone middleware for API Key authentication. +// The authService must implement the APIKeyValidator interface. +func APIKeyMiddleware(validator AuthorizationService, scheme SecurityScheme) fiber.Handler { + return func(c *fiber.Ctx) error { + authCtx, err := validateAPIKey(c, scheme, validator) + if err != nil { + status, label := classifyAuthError(err) + return c.Status(status).JSON(fiber.Map{ + "error": label, + "details": err.Error(), + }) + } + + c.Locals("auth", authCtx) + return c.Next() + } +} + +// AWSSignatureMiddleware creates a standalone middleware for AWS Signature V4 authentication. +// The authService must implement the AWSSignatureValidator interface. +func AWSSignatureMiddleware(validator AuthorizationService) fiber.Handler { + return func(c *fiber.Ctx) error { + authCtx, err := validateAWSSigV4(c, validator) + if err != nil { + status, label := classifyAuthError(err) + return c.Status(status).JSON(fiber.Map{ + "error": label, + "details": err.Error(), + }) + } + + c.Locals("auth", authCtx) + return c.Next() + } +} + +// classifyAuthError returns the HTTP status and error label for an authentication error. +func classifyAuthError(err error) (int, string) { + var authErr *AuthError + if errors.As(err, &authErr) { + if authErr.StatusCode >= 500 { + return authErr.StatusCode, "Server configuration error" + } + if authErr.StatusCode == 403 { + return authErr.StatusCode, "Authorization failed" + } + return authErr.StatusCode, "Authentication failed" + } + return 401, "Authentication failed" +} diff --git a/fiberoapi.go b/fiberoapi.go index 1c1ee36..a2fd49d 100644 --- a/fiberoapi.go +++ b/fiberoapi.go @@ -1,6 +1,7 @@ package fiberoapi import ( + "errors" "fmt" "net/http" "reflect" @@ -851,6 +852,19 @@ func Method[TInput any, TOutput any, TError any]( fiberHandler := func(c *fiber.Ctx) error { input, err := parseInput[TInput](app, c, fullPath, &options) if err != nil { + // Check for authentication/authorization errors first + var authErr *AuthError + if errors.As(err, &authErr) { + errType := "authentication_error" + if authErr.StatusCode == 403 { + errType = "authorization_error" + } + return c.Status(authErr.StatusCode).JSON(ErrorResponse{ + Code: authErr.StatusCode, + Details: authErr.Message, + Type: errType, + }) + } // Use custom validation error handler if configured if app.config.ValidationErrorHandler != nil { return app.config.ValidationErrorHandler(c, err)