Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 107 additions & 121 deletions a2a/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,22 @@ type SecurityRequirements map[SecuritySchemeName]SecuritySchemeScopes
// }
type SecurityRequirementsOptions []SecurityRequirements

type securityRequirements struct {
Schemes map[SecuritySchemeName]SecuritySchemeScopes `json:"schemes"`
}

// MarshalJSON implements json.Marshaler.
func (rs SecurityRequirementsOptions) MarshalJSON() ([]byte, error) {
type wrapper struct {
Schemes map[SecuritySchemeName]SecuritySchemeScopes `json:"schemes"`
}
var out []wrapper
var out []securityRequirements
for _, req := range rs {
out = append(out, wrapper{Schemes: req})
out = append(out, securityRequirements{Schemes: req})
}
return json.Marshal(out)
}

// UnmarshalJSON implements json.Unmarshaler.
func (rs *SecurityRequirementsOptions) UnmarshalJSON(b []byte) error {
type wrapper struct {
Schemes map[SecuritySchemeName]SecuritySchemeScopes `json:"schemes"`
}
var wrapped []wrapper
var wrapped []securityRequirements
if err := json.Unmarshal(b, &wrapped); err != nil {
return err
}
Expand All @@ -85,80 +83,76 @@ type SecuritySchemeScopes []string
// The key is the scheme name. Follows the OpenAPI 3.0 Security Scheme Object.
type NamedSecuritySchemes map[SecuritySchemeName]SecurityScheme

type securityScheme struct {
APIKey *APIKeySecurityScheme `json:"apiKeySecurityScheme,omitempty"`
HTTPAuth *HTTPAuthSecurityScheme `json:"httpAuthSecurityScheme,omitempty"`
MutualTLS *MutualTLSSecurityScheme `json:"mtlsSecurityScheme,omitempty"`
OAuth2 *OAuth2SecurityScheme `json:"oauth2SecurityScheme,omitempty"`
OpenIDConnect *OpenIDConnectSecurityScheme `json:"openIdConnectSecurityScheme,omitempty"`
}

// MarshalJSON implements json.Marshaler.
func (s NamedSecuritySchemes) MarshalJSON() ([]byte, error) {
out := make(map[SecuritySchemeName]any)
out := make(map[SecuritySchemeName]securityScheme, len(s))
for name, scheme := range s {
var wrapped any
var wrapper securityScheme
switch v := scheme.(type) {
// TODO: remove short JSON discriminator keys after transition period
case APIKeySecurityScheme:
wrapped = map[string]any{"apiKeySecurityScheme": v}
wrapper.APIKey = &v
case HTTPAuthSecurityScheme:
wrapped = map[string]any{"httpAuthSecurityScheme": v}
wrapper.HTTPAuth = &v
case OpenIDConnectSecurityScheme:
wrapped = map[string]any{"openIdConnectSecurityScheme": v}
wrapper.OpenIDConnect = &v
case MutualTLSSecurityScheme:
wrapped = map[string]any{"mtlsSecurityScheme": v}
wrapper.MutualTLS = &v
case OAuth2SecurityScheme:
wrapped = map[string]any{"oauth2SecurityScheme": v}
wrapper.OAuth2 = &v
default:
return nil, fmt.Errorf("unknown security scheme type %T", v)
}
out[name] = wrapped
out[name] = wrapper
}
return json.Marshal(out)
}

// UnmarshalJSON implements json.Unmarshaler.
func (s *NamedSecuritySchemes) UnmarshalJSON(b []byte) error {
var schemes map[SecuritySchemeName]json.RawMessage
var schemes map[SecuritySchemeName]securityScheme
if err := json.Unmarshal(b, &schemes); err != nil {
return err
}

result := make(map[SecuritySchemeName]SecurityScheme, len(schemes))
for name, rawMessage := range schemes {
var raw map[string]json.RawMessage
if err := json.Unmarshal(rawMessage, &raw); err != nil {
return err
result := make(NamedSecuritySchemes, len(schemes))
for name, wrapper := range schemes {
var n int
if wrapper.APIKey != nil {
result[name] = *wrapper.APIKey
n++
}
if v, ok := raw["apiKeySecurityScheme"]; ok {
var scheme APIKeySecurityScheme
if err := json.Unmarshal(v, &scheme); err != nil {
return err
}
result[name] = scheme
} else if v, ok := raw["httpAuthSecurityScheme"]; ok {
var scheme HTTPAuthSecurityScheme
if err := json.Unmarshal(v, &scheme); err != nil {
return err
}
result[name] = scheme
} else if v, ok := raw["mtlsSecurityScheme"]; ok {
var scheme MutualTLSSecurityScheme
if err := json.Unmarshal(v, &scheme); err != nil {
return err
}
result[name] = scheme
} else if v, ok := raw["oauth2SecurityScheme"]; ok {
var scheme OAuth2SecurityScheme
if err := json.Unmarshal(v, &scheme); err != nil {
return err
}
result[name] = scheme
} else if v, ok := raw["openIdConnectSecurityScheme"]; ok {
var scheme OpenIDConnectSecurityScheme
if err := json.Unmarshal(v, &scheme); err != nil {
return err
}
result[name] = scheme
} else {
keys := make([]string, 0, len(raw))
for k := range raw {
keys = append(keys, k)
if wrapper.HTTPAuth != nil {
result[name] = *wrapper.HTTPAuth
n++
}
if wrapper.OpenIDConnect != nil {
result[name] = *wrapper.OpenIDConnect
n++
}
if wrapper.MutualTLS != nil {
result[name] = *wrapper.MutualTLS
n++
}
if wrapper.OAuth2 != nil {
result[name] = *wrapper.OAuth2
n++
}
if n == 0 {
var raw map[SecuritySchemeName]json.RawMessage
if err := json.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("unknown security scheme for %s", name)
}
return fmt.Errorf("unknown security scheme type for %q: found keys %v", name, keys)
return fmt.Errorf("unknown security scheme type for %s: %v", name, jsonKeys([]byte(raw[name])))
}
if n != 1 {
return fmt.Errorf("expected exactly one security scheme type for %s, got %d", name, n)
}
}

Expand Down Expand Up @@ -268,25 +262,34 @@ const (
DeviceCodeOAuthFlowName OAuthFlowName = "deviceCode"
)

type oauth2 struct {
Description string `json:"description,omitempty"`
Oauth2MetadataURL string `json:"oauth2MetadataUrl,omitempty"`
Flows oauthFlows `json:"flows"`
Comment thread
yarolegovich marked this conversation as resolved.
}

type oauthFlows struct {
AuthorizationCode *AuthorizationCodeOAuthFlow `json:"authorizationCode,omitempty"`
ClientCredentials *ClientCredentialsOAuthFlow `json:"clientCredentials,omitempty"`
Implicit *ImplicitOAuthFlow `json:"implicit,omitempty"`
Password *PasswordOAuthFlow `json:"password,omitempty"`
DeviceCode *DeviceCodeOAuthFlow `json:"deviceCode,omitempty"`
}

// MarshalJSON implements json.Marshaler.
func (s OAuth2SecurityScheme) MarshalJSON() ([]byte, error) {
type wrapper struct {
Description string `json:"description,omitempty"`
Oauth2MetadataURL string `json:"oauth2MetadataUrl,omitempty"`
Flows map[OAuthFlowName]any `json:"flows,omitempty"`
}
wrapped := wrapper{Description: s.Description, Oauth2MetadataURL: s.Oauth2MetadataURL}
wrapped := oauth2{Description: s.Description, Oauth2MetadataURL: s.Oauth2MetadataURL}
switch v := s.Flows.(type) {
case AuthorizationCodeOAuthFlow:
wrapped.Flows = map[OAuthFlowName]any{"authorizationCode": v}
wrapped.Flows = oauthFlows{AuthorizationCode: &v}
case ClientCredentialsOAuthFlow:
wrapped.Flows = map[OAuthFlowName]any{"clientCredentials": v}
wrapped.Flows = oauthFlows{ClientCredentials: &v}
case ImplicitOAuthFlow:
wrapped.Flows = map[OAuthFlowName]any{"implicit": v}
wrapped.Flows = oauthFlows{Implicit: &v}
case PasswordOAuthFlow:
wrapped.Flows = map[OAuthFlowName]any{"password": v}
wrapped.Flows = oauthFlows{Password: &v}
case DeviceCodeOAuthFlow:
wrapped.Flows = map[OAuthFlowName]any{"deviceCode": v}
wrapped.Flows = oauthFlows{DeviceCode: &v}
default:
return nil, fmt.Errorf("unknown OAuth flow type %T", v)
}
Expand All @@ -295,62 +298,45 @@ func (s OAuth2SecurityScheme) MarshalJSON() ([]byte, error) {

// UnmarshalJSON implements json.Unmarshaler.
func (s *OAuth2SecurityScheme) UnmarshalJSON(b []byte) error {
type wrapper struct {
Description string `json:"description,omitempty"`
Oauth2MetadataURL string `json:"oauth2MetadataUrl,omitempty"`
Flows map[OAuthFlowName]json.RawMessage `json:"flows,omitempty"`
}
var scheme wrapper
var scheme oauth2
if err := json.Unmarshal(b, &scheme); err != nil {
return err
}

if len(scheme.Flows) != 1 {
return fmt.Errorf("expected exactly one OAuth flow, got %d", len(scheme.Flows))
s.Description = scheme.Description
s.Oauth2MetadataURL = scheme.Oauth2MetadataURL
var n int
if scheme.Flows.AuthorizationCode != nil {
s.Flows = *scheme.Flows.AuthorizationCode
n++
}

for name, rawMessage := range scheme.Flows {
switch name {
case "authorizationCode":
var flow AuthorizationCodeOAuthFlow
if err := json.Unmarshal(rawMessage, &flow); err != nil {
return err
}
s.Flows = flow
case "clientCredentials":
var flow ClientCredentialsOAuthFlow
if err := json.Unmarshal(rawMessage, &flow); err != nil {
return err
}
s.Flows = flow
case "implicit":
var flow ImplicitOAuthFlow
if err := json.Unmarshal(rawMessage, &flow); err != nil {
return err
}
s.Flows = flow
case "password":
var flow PasswordOAuthFlow
if err := json.Unmarshal(rawMessage, &flow); err != nil {
return err
}
s.Flows = flow
case "deviceCode":
var flow DeviceCodeOAuthFlow
if err := json.Unmarshal(rawMessage, &flow); err != nil {
return err
}
s.Flows = flow
default:
keys := make([]string, 0, len(scheme.Flows))
for k := range scheme.Flows {
keys = append(keys, string(k))
}
return fmt.Errorf("unknown OAuth flow type: %s, available: %v", name, keys)
if scheme.Flows.ClientCredentials != nil {
s.Flows = *scheme.Flows.ClientCredentials
n++
}
if scheme.Flows.Implicit != nil {
s.Flows = *scheme.Flows.Implicit
n++
}
if scheme.Flows.Password != nil {
s.Flows = *scheme.Flows.Password
n++
}
if scheme.Flows.DeviceCode != nil {
s.Flows = *scheme.Flows.DeviceCode
n++
}
if n == 0 {
var raw struct {
Flows json.RawMessage `json:"flows"`
}
if err := json.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("unknown OAuth flow")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's add %w for the error

}
return fmt.Errorf("unknown OAuth flow type: %v", jsonKeys(raw.Flows))
}
if n != 1 {
return fmt.Errorf("expected exactly one OAuth flow, got %d", n)
}
s.Description = scheme.Description
s.Oauth2MetadataURL = scheme.Oauth2MetadataURL
return nil
}

Expand Down
Loading
Loading