From b32830b8679e5b71c8cd38d366b1d39aca365d9b Mon Sep 17 00:00:00 2001 From: vbasiuk Date: Mon, 5 Feb 2024 19:19:24 +0200 Subject: [PATCH] resolve comments --- auth.go | 81 +++++++++++++++++++++++++------------------- pubsignals/common.go | 34 +++++++------------ 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/auth.go b/auth.go index 999ce1f..b15da53 100644 --- a/auth.go +++ b/auth.go @@ -310,8 +310,9 @@ func CreateContractInvokeRequestWithMessage( } } -func verifyAuthRequest(request protocol.AuthorizationRequestMessage) error { - groupIDValidationMap := make(map[int][]protocol.ZeroKnowledgeProofRequest) +// VerifyAuthRequest verifies auth request message +func VerifyAuthRequest(request protocol.AuthorizationRequestMessage) error { + groupIDValidationMap := make(map[int][]pubsignals.Query) for _, proofRequest := range request.Body.Scope { proofRequestQuery, err := unmarshalQuery(proofRequest.Query) @@ -320,30 +321,26 @@ func verifyAuthRequest(request protocol.AuthorizationRequestMessage) error { } groupID := proofRequestQuery.GroupID if groupID != 0 { - existingRequests := groupIDValidationMap[groupID] + existingQueries := groupIDValidationMap[groupID] // Validate that all requests in the group have the same schema, issuer, and circuit - for _, existingRequest := range existingRequests { - existingRequestQuery, err := unmarshalQuery(existingRequest.Query) - if err != nil { - return err - } - if existingRequestQuery.Type != proofRequestQuery.Type { + for _, existingQuery := range existingQueries { + if existingQuery.Type != proofRequestQuery.Type { return errors.New("all requests in the group should have the same type") } - if existingRequestQuery.Context != proofRequestQuery.Context { + if existingQuery.Context != proofRequestQuery.Context { return errors.New("all requests in the group should have the same context") } allowedIssuers := proofRequestQuery.AllowedIssuers - existingRequestAllowedIssuers := existingRequestQuery.AllowedIssuers + existingRequestAllowedIssuers := existingQuery.AllowedIssuers if !checkIssuersEquality(allowedIssuers, existingRequestAllowedIssuers) { return errors.New("all requests in the group should have the same issuer") } } - groupIDValidationMap[groupID] = append(existingRequests, proofRequest) + groupIDValidationMap[groupID] = append(existingQueries, proofRequestQuery) } } @@ -385,6 +382,11 @@ func checkIssuersEquality(issuers1, issuers2 []string) bool { return true } +type linkIdRequestId struct { + linkID *big.Int + requestID uint32 +} + // VerifyAuthResponse performs verification of auth response based on auth request func (v *Verifier) VerifyAuthResponse( ctx context.Context, @@ -401,12 +403,12 @@ func (v *Verifier) VerifyAuthResponse( return errors.Errorf("sender of the request is not a target of response - expected %s, given %s", request.From, response.To) } - err := verifyAuthRequest(request) + err := VerifyAuthRequest(request) if err != nil { return err } - groupIDToLinkIDMap := make(map[int][]map[string]*big.Int) + groupIDToLinkIDMap := make(map[int][]linkIdRequestId) for _, proofRequest := range request.Body.Scope { // prepare query from request query, err := unmarshalQuery(proofRequest.Query) @@ -475,32 +477,41 @@ func (v *Verifier) VerifyAuthResponse( return errors.Errorf("proof response doesn't contain from field") } - if pubSignals.LinkID != nil && groupID != 0 { - if existingLinks, exists := groupIDToLinkIDMap[groupID]; exists { - linkIDMap := map[string]*big.Int{"linkID": pubSignals.LinkID, "requestID": new(big.Int).SetUint64(uint64(proofResponse.ID))} - groupIDToLinkIDMap[groupID] = append(existingLinks, linkIDMap) - } else { - linkIDMap := map[string]*big.Int{"linkID": pubSignals.LinkID, "requestID": new(big.Int).SetUint64(uint64(proofResponse.ID))} - groupIDToLinkIDMap[groupID] = []map[string]*big.Int{linkIDMap} - } + err = verifyGroupIdMathch(pubSignals.LinkID, groupID, proofResponse.ID, groupIDToLinkIDMap) + if err != nil { + return err } - if groupID != 0 { - // verify grouping links - for groupIDfromMap, metas := range groupIDToLinkIDMap { - // Check that all linkIDs are the same - if len(metas) > 1 { - firstLinkID := metas[0]["linkID"] - for _, meta := range metas[1:] { - if meta["linkID"].Cmp(firstLinkID) != 0 { - return errors.Errorf("Link id validation failed for group %d, request linkID to requestIds info: %v", groupIDfromMap, metas) - } - } + + } + + return nil +} + +func verifyGroupIdMathch(linkID *big.Int, groupID int, requestID uint32, groupIDToLinkIDMap map[int][]linkIdRequestId) error { + if groupID == 0 { + return nil + } + if linkID != nil { + if existingLinks, exists := groupIDToLinkIDMap[groupID]; exists { + linkIDMap := linkIdRequestId{linkID: linkID, requestID: requestID} + groupIDToLinkIDMap[groupID] = append(existingLinks, linkIDMap) + } else { + linkIDMap := linkIdRequestId{linkID: linkID, requestID: requestID} + groupIDToLinkIDMap[groupID] = []linkIdRequestId{linkIDMap} + } + } + // verify grouping links + for groupIDfromMap, metas := range groupIDToLinkIDMap { + // Check that all linkIDs are the same + if len(metas) > 1 { + firstLinkID := metas[0].linkID + for _, meta := range metas[1:] { + if meta.linkID.Cmp(firstLinkID) != 0 { + return errors.Errorf("Link id validation failed for group %d, request linkID to requestIds info: %v", groupIDfromMap, metas) } } } - } - return nil } diff --git a/pubsignals/common.go b/pubsignals/common.go index 78b0e78..9b4b243 100644 --- a/pubsignals/common.go +++ b/pubsignals/common.go @@ -58,12 +58,11 @@ func ParseCredentialSubject(_ context.Context, credentialSubject any) (out []Pro return nil, errors.New("Failed to convert credential subject to JSONObject") } - entries := getObjectEntries(jsonObject) - if len(entries) == 0 { + if len(jsonObject) == 0 { return nil, errors.New("query must have at least 1 predicate") } - for fieldName, fieldReq := range entries { - fieldReqEntries := getObjectEntries(fieldReq.(map[string]interface{})) + for fieldName, fieldReq := range jsonObject { + fieldReqEntries := fieldReq.(map[string]interface{}) isSelectiveDisclosure := len(fieldReqEntries) == 0 if isSelectiveDisclosure { @@ -83,14 +82,6 @@ func ParseCredentialSubject(_ context.Context, credentialSubject any) (out []Pro return out, nil } -func getObjectEntries(obj map[string]interface{}) map[string]interface{} { - entries := make(map[string]interface{}) - for k, v := range obj { - entries[k] = v - } - return entries -} - // ParseQueryMetadata parse property query and return query metadata func ParseQueryMetadata(ctx context.Context, propertyQuery PropertyQuery, ldContextJSON, credentialType string, options merklize.Options) (query *QueryMetadata, err error) { datatype, err := options.TypeFromContext([]byte(ldContextJSON), fmt.Sprintf("%s.%s", credentialType, propertyQuery.FieldName)) @@ -139,17 +130,18 @@ func ParseQueryMetadata(ctx context.Context, propertyQuery PropertyQuery, ldCont if err != nil { return nil, err } - err = path.Prepend(credentialSubjectFullKey) - if err != nil { - return nil, err - } + } - query.ClaimPathKey, err = path.MtEntry() - if err != nil { - return nil, err - } - query.Path = &path + err = path.Prepend(credentialSubjectFullKey) + if err != nil { + return nil, err + } + + query.ClaimPathKey, err = path.MtEntry() + if err != nil { + return nil, err } + query.Path = &path if propertyQuery.OperatorValue != nil { if !IsValidOperation(datatype, propertyQuery.Operator) {