diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 60a5cf1e..4aba3132 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -130,6 +130,22 @@ func newOP( UserFormPath: "/device", UserCode: op.UserCodeBase20, }, + + // mTLS authentication (RFC 8705) - uncomment to enable + // To use mTLS clients, you need to: + // 1. Configure TLS on your server to request client certificates + // 2. Set up a Trust Store with your CA certificates + // 3. Register clients using storage.MTLSClient() or storage.SelfSignedTLSClient() + // + // AuthMethodTLSClientAuth: true, + // AuthMethodSelfSignedTLSClientAuth: true, + // TLSClientCertificateBoundAccessTokens: true, + // MTLSConfig: &op.MTLSConfig{ + // TrustStore: yourCACertPool, // x509.CertPool with trusted CAs + // // Optional: restrict by Policy OID or EKU + // // RequiredPolicyOIDs: []asn1.ObjectIdentifier{...}, + // // RequiredEKUs: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + // }, } handler, err := op.NewOpenIDProvider(issuer, config, storage, append([]op.Option{ diff --git a/example/server/storage/client.go b/example/server/storage/client.go index 010b9ce7..43e71257 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -1,6 +1,7 @@ package storage import ( + "crypto/x509" "time" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -34,6 +35,11 @@ type Client struct { clockSkew time.Duration postLogoutRedirectURIGlobs []string redirectURIGlobs []string + + // mTLS authentication (RFC 8705) + mtlsConfig *op.MTLSClientConfig // for tls_client_auth + registeredCerts []string // for self_signed_tls_client_auth (PEM-encoded) + registeredCertsParsed []*x509.Certificate // parsed certificates (internal) } // GetID must return the client_id @@ -127,6 +133,18 @@ func (c *Client) ClockSkew() time.Duration { return c.clockSkew } +// GetMTLSConfig returns the mTLS client configuration for tls_client_auth. +// Implements op.HasMTLSConfig interface. +func (c *Client) GetMTLSConfig() *op.MTLSClientConfig { + return c.mtlsConfig +} + +// GetRegisteredCertificates returns the registered certificates for self_signed_tls_client_auth. +// Implements op.HasSelfSignedCertificate interface. +func (c *Client) GetRegisteredCertificates() []string { + return c.registeredCerts +} + // RegisterClients enables you to register clients for the example implementation // there are some clients (web and native) to try out different cases // add more if necessary @@ -211,6 +229,66 @@ func DeviceClient(id, secret string) *Client { } } +// MTLSClient creates a client that uses tls_client_auth (PKI-based mTLS authentication). +// The client is identified by Subject DN or SAN (DNS/URI/IP/Email) in the certificate. +// This implements RFC 8705 Section 2.1.1. +// +// Parameters: +// - id: client identifier +// - mtlsConfig: mTLS client configuration specifying how to identify the client +// - boundTokens: if true, access tokens will be certificate-bound (cnf claim) +// +// Example: +// +// MTLSClient("mtls-client", &op.MTLSClientConfig{ +// SubjectDN: "CN=client1,O=Example,C=US", +// TLSClientCertificateBoundAccessTokens: true, +// }) +func MTLSClient(id string, mtlsConfig *op.MTLSClientConfig) *Client { + return &Client{ + id: id, + applicationType: op.ApplicationTypeWeb, + authMethod: oidc.AuthMethodTLSClientAuth, + loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeClientCredentials}, + accessTokenType: op.AccessTokenTypeJWT, // Required for certificate-bound tokens + mtlsConfig: mtlsConfig, + } +} + +// SelfSignedTLSClient creates a client that uses self_signed_tls_client_auth. +// The client authenticates by presenting a certificate that matches one of the +// pre-registered certificates (compared by thumbprint). +// This implements RFC 8705 Section 2.1.2. +// +// Parameters: +// - id: client identifier +// - certificates: PEM-encoded certificates to register for this client +// - boundTokens: if true, access tokens will be certificate-bound (cnf claim) +// +// Example: +// +// certPEM := `-----BEGIN CERTIFICATE----- +// MIIBkTCB+wIJAK... +// -----END CERTIFICATE-----` +// SelfSignedTLSClient("self-signed-client", true, certPEM) +func SelfSignedTLSClient(id string, boundTokens bool, certificates ...string) *Client { + return &Client{ + id: id, + applicationType: op.ApplicationTypeWeb, + authMethod: oidc.AuthMethodSelfSignedTLSClientAuth, + loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeClientCredentials}, + accessTokenType: op.AccessTokenTypeJWT, // Required for certificate-bound tokens + registeredCerts: certificates, + mtlsConfig: &op.MTLSClientConfig{ + TLSClientCertificateBoundAccessTokens: boundTokens, + }, + } +} + type hasRedirectGlobs struct { *Client } diff --git a/go.mod b/go.mod index ec2ad807..e4db2f94 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.9.2 github.com/go-chi/chi/v5 v5.2.3 github.com/go-jose/go-jose/v4 v4.0.5 + github.com/go-ldap/ldap/v3 v3.4.11 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.6.0 @@ -25,8 +26,10 @@ require ( ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/go-querystring v1.1.0 // indirect diff --git a/go.sum b/go.sum index 1648c86a..c693030f 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= +github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/bmatcuk/doublestar/v4 v4.9.2 h1:b0mc6WyRSYLjzofB2v/0cuDUZ+MqoGyH3r0dVij35GI= github.com/bmatcuk/doublestar/v4 v4.9.2/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -5,10 +9,14 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= +github.com/go-ldap/ldap/v3 v3.4.11 h1:4k0Yxweg+a3OyBLjdYn5OKglv18JNvfDykSoI8bW0gU= +github.com/go-ldap/ldap/v3 v3.4.11/go.mod h1:bY7t0FLK8OAVpp/vV6sSlpz3EQDGcQwc8pF0ujLgKvM= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -31,6 +39,20 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA= github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go index 62288d1b..fa1c5f5c 100644 --- a/pkg/oidc/discovery.go +++ b/pkg/oidc/discovery.go @@ -153,6 +153,23 @@ type DiscoveryConfiguration struct { // BackChannelLogoutSessionSupported specifies whether the OP can pass a sid (session ID) Claim in the Logout Token to identify the RP session with the OP. // If supported, the sid Claim is also included in ID Tokens issued by the OP. If omitted, the default value is false. BackChannelLogoutSessionSupported bool `json:"backchannel_logout_session_supported,omitempty"` + + // TLSClientCertificateBoundAccessTokens indicates whether the authorization server supports + // issuing certificate-bound access tokens as defined in RFC 8705. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // MTLSEndpointAliases contains alternative endpoints for mTLS client authentication. + // These endpoints require mutual TLS authentication. + MTLSEndpointAliases *MTLSEndpointAliases `json:"mtls_endpoint_aliases,omitempty"` +} + +// MTLSEndpointAliases contains alternative endpoints for mTLS client authentication +// as defined in RFC 8705 Section 5. +type MTLSEndpointAliases struct { + TokenEndpoint string `json:"token_endpoint,omitempty"` + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` } type AuthMethod string @@ -162,8 +179,13 @@ const ( AuthMethodPost AuthMethod = "client_secret_post" AuthMethodNone AuthMethod = "none" AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt" + + // RFC 8705: OAuth 2.0 Mutual-TLS Client Authentication + AuthMethodTLSClientAuth AuthMethod = "tls_client_auth" + AuthMethodSelfSignedTLSClientAuth AuthMethod = "self_signed_tls_client_auth" ) var AllAuthMethods = []AuthMethod{ AuthMethodBasic, AuthMethodPost, AuthMethodNone, AuthMethodPrivateKeyJWT, + AuthMethodTLSClientAuth, AuthMethodSelfSignedTLSClientAuth, } diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index d2b6f6d4..1c7ece01 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -207,6 +207,14 @@ func (i *IDTokenClaims) UnmarshalJSON(data []byte) error { return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims) } +// Confirmation represents the "cnf" (confirmation) claim as defined in RFC 7800. +// This is used for certificate-bound access tokens per RFC 8705. +type Confirmation struct { + // X509CertificateSHA256Thumbprint is the SHA-256 thumbprint of the X.509 certificate + // that the access token is bound to, base64url encoded. + X509CertificateSHA256Thumbprint string `json:"x5t#S256,omitempty"` +} + // ActorClaims provides the `act` claims used for impersonation or delegation Token Exchange. // // An actor can be nested in case an obtained token is used as actor token to obtain impersonation or delegation. diff --git a/pkg/op/client.go b/pkg/op/client.go index 86e3a4f8..568b3cba 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -91,6 +91,86 @@ type ClientJWTProfile interface { JWTProfileVerifier(context.Context) *JWTProfileVerifier } +// ClientMTLSProvider is an optional interface for providers that support mTLS client authentication. +type ClientMTLSProvider interface { + ClientProvider + MTLSConfig() *MTLSConfig + AuthMethodTLSClientAuthSupported() bool + AuthMethodSelfSignedTLSClientAuthSupported() bool +} + +// ClientMTLSAuth authenticates a client using mTLS certificate. +// Returns: +// - (clientID, true, nil) on successful authentication +// - ("", false, nil) when mTLS is not configured or no certificate present (fallback to other methods) +// - ("", false, error) on authentication failure +func ClientMTLSAuth(r *http.Request, p ClientMTLSProvider) (clientID string, authenticated bool, err error) { + ctx, span := Tracer.Start(r.Context(), "ClientMTLSAuth") + defer span.End() + + mtlsConfig := p.MTLSConfig() + // Determine client_id to identify which client needs to be validated. + // RFC 8705 requires the client_id parameter for mTLS client authentication, but + // we also check the BasicAuth username to fail-closed for mTLS-only clients. + clientID = r.FormValue("client_id") + if clientID == "" { + if basicID, _, ok := r.BasicAuth(); ok { + if decoded, err := url.QueryUnescape(basicID); err == nil { + clientID = decoded + } + } + } + if clientID == "" { + return "", false, nil // cannot identify client, fallback + } + + // Get client from storage + client, err := p.Storage().GetClientByClientID(ctx, clientID) + if err != nil { + return "", false, nil // Client not found, fallback + } + + // Check if client uses mTLS authentication + authMethod := client.AuthMethod() + if authMethod != oidc.AuthMethodTLSClientAuth && authMethod != oidc.AuthMethodSelfSignedTLSClientAuth { + return "", false, nil // Client doesn't use mTLS, fallback + } + + // Try to extract certificate from request (required for mTLS clients) + certs, err := ClientCertificateFromRequest(r, mtlsConfig) + if err != nil || len(certs) == 0 { + return "", false, oidc.ErrInvalidClient().WithDescription("no client certificate provided") + } + cert := certs[0] + + // Validate mTLS based on authentication method + if authMethod == oidc.AuthMethodTLSClientAuth { + if !p.AuthMethodTLSClientAuthSupported() { + return "", false, oidc.ErrInvalidClient().WithDescription("tls_client_auth not supported") + } + mtlsClient, ok := client.(HasMTLSConfig) + if !ok { + return "", false, oidc.ErrInvalidClient().WithDescription("client does not support mTLS configuration") + } + if err := ValidateTLSClientAuth(certs, mtlsConfig, mtlsClient.GetMTLSConfig()); err != nil { + return "", false, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + } else { // self_signed_tls_client_auth + if !p.AuthMethodSelfSignedTLSClientAuthSupported() { + return "", false, oidc.ErrInvalidClient().WithDescription("self_signed_tls_client_auth not supported") + } + selfSignedClient, ok := client.(HasSelfSignedCertificate) + if !ok { + return "", false, oidc.ErrInvalidClient().WithDescription("client does not support self-signed certificates") + } + if err := ValidateSelfSignedTLSClientAuth(cert, selfSignedClient.GetRegisteredCertificates()); err != nil { + return "", false, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + } + + return clientID, true, nil +} + func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { ctx, span := Tracer.Start(ctx, "ClientJWTAuth") defer span.End() @@ -140,15 +220,14 @@ type clientData struct { } // ClientIDFromRequest parses the request form and tries to obtain the client ID -// and reports if it is authenticated, using a JWT or static client secrets over +// and reports if it is authenticated, using mTLS, JWT, or static client secrets over // http basic auth. // -// If the Provider implements IntrospectorJWTProfile and "client_assertion" is -// present in the form data, JWT assertion will be verified and the -// client ID is taken from there. -// If any of them is absent, basic auth is attempted. -// In absence of basic auth data, the unauthenticated client id from the form -// data is returned. +// Authentication methods are tried in this order: +// 1. mTLS (if provider implements ClientMTLSProvider and client uses tls_client_auth/self_signed_tls_client_auth) +// 2. JWT assertion (if provider implements ClientJWTProfile and client_assertion is present) +// 3. Basic auth (client_secret_basic) +// 4. Form body (client_secret_post / none) // // If no client id can be obtained by any method, oidc.ErrInvalidClient // is returned with ErrMissingClientID wrapped in it. @@ -167,6 +246,18 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au return "", false, err } + // Try mTLS authentication first (RFC 8705) + if mtlsProvider, ok := p.(ClientMTLSProvider); ok { + clientID, authenticated, err = ClientMTLSAuth(r, mtlsProvider) + if err != nil { + return "", false, err // mTLS auth failed + } + if authenticated { + return clientID, true, nil // mTLS auth succeeded + } + // Fallback to other auth methods + } + JWTProfile, ok := p.(ClientJWTProfile) if ok && data.ClientAssertion != "" { // if JWTProfile is supported and client sent an assertion, check it and use it as response diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index e3ca6035..49bb6905 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -14,6 +14,30 @@ type DiscoverStorage interface { SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error) } +type mtlsDiscoveryConfig interface { + AuthMethodTLSClientAuthSupported() bool + AuthMethodSelfSignedTLSClientAuthSupported() bool + TLSClientCertificateBoundAccessTokensSupported() bool +} + +type mtlsEndpointAliasesProvider interface { + MTLSEndpointAliases() *oidc.MTLSEndpointAliases +} + +func tlsClientCertificateBoundAccessTokensSupported(c Configuration) bool { + if mc, ok := c.(mtlsDiscoveryConfig); ok { + return mc.TLSClientCertificateBoundAccessTokensSupported() + } + return false +} + +func mtlsEndpointAliases(c Configuration) *oidc.MTLSEndpointAliases { + if mc, ok := c.(mtlsEndpointAliasesProvider); ok { + return mc.MTLSEndpointAliases() + } + return nil +} + var DefaultSupportedScopes = []string{ oidc.ScopeOpenID, oidc.ScopeProfile, @@ -64,6 +88,8 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di RequestParameterSupported: config.RequestObjectSupported(), BackChannelLogoutSupported: config.BackChannelLogoutSupported(), BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(), + TLSClientCertificateBoundAccessTokens: tlsClientCertificateBoundAccessTokensSupported(config), + MTLSEndpointAliases: mtlsEndpointAliases(config), } } @@ -97,6 +123,8 @@ func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage RequestParameterSupported: config.RequestObjectSupported(), BackChannelLogoutSupported: config.BackChannelLogoutSupported(), BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(), + TLSClientCertificateBoundAccessTokens: tlsClientCertificateBoundAccessTokensSupported(config), + MTLSEndpointAliases: mtlsEndpointAliases(config), } } @@ -176,6 +204,15 @@ func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod { if c.AuthMethodPrivateKeyJWTSupported() { authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT) } + // mTLS authentication methods (RFC 8705) + if mc, ok := c.(mtlsDiscoveryConfig); ok { + if mc.AuthMethodTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodTLSClientAuth) + } + if mc.AuthMethodSelfSignedTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodSelfSignedTLSClientAuth) + } + } return authMethods } @@ -200,6 +237,15 @@ func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod { if c.AuthMethodPrivateKeyJWTSupported() { authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT) } + // mTLS authentication methods (RFC 8705) + if mc, ok := c.(mtlsDiscoveryConfig); ok { + if mc.AuthMethodTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodTLSClientAuth) + } + if mc.AuthMethodSelfSignedTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodSelfSignedTLSClientAuth) + } + } return authMethods } @@ -221,6 +267,15 @@ func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod { if c.AuthMethodPrivateKeyJWTSupported() { authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT) } + // mTLS authentication methods (RFC 8705) + if mc, ok := c.(mtlsDiscoveryConfig); ok { + if mc.AuthMethodTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodTLSClientAuth) + } + if mc.AuthMethodSelfSignedTLSClientAuthSupported() { + authMethods = append(authMethods, oidc.AuthMethodSelfSignedTLSClientAuth) + } + } return authMethods } diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 61afb62c..7f50755a 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -46,6 +46,47 @@ func TestDiscover(t *testing.T) { } } +func TestDiscover_MTLSEndpointAliases(t *testing.T) { + rec := httptest.NewRecorder() + config := &oidc.DiscoveryConfiguration{ + Issuer: "https://issuer.com", + MTLSEndpointAliases: &oidc.MTLSEndpointAliases{ + TokenEndpoint: "https://mtls.example.com/oauth/token", + IntrospectionEndpoint: "https://mtls.example.com/oauth/introspect", + RevocationEndpoint: "https://mtls.example.com/revoke", + UserinfoEndpoint: "https://mtls.example.com/userinfo", + }, + } + + op.Discover(rec, config) + require.Equal(t, http.StatusOK, rec.Code) + require.JSONEq(t, `{ + "issuer":"https://issuer.com", + "mtls_endpoint_aliases":{ + "token_endpoint":"https://mtls.example.com/oauth/token", + "introspection_endpoint":"https://mtls.example.com/oauth/introspect", + "revocation_endpoint":"https://mtls.example.com/revoke", + "userinfo_endpoint":"https://mtls.example.com/userinfo" + }, + "request_uri_parameter_supported":false + }`, rec.Body.String()) +} + +func TestCreateDiscoveryConfig_MTLSEndpointAliases(t *testing.T) { + cfg := *testConfig + cfg.MTLSEndpointAliases = &oidc.MTLSEndpointAliases{ + TokenEndpoint: "https://mtls.example.com/oauth/token", + IntrospectionEndpoint: "https://mtls.example.com/oauth/introspect", + RevocationEndpoint: "https://mtls.example.com/revoke", + UserinfoEndpoint: "https://mtls.example.com/userinfo", + } + provider := newTestProvider(&cfg) + ctx := op.ContextWithIssuer(context.Background(), testIssuer) + + got := op.CreateDiscoveryConfig(ctx, provider, provider.Storage()) + require.Equal(t, cfg.MTLSEndpointAliases, got.MTLSEndpointAliases) +} + func TestCreateDiscoveryConfig(t *testing.T) { type args struct { ctx context.Context diff --git a/pkg/op/mtls.go b/pkg/op/mtls.go new file mode 100644 index 00000000..6ae2d523 --- /dev/null +++ b/pkg/op/mtls.go @@ -0,0 +1,1278 @@ +package op + +import ( + "container/list" + "context" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/go-ldap/ldap/v3" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +// contextKey is a custom type for context keys to avoid collisions. +type contextKey int + +const ( + // certThumbprintKey is the context key for storing the certificate thumbprint. + certThumbprintKey contextKey = iota + // certChainKey is the context key for storing the client certificate chain. + certChainKey +) + +// ContextWithCertThumbprint returns a new context with the certificate thumbprint stored. +// This is used to pass the thumbprint from the authentication layer to token creation. +func ContextWithCertThumbprint(ctx context.Context, thumbprint string) context.Context { + return context.WithValue(ctx, certThumbprintKey, thumbprint) +} + +// ContextWithClientCertificateChain stores the (leaf-first) certificate chain in the context. +func ContextWithClientCertificateChain(ctx context.Context, certs []*x509.Certificate) context.Context { + if len(certs) == 0 { + return ctx + } + // Copy the slice header so callers can't mutate the stored slice. + copied := make([]*x509.Certificate, len(certs)) + copy(copied, certs) + return context.WithValue(ctx, certChainKey, copied) +} + +// ClientCertificateChainFromContext retrieves the client certificate chain from the context. +func ClientCertificateChainFromContext(ctx context.Context) []*x509.Certificate { + if v := ctx.Value(certChainKey); v != nil { + if certs, ok := v.([]*x509.Certificate); ok { + return certs + } + } + return nil +} + +// CertThumbprintFromContext retrieves the certificate thumbprint from the context. +// Returns empty string if no thumbprint is stored. +func CertThumbprintFromContext(ctx context.Context) string { + if v := ctx.Value(certThumbprintKey); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// SetCertThumbprintInContext extracts the certificate from the request and stores +// the thumbprint in the context if the client requires certificate-bound tokens. +// Returns the updated context (or original if no certificate binding needed). +func SetCertThumbprintInContext(ctx context.Context, r *http.Request, client Client, mtlsConfig *MTLSConfig, boundTokensSupported bool) (context.Context, error) { + if client == nil { + return ctx, oidc.ErrServerError().WithDescription("missing client") + } + + // Check if client requires certificate-bound tokens + mtlsClient, ok := client.(HasMTLSConfig) + if !ok { + return ctx, nil + } + clientConfig := mtlsClient.GetMTLSConfig() + if clientConfig == nil || !clientConfig.TLSClientCertificateBoundAccessTokens { + return ctx, nil + } + + if !boundTokensSupported { + return ctx, oidc.ErrServerError().WithDescription("certificate-bound access tokens not supported") + } + if client.AccessTokenType() != AccessTokenTypeJWT { + return ctx, oidc.ErrServerError().WithDescription("certificate-bound access tokens require JWT access tokens") + } + + // Prefer an already extracted chain (avoids double parsing in header-mode). + certs := ClientCertificateChainFromContext(ctx) + if len(certs) == 0 { + var err error + certs, err = ClientCertificateFromRequest(r, mtlsConfig) + if err != nil || len(certs) == 0 { + return ctx, oidc.ErrInvalidClient().WithDescription("no client certificate provided") + } + ctx = ContextWithClientCertificateChain(ctx, certs) + } + + // Store thumbprint in context + thumbprint := CalculateCertThumbprint(certs[0]) + return ContextWithCertThumbprint(ctx, thumbprint), nil +} + +// MTLSConfig is the global configuration for mTLS authentication. +type MTLSConfig struct { + // TrustStore is the pool of trusted CA certificates. + // Used for validating client certificate chains in tls_client_auth mode. + TrustStore *x509.CertPool + + // RequiredPolicyOIDs specifies certificate policy OIDs that must be present. + // Empty slice skips policy OID validation. + RequiredPolicyOIDs []asn1.ObjectIdentifier + + // RequiredEKUs specifies Extended Key Usages that must be present. + // Empty slice skips EKU validation. + RequiredEKUs []x509.ExtKeyUsage + + // EnableProxyHeaders MUST be explicitly set to true to enable header-based + // certificate extraction. Default: false (disabled). + EnableProxyHeaders bool + + // CertificateHeader is the HTTP header name for certificate forwarding. + // Required when EnableProxyHeaders is true. + CertificateHeader string + + // CertificateHeaderFormat specifies how the certificate is encoded. + // Supported values: "pem-urlencoded", "pem-base64", "der-base64", "xfcc" + // Required when EnableProxyHeaders is true. + CertificateHeaderFormat string + + // TrustedProxyCIDRs specifies CIDR ranges of trusted proxy IPs. + // REQUIRED when EnableProxyHeaders=true (fail-closed policy). + TrustedProxyCIDRs []string + + // parsedCIDRs is the parsed version of TrustedProxyCIDRs (internal) + parsedCIDRs []*net.IPNet + cidrOnce sync.Once + cidrErr error +} + +// MTLSClientConfig is the client-specific mTLS configuration. +type MTLSClientConfig struct { + // Client identifier validation (exactly one must be set) + SubjectDN string // RFC 4514 format Distinguished Name + SANDNS string // Subject Alternative Name: DNS + SANURI string // Subject Alternative Name: URI + SANIP string // Subject Alternative Name: IP Address + SANEmail string // Subject Alternative Name: Email + + // ClientTrustStore overrides the global TrustStore for this client. + ClientTrustStore *x509.CertPool + + // RequiredPolicyOIDs specifies additional policy OIDs required for this client. + RequiredPolicyOIDs []asn1.ObjectIdentifier + + // RequiredEKUs specifies additional EKUs required for this client. + RequiredEKUs []x509.ExtKeyUsage + + // TLSClientCertificateBoundAccessTokens indicates whether this client + // requests certificate-bound access tokens (RFC 8705 Section 3.4). + TLSClientCertificateBoundAccessTokens bool +} + +// Confirmation represents the RFC 8705 cnf claim for certificate-bound tokens. +type Confirmation struct { + X5tS256 string `json:"x5t#S256,omitempty"` +} + +// CertificateBoundClaims is a minimal helper structure for attaching cnf to token/introspection claims. +// The final token claim structs live in pkg/oidc, but this type is useful for internal plumbing/tests. +type CertificateBoundClaims struct { + Confirmation *Confirmation `json:"cnf,omitempty"` +} + +// HasMTLSConfig is an optional interface for clients that support PKI-based mTLS authentication (tls_client_auth). +type HasMTLSConfig interface { + GetMTLSConfig() *MTLSClientConfig +} + +// HasSelfSignedCertificate is an optional interface for clients that support self-signed certificate +// authentication (self_signed_tls_client_auth). +type HasSelfSignedCertificate interface { + // GetRegisteredCertificates returns the pre-registered certificates in PEM format. + GetRegisteredCertificates() []string +} + +func (c *MTLSConfig) ensureParsedCIDRs() error { + if c == nil { + return nil + } + c.cidrOnce.Do(func() { + if !c.EnableProxyHeaders { + return + } + c.parsedCIDRs = make([]*net.IPNet, 0, len(c.TrustedProxyCIDRs)) + for _, cidr := range c.TrustedProxyCIDRs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + c.cidrErr = fmt.Errorf("invalid CIDR %q: %w", cidr, err) + return + } + c.parsedCIDRs = append(c.parsedCIDRs, ipNet) + } + }) + return c.cidrErr +} + +// ValidateMTLSConfig validates the MTLSConfig at startup. +func ValidateMTLSConfig(config *MTLSConfig) error { + if config == nil { + return nil + } + + if config.EnableProxyHeaders { + if len(config.TrustedProxyCIDRs) == 0 { + return errors.New("TrustedProxyCIDRs is required when EnableProxyHeaders is true") + } + if config.CertificateHeader == "" { + return errors.New("CertificateHeader is required when EnableProxyHeaders is true") + } + if config.CertificateHeaderFormat == "" { + return errors.New("CertificateHeaderFormat is required when EnableProxyHeaders is true") + } + switch config.CertificateHeaderFormat { + case "pem-urlencoded", "pem-base64", "der-base64", "xfcc": + default: + return fmt.Errorf("unsupported CertificateHeaderFormat %q", config.CertificateHeaderFormat) + } + + if err := config.ensureParsedCIDRs(); err != nil { + return err + } + } + + return nil +} + +// ClientCertificateFromRequest extracts the client certificate chain from the request. +func ClientCertificateFromRequest(r *http.Request, config *MTLSConfig) ([]*x509.Certificate, error) { + if r == nil { + return nil, errors.New("nil request") + } + if config == nil { + config = &MTLSConfig{} + } + + if config.EnableProxyHeaders { + if len(config.TrustedProxyCIDRs) == 0 { + return nil, errors.New("TrustedProxyCIDRs is required when EnableProxyHeaders is true") + } + if config.CertificateHeader == "" { + return nil, errors.New("CertificateHeader is required when EnableProxyHeaders is true") + } + if config.CertificateHeaderFormat == "" { + return nil, errors.New("CertificateHeaderFormat is required when EnableProxyHeaders is true") + } + if err := config.ensureParsedCIDRs(); err != nil { + return nil, err + } + + // Extract remote IP + remoteHost, err := remoteHostFromAddr(r.RemoteAddr) + if err != nil { + return nil, errors.New("invalid remote address") + } + + if !isFromTrustedProxy(remoteHost, config) { + return nil, errors.New("request not from trusted proxy") + } + + // Extract from header + headerValue := r.Header.Get(config.CertificateHeader) + if headerValue == "" { + return nil, errors.New("certificate header is empty") + } + + return parseCertificateFromHeader(headerValue, config.CertificateHeaderFormat) + } + + // Direct TLS connection + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + return nil, errors.New("no client certificate provided") + } + + return r.TLS.PeerCertificates, nil +} + +func remoteHostFromAddr(remoteAddr string) (string, error) { + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + return strings.Trim(host, "[]"), nil + } + trimmed := strings.Trim(remoteAddr, "[]") + if trimmed == "" { + return "", errors.New("empty") + } + return trimmed, nil +} + +func isFromTrustedProxy(remoteHost string, config *MTLSConfig) bool { + if config == nil { + return false + } + + if err := config.ensureParsedCIDRs(); err != nil { + return false + } + + ip := net.ParseIP(remoteHost) + if ip == nil { + return false + } + + for _, ipNet := range config.parsedCIDRs { + if ipNet.Contains(ip) { + return true + } + } + + return false +} + +func parseCertificateFromHeader(headerValue, format string) ([]*x509.Certificate, error) { + var pemData []byte + + switch format { + case "pem-urlencoded": + decoded, err := url.QueryUnescape(headerValue) + if err != nil { + return nil, fmt.Errorf("failed to URL-decode certificate: %w", err) + } + pemData = []byte(decoded) + + case "pem-base64": + decoded, err := decodeBase64(headerValue) + if err != nil { + return nil, fmt.Errorf("failed to base64-decode certificate: %w", err) + } + pemData = decoded + + case "der-base64": + decoded, err := decodeBase64(headerValue) + if err != nil { + return nil, fmt.Errorf("failed to base64-decode DER certificate: %w", err) + } + cert, err := x509.ParseCertificate(decoded) + if err != nil { + return nil, fmt.Errorf("failed to parse DER certificate: %w", err) + } + return []*x509.Certificate{cert}, nil + + case "xfcc": + return parseXFCCHeader(headerValue) + + default: + return nil, fmt.Errorf("unsupported certificate header format: %s", format) + } + + // Parse PEM certificates + return parsePEMCertificates(pemData) +} + +func decodeBase64(s string) ([]byte, error) { + // Accept common variants used by proxies/gateways. + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + var lastErr error + for _, enc := range encodings { + b, err := enc.DecodeString(s) + if err == nil { + return b, nil + } + lastErr = err + } + return nil, lastErr +} + +func parsePEMCertificates(pemData []byte) ([]*x509.Certificate, error) { + var certs []*x509.Certificate + + for { + block, rest := pem.Decode(pemData) + if block == nil { + break + } + + if block.Type == "CERTIFICATE" { + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + certs = append(certs, cert) + } + + pemData = rest + } + + if len(certs) == 0 { + return nil, errors.New("no valid certificates found in PEM data") + } + + return certs, nil +} + +func parseXFCCHeader(headerValue string) ([]*x509.Certificate, error) { + // Parse Envoy X-Forwarded-Client-Cert format + // Format: Cert="...";Chain="...";... + var certs []*x509.Certificate + + // https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-client-cert + // Envoy may sanitize/append/forward XFCC; to avoid ambiguity and potential header injection in + // multi-proxy setups, require a single XFCC element (fail-closed). + elements := splitXFCCElements(headerValue) + if len(elements) != 1 { + return nil, errors.New("multiple XFCC elements are not supported; configure the proxy to sanitize XFCC") + } + + parts := splitXFCCPairs(elements[0]) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + eq := strings.IndexByte(part, '=') + if eq <= 0 { + return nil, errors.New("invalid XFCC key-value pair") + } + key := strings.TrimSpace(part[:eq]) + value := strings.TrimSpace(part[eq+1:]) + value = unquoteXFCCValue(value) + + switch strings.ToLower(key) { + case "cert": + decoded, err := url.QueryUnescape(value) + if err != nil { + return nil, fmt.Errorf("failed to decode XFCC Cert: %w", err) + } + parsed, err := parsePEMCertificates([]byte(decoded)) + if err != nil { + return nil, err + } + certs = append(certs, parsed...) + case "chain": + decoded, err := url.QueryUnescape(value) + if err != nil { + return nil, fmt.Errorf("failed to decode XFCC Chain: %w", err) + } + parsed, err := parsePEMCertificates([]byte(decoded)) + if err != nil { + return nil, err + } + certs = append(certs, parsed...) + } + } + + if len(certs) == 0 { + return nil, errors.New("no certificates found in XFCC header") + } + + return certs, nil +} + +func splitXFCCElements(s string) []string { + return splitXFCC(s, ',') +} + +func splitXFCCPairs(s string) []string { + return splitXFCC(s, ';') +} + +func splitXFCC(s string, delim byte) []string { + var parts []string + var b strings.Builder + inQuotes := false + escaped := false + for i := 0; i < len(s); i++ { + c := s[i] + if escaped { + escaped = false + b.WriteByte(c) + continue + } + if inQuotes && c == '\\' { + escaped = true + b.WriteByte(c) + continue + } + if c == '"' { + inQuotes = !inQuotes + b.WriteByte(c) + continue + } + if !inQuotes && c == delim { + part := strings.TrimSpace(b.String()) + if part != "" { + parts = append(parts, part) + } + b.Reset() + continue + } + b.WriteByte(c) + } + last := strings.TrimSpace(b.String()) + if last != "" { + parts = append(parts, last) + } + return parts +} + +func unquoteXFCCValue(v string) string { + if len(v) >= 2 && strings.HasPrefix(v, "\"") && strings.HasSuffix(v, "\"") { + v = v[1 : len(v)-1] + v = strings.ReplaceAll(v, `\"`, `"`) + v = strings.ReplaceAll(v, `\\`, `\`) + } + return v +} + +// ValidateCertificateChain validates the certificate chain against the trust store. +func ValidateCertificateChain(certs []*x509.Certificate, globalConfig *MTLSConfig, clientConfig *MTLSClientConfig) error { + if len(certs) == 0 { + return errors.New("no certificates provided") + } + + leaf := certs[0] + + // Determine trust store + var trustStore *x509.CertPool + if globalConfig != nil { + trustStore = globalConfig.TrustStore + } + if clientConfig != nil && clientConfig.ClientTrustStore != nil { + trustStore = clientConfig.ClientTrustStore + } + + if trustStore == nil { + return errors.New("no trust store configured") + } + + // Build intermediate pool from remaining certs + intermediates := x509.NewCertPool() + for _, cert := range certs[1:] { + intermediates.AddCert(cert) + } + + opts := x509.VerifyOptions{ + Roots: trustStore, + Intermediates: intermediates, + // EKU enforcement is handled separately (ValidateExtKeyUsage) based on configuration. + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + } + + _, err := leaf.Verify(opts) + if err != nil { + return fmt.Errorf("certificate chain validation failed: %w", err) + } + + return nil +} + +// ValidatePolicyOIDs validates that the certificate contains the required policy OIDs. +func ValidatePolicyOIDs(cert *x509.Certificate, requiredOIDs []asn1.ObjectIdentifier) error { + if cert == nil { + return errors.New("nil certificate") + } + if len(requiredOIDs) == 0 { + return nil + } + + for _, required := range requiredOIDs { + found := false + for _, policy := range cert.PolicyIdentifiers { + if policy.Equal(required) { + found = true + break + } + } + if !found { + return fmt.Errorf("certificate missing required policy OID: %s", required.String()) + } + } + + return nil +} + +// ValidateExtKeyUsage validates that the certificate contains the required EKUs. +func ValidateExtKeyUsage(cert *x509.Certificate, requiredEKUs []x509.ExtKeyUsage) error { + if cert == nil { + return errors.New("nil certificate") + } + if len(requiredEKUs) == 0 { + return nil + } + + for _, required := range requiredEKUs { + found := false + for _, eku := range cert.ExtKeyUsage { + if eku == required { + found = true + break + } + } + if !found { + return fmt.Errorf("certificate missing required EKU: %v", required) + } + } + + return nil +} + +// ValidateClientIdentifier validates the certificate against the client's identifier configuration. +func ValidateClientIdentifier(cert *x509.Certificate, clientConfig *MTLSClientConfig) error { + if cert == nil { + return errors.New("nil certificate") + } + if clientConfig == nil { + return errors.New("no client configuration provided") + } + + if err := ValidateMTLSClientConfig(clientConfig); err != nil { + return err + } + + if clientConfig.SubjectDN != "" { + return matchSubjectDN(cert, clientConfig.SubjectDN) + } + if clientConfig.SANDNS != "" { + return matchSANDNS(cert, clientConfig.SANDNS) + } + if clientConfig.SANURI != "" { + return matchSANURI(cert, clientConfig.SANURI) + } + if clientConfig.SANIP != "" { + return matchSANIP(cert, clientConfig.SANIP) + } + if clientConfig.SANEmail != "" { + return matchSANEmail(cert, clientConfig.SANEmail) + } + + return errors.New("no client identifier configured") +} + +// ValidateMTLSClientConfig ensures exactly one identifier is configured (RFC 8705 requirement). +func ValidateMTLSClientConfig(clientConfig *MTLSClientConfig) error { + if clientConfig == nil { + return errors.New("no client configuration provided") + } + count := 0 + if clientConfig.SubjectDN != "" { + count++ + } + if clientConfig.SANDNS != "" { + count++ + } + if clientConfig.SANURI != "" { + count++ + } + if clientConfig.SANIP != "" { + count++ + } + if clientConfig.SANEmail != "" { + count++ + } + if count == 0 { + return errors.New("no client identifier configured") + } + if count > 1 { + return errors.New("multiple client identifiers configured") + } + return nil +} + +// matchSubjectDN compares certificate Subject DN with expected DN (RFC 4514 format). +// RFC 4514 format is "CN=...,O=...,C=..." (most specific first) +// DER/ToRDNSequence order is "C,O,CN" (least specific first) +// We reverse the parsed DN to match DER order for comparison. +func matchSubjectDN(cert *x509.Certificate, expectedDN string) error { + expectedRDNs, err := getCachedExpectedDN(expectedDN) + if err != nil { + return fmt.Errorf("invalid expected DN: %w", err) + } + + // Get certificate subject as RDN sequence (DER order) + certRDNs := cert.Subject.ToRDNSequence() + + // Compare RDN sequences + if !rdnSequenceEqual(certRDNs, expectedRDNs) { + return errors.New("certificate subject does not match expected DN") + } + + return nil +} + +func reverseExpectedRDNs(rdns [][]expectedAttribute) { + for i, j := 0, len(rdns)-1; i < j; i, j = i+1, j-1 { + rdns[i], rdns[j] = rdns[j], rdns[i] + } +} + +type expectedAttribute struct { + TypeOID asn1.ObjectIdentifier + Value string +} + +type cachedExpectedDN struct { + rdns [][]expectedAttribute + err error +} + +type lruEntry[K comparable, V any] struct { + key K + value V +} + +// lruCache is a small bounded LRU cache for avoiding repeated parsing in hot paths. +// It is safe for concurrent use. +type lruCache[K comparable, V any] struct { + mu sync.Mutex + capacity int + ll *list.List + m map[K]*list.Element +} + +func newLRU[K comparable, V any](capacity int) *lruCache[K, V] { + if capacity < 1 { + capacity = 1 + } + return &lruCache[K, V]{ + capacity: capacity, + ll: list.New(), + m: make(map[K]*list.Element, capacity), + } +} + +func (c *lruCache[K, V]) Get(key K) (V, bool) { + var zero V + if c == nil { + return zero, false + } + c.mu.Lock() + defer c.mu.Unlock() + + if ele, ok := c.m[key]; ok { + c.ll.MoveToFront(ele) + return ele.Value.(lruEntry[K, V]).value, true + } + return zero, false +} + +func (c *lruCache[K, V]) Add(key K, value V) { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + + if ele, ok := c.m[key]; ok { + ele.Value = lruEntry[K, V]{key: key, value: value} + c.ll.MoveToFront(ele) + return + } + + ele := c.ll.PushFront(lruEntry[K, V]{key: key, value: value}) + c.m[key] = ele + if c.ll.Len() > c.capacity { + c.removeOldest() + } +} + +func (c *lruCache[K, V]) removeOldest() { + ele := c.ll.Back() + if ele == nil { + return + } + c.ll.Remove(ele) + ent := ele.Value.(lruEntry[K, V]) + delete(c.m, ent.key) +} + +var expectedDNCache = newLRU[string, cachedExpectedDN](1024) + +func getCachedExpectedDN(expectedDN string) ([][]expectedAttribute, error) { + if expectedDN == "" { + return nil, errors.New("empty DN") + } + if v, ok := expectedDNCache.Get(expectedDN); ok { + return v.rdns, v.err + } + rdns, err := parseExpectedDN(expectedDN) + if err == nil { + // Reverse to match DER order (C, O, CN). Reverse in-place before caching. + reverseExpectedRDNs(rdns) + } + expectedDNCache.Add(expectedDN, cachedExpectedDN{rdns: rdns, err: err}) + return rdns, err +} + +var shortNameToOID = map[string]asn1.ObjectIdentifier{ + "CN": {2, 5, 4, 3}, + "O": {2, 5, 4, 10}, + "OU": {2, 5, 4, 11}, + "C": {2, 5, 4, 6}, + "ST": {2, 5, 4, 8}, + "L": {2, 5, 4, 7}, + "STREET": {2, 5, 4, 9}, + "SERIALNUMBER": {2, 5, 4, 5}, +} + +func parseExpectedDN(expectedDN string) ([][]expectedAttribute, error) { + parsed, err := ldap.ParseDN(expectedDN) + if err != nil { + return nil, err + } + + rdns := make([][]expectedAttribute, 0, len(parsed.RDNs)) + for _, rdn := range parsed.RDNs { + attrs := make([]expectedAttribute, 0, len(rdn.Attributes)) + for _, attr := range rdn.Attributes { + oid, err := attrTypeToOID(attr.Type) + if err != nil { + return nil, err + } + attrs = append(attrs, expectedAttribute{ + TypeOID: oid, + Value: attr.Value, + }) + } + rdns = append(rdns, attrs) + } + return rdns, nil +} + +func attrTypeToOID(expectedType string) (asn1.ObjectIdentifier, error) { + if oid, ok := shortNameToOID[strings.ToUpper(expectedType)]; ok { + return oid, nil + } + oid, err := parseOIDString(expectedType) + if err != nil { + return nil, fmt.Errorf("unsupported attribute type %q", expectedType) + } + return oid, nil +} + +func parseOIDString(s string) (asn1.ObjectIdentifier, error) { + parts := strings.Split(s, ".") + if len(parts) < 2 { + return nil, errors.New("not an OID") + } + oid := make(asn1.ObjectIdentifier, 0, len(parts)) + for _, p := range parts { + if p == "" { + return nil, errors.New("invalid OID") + } + var n int + for _, r := range p { + if r < '0' || r > '9' { + return nil, errors.New("invalid OID") + } + n = n*10 + int(r-'0') + } + oid = append(oid, n) + } + return oid, nil +} + +func rdnSequenceEqual(certRDNs pkix.RDNSequence, expectedRDNs [][]expectedAttribute) bool { + if len(certRDNs) != len(expectedRDNs) { + return false + } + + for i := range certRDNs { + if !rdnEqual(certRDNs[i], expectedRDNs[i]) { + return false + } + } + + return true +} + +func rdnEqual(certRDN pkix.RelativeDistinguishedNameSET, expectedRDN []expectedAttribute) bool { + if len(certRDN) != len(expectedRDN) { + return false + } + + matched := make([]bool, len(expectedRDN)) + for _, certAttr := range certRDN { + found := false + for i, expAttr := range expectedRDN { + if matched[i] { + continue + } + if certAttr.Type.Equal(expAttr.TypeOID) && + attrValueEqual(certAttr.Value, expAttr.Value) { + matched[i] = true + found = true + break + } + } + if !found { + return false + } + } + + return true +} + +func attrValueEqual(certValue, expectedValue any) bool { + certStr := normalizeForMatch(fmt.Sprintf("%v", certValue)) + expStr := normalizeForMatch(fmt.Sprintf("%v", expectedValue)) + return strings.EqualFold(certStr, expStr) +} + +func normalizeForMatch(s string) string { + s = strings.TrimSpace(s) + s = strings.Join(strings.Fields(s), " ") + return s +} + +// matchSANDNS checks if the certificate contains the expected DNS SAN. +func matchSANDNS(cert *x509.Certificate, expected string) error { + for _, dns := range cert.DNSNames { + if strings.EqualFold(dns, expected) { + return nil + } + } + return errors.New("certificate does not contain expected DNS SAN") +} + +// matchSANURI checks if the certificate contains the expected URI SAN. +func matchSANURI(cert *x509.Certificate, expected string) error { + expectedURL, err := url.Parse(expected) + if err != nil { + return fmt.Errorf("invalid expected URI: %w", err) + } + + for _, u := range cert.URIs { + if uriEqual(u, expectedURL) { + return nil + } + } + return errors.New("certificate does not contain expected URI SAN") +} + +func uriEqual(a, b *url.URL) bool { + return strings.EqualFold(a.Scheme, b.Scheme) && + strings.EqualFold(a.Host, b.Host) && + a.Path == b.Path && + a.RawQuery == b.RawQuery && + a.Fragment == b.Fragment +} + +// matchSANIP checks if the certificate contains the expected IP SAN. +func matchSANIP(cert *x509.Certificate, expected string) error { + expectedIP := net.ParseIP(expected) + if expectedIP == nil { + return fmt.Errorf("invalid expected IP: %s", expected) + } + + for _, ip := range cert.IPAddresses { + if ip.Equal(expectedIP) { + return nil + } + } + return errors.New("certificate does not contain expected IP SAN") +} + +// matchSANEmail checks if the certificate contains the expected email SAN. +func matchSANEmail(cert *x509.Certificate, expected string) error { + expParts := strings.SplitN(expected, "@", 2) + if len(expParts) != 2 { + return fmt.Errorf("invalid expected email: %s", expected) + } + + for _, email := range cert.EmailAddresses { + certParts := strings.SplitN(email, "@", 2) + if len(certParts) != 2 { + continue + } + // Local-part: exact match (case-sensitive per RFC 5321) + // Domain: case-insensitive + if certParts[0] == expParts[0] && + strings.EqualFold(certParts[1], expParts[1]) { + return nil + } + } + return errors.New("certificate does not contain expected email SAN") +} + +// ValidateTLSClientAuth performs full tls_client_auth validation (RFC 8705 Section 2.1). +// It validates: +// - certificate chain against the trust store (global or client-specific) +// - policy OIDs (global + client-specific, AND) +// - EKUs (global + client-specific, AND) +// - client identifier (Subject DN or SAN) +func ValidateTLSClientAuth(certs []*x509.Certificate, globalConfig *MTLSConfig, clientConfig *MTLSClientConfig) error { + if len(certs) == 0 { + return errors.New("no client certificate provided") + } + if clientConfig == nil { + return errors.New("no client configuration provided") + } + leaf := certs[0] + + // 1. Validate certificate chain + if err := ValidateCertificateChain(certs, globalConfig, clientConfig); err != nil { + return fmt.Errorf("certificate chain validation failed: %w", err) + } + + // 2. Validate global policy OIDs + if globalConfig != nil { + if err := ValidatePolicyOIDs(leaf, globalConfig.RequiredPolicyOIDs); err != nil { + return fmt.Errorf("global policy OID validation failed: %w", err) + } + if err := ValidateExtKeyUsage(leaf, globalConfig.RequiredEKUs); err != nil { + return fmt.Errorf("global EKU validation failed: %w", err) + } + } + + // 3. Validate client-specific policy OIDs and EKUs + if len(clientConfig.RequiredPolicyOIDs) > 0 { + if err := ValidatePolicyOIDs(leaf, clientConfig.RequiredPolicyOIDs); err != nil { + return fmt.Errorf("client policy OID validation failed: %w", err) + } + } + if len(clientConfig.RequiredEKUs) > 0 { + if err := ValidateExtKeyUsage(leaf, clientConfig.RequiredEKUs); err != nil { + return fmt.Errorf("client EKU validation failed: %w", err) + } + } + // 4. Validate client identifier + if err := ValidateClientIdentifier(leaf, clientConfig); err != nil { + return err + } + + return nil +} + +type mtlsClientAuthSupport interface { + MTLSConfig() *MTLSConfig + AuthMethodTLSClientAuthSupported() bool + AuthMethodSelfSignedTLSClientAuthSupported() bool +} + +// validateMTLSClientAuthForClient validates an already identified mTLS client against the request certificate. +// It returns a context that contains the extracted certificate chain to avoid re-parsing it later in the same request. +func validateMTLSClientAuthForClient(ctx context.Context, r *http.Request, provider mtlsClientAuthSupport, client Client) (context.Context, error) { + if client == nil { + return ctx, oidc.ErrServerError().WithDescription("missing client") + } + if provider == nil { + return ctx, oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported") + } + + switch client.AuthMethod() { + case oidc.AuthMethodTLSClientAuth, oidc.AuthMethodSelfSignedTLSClientAuth: + default: + return ctx, nil + } + + certs := ClientCertificateChainFromContext(ctx) + if len(certs) == 0 { + var err error + certs, err = ClientCertificateFromRequest(r, provider.MTLSConfig()) + if err != nil || len(certs) == 0 { + return ctx, oidc.ErrInvalidClient().WithDescription("no client certificate provided") + } + ctx = ContextWithClientCertificateChain(ctx, certs) + } + + switch client.AuthMethod() { + case oidc.AuthMethodTLSClientAuth: + if !provider.AuthMethodTLSClientAuthSupported() { + return ctx, oidc.ErrInvalidClient().WithDescription("tls_client_auth not supported") + } + mtlsClient, ok := client.(HasMTLSConfig) + if !ok { + return ctx, oidc.ErrInvalidClient().WithDescription("client does not support mTLS configuration") + } + if err := ValidateTLSClientAuth(certs, provider.MTLSConfig(), mtlsClient.GetMTLSConfig()); err != nil { + return ctx, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + return ctx, nil + + case oidc.AuthMethodSelfSignedTLSClientAuth: + if !provider.AuthMethodSelfSignedTLSClientAuthSupported() { + return ctx, oidc.ErrInvalidClient().WithDescription("self_signed_tls_client_auth not supported") + } + selfSignedClient, ok := client.(HasSelfSignedCertificate) + if !ok { + return ctx, oidc.ErrInvalidClient().WithDescription("client does not support self-signed certificates") + } + if err := ValidateSelfSignedTLSClientAuth(certs[0], selfSignedClient.GetRegisteredCertificates()); err != nil { + return ctx, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + return ctx, nil + } + + return ctx, oidc.ErrInvalidClient() +} + +// CalculateCertThumbprint calculates the SHA-256 thumbprint of a certificate. +func CalculateCertThumbprint(cert *x509.Certificate) string { + hash := sha256.Sum256(cert.Raw) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// VerifyCertificateBinding verifies that a certificate matches the expected thumbprint. +func VerifyCertificateBinding(cert *x509.Certificate, expectedThumbprint string) error { + if cert == nil { + return errors.New("nil certificate") + } + actualThumbprint := CalculateCertThumbprint(cert) + if actualThumbprint != expectedThumbprint { + return errors.New("certificate binding mismatch") + } + return nil +} + +// CreateCertificateBoundClaims creates a cnf claim for a certificate-bound token. +// Returns nil if cert is nil. +func CreateCertificateBoundClaims(cert *x509.Certificate) *CertificateBoundClaims { + if cert == nil { + return nil + } + return &CertificateBoundClaims{ + Confirmation: &Confirmation{ + X5tS256: CalculateCertThumbprint(cert), + }, + } +} + +// VerifyCertificateBindingWithConfirmation verifies binding against a cnf claim. +// If cnf is nil or cnf.X5tS256 is empty, this is treated as "no binding required". +func VerifyCertificateBindingWithConfirmation(cert *x509.Certificate, cnf *Confirmation) error { + if cnf == nil || cnf.X5tS256 == "" { + return nil + } + if cert == nil { + return errors.New("certificate required for bound token") + } + return VerifyCertificateBinding(cert, cnf.X5tS256) +} + +// GetCnfFromIntrospectionResponse extracts the Confirmation (cnf) claim from +// an introspection response's Claims map. +// Returns nil if not present or invalid format. +func GetCnfFromIntrospectionResponse(claims map[string]any) *Confirmation { + thumbprint := GetCnfThumbprintFromClaims(claims) + if thumbprint == "" { + return nil + } + return &Confirmation{X5tS256: thumbprint} +} + +// VerifyCertificateBindingForIntrospection verifies that the certificate in the request +// matches the cnf claim from an introspection response. +// This is a convenience function for resource servers using token introspection. +// +// Usage: +// +// resp, _ := introspectionClient.IntrospectToken(ctx, token) +// if resp.Active { +// if err := op.VerifyCertificateBindingForIntrospection(r, mtlsConfig, resp.Claims); err != nil { +// // Certificate binding verification failed +// } +// } +func VerifyCertificateBindingForIntrospection(r *http.Request, mtlsConfig *MTLSConfig, introspectionClaims map[string]any) error { + thumbprint := GetCnfThumbprintFromClaims(introspectionClaims) + return VerifyCertificateBindingFromRequest(r, mtlsConfig, thumbprint) +} + +// GetCnfThumbprintFromClaims extracts the x5t#S256 thumbprint from a cnf claim map. +// Returns empty string if not found or invalid format. +func GetCnfThumbprintFromClaims(claims map[string]any) string { + if claims == nil { + return "" + } + cnf, ok := claims["cnf"] + if !ok { + return "" + } + switch v := cnf.(type) { + case map[string]any: + if thumbprint, ok := v["x5t#S256"].(string); ok { + return thumbprint + } + case map[string]string: + if thumbprint, ok := v["x5t#S256"]; ok { + return thumbprint + } + } + return "" +} + +// VerifyCertificateBindingFromRequest verifies that the certificate in the request +// matches the thumbprint from the token's cnf claim. +// If cnfThumbprint is empty, no binding verification is performed (returns nil). +// This function is used by resource servers (UserInfo, protected resources) to verify +// certificate-bound access tokens per RFC 8705 Section 3. +func VerifyCertificateBindingFromRequest(r *http.Request, mtlsConfig *MTLSConfig, cnfThumbprint string) error { + if cnfThumbprint == "" { + return nil // No binding required + } + if mtlsConfig == nil { + return errors.New("mTLS config required for certificate-bound token verification") + } + + certs, err := ClientCertificateFromRequest(r, mtlsConfig) + if err != nil { + return fmt.Errorf("certificate required for bound token: %w", err) + } + if len(certs) == 0 { + return errors.New("certificate required for bound token") + } + + return VerifyCertificateBinding(certs[0], cnfThumbprint) +} + +type cachedRegisteredThumbprint struct { + thumbprint string + ok bool +} + +var registeredCertThumbprintCache = newLRU[string, cachedRegisteredThumbprint](1024) + +func getRegisteredCertThumbprint(pemCert string) (string, bool) { + if pemCert == "" { + return "", false + } + if v, ok := registeredCertThumbprintCache.Get(pemCert); ok { + return v.thumbprint, v.ok + } + + block, _ := pem.Decode([]byte(pemCert)) + if block == nil || block.Type != "CERTIFICATE" { + registeredCertThumbprintCache.Add(pemCert, cachedRegisteredThumbprint{ok: false}) + return "", false + } + + registered, err := x509.ParseCertificate(block.Bytes) + if err != nil { + registeredCertThumbprintCache.Add(pemCert, cachedRegisteredThumbprint{ok: false}) + return "", false + } + + thumbprint := CalculateCertThumbprint(registered) + registeredCertThumbprintCache.Add(pemCert, cachedRegisteredThumbprint{thumbprint: thumbprint, ok: true}) + return thumbprint, true +} + +// ValidateSelfSignedTLSClientAuth validates a certificate against registered self-signed certificates. +func ValidateSelfSignedTLSClientAuth(cert *x509.Certificate, registeredCerts []string) error { + if cert == nil { + return errors.New("nil certificate") + } + certThumbprint := CalculateCertThumbprint(cert) + + for _, pemCert := range registeredCerts { + if thumbprint, ok := getRegisteredCertThumbprint(pemCert); ok && thumbprint == certThumbprint { + return nil + } + } + + return errors.New("no matching registered certificate") +} diff --git a/pkg/op/mtls_test.go b/pkg/op/mtls_test.go new file mode 100644 index 00000000..0e6c6a67 --- /dev/null +++ b/pkg/op/mtls_test.go @@ -0,0 +1,2755 @@ +package op_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/oidc/v3/pkg/op" +) + +// ============================================================================= +// Test Certificate Generation Helpers +// ============================================================================= + +type testCertOptions struct { + subject pkix.Name + dnsNames []string + ipAddresses []net.IP + uris []*url.URL + emails []string + policyOIDs []asn1.ObjectIdentifier + extKeyUsage []x509.ExtKeyUsage + isCA bool + parent *x509.Certificate + parentKey *ecdsa.PrivateKey + notBefore time.Time + notAfter time.Time +} + +// Certificate policies extension OID +var oidCertificatePolicies = asn1.ObjectIdentifier{2, 5, 29, 32} + +// buildCertPoliciesExtension creates a certificate policies extension +func buildCertPoliciesExtension(policyOIDs []asn1.ObjectIdentifier) (pkix.Extension, error) { + type policyInformation struct { + PolicyIdentifier asn1.ObjectIdentifier + } + + var policies []policyInformation + for _, oid := range policyOIDs { + policies = append(policies, policyInformation{PolicyIdentifier: oid}) + } + + data, err := asn1.Marshal(policies) + if err != nil { + return pkix.Extension{}, err + } + + return pkix.Extension{ + Id: oidCertificatePolicies, + Critical: false, + Value: data, + }, nil +} + +func generateTestCert(t *testing.T, opts testCertOptions) (*x509.Certificate, *ecdsa.PrivateKey) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + notBefore := opts.notBefore + if notBefore.IsZero() { + notBefore = time.Now().Add(-time.Hour) + } + notAfter := opts.notAfter + if notAfter.IsZero() { + notAfter = time.Now().Add(24 * time.Hour) + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: opts.subject, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: opts.extKeyUsage, + BasicConstraintsValid: true, + IsCA: opts.isCA, + DNSNames: opts.dnsNames, + IPAddresses: opts.ipAddresses, + URIs: opts.uris, + EmailAddresses: opts.emails, + } + + // Add policy OIDs as an extension (PolicyIdentifiers field alone doesn't work) + if len(opts.policyOIDs) > 0 { + policyExt, err := buildCertPoliciesExtension(opts.policyOIDs) + require.NoError(t, err) + template.ExtraExtensions = append(template.ExtraExtensions, policyExt) + } + + if opts.isCA { + template.KeyUsage |= x509.KeyUsageCertSign + } + + parent := template + parentKey := key + if opts.parent != nil && opts.parentKey != nil { + parent = opts.parent + parentKey = opts.parentKey + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, parent, &key.PublicKey, parentKey) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(certDER) + require.NoError(t, err) + + return cert, key +} + +func generateTestCA(t *testing.T, subject pkix.Name) (*x509.Certificate, *ecdsa.PrivateKey) { + t.Helper() + return generateTestCert(t, testCertOptions{ + subject: subject, + isCA: true, + }) +} + +func certToPEM(cert *x509.Certificate) string { + return string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + })) +} + +func calculateThumbprint(cert *x509.Certificate) string { + hash := sha256.Sum256(cert.Raw) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// ============================================================================= +// MTLSConfig Tests +// ============================================================================= + +func TestMTLSConfig_Validation(t *testing.T) { + trustStore := x509.NewCertPool() + + tests := []struct { + name string + config *op.MTLSConfig + wantErr string + }{ + { + name: "valid direct TLS config", + config: &op.MTLSConfig{ + TrustStore: trustStore, + }, + wantErr: "", + }, + { + name: "proxy headers without TrustedProxyCIDRs", + config: &op.MTLSConfig{ + TrustStore: trustStore, + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: nil, // Missing! + }, + wantErr: "TrustedProxyCIDRs is required when EnableProxyHeaders is true", + }, + { + name: "proxy headers without CertificateHeader", + config: &op.MTLSConfig{ + TrustStore: trustStore, + EnableProxyHeaders: true, + CertificateHeader: "", // Missing! + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + }, + wantErr: "CertificateHeader is required when EnableProxyHeaders is true", + }, + { + name: "proxy headers without CertificateHeaderFormat", + config: &op.MTLSConfig{ + TrustStore: trustStore, + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "", // Missing! + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + }, + wantErr: "CertificateHeaderFormat is required when EnableProxyHeaders is true", + }, + { + name: "valid proxy headers config", + config: &op.MTLSConfig{ + TrustStore: trustStore, + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + }, + wantErr: "", + }, + { + name: "proxy headers with unsupported CertificateHeaderFormat", + config: &op.MTLSConfig{ + TrustStore: trustStore, + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "unknown-format", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + }, + wantErr: "unsupported CertificateHeaderFormat", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.ValidateMTLSConfig(tt.config) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +// ============================================================================= +// Certificate Extraction Tests +// ============================================================================= + +func TestClientCertificateFromRequest_TLS(t *testing.T) { + // Generate test certificate + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: false, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert}, + } + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) + assert.Equal(t, cert, certs[0]) +} + +func TestClientCertificateFromRequest_TLS_WithChain(t *testing.T) { + // Generate CA and client cert + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + parent: ca, + parentKey: caKey, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: false, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert, ca}, + } + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 2) + assert.Equal(t, clientCert, certs[0]) // Leaf first + assert.Equal(t, ca, certs[1]) // Then intermediate/CA +} + +func TestClientCertificateFromRequest_NoCert(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: false, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + // No TLS connection state + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_Header_Disabled(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: false, // Disabled + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + // No TLS - should fail because proxy headers are disabled + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_Header_UntrustedIP(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, // Only 10.x.x.x + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "192.168.1.1:12345" // Not in trusted range + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "not from trusted proxy") +} + +func TestClientCertificateFromRequest_Header_TrustedIP(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "10.0.0.1:12345" // In trusted range + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) + assert.Equal(t, cert.Subject.CommonName, certs[0].Subject.CommonName) +} + +func TestClientCertificateFromRequest_Header_PEMBase64(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-base64", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", base64.StdEncoding.EncodeToString([]byte(certToPEM(cert)))) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_Header_DERBase64(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "der-base64", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", base64.StdEncoding.EncodeToString(cert.Raw)) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_Header_InvalidFormat(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", "not-valid-pem") + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +// ============================================================================= +// Certificate Chain Validation Tests +// ============================================================================= + +func TestValidateCertificateChain_ValidCA(t *testing.T) { + // Create CA + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA", Organization: []string{"Test Org"}}) + + // Create client cert signed by CA + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + // Create trust store with CA + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.NoError(t, err) +} + +func TestValidateCertificateChain_UntrustedCA(t *testing.T) { + // Create untrusted CA + untrustedCA, untrustedKey := generateTestCA(t, pkix.Name{CommonName: "Untrusted CA"}) + + // Create client cert signed by untrusted CA + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: untrustedCA, + parentKey: untrustedKey, + }) + + // Create trust store with different CA + trustedCA, _ := generateTestCA(t, pkix.Name{CommonName: "Trusted CA"}) + trustStore := x509.NewCertPool() + trustStore.AddCert(trustedCA) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +func TestValidateCertificateChain_ExpiredCert(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + // Create expired client cert + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + notBefore: time.Now().Add(-48 * time.Hour), + notAfter: time.Now().Add(-24 * time.Hour), // Expired + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +func TestValidateCertificateChain_WithIntermediates(t *testing.T) { + // Create root CA + rootCA, rootKey := generateTestCA(t, pkix.Name{CommonName: "Root CA"}) + + // Create intermediate CA + intermediateCA, intermediateKey := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "Intermediate CA"}, + isCA: true, + parent: rootCA, + parentKey: rootKey, + }) + + // Create client cert signed by intermediate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: intermediateCA, + parentKey: intermediateKey, + }) + + // Trust store only has root CA + trustStore := x509.NewCertPool() + trustStore.AddCert(rootCA) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + // Provide chain: [leaf, intermediate] + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert, intermediateCA}, globalConfig, nil) + require.NoError(t, err) +} + +func TestValidateCertificateChain_ClientTrustStore(t *testing.T) { + // Create two CAs + globalCA, _ := generateTestCA(t, pkix.Name{CommonName: "Global CA"}) + clientCA, clientCAKey := generateTestCA(t, pkix.Name{CommonName: "Client CA"}) + + // Create client cert signed by client-specific CA + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: clientCA, + parentKey: clientCAKey, + }) + + // Global trust store has only globalCA + globalTrustStore := x509.NewCertPool() + globalTrustStore.AddCert(globalCA) + + // Client trust store has clientCA + clientTrustStore := x509.NewCertPool() + clientTrustStore.AddCert(clientCA) + + globalConfig := &op.MTLSConfig{ + TrustStore: globalTrustStore, + } + clientConfig := &op.MTLSClientConfig{ + ClientTrustStore: clientTrustStore, + } + + // Should pass because client-specific trust store overrides global + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.NoError(t, err) +} + +// ============================================================================= +// Policy OID Validation Tests +// ============================================================================= + +func TestValidatePolicyOIDs_Match(t *testing.T) { + requiredOID := asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 2, 1, 3, 13} + + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + policyOIDs: []asn1.ObjectIdentifier{requiredOID}, + }) + + err := op.ValidatePolicyOIDs(cert, []asn1.ObjectIdentifier{requiredOID}) + require.NoError(t, err) +} + +func TestValidatePolicyOIDs_Missing(t *testing.T) { + requiredOID := asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 2, 1, 3, 13} + differentOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5} + + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + policyOIDs: []asn1.ObjectIdentifier{differentOID}, + }) + + err := op.ValidatePolicyOIDs(cert, []asn1.ObjectIdentifier{requiredOID}) + require.Error(t, err) +} + +func TestValidatePolicyOIDs_Empty(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + // Empty required OIDs = skip validation + err := op.ValidatePolicyOIDs(cert, nil) + require.NoError(t, err) +} + +// ============================================================================= +// Extended Key Usage Validation Tests +// ============================================================================= + +func TestValidateExtKeyUsage_ClientAuth(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + err := op.ValidateExtKeyUsage(cert, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) + require.NoError(t, err) +} + +func TestValidateExtKeyUsage_Missing(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, // Wrong EKU + }) + + err := op.ValidateExtKeyUsage(cert, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) + require.Error(t, err) +} + +func TestValidateExtKeyUsage_Empty(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + // Empty required EKUs = skip validation + err := op.ValidateExtKeyUsage(cert, nil) + require.NoError(t, err) +} + +// ============================================================================= +// Subject DN Comparison Tests (RFC 4517) +// ============================================================================= + +func TestSubjectDN_ExactMatch(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + Organization: []string{"Example Inc"}, + Country: []string{"US"}, + }, + }) + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com,O=Example Inc,C=US", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSubjectDN_OIDTypes_Match(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + Organization: []string{"Example Inc"}, + Country: []string{"US"}, + }, + }) + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "2.5.4.3=client.example.com,2.5.4.10=Example Inc,2.5.4.6=US", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSubjectDN_UnsupportedAttributeType_Rejected(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + }, + }) + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "DC=example,CN=client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid expected DN") +} + +func TestSubjectDN_InvalidOID_Rejected(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + }, + }) + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "2.5.4.a=client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid expected DN") +} + +func TestSubjectDN_DifferentOrder_Rejected(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + Organization: []string{"Example Inc"}, + Country: []string{"US"}, + }, + }) + + // Different RDN order (C first instead of last) + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "C=US,O=Example Inc,CN=client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +func TestSubjectDN_CaseNormalization(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Client.Example.COM", + Organization: []string{"EXAMPLE INC"}, + }, + }) + + // Different case should still match (caseIgnoreMatch) + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com,O=example inc", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSubjectDN_WhitespaceNormalization(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + Organization: []string{"Example Inc"}, // Extra space + }, + }) + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com,O=Example Inc", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +// ============================================================================= +// SAN Comparison Tests +// ============================================================================= + +func TestSANDNS_CaseInsensitive(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + dnsNames: []string{"Client.Example.COM"}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANDNS: "client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSANDNS_NoMatch(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + dnsNames: []string{"other.example.com"}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANDNS: "client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +func TestSANIP_BinaryComparison_IPv4(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + ipAddresses: []net.IP{net.ParseIP("192.168.1.100")}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANIP: "192.168.1.100", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSANIP_BinaryComparison_IPv6(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + ipAddresses: []net.IP{net.ParseIP("2001:db8::1")}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANIP: "2001:db8::1", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSANURI_Normalized(t *testing.T) { + uri, _ := url.Parse("https://Client.Example.COM/path") + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + uris: []*url.URL{uri}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANURI: "https://client.example.com/path", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSANEmail_LocalExact_DomainInsensitive(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + emails: []string{"User@Example.COM"}, + }) + + // Local part is case-sensitive, domain is case-insensitive + clientConfig := &op.MTLSClientConfig{ + SANEmail: "User@example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.NoError(t, err) +} + +func TestSANEmail_LocalDifferent_Rejected(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + emails: []string{"User@example.com"}, + }) + + // Different local part case - should fail + clientConfig := &op.MTLSClientConfig{ + SANEmail: "user@example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +func TestSAN_NoWildcard(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + dnsNames: []string{"*.example.com"}, + }) + + // Wildcard in cert should not match non-wildcard expected + clientConfig := &op.MTLSClientConfig{ + SANDNS: "client.example.com", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +// ============================================================================= +// Certificate Thumbprint Tests +// ============================================================================= + +func TestCalculateCertThumbprint(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + thumbprint := op.CalculateCertThumbprint(cert) + require.NotEmpty(t, thumbprint) + + // Verify it's base64url encoded + decoded, err := base64.RawURLEncoding.DecodeString(thumbprint) + require.NoError(t, err) + assert.Len(t, decoded, 32) // SHA-256 = 32 bytes +} + +func TestVerifyCertificateBinding_Match(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + thumbprint := calculateThumbprint(cert) + + err := op.VerifyCertificateBinding(cert, thumbprint) + require.NoError(t, err) +} + +func TestVerifyCertificateBinding_Mismatch(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + wrongThumbprint := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + + err := op.VerifyCertificateBinding(cert, wrongThumbprint) + require.Error(t, err) +} + +// ============================================================================= +// Self-Signed Certificate Tests +// ============================================================================= + +func TestValidateSelfSignedTLSClientAuth_Match(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed-client"}, + }) + + registeredCerts := []string{certToPEM(cert)} + + err := op.ValidateSelfSignedTLSClientAuth(cert, registeredCerts) + require.NoError(t, err) +} + +func TestValidateSelfSignedTLSClientAuth_NoMatch(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed-client"}, + }) + + otherCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "other-client"}, + }) + + registeredCerts := []string{certToPEM(otherCert)} + + err := op.ValidateSelfSignedTLSClientAuth(cert, registeredCerts) + require.Error(t, err) +} + +func TestValidateSelfSignedTLSClientAuth_MultipleCerts(t *testing.T) { + cert1, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-1"}, + }) + cert2, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-2"}, + }) + cert3, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-3"}, + }) + + // Register cert1 and cert3 + registeredCerts := []string{certToPEM(cert1), certToPEM(cert3)} + + // cert2 should not match + err := op.ValidateSelfSignedTLSClientAuth(cert2, registeredCerts) + require.Error(t, err) + + // cert1 should match + err = op.ValidateSelfSignedTLSClientAuth(cert1, registeredCerts) + require.NoError(t, err) + + // cert3 should match + err = op.ValidateSelfSignedTLSClientAuth(cert3, registeredCerts) + require.NoError(t, err) +} + +// ============================================================================= +// Fail-Closed Behavior Tests +// ============================================================================= + +func TestFailClosed_EmptyTrustStore(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client"}, + parent: ca, + parentKey: caKey, + }) + + // Empty trust store + globalConfig := &op.MTLSConfig{ + TrustStore: x509.NewCertPool(), + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +func TestFailClosed_ProxyUntrustedIP_NoFallback(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "192.168.1.1:12345" // Not in trusted range + + // Also set TLS cert - should NOT fall back to this + r.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert}, + } + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "not from trusted proxy") +} + +// ============================================================================= +// XFCC Header Format Tests (Envoy X-Forwarded-Client-Cert) +// ============================================================================= + +func TestClientCertificateFromRequest_Header_XFCC(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // XFCC format: Cert="" + xfccValue := fmt.Sprintf(`Cert="%s"`, url.QueryEscape(certToPEM(cert))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) + assert.Equal(t, cert.Subject.CommonName, certs[0].Subject.CommonName) +} + +func TestClientCertificateFromRequest_Header_XFCC_CaseInsensitiveKey(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + xfccValue := fmt.Sprintf(`cert="%s"`, url.QueryEscape(certToPEM(cert))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_Header_XFCC_SubjectWithComma(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + xfccValue := fmt.Sprintf(`Subject="CN=a,b";Cert="%s"`, url.QueryEscape(certToPEM(cert))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_Header_XFCC_MultipleElements_Rejected(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + xfccValue := fmt.Sprintf(`Cert="%s",Cert="%s"`, + url.QueryEscape(certToPEM(cert)), + url.QueryEscape(certToPEM(cert))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple XFCC elements") +} + +func TestClientCertificateFromRequest_Header_XFCC_WithChain(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + parent: ca, + parentKey: caKey, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // XFCC format with Chain: Cert="";Chain="" + xfccValue := fmt.Sprintf(`Cert="%s";Chain="%s"`, + url.QueryEscape(certToPEM(clientCert)), + url.QueryEscape(certToPEM(ca))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 2) + assert.Equal(t, clientCert.Subject.CommonName, certs[0].Subject.CommonName) + assert.Equal(t, ca.Subject.CommonName, certs[1].Subject.CommonName) +} + +func TestClientCertificateFromRequest_Header_XFCC_InvalidFormat(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", "invalid-xfcc-format") + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_Header_XFCC_NoCertField(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // XFCC without Cert field (only Hash) + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", `Hash=abc123;Subject="CN=test"`) + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +// ============================================================================= +// ValidateTLSClientAuth Integration Tests +// ============================================================================= + +func TestValidateTLSClientAuth_FullFlow(t *testing.T) { + // Create CA + ca, caKey := generateTestCA(t, pkix.Name{ + CommonName: "Test CA", + Organization: []string{"Test Org"}, + Country: []string{"US"}, + }) + + // Required OIDs + policyOID := asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 2, 1, 3, 13} + + // Create client cert with all required attributes + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "client.example.com", + Organization: []string{"Client Org"}, + Country: []string{"US"}, + }, + parent: ca, + parentKey: caKey, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + policyOIDs: []asn1.ObjectIdentifier{policyOID}, + }) + + // Setup trust store + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + RequiredPolicyOIDs: []asn1.ObjectIdentifier{policyOID}, + RequiredEKUs: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com,O=Client Org,C=US", + } + + // Full validation should pass + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.NoError(t, err) +} + +func TestValidateTLSClientAuth_NilClientConfig(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client"}, + parent: ca, + parentKey: caKey, + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no client configuration") +} + +func TestValidateTLSClientAuth_FailAtChainValidation(t *testing.T) { + // Create untrusted CA + untrustedCA, untrustedKey := generateTestCA(t, pkix.Name{CommonName: "Untrusted CA"}) + + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: untrustedCA, + parentKey: untrustedKey, + }) + + // Trust store with different CA + trustedCA, _ := generateTestCA(t, pkix.Name{CommonName: "Trusted CA"}) + trustStore := x509.NewCertPool() + trustStore.AddCert(trustedCA) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com", + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "certificate chain") +} + +func TestValidateTLSClientAuth_FailAtPolicyOID(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + // Client cert without required policy OID + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + requiredOID := asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 2, 1, 3, 13} + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + RequiredPolicyOIDs: []asn1.ObjectIdentifier{requiredOID}, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com", + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "policy OID") +} + +func TestValidateTLSClientAuth_FailAtEKU(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + // Client cert with ClientAuth (passes chain validation) but missing CodeSigning + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, // Has ClientAuth + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + // Require CodeSigning in addition to ClientAuth (which is checked in chain validation) + RequiredEKUs: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com", + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "EKU") +} + +func TestValidateTLSClientAuth_FailAtClientIdentifier(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "actual-client.example.com"}, + parent: ca, + parentKey: caKey, + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=expected-client.example.com", // Mismatch! + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "subject") +} + +func TestValidateTLSClientAuth_ClientSpecificPolicyOID(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + globalOID := asn1.ObjectIdentifier{1, 2, 3, 4} + clientOID := asn1.ObjectIdentifier{5, 6, 7, 8} + + // Cert has global OID but not client-specific OID + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + policyOIDs: []asn1.ObjectIdentifier{globalOID}, // Missing clientOID + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + RequiredPolicyOIDs: []asn1.ObjectIdentifier{globalOID}, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com", + RequiredPolicyOIDs: []asn1.ObjectIdentifier{clientOID}, // Additional requirement + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "policy OID") +} + +func TestValidateTLSClientAuth_ClientSpecificEKU(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + // Cert has ClientAuth but client requires CodeSigning too + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client.example.com"}, + parent: ca, + parentKey: caKey, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + RequiredEKUs: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=client.example.com", + RequiredEKUs: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, // Additional requirement + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{clientCert}, globalConfig, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "EKU") +} + +// ============================================================================= +// Confirmation (cnf) Claim Tests - RFC 8705 Section 3 +// ============================================================================= + +func TestConfirmation_Serialization(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + thumbprint := op.CalculateCertThumbprint(cert) + + cnf := op.Confirmation{ + X5tS256: thumbprint, + } + + // Test that the struct can be properly created + assert.NotEmpty(t, cnf.X5tS256) + assert.Equal(t, thumbprint, cnf.X5tS256) +} + +func TestCreateCertificateBoundClaims(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + claims := op.CreateCertificateBoundClaims(cert) + + require.NotNil(t, claims) + require.NotNil(t, claims.Confirmation) + assert.NotEmpty(t, claims.Confirmation.X5tS256) + + // Verify the thumbprint matches + expectedThumbprint := op.CalculateCertThumbprint(cert) + assert.Equal(t, expectedThumbprint, claims.Confirmation.X5tS256) +} + +func TestCreateCertificateBoundClaims_NilCert(t *testing.T) { + claims := op.CreateCertificateBoundClaims(nil) + assert.Nil(t, claims) +} + +func TestVerifyCertificateBinding_WithConfirmation(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + // Create confirmation from same cert + cnf := &op.Confirmation{ + X5tS256: op.CalculateCertThumbprint(cert), + } + + // Should pass - same cert + err := op.VerifyCertificateBindingWithConfirmation(cert, cnf) + require.NoError(t, err) +} + +func TestVerifyCertificateBinding_WithConfirmation_Mismatch(t *testing.T) { + cert1, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-1"}, + }) + cert2, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-2"}, + }) + + // Create confirmation from cert1 + cnf := &op.Confirmation{ + X5tS256: op.CalculateCertThumbprint(cert1), + } + + // Verify with cert2 - should fail + err := op.VerifyCertificateBindingWithConfirmation(cert2, cnf) + require.Error(t, err) +} + +func TestVerifyCertificateBinding_NilConfirmation(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + // Nil confirmation should pass (no binding required) + err := op.VerifyCertificateBindingWithConfirmation(cert, nil) + require.NoError(t, err) +} + +func TestVerifyCertificateBinding_EmptyThumbprint(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + cnf := &op.Confirmation{ + X5tS256: "", // Empty + } + + // Empty thumbprint in confirmation should pass (no binding) + err := op.VerifyCertificateBindingWithConfirmation(cert, cnf) + require.NoError(t, err) +} + +// ============================================================================= +// Nil/Empty Input Handling Tests +// ============================================================================= + +func TestValidateMTLSConfig_Nil(t *testing.T) { + // Nil config should not cause panic + err := op.ValidateMTLSConfig(nil) + require.NoError(t, err) +} + +func TestClientCertificateFromRequest_NilConfig(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/token", nil) + + _, err := op.ClientCertificateFromRequest(r, nil) + require.Error(t, err) +} + +func TestValidateCertificateChain_EmptyCerts(t *testing.T) { + trustStore := x509.NewCertPool() + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{}, globalConfig, nil) + require.Error(t, err) +} + +func TestValidateCertificateChain_NilCerts(t *testing.T) { + trustStore := x509.NewCertPool() + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + err := op.ValidateCertificateChain(nil, globalConfig, nil) + require.Error(t, err) +} + +func TestValidateCertificateChain_NilTrustStore(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client"}, + parent: ca, + parentKey: caKey, + }) + + globalConfig := &op.MTLSConfig{ + TrustStore: nil, // No trust store + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "trust store") +} + +func TestValidateClientIdentifier_NilConfig(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + err := op.ValidateClientIdentifier(cert, nil) + require.Error(t, err) +} + +func TestValidateClientIdentifier_EmptyConfig(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + // Config with no identifier set + clientConfig := &op.MTLSClientConfig{} + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) + assert.Contains(t, err.Error(), "no client identifier") +} + +func TestValidateTLSClientAuth_EmptyCerts(t *testing.T) { + globalConfig := &op.MTLSConfig{ + TrustStore: x509.NewCertPool(), + } + + err := op.ValidateTLSClientAuth([]*x509.Certificate{}, globalConfig, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no client certificate") +} + +func TestValidatePolicyOIDs_NilCert(t *testing.T) { + requiredOID := asn1.ObjectIdentifier{1, 2, 3} + + err := op.ValidatePolicyOIDs(nil, []asn1.ObjectIdentifier{requiredOID}) + require.Error(t, err) + assert.Contains(t, err.Error(), "nil certificate") +} + +// ============================================================================= +// Invalid Input Format Tests +// ============================================================================= + +func TestValidateMTLSConfig_InvalidCIDR(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"not-a-valid-cidr"}, + } + + err := op.ValidateMTLSConfig(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid CIDR") +} + +func TestClientCertificateFromRequest_InvalidRemoteAddr(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "invalid-address" // No port, not valid format + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_MalformedPEM(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape("-----BEGIN CERTIFICATE-----\ninvalid-base64-data\n-----END CERTIFICATE-----")) + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_EmptyHeader(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", "") // Empty header + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestClientCertificateFromRequest_UnsupportedFormat(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "unknown-format", // Unsupported + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", certToPEM(cert)) + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +func TestSubjectDN_InvalidDNFormat(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + // Empty CN value - cert has CN=test, config expects CN="" (empty) + // This should fail because the values don't match + clientConfig := &op.MTLSClientConfig{ + SubjectDN: "CN=", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + // RFC 4517 distinguishedNameMatch: empty value != "test" + require.Error(t, err, "empty CN value should not match non-empty cert CN") +} + +func TestSANIP_InvalidIPFormat(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + ipAddresses: []net.IP{net.ParseIP("192.168.1.1")}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANIP: "not-an-ip-address", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +func TestSANURI_InvalidURIFormat(t *testing.T) { + uri, _ := url.Parse("https://example.com") + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + uris: []*url.URL{uri}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANURI: "://invalid-uri", + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +func TestSANEmail_InvalidEmailFormat(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + emails: []string{"user@example.com"}, + }) + + clientConfig := &op.MTLSClientConfig{ + SANEmail: "no-at-sign", // Invalid email + } + + err := op.ValidateClientIdentifier(cert, clientConfig) + require.Error(t, err) +} + +// ============================================================================= +// Security Edge Case Tests +// ============================================================================= + +func TestClientCertificateFromRequest_IPv6TrustedProxy(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"2001:db8::/32"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "[2001:db8::1]:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_IPv6Untrusted(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"2001:db8::/32"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = "[2001:db9::1]:12345" // Different /32 + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "not from trusted proxy") +} + +func TestClientCertificateFromRequest_MultipleTrustedCIDRs(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}, + } + + tests := []struct { + name string + remoteAddr string + wantErr bool + }{ + {"10.x trusted", "10.1.2.3:12345", false}, + {"172.16.x trusted", "172.16.1.1:12345", false}, + {"192.168.x trusted", "192.168.1.1:12345", false}, + {"8.8.8.8 untrusted", "8.8.8.8:12345", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certToPEM(cert))) + r.RemoteAddr = tt.remoteAddr + + _, err := op.ClientCertificateFromRequest(r, config) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCertificate_NoEKU(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + // Cert with no EKU at all + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client"}, + parent: ca, + parentKey: caKey, + extKeyUsage: nil, // No EKU + }) + + trustStore := x509.NewCertPool() + trustStore.AddCert(ca) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustStore, + } + + // ValidateCertificateChain uses KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny} + // which means EKU is NOT enforced at chain validation level. + // EKU enforcement is handled separately by ValidateExtKeyUsage(). + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.NoError(t, err, "chain validation should pass - EKU enforcement is separate") + + // Verify that EKU enforcement works separately when configured + err = op.ValidateExtKeyUsage(clientCert, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) + require.Error(t, err, "EKU validation should fail for cert without ClientAuth EKU") +} + +func TestCertificate_MultipleSANsSameType(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + dnsNames: []string{"first.example.com", "second.example.com", "third.example.com"}, + }) + + // Should match any of the SANs + tests := []struct { + name string + sanDNS string + wantErr bool + }{ + {"match first", "first.example.com", false}, + {"match second", "second.example.com", false}, + {"match third", "third.example.com", false}, + {"match none", "other.example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConfig := &op.MTLSClientConfig{ + SANDNS: tt.sanDNS, + } + err := op.ValidateClientIdentifier(cert, clientConfig) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSelfSignedTLSClientAuth_EmptyRegisteredCerts(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed"}, + }) + + err := op.ValidateSelfSignedTLSClientAuth(cert, []string{}) + require.Error(t, err) +} + +func TestSelfSignedTLSClientAuth_InvalidPEM(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed"}, + }) + + // Invalid PEM should be skipped, not cause error (but no match) + registeredCerts := []string{ + "not-valid-pem", + "-----BEGIN CERTIFICATE-----\ninvalid\n-----END CERTIFICATE-----", + } + + err := op.ValidateSelfSignedTLSClientAuth(cert, registeredCerts) + require.Error(t, err) + assert.Contains(t, err.Error(), "no matching") +} + +func TestVerifyCertificateBinding_NilCert(t *testing.T) { + err := op.VerifyCertificateBinding(nil, "some-thumbprint") + require.Error(t, err) + assert.Contains(t, err.Error(), "nil certificate") +} + +func TestCalculateCertThumbprint_Deterministic(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test"}, + }) + + // Thumbprint should be deterministic + tp1 := op.CalculateCertThumbprint(cert) + tp2 := op.CalculateCertThumbprint(cert) + tp3 := op.CalculateCertThumbprint(cert) + + assert.Equal(t, tp1, tp2) + assert.Equal(t, tp2, tp3) +} + +// ============================================================================= +// XFCC Additional Edge Cases +// ============================================================================= + +func TestClientCertificateFromRequest_Header_XFCC_MultipleElements(t *testing.T) { + cert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // XFCC with additional fields (Hash, Subject, etc.) + xfccValue := fmt.Sprintf(`Hash=abc123;Cert="%s";Subject="CN=test-client"`, + url.QueryEscape(certToPEM(cert))) + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", xfccValue) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_Header_XFCC_EmptyCert(t *testing.T) { + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Forwarded-Client-Cert", + CertificateHeaderFormat: "xfcc", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // XFCC with empty Cert value + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Forwarded-Client-Cert", `Cert=""`) + r.RemoteAddr = "10.0.0.1:12345" + + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +// ============================================================================= +// ValidateMTLSClientConfig Tests +// ============================================================================= + +func TestValidateMTLSClientConfig_NilConfig(t *testing.T) { + err := op.ValidateMTLSClientConfig(nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no client configuration") +} + +func TestValidateMTLSClientConfig_NoIdentifier(t *testing.T) { + config := &op.MTLSClientConfig{} + err := op.ValidateMTLSClientConfig(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "no client identifier") +} + +func TestValidateMTLSClientConfig_MultipleIdentifiers(t *testing.T) { + tests := []struct { + name string + config *op.MTLSClientConfig + }{ + { + name: "SubjectDN and SANDNS", + config: &op.MTLSClientConfig{ + SubjectDN: "CN=test", + SANDNS: "test.example.com", + }, + }, + { + name: "SubjectDN and SANURI", + config: &op.MTLSClientConfig{ + SubjectDN: "CN=test", + SANURI: "https://test.example.com", + }, + }, + { + name: "SANDNS and SANIP", + config: &op.MTLSClientConfig{ + SANDNS: "test.example.com", + SANIP: "192.168.1.1", + }, + }, + { + name: "All identifiers", + config: &op.MTLSClientConfig{ + SubjectDN: "CN=test", + SANDNS: "test.example.com", + SANURI: "https://test.example.com", + SANIP: "192.168.1.1", + SANEmail: "test@example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.ValidateMTLSClientConfig(tt.config) + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple client identifiers") + }) + } +} + +func TestValidateMTLSClientConfig_ExactlyOneIdentifier(t *testing.T) { + tests := []struct { + name string + config *op.MTLSClientConfig + }{ + { + name: "SubjectDN only", + config: &op.MTLSClientConfig{SubjectDN: "CN=test,O=Example,C=US"}, + }, + { + name: "SANDNS only", + config: &op.MTLSClientConfig{SANDNS: "client.example.com"}, + }, + { + name: "SANURI only", + config: &op.MTLSClientConfig{SANURI: "https://client.example.com"}, + }, + { + name: "SANIP only", + config: &op.MTLSClientConfig{SANIP: "192.168.1.100"}, + }, + { + name: "SANEmail only", + config: &op.MTLSClientConfig{SANEmail: "client@example.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.ValidateMTLSClientConfig(tt.config) + require.NoError(t, err) + }) + } +} + +// ============================================================================= +// Certificate Validity Period Tests +// ============================================================================= + +func TestCertificate_NotYetValid(t *testing.T) { + // Create a certificate with notBefore in the future + caCert, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Future Client", + Organization: []string{"Test Org"}, + }, + notBefore: time.Now().Add(24 * time.Hour), // Not valid until tomorrow + notAfter: time.Now().Add(48 * time.Hour), + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + parent: caCert, + parentKey: caKey, + }) + + trustPool := x509.NewCertPool() + trustPool.AddCert(caCert) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustPool, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +func TestCertificate_Expired(t *testing.T) { + // Create an expired certificate + caCert, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Expired Client", + Organization: []string{"Test Org"}, + }, + notBefore: time.Now().Add(-48 * time.Hour), // Started 2 days ago + notAfter: time.Now().Add(-24 * time.Hour), // Expired yesterday + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + parent: caCert, + parentKey: caKey, + }) + + trustPool := x509.NewCertPool() + trustPool.AddCert(caCert) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustPool, + } + + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +// ============================================================================= +// Loopback and Special IP Tests +// ============================================================================= + +func TestClientCertificateFromRequest_LoopbackIP(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + certPEM := certToPEM(clientCert) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"127.0.0.0/8"}, + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certPEM)) + r.RemoteAddr = "127.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +func TestClientCertificateFromRequest_SingleIPCIDR(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + certPEM := certToPEM(clientCert) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.1/32"}, // Single IP + } + + // Request from exact trusted IP + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certPEM)) + r.RemoteAddr = "10.0.0.1:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) + + // Request from adjacent IP (should fail) + r2 := httptest.NewRequest(http.MethodPost, "/token", nil) + r2.Header.Set("X-Client-Cert", url.QueryEscape(certPEM)) + r2.RemoteAddr = "10.0.0.2:12345" + + _, err = op.ClientCertificateFromRequest(r2, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "not from trusted proxy") +} + +func TestClientCertificateFromRequest_IPv6SingleIP(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + certPEM := certToPEM(clientCert) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"::1/128"}, // IPv6 single IP (localhost) + } + + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", url.QueryEscape(certPEM)) + r.RemoteAddr = "[::1]:12345" + + certs, err := op.ClientCertificateFromRequest(r, config) + require.NoError(t, err) + require.Len(t, certs, 1) +} + +// ============================================================================= +// Certificate Chain Edge Cases +// ============================================================================= + +func TestValidateCertificateChain_IntermediateCA(t *testing.T) { + // Create root CA + rootCert, rootKey := generateTestCA(t, pkix.Name{CommonName: "Root CA", Organization: []string{"Test"}}) + + // Create intermediate CA + intermediateCert, intermediateKey := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "Intermediate CA", Organization: []string{"Test"}}, + isCA: true, + parent: rootCert, + parentKey: rootKey, + }) + + // Create client certificate signed by intermediate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "Client", Organization: []string{"Test"}}, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + parent: intermediateCert, + parentKey: intermediateKey, + }) + + // Trust only root CA + trustPool := x509.NewCertPool() + trustPool.AddCert(rootCert) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustPool, + } + + // Validate with full chain (client + intermediate) + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert, intermediateCert}, globalConfig, nil) + require.NoError(t, err) +} + +func TestValidateCertificateChain_MissingIntermediate(t *testing.T) { + // Create root CA + rootCert, rootKey := generateTestCA(t, pkix.Name{CommonName: "Root CA", Organization: []string{"Test"}}) + + // Create intermediate CA + intermediateCert, intermediateKey := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "Intermediate CA", Organization: []string{"Test"}}, + isCA: true, + parent: rootCert, + parentKey: rootKey, + }) + + // Create client certificate signed by intermediate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "Client", Organization: []string{"Test"}}, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + parent: intermediateCert, + parentKey: intermediateKey, + }) + + // Trust only root CA + trustPool := x509.NewCertPool() + trustPool.AddCert(rootCert) + + globalConfig := &op.MTLSConfig{ + TrustStore: trustPool, + } + + // Validate without intermediate - should fail + err := op.ValidateCertificateChain([]*x509.Certificate{clientCert}, globalConfig, nil) + require.Error(t, err) +} + +// ============================================================================= +// Header Parsing Edge Cases +// ============================================================================= + +func TestClientCertificateFromRequest_HeaderWithWhitespace(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + certPEM := certToPEM(clientCert) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // URL-encode with leading/trailing spaces (after URL encoding) + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", " "+url.QueryEscape(certPEM)+" ") + r.RemoteAddr = "10.0.0.1:12345" + + // Should handle whitespace gracefully (implementation dependent) + _, err := op.ClientCertificateFromRequest(r, config) + // Either succeeds or fails with decoding error, not panic + if err != nil { + assert.NotContains(t, err.Error(), "panic") + } +} + +func TestClientCertificateFromRequest_DoubleURLEncoded(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + certPEM := certToPEM(clientCert) + + config := &op.MTLSConfig{ + EnableProxyHeaders: true, + CertificateHeader: "X-Client-Cert", + CertificateHeaderFormat: "pem-urlencoded", + TrustedProxyCIDRs: []string{"10.0.0.0/8"}, + } + + // Double URL-encode (common mistake) + doubleEncoded := url.QueryEscape(url.QueryEscape(certPEM)) + r := httptest.NewRequest(http.MethodPost, "/token", nil) + r.Header.Set("X-Client-Cert", doubleEncoded) + r.RemoteAddr = "10.0.0.1:12345" + + // Should fail (not interpret as valid cert) + _, err := op.ClientCertificateFromRequest(r, config) + require.Error(t, err) +} + +// ============================================================================= +// Self-Signed Certificate Edge Cases +// ============================================================================= + +func TestSelfSignedTLSClientAuth_CertWithDifferentKey(t *testing.T) { + // Create self-signed certificate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed-client"}, + }) + certPEM := certToPEM(clientCert) + + // Create another self-signed certificate (same subject but different key) + otherCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "self-signed-client"}, + }) + + // Try to validate with a different certificate registered + err := op.ValidateSelfSignedTLSClientAuth(otherCert, []string{certPEM}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no matching") +} + +func TestSelfSignedTLSClientAuth_MultipleRegisteredCerts(t *testing.T) { + // Create multiple self-signed certificates + cert1, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-1"}, + }) + cert2, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-2"}, + }) + cert3, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client-3"}, + }) + + registeredCerts := []string{ + certToPEM(cert1), + certToPEM(cert2), + certToPEM(cert3), + } + + // Validate each certificate + err := op.ValidateSelfSignedTLSClientAuth(cert1, registeredCerts) + require.NoError(t, err) + + err = op.ValidateSelfSignedTLSClientAuth(cert2, registeredCerts) + require.NoError(t, err) + + err = op.ValidateSelfSignedTLSClientAuth(cert3, registeredCerts) + require.NoError(t, err) + + // Validate unregistered certificate + unregisteredCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "unregistered"}, + }) + err = op.ValidateSelfSignedTLSClientAuth(unregisteredCert, registeredCerts) + require.Error(t, err) +} + +// ============================================================================= +// OID Validation Edge Cases +// ============================================================================= + +func TestValidatePolicyOIDs_EmptyOIDList(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + // Empty required OIDs should pass + err := op.ValidatePolicyOIDs(clientCert, []asn1.ObjectIdentifier{}) + require.NoError(t, err) + + // Nil required OIDs should also pass + err = op.ValidatePolicyOIDs(clientCert, nil) + require.NoError(t, err) +} + +func TestValidateExtKeyUsage_EmptyEKUList(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "test-client"}, + }) + + // Empty required EKUs should pass + err := op.ValidateExtKeyUsage(clientCert, []x509.ExtKeyUsage{}) + require.NoError(t, err) + + // Nil required EKUs should also pass + err = op.ValidateExtKeyUsage(clientCert, nil) + require.NoError(t, err) +} + +// ============================================================================= +// Subject DN Edge Cases +// ============================================================================= + +func TestValidateClientIdentifier_SubjectDN_CaseSensitivity(t *testing.T) { + // Create certificate with specific DN + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Test Client", + Organization: []string{"Test Organization"}, + Country: []string{"US"}, + }, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + // Exact match should work + err := op.ValidateClientIdentifier(clientCert, &op.MTLSClientConfig{ + SubjectDN: "CN=Test Client,O=Test Organization,C=US", + }) + require.NoError(t, err) + + // Note: RFC 4517 distinguishedNameMatch rules may allow case-insensitive + // comparison for certain attributes. Different implementations may vary. + // We test that a completely different value fails. + err = op.ValidateClientIdentifier(clientCert, &op.MTLSClientConfig{ + SubjectDN: "CN=Different Client,O=Test Organization,C=US", + }) + require.Error(t, err) +} + +func TestValidateClientIdentifier_SubjectDN_SpecialCharacters(t *testing.T) { + // Create certificate with special characters in DN + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Test, Client + Special", + Organization: []string{"Test \"Org\""}, + Country: []string{"US"}, + }, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + // RFC 4514 escaped format + err := op.ValidateClientIdentifier(clientCert, &op.MTLSClientConfig{ + SubjectDN: "CN=Test\\, Client \\+ Special,O=Test \\\"Org\\\",C=US", + }) + require.NoError(t, err) +} + +// ============================================================================= +// Certificate-Bound Token Helper Functions Tests +// ============================================================================= + +func TestGetCnfThumbprintFromClaims(t *testing.T) { + tests := []struct { + name string + claims map[string]any + expected string + }{ + { + name: "nil claims", + claims: nil, + expected: "", + }, + { + name: "empty claims", + claims: map[string]any{}, + expected: "", + }, + { + name: "no cnf claim", + claims: map[string]any{ + "sub": "user123", + }, + expected: "", + }, + { + name: "cnf claim is not a map", + claims: map[string]any{ + "cnf": "invalid", + }, + expected: "", + }, + { + name: "cnf claim without x5t#S256", + claims: map[string]any{ + "cnf": map[string]any{ + "other": "value", + }, + }, + expected: "", + }, + { + name: "cnf claim with x5t#S256 as non-string", + claims: map[string]any{ + "cnf": map[string]any{ + "x5t#S256": 12345, + }, + }, + expected: "", + }, + { + name: "valid cnf claim", + claims: map[string]any{ + "cnf": map[string]any{ + "x5t#S256": "bwcK0esc3ACC3DB2Y5_lESsXE8o9ltc05O89jdN-dg2", + }, + }, + expected: "bwcK0esc3ACC3DB2Y5_lESsXE8o9ltc05O89jdN-dg2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := op.GetCnfThumbprintFromClaims(tt.claims) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetCnfFromIntrospectionResponse(t *testing.T) { + tests := []struct { + name string + claims map[string]any + expectNil bool + expectedX5t string + }{ + { + name: "nil claims", + claims: nil, + expectNil: true, + }, + { + name: "no cnf claim", + claims: map[string]any{"sub": "user"}, + expectNil: true, + }, + { + name: "valid cnf claim", + claims: map[string]any{ + "cnf": map[string]any{ + "x5t#S256": "test-thumbprint", + }, + }, + expectNil: false, + expectedX5t: "test-thumbprint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := op.GetCnfFromIntrospectionResponse(tt.claims) + if tt.expectNil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, tt.expectedX5t, result.X5tS256) + } + }) + } +} + +func TestVerifyCertificateBindingFromRequest(t *testing.T) { + // Generate a test certificate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Test Client", + }, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + // Calculate the expected thumbprint + expectedThumbprint := op.CalculateCertThumbprint(clientCert) + + // Create a request with the certificate + r := httptest.NewRequest(http.MethodPost, "/userinfo", nil) + r.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert}, + } + + mtlsConfig := &op.MTLSConfig{} + + tests := []struct { + name string + request *http.Request + mtlsConfig *op.MTLSConfig + cnfThumbprint string + expectError bool + }{ + { + name: "empty thumbprint - no verification needed", + request: r, + mtlsConfig: mtlsConfig, + cnfThumbprint: "", + expectError: false, + }, + { + name: "matching thumbprint", + request: r, + mtlsConfig: mtlsConfig, + cnfThumbprint: expectedThumbprint, + expectError: false, + }, + { + name: "mismatched thumbprint", + request: r, + mtlsConfig: mtlsConfig, + cnfThumbprint: "wrong-thumbprint", + expectError: true, + }, + { + name: "nil mtls config with thumbprint", + request: r, + mtlsConfig: nil, + cnfThumbprint: expectedThumbprint, + expectError: true, + }, + { + name: "no certificate in request", + request: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/userinfo", nil) + return req + }(), + mtlsConfig: mtlsConfig, + cnfThumbprint: expectedThumbprint, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.VerifyCertificateBindingFromRequest(tt.request, tt.mtlsConfig, tt.cnfThumbprint) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVerifyCertificateBindingForIntrospection(t *testing.T) { + // Generate a test certificate + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{ + CommonName: "Test Client", + }, + extKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + + // Calculate the expected thumbprint + expectedThumbprint := op.CalculateCertThumbprint(clientCert) + + // Create a request with the certificate + r := httptest.NewRequest(http.MethodPost, "/resource", nil) + r.TLS = &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert}, + } + + mtlsConfig := &op.MTLSConfig{} + + tests := []struct { + name string + claims map[string]any + expectError bool + }{ + { + name: "nil claims - no verification", + claims: nil, + expectError: false, + }, + { + name: "no cnf claim - no verification", + claims: map[string]any{"sub": "user"}, + expectError: false, + }, + { + name: "matching cnf claim", + claims: map[string]any{ + "cnf": map[string]any{ + "x5t#S256": expectedThumbprint, + }, + }, + expectError: false, + }, + { + name: "mismatched cnf claim", + claims: map[string]any{ + "cnf": map[string]any{ + "x5t#S256": "wrong-thumbprint", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.VerifyCertificateBindingForIntrospection(r, mtlsConfig, tt.claims) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index df341f60..de979550 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -171,6 +171,14 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig BackChannelLogoutSupported bool BackChannelLogoutSessionSupported bool + + // mTLS settings (RFC 8705) + AuthMethodTLSClientAuth bool + AuthMethodSelfSignedTLSClientAuth bool + TLSClientCertificateBoundAccessTokens bool + MTLSConfig *MTLSConfig + // MTLSEndpointAliases provides alternative endpoints for mTLS connections. + MTLSEndpointAliases *oidc.MTLSEndpointAliases } // Endpoints defines endpoint routes. @@ -276,6 +284,9 @@ func NewProvider( return nil, err } } + if err := ValidateMTLSConfig(config.MTLSConfig); err != nil { + return nil, err + } o.issuer, err = issuer(o.insecure) if err != nil { @@ -431,6 +442,26 @@ func (o *Provider) BackChannelLogoutSessionSupported() bool { return o.config.BackChannelLogoutSessionSupported } +func (o *Provider) AuthMethodTLSClientAuthSupported() bool { + return o.config.AuthMethodTLSClientAuth +} + +func (o *Provider) AuthMethodSelfSignedTLSClientAuthSupported() bool { + return o.config.AuthMethodSelfSignedTLSClientAuth +} + +func (o *Provider) TLSClientCertificateBoundAccessTokensSupported() bool { + return o.config.TLSClientCertificateBoundAccessTokens +} + +func (o *Provider) MTLSConfig() *MTLSConfig { + return o.config.MTLSConfig +} + +func (o *Provider) MTLSEndpointAliases() *oidc.MTLSEndpointAliases { + return o.config.MTLSEndpointAliases +} + func (o *Provider) Storage() Storage { return o.storage } diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index c1520e22..fd18ae5a 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -3,6 +3,7 @@ package op_test import ( "context" "crypto/sha256" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -134,6 +135,8 @@ func TestRoutes(t *testing.T) { headerContains map[string]string json string // test for exact json output contains []string // when the body output is not constant, we just check for snippets to be present in the response + expiresInMin int + expiresInMax int }{ { name: "health", @@ -220,8 +223,12 @@ func TestRoutes(t *testing.T) { wantCode: http.StatusOK, contains: []string{ `{"access_token":"`, - `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer"`, + `"scope":"openid offline_access"`, + `"refresh_token":"`, }, + expiresInMin: 299, + expiresInMax: 300, }, { name: "Client credentials exchange", @@ -232,8 +239,10 @@ func TestRoutes(t *testing.T) { "grant_type": string(oidc.GrantTypeClientCredentials), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, - wantCode: http.StatusOK, - contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`}, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `"token_type":"Bearer"`, `"scope":"openid offline_access"`}, + expiresInMin: 299, + expiresInMax: 300, }, { // This call will fail. A successful test is already @@ -304,8 +313,11 @@ func TestRoutes(t *testing.T) { contains: []string{ `{"access_token":"`, `","token_type":"Bearer","refresh_token":"`, - `","expires_in":299,"id_token":"`, + `","expires_in":`, + `,"id_token":"`, }, + expiresInMin: 299, + expiresInMax: 300, }, { name: "revoke", @@ -390,6 +402,22 @@ func TestRoutes(t *testing.T) { t.Log(respBodyString) t.Log(resp.Header) + if tt.expiresInMin > 0 || tt.expiresInMax > 0 { + var payload map[string]any + require.NoError(t, json.Unmarshal(respBody, &payload)) + raw, ok := payload["expires_in"] + require.True(t, ok) + expires, ok := raw.(float64) + require.True(t, ok) + expiresInt := int(expires) + if tt.expiresInMin > 0 { + assert.GreaterOrEqual(t, expiresInt, tt.expiresInMin) + } + if tt.expiresInMax > 0 { + assert.LessOrEqual(t, expiresInt, tt.expiresInMax) + } + } + if tt.json != "" { assert.JSONEq(t, tt.json, respBodyString) } diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go index 35b07694..e50226a7 100644 --- a/pkg/op/server_http_routes_test.go +++ b/pkg/op/server_http_routes_test.go @@ -2,6 +2,7 @@ package op_test import ( "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -86,6 +87,8 @@ func TestServerRoutes(t *testing.T) { headerContains map[string]string json string // test for exact json output contains []string // when the body output is not constant, we just check for snippets to be present in the response + expiresInMin int + expiresInMax int }{ { name: "health", @@ -146,8 +149,10 @@ func TestServerRoutes(t *testing.T) { "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtProfileToken, }, - wantCode: http.StatusOK, - contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299,"scope":"openid"}`}, + wantCode: http.StatusOK, + contains: []string{`{"access_token":`, `"token_type":"Bearer"`, `"scope":"openid"`}, + expiresInMin: 299, + expiresInMax: 300, }, { name: "Token exchange", @@ -163,8 +168,12 @@ func TestServerRoutes(t *testing.T) { wantCode: http.StatusOK, contains: []string{ `{"access_token":"`, - `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer"`, + `"scope":"openid offline_access"`, + `"refresh_token":"`, }, + expiresInMin: 299, + expiresInMax: 300, }, { name: "Client credentials exchange", @@ -175,8 +184,10 @@ func TestServerRoutes(t *testing.T) { "grant_type": string(oidc.GrantTypeClientCredentials), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, - wantCode: http.StatusOK, - contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`}, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `"token_type":"Bearer"`, `"scope":"openid offline_access"`}, + expiresInMin: 299, + expiresInMax: 300, }, { // This call will fail. A successful test is already @@ -247,8 +258,11 @@ func TestServerRoutes(t *testing.T) { contains: []string{ `{"access_token":"`, `","token_type":"Bearer","refresh_token":"`, - `","expires_in":299,"id_token":"`, + `","expires_in":`, + `,"id_token":"`, }, + expiresInMin: 299, + expiresInMax: 300, }, { name: "revoke", @@ -333,6 +347,22 @@ func TestServerRoutes(t *testing.T) { t.Log(respBodyString) t.Log(resp.Header) + if tt.expiresInMin > 0 || tt.expiresInMax > 0 { + var payload map[string]any + require.NoError(t, json.Unmarshal(respBody, &payload)) + raw, ok := payload["expires_in"] + require.True(t, ok) + expires, ok := raw.(float64) + require.True(t, ok) + expiresInt := int(expires) + if tt.expiresInMin > 0 { + assert.GreaterOrEqual(t, expiresInt, tt.expiresInMin) + } + if tt.expiresInMax > 0 { + assert.LessOrEqual(t, expiresInt, tt.expiresInMax) + } + } + if tt.json != "" { assert.JSONEq(t, tt.json, respBodyString) } diff --git a/pkg/op/token.go b/pkg/op/token.go index fee57ad9..f6896b07 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -178,6 +178,15 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex if actorReq, ok := tokenRequest.(TokenActorRequest); ok { claims.Actor = actorReq.GetActor() } + // Add certificate-bound token cnf claim if thumbprint is in context (RFC 8705) + if thumbprint := CertThumbprintFromContext(ctx); thumbprint != "" { + if claims.Claims == nil { + claims.Claims = make(map[string]any) + } + claims.Claims["cnf"] = map[string]string{ + "x5t#S256": thumbprint, + } + } signingKey, err := storage.SigningKey(ctx) if err != nil { return "", err diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index e209b64a..a8429671 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -19,15 +19,50 @@ func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err, exchanger.Logger()) + return } - validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) - if err != nil { - RequestError(w, r, err, exchanger.Logger()) - return + var ( + validatedRequest TokenRequest + client Client + ) + + // mTLS client authentication for client_credentials (RFC 8705) + // Unlike other flows, this handler doesn't use ClientIDFromRequest, so we must validate here. + if mtls, ok := exchanger.(mtlsClientCredentialsSupport); ok && request.ClientID != "" && + (mtls.AuthMethodTLSClientAuthSupported() || mtls.AuthMethodSelfSignedTLSClientAuthSupported()) { + c, err := exchanger.Storage().GetClientByClientID(r.Context(), request.ClientID) + if err == nil && (c.AuthMethod() == oidc.AuthMethodTLSClientAuth || c.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth) { + validatedRequest, client, err = validateClientCredentialsRequestMTLS(r.Context(), r, request, exchanger, c) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } } - resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) + if validatedRequest == nil { + validatedRequest, client, err = ValidateClientCredentialsRequest(r.Context(), request, exchanger) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + + // Set certificate thumbprint in context for certificate-bound tokens (RFC 8705) + tokenCtx := r.Context() + if mtlsProvider, ok := exchanger.(interface{ MTLSConfig() *MTLSConfig }); ok { + boundSupported := false + if s, ok := exchanger.(interface{ TLSClientCertificateBoundAccessTokensSupported() bool }); ok { + boundSupported = s.TLSClientCertificateBoundAccessTokensSupported() + } + tokenCtx, err = SetCertThumbprintInContext(tokenCtx, r, client, mtlsProvider.MTLSConfig(), boundSupported) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + resp, err := CreateClientCredentialsTokenResponse(tokenCtx, validatedRequest, exchanger, client) if err != nil { RequestError(w, r, err, exchanger.Logger()) return @@ -36,6 +71,68 @@ func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger httphelper.MarshalJSON(w, resp) } +type mtlsClientCredentialsSupport interface { + MTLSConfig() *MTLSConfig + AuthMethodTLSClientAuthSupported() bool + AuthMethodSelfSignedTLSClientAuthSupported() bool +} + +func validateClientCredentialsRequestMTLS(ctx context.Context, r *http.Request, request *oidc.ClientCredentialsRequest, exchanger Exchanger, client Client) (TokenRequest, Client, error) { + storage, ok := exchanger.Storage().(ClientCredentialsStorage) + if !ok { + return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") + } + mtls, ok := exchanger.(mtlsClientCredentialsSupport) + if !ok { + return nil, nil, oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported") + } + mtlsConfig := mtls.MTLSConfig() + + certs, err := ClientCertificateFromRequest(r, mtlsConfig) + if err != nil || len(certs) == 0 { + return nil, nil, oidc.ErrInvalidClient().WithDescription("no client certificate provided") + } + + switch client.AuthMethod() { + case oidc.AuthMethodTLSClientAuth: + if !mtls.AuthMethodTLSClientAuthSupported() { + return nil, nil, oidc.ErrInvalidClient().WithDescription("tls_client_auth not supported") + } + mtlsClient, ok := client.(HasMTLSConfig) + if !ok { + return nil, nil, oidc.ErrInvalidClient().WithDescription("client does not support mTLS configuration") + } + if err := ValidateTLSClientAuth(certs, mtlsConfig, mtlsClient.GetMTLSConfig()); err != nil { + return nil, nil, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + + case oidc.AuthMethodSelfSignedTLSClientAuth: + if !mtls.AuthMethodSelfSignedTLSClientAuthSupported() { + return nil, nil, oidc.ErrInvalidClient().WithDescription("self_signed_tls_client_auth not supported") + } + selfSignedClient, ok := client.(HasSelfSignedCertificate) + if !ok { + return nil, nil, oidc.ErrInvalidClient().WithDescription("client does not support self-signed certificates") + } + if err := ValidateSelfSignedTLSClientAuth(certs[0], selfSignedClient.GetRegisteredCertificates()); err != nil { + return nil, nil, oidc.ErrInvalidClient().WithDescription("mTLS client authentication failed").WithParent(err) + } + + default: + return nil, nil, oidc.ErrInvalidClient() + } + + if !ValidateGrantType(client, oidc.GrantTypeClientCredentials) { + return nil, nil, oidc.ErrUnauthorizedClient() + } + + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, request.ClientID, request.Scope) + if err != nil { + return nil, nil, err + } + return tokenRequest, client, nil +} + // ParseClientCredentialsRequest parsed the http request into a oidc.ClientCredentialsRequest func ParseClientCredentialsRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.ClientCredentialsRequest, error) { err := r.ParseForm() diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 5ed890f9..cde75084 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -18,6 +18,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err, exchanger.Logger()) + return } if tokenReq.Code == "" { RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) @@ -28,7 +29,33 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { RequestError(w, r, err, exchanger.Logger()) return } - resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") + tokenCtx := r.Context() + // Enforce mTLS client authentication for mTLS clients (RFC 8705). + if client.AuthMethod() == oidc.AuthMethodTLSClientAuth || client.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth { + mtlsProvider, ok := exchanger.(mtlsClientAuthSupport) + if !ok { + RequestError(w, r, oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported"), exchanger.Logger()) + return + } + tokenCtx, err = validateMTLSClientAuthForClient(tokenCtx, r, mtlsProvider, client) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + // Set certificate thumbprint in context for certificate-bound tokens (RFC 8705) + if mtlsProvider, ok := exchanger.(interface{ MTLSConfig() *MTLSConfig }); ok { + boundSupported := false + if s, ok := exchanger.(interface{ TLSClientCertificateBoundAccessTokensSupported() bool }); ok { + boundSupported = s.TLSClientCertificateBoundAccessTokensSupported() + } + tokenCtx, err = SetCertThumbprintInContext(tokenCtx, r, client, mtlsProvider.MTLSConfig(), boundSupported) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + resp, err := CreateTokenResponse(tokenCtx, authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { RequestError(w, r, err, exchanger.Logger()) return @@ -104,6 +131,13 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") } + // mTLS authentication (tls_client_auth, self_signed_tls_client_auth) + // The actual mTLS validation is performed in ClientIDFromRequest/ClientMTLSAuth. + // If we reach here with an mTLS auth method, the client was already authenticated. + if client.AuthMethod() == oidc.AuthMethodTLSClientAuth || + client.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth { + return request, client, nil + } if client.AuthMethod() == oidc.AuthMethodNone { if codeChallenge == nil { return nil, nil, oidc.ErrInvalidRequest().WithDescription("PKCE required") diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 5c762a04..14307b16 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -141,6 +141,55 @@ func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err, exchanger.Logger()) + return + } + if clientID == "" { + clientID = r.FormValue("client_id") + } + + tokenCtx := r.Context() + + // Prefer mTLS client authentication when the client is registered for it (RFC 8705). + if clientID != "" { + if c, err := exchanger.Storage().GetClientByClientID(tokenCtx, clientID); err == nil && + (c.AuthMethod() == oidc.AuthMethodTLSClientAuth || c.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth) { + mtlsProvider, ok := exchanger.(mtlsClientAuthSupport) + if !ok { + RequestError(w, r, oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported"), exchanger.Logger()) + return + } + tokenCtx, err = validateMTLSClientAuthForClient(tokenCtx, r, mtlsProvider, c) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + + tokenExchangeRequest, client, err := ValidateTokenExchangeRequestAuthenticatedClient(tokenCtx, tokenExchangeReq, c, exchanger) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + // Set certificate thumbprint in context for certificate-bound tokens (RFC 8705) + if mtlsProvider, ok := exchanger.(interface{ MTLSConfig() *MTLSConfig }); ok { + boundSupported := false + if s, ok := exchanger.(interface{ TLSClientCertificateBoundAccessTokensSupported() bool }); ok { + boundSupported = s.TLSClientCertificateBoundAccessTokensSupported() + } + tokenCtx, err = SetCertThumbprintInContext(tokenCtx, r, client, mtlsProvider.MTLSConfig(), boundSupported) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + + resp, err := CreateTokenExchangeResponse(tokenCtx, tokenExchangeRequest, client, exchanger) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + httphelper.MarshalJSON(w, resp) + return + } } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) @@ -148,7 +197,19 @@ func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) RequestError(w, r, err, exchanger.Logger()) return } - resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) + // Set certificate thumbprint in context for certificate-bound tokens (RFC 8705) + if mtlsProvider, ok := exchanger.(interface{ MTLSConfig() *MTLSConfig }); ok { + boundSupported := false + if s, ok := exchanger.(interface{ TLSClientCertificateBoundAccessTokensSupported() bool }); ok { + boundSupported = s.TLSClientCertificateBoundAccessTokensSupported() + } + tokenCtx, err = SetCertThumbprintInContext(tokenCtx, r, client, mtlsProvider.MTLSConfig(), boundSupported) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + resp, err := CreateTokenExchangeResponse(tokenCtx, tokenExchangeRequest, client, exchanger) if err != nil { RequestError(w, r, err, exchanger.Logger()) return @@ -228,6 +289,48 @@ func ValidateTokenExchangeRequest( return req, client, nil } +// ValidateTokenExchangeRequestAuthenticatedClient validates a token exchange request for a client that was already authenticated. +// This is used for mTLS-based client authentication where no client_secret is expected. +func ValidateTokenExchangeRequestAuthenticatedClient( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, Client, error) { + ctx, span := Tracer.Start(ctx, "ValidateTokenExchangeRequestAuthenticatedClient") + defer span.End() + + if oidcTokenExchangeRequest.SubjectToken == "" { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing") + } + + if oidcTokenExchangeRequest.SubjectTokenType == "" { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") + } + + if client == nil { + return nil, nil, oidc.ErrInvalidClient() + } + + if oidcTokenExchangeRequest.RequestedTokenType != "" && !oidcTokenExchangeRequest.RequestedTokenType.IsSupported() { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported") + } + + if !oidcTokenExchangeRequest.SubjectTokenType.IsSupported() { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported") + } + + if oidcTokenExchangeRequest.ActorTokenType != "" && !oidcTokenExchangeRequest.ActorTokenType.IsSupported() { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") + } + + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + func CreateTokenExchangeRequest( ctx context.Context, oidcTokenExchangeRequest *oidc.TokenExchangeRequest, diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index fd7fdd6e..d0442446 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -23,6 +23,7 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err, exchanger.Logger()) + return } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index 4ffd551f..a9b6e380 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -31,13 +31,40 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err, exchanger.Logger()) + return } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { RequestError(w, r, err, exchanger.Logger()) return } - resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) + tokenCtx := r.Context() + // Enforce mTLS client authentication for mTLS clients (RFC 8705). + if client.AuthMethod() == oidc.AuthMethodTLSClientAuth || client.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth { + mtlsProvider, ok := exchanger.(mtlsClientAuthSupport) + if !ok { + RequestError(w, r, oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported"), exchanger.Logger()) + return + } + tokenCtx, err = validateMTLSClientAuthForClient(tokenCtx, r, mtlsProvider, client) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + // Set certificate thumbprint in context for certificate-bound tokens (RFC 8705) + if mtlsProvider, ok := exchanger.(interface{ MTLSConfig() *MTLSConfig }); ok { + boundSupported := false + if s, ok := exchanger.(interface{ TLSClientCertificateBoundAccessTokensSupported() bool }); ok { + boundSupported = s.TLSClientCertificateBoundAccessTokensSupported() + } + tokenCtx, err = SetCertThumbprintInContext(tokenCtx, r, client, mtlsProvider.MTLSConfig(), boundSupported) + if err != nil { + RequestError(w, r, err, exchanger.Logger()) + return + } + } + resp, err := CreateTokenResponse(tokenCtx, validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { RequestError(w, r, err, exchanger.Logger()) return @@ -124,6 +151,14 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { return nil, nil, oidc.ErrInvalidClient() } + // mTLS authentication (tls_client_auth, self_signed_tls_client_auth) + // The actual mTLS validation is performed in ClientIDFromRequest/ClientMTLSAuth. + // If we reach here with an mTLS auth method, the client was already authenticated. + if client.AuthMethod() == oidc.AuthMethodTLSClientAuth || + client.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth { + request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) + return request, client, err + } if client.AuthMethod() == oidc.AuthMethodNone { request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) return request, client, err diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index ec792a6a..0346a049 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -101,8 +101,9 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token } return "", "", "", err } - clientID, clientSecret, ok := r.BasicAuth() - if ok { + + var basicID, basicSecret string + if clientID, clientSecret, ok := r.BasicAuth(); ok { clientID, err = url.QueryUnescape(clientID) if err != nil { return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) @@ -111,10 +112,33 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token if err != nil { return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) } - if err = AuthorizeClientIDSecret(r.Context(), clientID, clientSecret, revoker.Storage()); err != nil { + basicID = clientID + basicSecret = clientSecret + } + + candidateClientID := req.ClientID + if candidateClientID == "" { + candidateClientID = basicID + } + if candidateClientID != "" { + if client, err := revoker.Storage().GetClientByClientID(r.Context(), candidateClientID); err == nil && + (client.AuthMethod() == oidc.AuthMethodTLSClientAuth || client.AuthMethod() == oidc.AuthMethodSelfSignedTLSClientAuth) { + mtlsProvider, ok := revoker.(mtlsClientAuthSupport) + if !ok { + return "", "", "", oidc.ErrInvalidClient().WithDescription("mTLS authentication not supported") + } + if _, err := validateMTLSClientAuthForClient(r.Context(), r, mtlsProvider, client); err != nil { + return "", "", "", err + } + return req.Token, req.TokenTypeHint, client.GetID(), nil + } + } + + if basicID != "" { + if err = AuthorizeClientIDSecret(r.Context(), basicID, basicSecret, revoker.Storage()); err != nil { return "", "", "", err } - return req.Token, req.TokenTypeHint, clientID, nil + return req.Token, req.TokenTypeHint, basicID, nil } if req.ClientID == "" { return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization") diff --git a/pkg/op/token_revocation_mtls_test.go b/pkg/op/token_revocation_mtls_test.go new file mode 100644 index 00000000..2fcedb7b --- /dev/null +++ b/pkg/op/token_revocation_mtls_test.go @@ -0,0 +1,271 @@ +package op_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/zitadel/schema" + + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/oidc/v3/pkg/op/mock" +) + +type revocationMTLSTestRevoker struct { + decoder httphelper.Decoder + storage op.Storage + mtlsConfig *op.MTLSConfig + tlsSupported bool + selfSignedSupported bool +} + +func (r *revocationMTLSTestRevoker) Decoder() httphelper.Decoder { + return r.decoder +} + +func (r *revocationMTLSTestRevoker) Crypto() op.Crypto { + return nil +} + +func (r *revocationMTLSTestRevoker) Storage() op.Storage { + return r.storage +} + +func (r *revocationMTLSTestRevoker) AccessTokenVerifier(context.Context) *op.AccessTokenVerifier { + return nil +} + +func (r *revocationMTLSTestRevoker) AuthMethodPrivateKeyJWTSupported() bool { + return false +} + +func (r *revocationMTLSTestRevoker) AuthMethodPostSupported() bool { + return false +} + +func (r *revocationMTLSTestRevoker) MTLSConfig() *op.MTLSConfig { + return r.mtlsConfig +} + +func (r *revocationMTLSTestRevoker) AuthMethodTLSClientAuthSupported() bool { + return r.tlsSupported +} + +func (r *revocationMTLSTestRevoker) AuthMethodSelfSignedTLSClientAuthSupported() bool { + return r.selfSignedSupported +} + +type mtlsTestClient struct { + id string + authMethod oidc.AuthMethod + accessTokenType op.AccessTokenType + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + mtlsConfig *op.MTLSClientConfig + registeredCerts []string +} + +func (c *mtlsTestClient) GetID() string { + return c.id +} + +func (c *mtlsTestClient) RedirectURIs() []string { + return nil +} + +func (c *mtlsTestClient) PostLogoutRedirectURIs() []string { + return nil +} + +func (c *mtlsTestClient) ApplicationType() op.ApplicationType { + return op.ApplicationTypeWeb +} + +func (c *mtlsTestClient) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +func (c *mtlsTestClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +func (c *mtlsTestClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +func (c *mtlsTestClient) LoginURL(id string) string { + return "" +} + +func (c *mtlsTestClient) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} + +func (c *mtlsTestClient) IDTokenLifetime() time.Duration { + return time.Minute +} + +func (c *mtlsTestClient) DevMode() bool { + return false +} + +func (c *mtlsTestClient) RestrictAdditionalIdTokenScopes() func([]string) []string { + return func(scopes []string) []string { return scopes } +} + +func (c *mtlsTestClient) RestrictAdditionalAccessTokenScopes() func([]string) []string { + return func(scopes []string) []string { return scopes } +} + +func (c *mtlsTestClient) IsScopeAllowed(string) bool { + return true +} + +func (c *mtlsTestClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +func (c *mtlsTestClient) ClockSkew() time.Duration { + return 0 +} + +func (c *mtlsTestClient) GetMTLSConfig() *op.MTLSClientConfig { + return c.mtlsConfig +} + +func (c *mtlsTestClient) GetRegisteredCertificates() []string { + return c.registeredCerts +} + +func newTestDecoder() *schema.Decoder { + dec := schema.NewDecoder() + dec.IgnoreUnknownKeys(true) + return dec +} + +func TestParseTokenRevocationRequest_MTLSClientAuth_Success(t *testing.T) { + ca, caKey := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client1"}, + parent: ca, + parentKey: caKey, + }) + + pool := x509.NewCertPool() + pool.AddCert(ca) + + client := &mtlsTestClient{ + id: "client1", + authMethod: oidc.AuthMethodTLSClientAuth, + accessTokenType: op.AccessTokenTypeJWT, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode}, + mtlsConfig: &op.MTLSClientConfig{SubjectDN: "CN=client1"}, + } + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().GetClientByClientID(gomock.Any(), "client1").Return(client, nil) + + revoker := &revocationMTLSTestRevoker{ + decoder: newTestDecoder(), + storage: storage, + mtlsConfig: &op.MTLSConfig{TrustStore: pool}, + tlsSupported: true, + } + + r := httptest.NewRequest(http.MethodPost, "/revoke", strings.NewReader("token=foo&client_id=client1")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.TLS = &tls.ConnectionState{PeerCertificates: []*x509.Certificate{clientCert}} + + token, hint, clientID, err := op.ParseTokenRevocationRequest(r, revoker) + require.NoError(t, err) + require.Equal(t, "foo", token) + require.Empty(t, hint) + require.Equal(t, "client1", clientID) +} + +func TestParseTokenRevocationRequest_MTLSClientAuth_NoCert(t *testing.T) { + ca, _ := generateTestCA(t, pkix.Name{CommonName: "Test CA"}) + pool := x509.NewCertPool() + pool.AddCert(ca) + + client := &mtlsTestClient{ + id: "client1", + authMethod: oidc.AuthMethodTLSClientAuth, + accessTokenType: op.AccessTokenTypeJWT, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode}, + mtlsConfig: &op.MTLSClientConfig{SubjectDN: "CN=client1"}, + } + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().GetClientByClientID(gomock.Any(), "client1").Return(client, nil) + + revoker := &revocationMTLSTestRevoker{ + decoder: newTestDecoder(), + storage: storage, + mtlsConfig: &op.MTLSConfig{TrustStore: pool}, + tlsSupported: true, + } + + r := httptest.NewRequest(http.MethodPost, "/revoke", strings.NewReader("token=foo&client_id=client1")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + _, _, _, err := op.ParseTokenRevocationRequest(r, revoker) + require.Error(t, err) + + var oidcErr *oidc.Error + require.ErrorAs(t, err, &oidcErr) + require.Equal(t, oidc.InvalidClient, oidcErr.ErrorType) +} + +func TestParseTokenRevocationRequest_SelfSignedTLSClientAuth_Success(t *testing.T) { + clientCert, _ := generateTestCert(t, testCertOptions{ + subject: pkix.Name{CommonName: "client1"}, + }) + + client := &mtlsTestClient{ + id: "client1", + authMethod: oidc.AuthMethodSelfSignedTLSClientAuth, + accessTokenType: op.AccessTokenTypeJWT, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode}, + registeredCerts: []string{certToPEM(clientCert)}, + } + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().GetClientByClientID(gomock.Any(), "client1").Return(client, nil) + + revoker := &revocationMTLSTestRevoker{ + decoder: newTestDecoder(), + storage: storage, + mtlsConfig: &op.MTLSConfig{}, + selfSignedSupported: true, + } + + r := httptest.NewRequest(http.MethodPost, "/revoke", strings.NewReader("token=foo&client_id=client1")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.TLS = &tls.ConnectionState{PeerCertificates: []*x509.Certificate{clientCert}} + + token, hint, clientID, err := op.ParseTokenRevocationRequest(r, revoker) + require.NoError(t, err) + require.Equal(t, "foo", token) + require.Empty(t, hint) + require.Equal(t, "client1", clientID) +} diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 7a828f4d..61406cc6 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -17,6 +17,13 @@ type UserinfoProvider interface { AccessTokenVerifier(context.Context) *AccessTokenVerifier } +// UserinfoMTLSProvider is an optional interface for providers that support +// certificate-bound access token verification at the UserInfo endpoint (RFC 8705). +type UserinfoMTLSProvider interface { + UserinfoProvider + MTLSConfig() *MTLSConfig +} + func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { Userinfo(w, r, userinfoProvider) @@ -33,11 +40,25 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken) + tokenID, subject, claims, ok := getTokenIDAndSubjectAndClaims(r.Context(), userinfoProvider, accessToken) if !ok { http.Error(w, "access token invalid", http.StatusUnauthorized) return } + + // Verify certificate-bound token if cnf claim is present (RFC 8705) + if cnfThumbprint := GetCnfThumbprintFromClaims(claims); cnfThumbprint != "" { + mtlsProvider, ok := userinfoProvider.(UserinfoMTLSProvider) + if !ok { + http.Error(w, "certificate-bound token not supported", http.StatusUnauthorized) + return + } + if err := VerifyCertificateBindingFromRequest(r, mtlsProvider.MTLSConfig(), cnfThumbprint); err != nil { + http.Error(w, "certificate binding verification failed", http.StatusUnauthorized) + return + } + } + info := new(oidc.UserInfo) err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin")) if err != nil { @@ -85,20 +106,28 @@ func getAccessToken(r *http.Request) (string, error) { } func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { - ctx, span := Tracer.Start(ctx, "getTokenIDAndSubject") + tokenID, subject, _, ok := getTokenIDAndSubjectAndClaims(ctx, userinfoProvider, accessToken) + return tokenID, subject, ok +} + +// getTokenIDAndSubjectAndClaims returns token ID, subject, and claims (for JWT tokens). +// For opaque tokens, claims will be nil. +func getTokenIDAndSubjectAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, map[string]any, bool) { + ctx, span := Tracer.Start(ctx, "getTokenIDAndSubjectAndClaims") defer span.End() tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { splitToken := strings.Split(tokenIDSubject, ":") if len(splitToken) != 2 { - return "", "", false + return "", "", nil, false } - return splitToken[0], splitToken[1], true + // Opaque token - no claims available directly + return splitToken[0], splitToken[1], nil, true } accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx)) if err != nil { - return "", "", false + return "", "", nil, false } - return accessTokenClaims.JWTID, accessTokenClaims.Subject, true + return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims.Claims, true }