From ca4b5f62f22a11aac89698ff505c17b17e0b6d72 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Tue, 17 Sep 2024 22:27:54 -0700 Subject: [PATCH] fix: prevent overwriting schema validations --- formdata.go | 2 +- huma.go | 2 +- schema.go | 114 +++++++++++++++++++++---------------------------- schema_test.go | 3 ++ 4 files changed, 53 insertions(+), 68 deletions(-) diff --git a/formdata.go b/formdata.go index 1a33b3d..f153348 100644 --- a/formdata.go +++ b/formdata.go @@ -226,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema { continue } - if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required") { + if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) { requiredFields[i] = name schema.requiredMap[name] = true } diff --git a/huma.go b/huma.go index 940ebcc..0a6c2cc 100644 --- a/huma.go +++ b/huma.go @@ -174,7 +174,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p pfi.TimeFormat = timeFormat } - if !boolTag(f, "hidden") { + if !boolTag(f, "hidden", false) { desc := "" if pfi.Schema != nil { // If the schema has a description, use it. Some tools will not show diff --git a/schema.go b/schema.go index 5721fb6..b1e9b28 100644 --- a/schema.go +++ b/schema.go @@ -329,7 +329,7 @@ func (s *Schema) PrecomputeMessages() { } } -func boolTag(f reflect.StructField, tag string) bool { +func boolTag(f reflect.StructField, tag string, def bool) bool { if v := f.Tag.Get(tag); v != "" { if v == "true" { return true @@ -339,10 +339,10 @@ func boolTag(f reflect.StructField, tag string) bool { panic(fmt.Errorf("invalid bool tag '%s' for field '%s': %v", tag, f.Name, v)) } } - return false + return def } -func intTag(f reflect.StructField, tag string) *int { +func intTag(f reflect.StructField, tag string, def *int) *int { if v := f.Tag.Get(tag); v != "" { if i, err := strconv.Atoi(v); err == nil { return &i @@ -350,10 +350,10 @@ func intTag(f reflect.StructField, tag string) *int { panic(fmt.Errorf("invalid int tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err)) } } - return nil + return def } -func floatTag(f reflect.StructField, tag string) *float64 { +func floatTag(f reflect.StructField, tag string, def *float64) *float64 { if v := f.Tag.Get(tag); v != "" { if i, err := strconv.ParseFloat(v, 64); err == nil { return &i @@ -361,7 +361,14 @@ func floatTag(f reflect.StructField, tag string) *float64 { panic(fmt.Errorf("invalid float tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err)) } } - return nil + return def +} + +func stringTag(f reflect.StructField, tag string, def string) string { + if v := f.Tag.Get(tag); v != "" { + return v + } + return def } // ensureType panics if the given value does not match the JSON Schema type. @@ -508,18 +515,14 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch if fs == nil { return fs } - if doc := f.Tag.Get("doc"); doc != "" { - fs.Description = doc - } + fs.Description = stringTag(f, "doc", fs.Description) if fs.Format == "date-time" && f.Tag.Get("header") != "" { // Special case: this is a header and uses a different date/time format. // Note that it can still be overridden by the `format` or `timeFormat` // tags later. fs.Format = "date-time-http" } - if format := f.Tag.Get("format"); format != "" { - fs.Format = format - } + fs.Format = stringTag(f, "format", fs.Format) if timeFmt := f.Tag.Get("timeFormat"); timeFmt != "" { switch timeFmt { case "2006-01-02": @@ -530,9 +533,7 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch fs.Format = timeFmt } } - if enc := f.Tag.Get("encoding"); enc != "" { - fs.ContentEncoding = enc - } + fs.ContentEncoding = stringTag(f, "encoding", fs.ContentEncoding) if defaultValue := jsonTag(registry, f, fs, "default"); defaultValue != nil { fs.Default = defaultValue } @@ -559,56 +560,37 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch } } - if _, ok := f.Tag.Lookup("nullable"); ok { - fs.Nullable = boolTag(f, "nullable") - if fs.Nullable && fs.Ref != "" { - // Nullability is only supported for scalar types for now. Objects are - // much more complicated because the `null` type lives within the object - // definition (requiring multiple copies of the object) or needs to use - // `anyOf` or `not` which is not supported by all code generators, or is - // supported poorly & generates hard-to-use code. This is less than ideal - // but a compromise for now to support some nullability built-in. - panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref)) - } - } - - if _, ok := f.Tag.Lookup("minimum"); ok { - fs.Minimum = floatTag(f, "minimum") - } - - fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum") - - if _, ok := f.Tag.Lookup("maximum"); ok { - fs.Maximum = floatTag(f, "maximum") - } - fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum") - fs.MultipleOf = floatTag(f, "multipleOf") - if _, ok := f.Tag.Lookup("minLength"); ok { - fs.MinLength = intTag(f, "minLength") - } - - if _, ok := f.Tag.Lookup("maxLength"); ok { - fs.MaxLength = intTag(f, "maxLength") - } - fs.Pattern = f.Tag.Get("pattern") - fs.PatternDescription = f.Tag.Get("patternDescription") - if _, ok := f.Tag.Lookup("minItems"); ok { - fs.MinItems = intTag(f, "minItems") - } - if _, ok := f.Tag.Lookup("maxItems"); ok { - fs.MaxItems = intTag(f, "maxItems") - } - fs.UniqueItems = boolTag(f, "uniqueItems") - fs.MinProperties = intTag(f, "minProperties") - fs.MaxProperties = intTag(f, "maxProperties") - fs.ReadOnly = boolTag(f, "readOnly") - fs.WriteOnly = boolTag(f, "writeOnly") - fs.Deprecated = boolTag(f, "deprecated") + fs.Nullable = boolTag(f, "nullable", fs.Nullable) + if fs.Nullable && fs.Ref != "" { + // Nullability is only supported for scalar types for now. Objects are + // much more complicated because the `null` type lives within the object + // definition (requiring multiple copies of the object) or needs to use + // `anyOf` or `not` which is not supported by all code generators, or is + // supported poorly & generates hard-to-use code. This is less than ideal + // but a compromise for now to support some nullability built-in. + panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref)) + } + + fs.Minimum = floatTag(f, "minimum", fs.Minimum) + fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum", fs.ExclusiveMinimum) + fs.Maximum = floatTag(f, "maximum", fs.Maximum) + fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum", fs.ExclusiveMaximum) + fs.MultipleOf = floatTag(f, "multipleOf", fs.MultipleOf) + fs.MinLength = intTag(f, "minLength", fs.MinLength) + fs.MaxLength = intTag(f, "maxLength", fs.MaxLength) + fs.Pattern = stringTag(f, "pattern", fs.Pattern) + fs.PatternDescription = stringTag(f, "patternDescription", fs.PatternDescription) + fs.MinItems = intTag(f, "minItems", fs.MinItems) + fs.MaxItems = intTag(f, "maxItems", fs.MaxItems) + fs.UniqueItems = boolTag(f, "uniqueItems", fs.UniqueItems) + fs.MinProperties = intTag(f, "minProperties", fs.MinProperties) + fs.MaxProperties = intTag(f, "maxProperties", fs.MaxProperties) + fs.ReadOnly = boolTag(f, "readOnly", fs.ReadOnly) + fs.WriteOnly = boolTag(f, "writeOnly", fs.WriteOnly) + fs.Deprecated = boolTag(f, "deprecated", fs.Deprecated) fs.PrecomputeMessages() - if v := f.Tag.Get("hidden"); v != "" { - fs.hidden = boolTag(f, "hidden") - } + fs.hidden = boolTag(f, "hidden", fs.hidden) return fs } @@ -830,7 +812,7 @@ func schemaFromType(r Registry, t reflect.Type) *Schema { } if _, ok := f.Tag.Lookup("required"); ok { - fieldRequired = boolTag(f, "required") + fieldRequired = boolTag(f, "required", false) } if dr := f.Tag.Get("dependentRequired"); strings.TrimSpace(dr) != "" { @@ -885,12 +867,12 @@ func schemaFromType(r Registry, t reflect.Type) *Schema { additionalProps := false if f, ok := t.FieldByName("_"); ok { if _, ok = f.Tag.Lookup("additionalProperties"); ok { - additionalProps = boolTag(f, "additionalProperties") + additionalProps = boolTag(f, "additionalProperties", false) } if _, ok := f.Tag.Lookup("nullable"); ok { // Allow overriding nullability per struct. - s.Nullable = boolTag(f, "nullable") + s.Nullable = boolTag(f, "nullable", false) } } s.AdditionalProperties = additionalProps diff --git a/schema_test.go b/schema_test.go index 4263f7a..bac7e46 100644 --- a/schema_test.go +++ b/schema_test.go @@ -1411,6 +1411,7 @@ type ExampleInputStruct struct { Email string `json:"email" format:"email" doc:"Contact e-mail address"` Age *int `json:"age,omitempty" minimum:"0"` Comment string `json:"comment,omitempty" maxLength:"256"` + Pattern string `json:"pattern" pattern:"^[a-z]+$"` } // Implements SchemaTransformer interface, reusing parts of the schema from `ExampleInputStruct` @@ -1419,6 +1420,7 @@ type ExampleUpdateStruct struct { Email *string `json:"email" doc:"Override doc for email"` Age OmittableNullable[int] `json:"age"` Comment OmittableNullable[string] `json:"comment"` + Pattern string `json:"pattern"` } func (u *ExampleUpdateStruct) TransformSchema(r huma.Registry, s *huma.Schema) *huma.Schema { @@ -1449,6 +1451,7 @@ func TestSchemaTransformer(t *testing.T) { assert.True(t, s.Properties["age"].Nullable) assert.Equal(t, inputSchema.Properties["comment"].MaxLength, s.Properties["comment"].MaxLength) assert.True(t, s.Properties["comment"].Nullable) + assert.Equal(t, inputSchema.Properties["pattern"].Pattern, s.Properties["pattern"].Pattern) } updateSchema1 := r.Schema(reflect.TypeOf(ExampleUpdateStruct{}), false, "") validateSchema(updateSchema1)