Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 18 additions & 32 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ var (
Tracer = otel.Tracer("github.com/zitadel/oidc/pkg/client")
)

type ClientSecretBasicAuthRequest interface {
Auth(req *http.Request)
}

// Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to override the discovery endpoint url
func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
Expand Down Expand Up @@ -64,10 +60,10 @@ type TokenEndpointCaller interface {
}

func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
return callTokenEndpoint(ctx, request, nil, caller)
return CallTokenEndpointWithAuthFn(ctx, request, nil, caller)
}

func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
func CallTokenEndpointWithAuthFn(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
ctx, span := Tracer.Start(ctx, "callTokenEndpoint")
defer span.End()

Expand All @@ -76,10 +72,6 @@ func callTokenEndpoint(ctx context.Context, request any, authFn any, caller Toke
return nil, err
}

if basicAuthRequest, ok := request.(ClientSecretBasicAuthRequest); ok {
basicAuthRequest.Auth(req)
}

tokenRes := new(oidc.AccessTokenResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
Expand Down Expand Up @@ -154,12 +146,6 @@ type RevokeRequest struct {
ClientSecret string `schema:"client_secret"`
}

func (r RevokeRequest) Auth(req *http.Request) {
if r.ClientSecret != "" {
req.SetBasicAuth(url.QueryEscape(r.ClientID), url.QueryEscape(r.ClientSecret))
}
}

func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error {
ctx, span := Tracer.Start(ctx, "CallRevokeEndpoint")
defer span.End()
Expand All @@ -174,10 +160,6 @@ func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller Rev
return err
}

if basicAuthRequest, ok := request.(ClientSecretBasicAuthRequest); ok {
basicAuthRequest.Auth(req)
}

client := caller.HttpClient()
client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
Expand Down Expand Up @@ -258,7 +240,6 @@ func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCr
if err != nil {
return nil, err
}
request.Auth(req)

resp := new(oidc.DeviceAuthorizationResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
Expand All @@ -272,21 +253,15 @@ type DeviceAccessTokenRequest struct {
oidc.DeviceAccessTokenRequest
}

func (r *DeviceAccessTokenRequest) Auth(req *http.Request) {
if r.ClientSecret != "" {
req.SetBasicAuth(url.QueryEscape(r.ClientID), url.QueryEscape(r.ClientSecret))
}
}

func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
// CallDeviceAccessTokenEndpointWithAuthFn calls the device access token endpoint, accepting an authFn for custom authentication.
func CallDeviceAccessTokenEndpointWithAuthFn(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller, authFn any) (*oidc.AccessTokenResponse, error) {
ctx, span := Tracer.Start(ctx, "CallDeviceAccessTokenEndpoint")
defer span.End()

req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
request.Auth(req)

resp := new(oidc.AccessTokenResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
Expand All @@ -295,7 +270,13 @@ func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTok
return resp, nil
}

func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
// Deprecated: Use CallDeviceAccessTokenEndpointWithAuthFn instead.
Comment thread
suqin-haha marked this conversation as resolved.
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
return CallDeviceAccessTokenEndpointWithAuthFn(ctx, request, caller, nil)
}

// PollDeviceAccessTokenEndpointWithAuthFn polls the device access token endpoint, accepting an authFn for custom authentication.
func PollDeviceAccessTokenEndpointWithAuthFn(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller, authFn any) (*oidc.AccessTokenResponse, error) {
ctx, span := Tracer.Start(ctx, "PollDeviceAccessTokenEndpoint")
defer span.End()

Expand All @@ -310,7 +291,7 @@ func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration,
ctx, cancel := context.WithTimeout(ctx, interval)
defer cancel()

resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
resp, err := CallDeviceAccessTokenEndpointWithAuthFn(ctx, request, caller, authFn)
if err == nil {
return resp, nil
}
Expand All @@ -332,3 +313,8 @@ func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration,
}
}
}

// Deprecated: Use PollDeviceAccessTokenEndpointWithAuthFn instead.
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
Comment thread
suqin-haha marked this conversation as resolved.
return PollDeviceAccessTokenEndpointWithAuthFn(ctx, interval, request, caller, nil)
}
32 changes: 27 additions & 5 deletions pkg/client/rp/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"time"

"github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
)

func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
Expand Down Expand Up @@ -53,17 +55,37 @@ func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Dur
defer span.End()

ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")

req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,
DeviceCode: deviceCode,
},
ClientCredentialsRequest: &oidc.ClientCredentialsRequest{
Scope: nil,
},
}

req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
if err != nil {
return nil, err
}
var authFn httphelper.RequestAuthorization

// https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
// The client MUST NOT use more than one authentication method in each request.
switch rp.OAuthConfig().Endpoint.AuthStyle {
case oauth2.AuthStyleInHeader:
authFn = httphelper.AuthorizeBasic(rp.OAuthConfig().ClientID, rp.OAuthConfig().ClientSecret)
default:
if signer := rp.Signer(); signer != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
if err != nil {
return nil, fmt.Errorf("failed to build assertion: %w", err)
}
req.ClientAssertion = assertion
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
} else {
req.ClientID = rp.OAuthConfig().ClientID
req.ClientSecret = rp.OAuthConfig().ClientSecret
}

return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
}
return client.PollDeviceAccessTokenEndpointWithAuthFn(ctx, interval, req, tokenEndpointCaller{rp}, authFn)
}
60 changes: 42 additions & 18 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -816,12 +816,6 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"`
}

func (r RefreshTokenRequest) Auth(req *http.Request) {
if r.ClientSecret != "" {
req.SetBasicAuth(url.QueryEscape(r.ClientID), url.QueryEscape(r.ClientSecret))
}
}

// RefreshTokens performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid.
Expand All @@ -834,16 +828,34 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres
defer span.End()

ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")

