diff --git a/oauth2x/auth.go b/oauth2x/auth.go index b3ca9ce8..8f4e99ef 100644 --- a/oauth2x/auth.go +++ b/oauth2x/auth.go @@ -2,7 +2,11 @@ package oauth2x import ( "context" + "encoding/json" + "errors" "net/http" + "net/url" + "time" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -13,15 +17,29 @@ import ( "golang.org/x/oauth2/clientcredentials" ) +var ( + tokenEndpointClient = &http.Client{ + Timeout: 5 * time.Second, // nolint:gmnd // clear and unexported + } + + // ErrTokenEndpointMissing is returned when the issuers .well-known/openid-configuration is missing the token_endpoint key. + ErrTokenEndpointMissing = errors.New("token endpoint missing from issuer well-known openid-configuration") +) + // NewClientCredentialsTokenSrc returns an oauth2 client credentials token source -func NewClientCredentialsTokenSrc(ctx context.Context, cfg Config) oauth2.TokenSource { +func NewClientCredentialsTokenSrc(ctx context.Context, cfg Config) (oauth2.TokenSource, error) { + tokenEndpoint, err := fetchIssuerTokenEndpoint(ctx, cfg.Issuer) + if err != nil { + return nil, err + } + ccCfg := clientcredentials.Config{ ClientID: cfg.ID, ClientSecret: cfg.Secret, - TokenURL: cfg.TokenURL, + TokenURL: tokenEndpoint, } - return ccCfg.TokenSource(ctx) + return ccCfg.TokenSource(ctx), nil } // NewClient returns a http client using requested token source @@ -32,9 +50,9 @@ func NewClient(ctx context.Context, tokenSrc oauth2.TokenSource) *http.Client { // Config handles reading in all the config values available // for setting up an oauth2 configuration type Config struct { - ID string `mapstructure:"id"` - Secret string `mapstructure:"secret"` - TokenURL string `mapstructure:"token_url"` + ID string `mapstructure:"id"` + Secret string `mapstructure:"secret"` + Issuer string `mapstructure:"issuer"` } // MustViperFlags adds oidc oauth2 client credentials config to the provided flagset and binds to viper @@ -45,6 +63,36 @@ func MustViperFlags(v *viper.Viper, flags *pflag.FlagSet) { flags.String("oidc-client-secret", "", "oidc client secret") viperx.MustBindFlag(v, "oidc.client.secret", flags.Lookup("oidc-client-secret")) - flags.String("oidc-client-token-url", "", "oidc token url") - viperx.MustBindFlag(v, "oidc.client.token_url", flags.Lookup("oidc-client-token-url")) + flags.String("oidc-client-issuer", "", "oidc issuer") + viperx.MustBindFlag(v, "oidc.client.issuer", flags.Lookup("oidc-client-issuer")) +} + +func fetchIssuerTokenEndpoint(ctx context.Context, issuer string) (string, error) { + uri, err := url.JoinPath(issuer, ".well-known", "openid-configuration") + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return "", err + } + + res, err := tokenEndpointClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() //nolint:errcheck // no need to check + + var m map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&m); err != nil { + return "", err + } + + tokenEndpoint, ok := m["token_endpoint"] + if !ok { + return "", ErrTokenEndpointMissing + } + + return tokenEndpoint.(string), nil }