diff --git a/go.mod b/go.mod index 5500d63a..745c5c2c 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/kr/pretty v0.3.1 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/pkg/errors v0.9.1 // indirect - github.com/russellhaering/goxmldsig v1.2.0 + github.com/russellhaering/goxmldsig v1.3.0 github.com/stretchr/testify v1.8.1 github.com/zenazn/goji v1.0.1 golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed diff --git a/go.sum b/go.sum index a2bb4b19..7ab71ea2 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russellhaering/goxmldsig v1.2.0 h1:Y6GTTc9Un5hCxSzVz4UIWQ/zuVwDvzJk80guqzwx6Vg= -github.com/russellhaering/goxmldsig v1.2.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/identity_provider.go b/identity_provider.go index 83f02991..b03dc89b 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -96,6 +96,7 @@ type AssertionMaker interface { // and password). type IdentityProvider struct { Key crypto.PrivateKey + Signer crypto.Signer Logger logger.Interface Certificate *x509.Certificate Intermediates []*x509.Certificate @@ -831,24 +832,8 @@ const canonicalizerPrefixList = "" // MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`. func (req *IdpAuthnRequest) MakeAssertionEl() error { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -1049,24 +1034,8 @@ func (req *IdpAuthnRequest) MakeResponse() error { // Sign the response element (we've already signed the Assertion element) { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -1084,3 +1053,44 @@ func (req *IdpAuthnRequest) MakeResponse() error { req.ResponseEl = responseEl return nil } + +// signingContext will create a signing context for the request. +func (req *IdpAuthnRequest) signingContext() (*dsig.SigningContext, error) { + // Create a cert chain based off of the IDP cert and its intermediates. + certificates := [][]byte{req.IDP.Certificate.Raw} + for _, cert := range req.IDP.Intermediates { + certificates = append(certificates, cert.Raw) + } + + var signingContext *dsig.SigningContext + var err error + // If signer is set, use it instead of the private key. + if req.IDP.Signer != nil { + signingContext, err = dsig.NewSigningContext(req.IDP.Signer, certificates) + if err != nil { + return nil, err + } + } else { + keyPair := tls.Certificate{ + Certificate: certificates, + PrivateKey: req.IDP.Key, + Leaf: req.IDP.Certificate, + } + keyStore := dsig.TLSCertKeyStore(keyPair) + + signingContext = dsig.NewDefaultSigningContext(keyStore) + } + + // Default to using SHA1 if the signature method isn't set. + signatureMethod := req.IDP.SignatureMethod + if signatureMethod == "" { + signatureMethod = dsig.RSASHA1SignatureMethod + } + + signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) + if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + return nil, err + } + + return signingContext, nil +} diff --git a/identity_provider_go116_test.go b/identity_provider_go116_test.go index ead0a780..6d4a0a53 100644 --- a/identity_provider_go116_test.go +++ b/identity_provider_go116_test.go @@ -18,7 +18,7 @@ import ( ) func TestIDPHTTPCanHandleSSORequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` diff --git a/identity_provider_go117_test.go b/identity_provider_go117_test.go index 0ce6a1a7..68624518 100644 --- a/identity_provider_go117_test.go +++ b/identity_provider_go117_test.go @@ -18,7 +18,7 @@ import ( ) func TestIDPHTTPCanHandleSSORequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` diff --git a/identity_provider_test.go b/identity_provider_test.go index 2beaf83b..6ad81e2c 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -25,6 +25,7 @@ import ( "github.com/beevik/etree" "github.com/golang-jwt/jwt/v4" + dsig "github.com/russellhaering/goxmldsig" "github.com/crewjam/saml/logger" "github.com/crewjam/saml/testsaml" @@ -37,6 +38,7 @@ type IdentityProviderTest struct { SP ServiceProvider Key crypto.PrivateKey + Signer crypto.Signer Certificate *x509.Certificate SessionProvider SessionProvider IDP IdentityProvider @@ -50,7 +52,7 @@ func mustParseURL(s string) url.URL { return *rv } -func mustParsePrivateKey(pemStr []byte) crypto.PrivateKey { +func mustParsePrivateKey(pemStr []byte) crypto.Signer { b, _ := pem.Decode(pemStr) if b == nil { panic("cannot parse PEM") @@ -74,7 +76,28 @@ func mustParseCertificate(pemStr []byte) *x509.Certificate { return cert } -func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { +// idpTestOpts are options that can be applied to the identity provider. +type idpTestOpts struct { + apply func(*testing.T, *IdentityProviderTest) +} + +// applyKey will set the private key for the identity provider. +var applyKey = idpTestOpts{ + apply: func(t *testing.T, test *IdentityProviderTest) { + test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) + (&test.IDP).Key = test.Key + }, +} + +// applySigner will set the signer for the identity provider. +var applySigner = idpTestOpts{ + apply: func(t *testing.T, test *IdentityProviderTest) { + test.Signer = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) + (&test.IDP).Signer = test.Signer + }, +} + +func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProviderTest { test := IdentityProviderTest{} TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") @@ -94,11 +117,9 @@ func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { IDPMetadata: &EntityDescriptor{}, } - test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) test.Certificate = mustParseCertificate(golden.Get(t, "idp_cert.pem")) test.IDP = IdentityProvider{ - Key: test.Key, Certificate: test.Certificate, Logger: logger.DefaultLogger, MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"), @@ -118,6 +139,11 @@ func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { }, } + // apply the test options + for _, opt := range opts { + opt.apply(t, &test) + } + // bind the service provider and the IDP test.SP.IDPMetadata = test.IDP.Metadata() return &test @@ -140,7 +166,7 @@ func (mspp *mockServiceProviderProvider) GetServiceProvider(r *http.Request, ser } func TestIDPCanProduceMetadata(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) expected := &EntityDescriptor{ ValidUntil: TimeNow().Add(DefaultValidDuration), CacheDuration: DefaultValidDuration, @@ -201,7 +227,7 @@ func TestIDPCanProduceMetadata(t *testing.T) { } func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() r, _ := http.NewRequest("GET", "https://idp.example.com/saml/metadata", nil) test.IDP.Handler().ServeHTTP(w, r) @@ -212,7 +238,7 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { } func TestIDPCanHandleRequestWithNewSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", @@ -238,7 +264,7 @@ func TestIDPCanHandleRequestWithNewSession(t *testing.T) { } func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -263,7 +289,7 @@ func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { } func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -292,7 +318,7 @@ func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { } func TestIDPRejectsInvalidRequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { panic("not reached") @@ -313,7 +339,7 @@ func TestIDPRejectsInvalidRequest(t *testing.T) { } func TestIDPCanParse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D", nil) req, err := NewIdpAuthnRequest(&test.IDP, r) assert.Check(t, err) @@ -337,7 +363,7 @@ func TestIDPCanParse(t *testing.T) { } func TestIDPCanValidate(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -461,7 +487,7 @@ func TestIDPCanValidate(t *testing.T) { } func TestIDPMakeAssertion(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -645,7 +671,7 @@ func TestIDPMakeAssertion(t *testing.T) { } func TestIDPMarshalAssertion(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -693,8 +719,19 @@ func TestIDPMarshalAssertion(t *testing.T) { golden.Assert(t, string(assertionBuffer), t.Name()+"_encrypted_assertion") } -func TestIDPMakeResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) +func TestIDPMakeResponsePrivateKey(t *testing.T) { + test := NewIdentityProviderTest(t, applyKey) + + testMakeResponse(t, test) +} + +func TestIDPMakeResponseSigner(t *testing.T) { + test := NewIdentityProviderTest(t, applySigner) + + testMakeResponse(t, test) +} + +func testMakeResponse(t *testing.T, test *IdentityProviderTest) { req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -715,6 +752,16 @@ func TestIDPMakeResponse(t *testing.T) { err = req.MakeResponse() assert.Check(t, err) + certificateStore := &dsig.MemoryX509CertificateStore{ + Roots: []*x509.Certificate{ + req.IDP.Certificate, + }, + } + validationCtx := dsig.NewDefaultValidationContext(certificateStore) + validationCtx.Clock = dsig.NewFakeClockAt(req.IDP.Certificate.NotBefore) + _, err = validationCtx.Validate(req.ResponseEl) + assert.Check(t, err) + response := Response{} err = unmarshalEtreeHack(req.ResponseEl, &response) assert.Check(t, err) @@ -724,11 +771,11 @@ func TestIDPMakeResponse(t *testing.T) { doc.Indent(2) responseStr, err := doc.WriteToString() assert.Check(t, err) - golden.Assert(t, responseStr, t.Name()+"_response.xml") + golden.Assert(t, responseStr, "TestIDPMakeResponse_response.xml") } func TestIDPWriteResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -748,7 +795,7 @@ func TestIDPWriteResponse(t *testing.T) { } func TestIDPIDPInitiatedNewSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { fmt.Fprintf(w, "RelayState: %s", req.RelayState) @@ -764,7 +811,7 @@ func TestIDPIDPInitiatedNewSession(t *testing.T) { } func TestIDPIDPInitiatedExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -782,7 +829,7 @@ func TestIDPIDPInitiatedExistingSession(t *testing.T) { } func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -799,7 +846,7 @@ func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { } func TestIDPCanHandleUnencryptedResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} @@ -847,7 +894,7 @@ func TestIDPCanHandleUnencryptedResponse(t *testing.T) { } func TestIDPRequestedAttributes(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) metadata := EntityDescriptor{} err := xml.Unmarshal(golden.Get(t, "TestIDPRequestedAttributes_idp_metadata.xml"), &metadata) assert.Check(t, err) @@ -977,7 +1024,7 @@ func TestIDPRequestedAttributes(t *testing.T) { } func TestIDPNoDestination(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} @@ -1017,7 +1064,7 @@ func TestIDPNoDestination(t *testing.T) { } func TestIDPRejectDecompressionBomb(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index 2141ca89..13ca10b9 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -20,6 +20,7 @@ import ( type Options struct { URL url.URL Key crypto.PrivateKey + Signer crypto.Signer Logger logger.Interface Certificate *x509.Certificate Store Store @@ -59,6 +60,7 @@ func New(opts Options) (*Server, error) { serviceProviders: map[string]*saml.EntityDescriptor{}, IDP: saml.IdentityProvider{ Key: opts.Key, + Signer: opts.Signer, Logger: logr, Certificate: opts.Certificate, MetadataURL: metadataURL,