Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent overwriting schema validations #575

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion formdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
continue
}

if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required") {
if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) {
requiredFields[i] = name
schema.requiredMap[name] = true
}
Expand Down
2 changes: 1 addition & 1 deletion huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
pfi.TimeFormat = timeFormat
}

if !boolTag(f, "hidden") {
if !boolTag(f, "hidden", false) {
desc := ""
if pfi.Schema != nil {
// If the schema has a description, use it. Some tools will not show
Expand Down
114 changes: 48 additions & 66 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (s *Schema) PrecomputeMessages() {
}
}

func boolTag(f reflect.StructField, tag string) bool {
func boolTag(f reflect.StructField, tag string, def bool) bool {
if v := f.Tag.Get(tag); v != "" {
if v == "true" {
return true
Expand All @@ -339,29 +339,36 @@ func boolTag(f reflect.StructField, tag string) bool {
panic(fmt.Errorf("invalid bool tag '%s' for field '%s': %v", tag, f.Name, v))
}
}
return false
return def
}

func intTag(f reflect.StructField, tag string) *int {
func intTag(f reflect.StructField, tag string, def *int) *int {
if v := f.Tag.Get(tag); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return &i
} else {
panic(fmt.Errorf("invalid int tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err))
}
}
return nil
return def
}

func floatTag(f reflect.StructField, tag string) *float64 {
func floatTag(f reflect.StructField, tag string, def *float64) *float64 {
if v := f.Tag.Get(tag); v != "" {
if i, err := strconv.ParseFloat(v, 64); err == nil {
return &i
} else {
panic(fmt.Errorf("invalid float tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err))
}
}
return nil
return def
}

func stringTag(f reflect.StructField, tag string, def string) string {
if v := f.Tag.Get(tag); v != "" {
return v
}
return def
}

// ensureType panics if the given value does not match the JSON Schema type.
Expand Down Expand Up @@ -508,18 +515,14 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
if fs == nil {
return fs
}
if doc := f.Tag.Get("doc"); doc != "" {
fs.Description = doc
}
fs.Description = stringTag(f, "doc", fs.Description)
if fs.Format == "date-time" && f.Tag.Get("header") != "" {
// Special case: this is a header and uses a different date/time format.
// Note that it can still be overridden by the `format` or `timeFormat`
// tags later.
fs.Format = "date-time-http"
}
if format := f.Tag.Get("format"); format != "" {
fs.Format = format
}
fs.Format = stringTag(f, "format", fs.Format)
if timeFmt := f.Tag.Get("timeFormat"); timeFmt != "" {
switch timeFmt {
case "2006-01-02":
Expand All @@ -530,9 +533,7 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
fs.Format = timeFmt
}
}
if enc := f.Tag.Get("encoding"); enc != "" {
fs.ContentEncoding = enc
}
fs.ContentEncoding = stringTag(f, "encoding", fs.ContentEncoding)
if defaultValue := jsonTag(registry, f, fs, "default"); defaultValue != nil {
fs.Default = defaultValue
}
Expand All @@ -559,56 +560,37 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
}
}

if _, ok := f.Tag.Lookup("nullable"); ok {
fs.Nullable = boolTag(f, "nullable")
if fs.Nullable && fs.Ref != "" {
// Nullability is only supported for scalar types for now. Objects are
// much more complicated because the `null` type lives within the object
// definition (requiring multiple copies of the object) or needs to use
// `anyOf` or `not` which is not supported by all code generators, or is
// supported poorly & generates hard-to-use code. This is less than ideal
// but a compromise for now to support some nullability built-in.
panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref))
}
}

if _, ok := f.Tag.Lookup("minimum"); ok {
fs.Minimum = floatTag(f, "minimum")
}

fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum")

if _, ok := f.Tag.Lookup("maximum"); ok {
fs.Maximum = floatTag(f, "maximum")
}
fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum")
fs.MultipleOf = floatTag(f, "multipleOf")
if _, ok := f.Tag.Lookup("minLength"); ok {
fs.MinLength = intTag(f, "minLength")
}

if _, ok := f.Tag.Lookup("maxLength"); ok {
fs.MaxLength = intTag(f, "maxLength")
}
fs.Pattern = f.Tag.Get("pattern")
fs.PatternDescription = f.Tag.Get("patternDescription")
if _, ok := f.Tag.Lookup("minItems"); ok {
fs.MinItems = intTag(f, "minItems")
}
if _, ok := f.Tag.Lookup("maxItems"); ok {
fs.MaxItems = intTag(f, "maxItems")
}
fs.UniqueItems = boolTag(f, "uniqueItems")
fs.MinProperties = intTag(f, "minProperties")
fs.MaxProperties = intTag(f, "maxProperties")
fs.ReadOnly = boolTag(f, "readOnly")
fs.WriteOnly = boolTag(f, "writeOnly")
fs.Deprecated = boolTag(f, "deprecated")
fs.Nullable = boolTag(f, "nullable", fs.Nullable)
if fs.Nullable && fs.Ref != "" {
// Nullability is only supported for scalar types for now. Objects are
// much more complicated because the `null` type lives within the object
// definition (requiring multiple copies of the object) or needs to use
// `anyOf` or `not` which is not supported by all code generators, or is
// supported poorly & generates hard-to-use code. This is less than ideal
// but a compromise for now to support some nullability built-in.
panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref))
}

fs.Minimum = floatTag(f, "minimum", fs.Minimum)
fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum", fs.ExclusiveMinimum)
fs.Maximum = floatTag(f, "maximum", fs.Maximum)
fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum", fs.ExclusiveMaximum)
fs.MultipleOf = floatTag(f, "multipleOf", fs.MultipleOf)
fs.MinLength = intTag(f, "minLength", fs.MinLength)
fs.MaxLength = intTag(f, "maxLength", fs.MaxLength)
fs.Pattern = stringTag(f, "pattern", fs.Pattern)
fs.PatternDescription = stringTag(f, "patternDescription", fs.PatternDescription)
fs.MinItems = intTag(f, "minItems", fs.MinItems)
fs.MaxItems = intTag(f, "maxItems", fs.MaxItems)
fs.UniqueItems = boolTag(f, "uniqueItems", fs.UniqueItems)
fs.MinProperties = intTag(f, "minProperties", fs.MinProperties)
fs.MaxProperties = intTag(f, "maxProperties", fs.MaxProperties)
fs.ReadOnly = boolTag(f, "readOnly", fs.ReadOnly)
fs.WriteOnly = boolTag(f, "writeOnly", fs.WriteOnly)
fs.Deprecated = boolTag(f, "deprecated", fs.Deprecated)
fs.PrecomputeMessages()

if v := f.Tag.Get("hidden"); v != "" {
fs.hidden = boolTag(f, "hidden")
}
fs.hidden = boolTag(f, "hidden", fs.hidden)

return fs
}
Expand Down Expand Up @@ -830,7 +812,7 @@ func schemaFromType(r Registry, t reflect.Type) *Schema {
}

if _, ok := f.Tag.Lookup("required"); ok {
fieldRequired = boolTag(f, "required")
fieldRequired = boolTag(f, "required", false)
}

if dr := f.Tag.Get("dependentRequired"); strings.TrimSpace(dr) != "" {
Expand Down Expand Up @@ -885,12 +867,12 @@ func schemaFromType(r Registry, t reflect.Type) *Schema {
additionalProps := false
if f, ok := t.FieldByName("_"); ok {
if _, ok = f.Tag.Lookup("additionalProperties"); ok {
additionalProps = boolTag(f, "additionalProperties")
additionalProps = boolTag(f, "additionalProperties", false)
}

if _, ok := f.Tag.Lookup("nullable"); ok {
// Allow overriding nullability per struct.
s.Nullable = boolTag(f, "nullable")
s.Nullable = boolTag(f, "nullable", false)
}
}
s.AdditionalProperties = additionalProps
Expand Down
3 changes: 3 additions & 0 deletions schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,7 @@ type ExampleInputStruct struct {
Email string `json:"email" format:"email" doc:"Contact e-mail address"`
Age *int `json:"age,omitempty" minimum:"0"`
Comment string `json:"comment,omitempty" maxLength:"256"`
Pattern string `json:"pattern" pattern:"^[a-z]+$"`
}

// Implements SchemaTransformer interface, reusing parts of the schema from `ExampleInputStruct`
Expand All @@ -1419,6 +1420,7 @@ type ExampleUpdateStruct struct {
Email *string `json:"email" doc:"Override doc for email"`
Age OmittableNullable[int] `json:"age"`
Comment OmittableNullable[string] `json:"comment"`
Pattern string `json:"pattern"`
}

func (u *ExampleUpdateStruct) TransformSchema(r huma.Registry, s *huma.Schema) *huma.Schema {
Expand Down Expand Up @@ -1449,6 +1451,7 @@ func TestSchemaTransformer(t *testing.T) {
assert.True(t, s.Properties["age"].Nullable)
assert.Equal(t, inputSchema.Properties["comment"].MaxLength, s.Properties["comment"].MaxLength)
assert.True(t, s.Properties["comment"].Nullable)
assert.Equal(t, inputSchema.Properties["pattern"].Pattern, s.Properties["pattern"].Pattern)
}
updateSchema1 := r.Schema(reflect.TypeOf(ExampleUpdateStruct{}), false, "")
validateSchema(updateSchema1)
Expand Down