var authFn httphelper.RequestAuthorization
request := RefreshTokenRequest{
RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes,
ClientID: rp.OAuthConfig().ClientID,
ClientSecret: rp.OAuthConfig().ClientSecret,
ClientAssertion: clientAssertion,
ClientAssertionType: clientAssertionType,
GrantType: oidc.GrantTypeRefreshToken,
}
newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes,
GrantType: oidc.GrantTypeRefreshToken,
}

// https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
// The client MUST NOT use more than one authentication method in each request.
switch rp.OAuthConfig().Endpoint.AuthStyle {
case oauth2.AuthStyleInHeader:
if clientAssertion != "" {
return nil, errors.New("client assertion is not supported with AuthStyleInHeader")
}
authFn = httphelper.AuthorizeBasic(rp.OAuthConfig().ClientID, rp.OAuthConfig().ClientSecret)
default:
// use client id and secret in the request body
if clientAssertion != "" && clientAssertionType != "" {
request.ClientAssertion = clientAssertion
request.ClientAssertionType = clientAssertionType
} else {
request.ClientID = rp.OAuthConfig().ClientID
request.ClientSecret = rp.OAuthConfig().ClientSecret
}
}

newToken, err := client.CallTokenEndpointWithAuthFn(ctx, request, authFn, tokenEndpointCaller{RelyingParty: rp})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -881,14 +893,26 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
ctx, span := client.Tracer.Start(ctx, "RefreshTokens")
defer span.End()

request := client.RevokeRequest{
Token: token,
TokenTypeHint: tokenTypeHint,
ClientID: rp.OAuthConfig().ClientID,
ClientSecret: rp.OAuthConfig().ClientSecret,
}
var authFn httphelper.RequestAuthorization

// https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
// The client MUST NOT use more than one authentication method in each request.
switch rp.OAuthConfig().Endpoint.AuthStyle {
case oauth2.AuthStyleInHeader:
authFn = httphelper.AuthorizeBasic(rp.OAuthConfig().ClientID, rp.OAuthConfig().ClientSecret)
default:
// use client id and secret in the request body
request.ClientID = rp.OAuthConfig().ClientID
request.ClientSecret = rp.OAuthConfig().ClientSecret
}

if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
return client.CallRevokeEndpoint(ctx, request, nil, rc)
return client.CallRevokeEndpoint(ctx, request, authFn, rc)
}
return ErrRelyingPartyNotSupportRevokeCaller
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ type Encoder interface {
Encode(src any, dst map[string][]string) error
}

type FormAuthorization func(url.Values)
type RequestAuthorization func(*http.Request)
type (
FormAuthorization func(url.Values)
RequestAuthorization func(*http.Request)
)

func AuthorizeBasic(user, password string) RequestAuthorization {
return func(req *http.Request) {
Expand All @@ -40,15 +42,15 @@ func FormRequest(ctx context.Context, endpoint string, request any, encoder Enco
if err := encoder.Encode(request, form); err != nil {
return nil, err
}
if fn, ok := authFn.(FormAuthorization); ok {
if fn, ok := authFn.(FormAuthorization); ok && fn != nil {
fn(form)
}
body := strings.NewReader(form.Encode())
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
if err != nil {
return nil, err
}
if fn, ok := authFn.(RequestAuthorization); ok {
if fn, ok := authFn.(RequestAuthorization); ok && fn != nil {
fn(req)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
Expand Down
24 changes: 16 additions & 8 deletions pkg/oidc/token_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package oidc
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"slices"
"time"

Expand Down Expand Up @@ -94,6 +92,14 @@ func (a *AccessTokenRequest) SetClientSecret(clientSecret string) {
a.ClientSecret = clientSecret
}

func (a *AccessTokenRequest) IsSetClientAssertion() bool {
return a.ClientAssertion != "" && a.ClientAssertionType != ""
}

func (a *AccessTokenRequest) IsSetClientIDAndClientSecret() bool {
return a.ClientID != "" && a.ClientSecret != ""
}

// RefreshTokenRequest is not useful for making refresh requests because the
// grant_type is not included explicitly but rather implied.
type RefreshTokenRequest struct {
Expand All @@ -119,6 +125,14 @@ func (a *RefreshTokenRequest) SetClientSecret(clientSecret string) {
a.ClientSecret = clientSecret
}

func (a *RefreshTokenRequest) IsSetClientAssertion() bool {
return a.ClientAssertion != "" && a.ClientAssertionType != ""
}

func (a *RefreshTokenRequest) IsSetClientIDAndClientSecret() bool {
return a.ClientID != "" && a.ClientSecret != ""
}

type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Expand Down Expand Up @@ -245,9 +259,3 @@ type ClientCredentialsRequest struct {
ClientAssertion string `schema:"client_assertion,omitempty"`
ClientAssertionType string `schema:"client_assertion_type,omitempty"`
}

func (r *ClientCredentialsRequest) Auth(req *http.Request) {
if r.ClientSecret != "" {
req.SetBasicAuth(url.QueryEscape(r.ClientID), url.QueryEscape(r.ClientSecret))
}
}
12 changes: 12 additions & 0 deletions pkg/op/server_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,20 @@ func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredential
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}

// https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
// The client MUST NOT use more than one authentication method in each request.
assertionAuthExists := cc.ClientAssertion != "" && cc.ClientAssertionType != ""
secretAuthExists := cc.ClientSecret != "" && cc.ClientID != ""
if assertionAuthExists && secretAuthExists {
return nil, oidc.ErrInvalidRequest().WithDescription("client authentication must not use more than one method")
}

// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
if assertionAuthExists || secretAuthExists {
return nil, oidc.ErrInvalidRequest().WithDescription("client authentication must not use more than one method")
}
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
Expand Down
Loading