-
Notifications
You must be signed in to change notification settings - Fork 1
/
storage.go
145 lines (125 loc) · 3.98 KB
/
storage.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package fmap
import (
"errors"
"fmt"
"reflect"
"unsafe"
)
var cache = map[reflect.Type]Storage{}
// Get returns a map of field objects.
// It takes a parameter `T` of type `any`, representing the type to be used for Fields map creation.
func Get[T any]() (Storage, error) {
var tt T
return getFrom(reflect.TypeOf(tt))
}
type storage struct {
asMap map[string]Field
paths []string
}
func (s *storage) Find(path string) (Field, bool) {
field, ok := s.asMap[path]
return field, ok
}
func (s *storage) MustFind(path string) Field {
return s.asMap[path]
}
func (s *storage) GetAllPaths() []string {
return s.paths
}
// GetFrom returns a map of field objects. It takes a parameter `obj` of type `interface{}` representing the object to be analyzed.
// The function first checks if the `obj` type is already in the cache, and if it exists, it returns the cached value.
// Otherwise, it creates a new empty map with storage.
func GetFrom(obj interface{}) (Storage, error) {
typeOf := reflect.TypeOf(obj)
return getFrom(typeOf)
}
func getFrom(typeOf reflect.Type) (Storage, error) {
if typeOf.Kind() == reflect.Struct {
typeOf = reflect.PointerTo(typeOf)
}
if typeOf.Kind() != reflect.Pointer ||
(typeOf.Kind() == reflect.Pointer && typeOf.Elem().Kind() != reflect.Struct) {
return nil, fmt.Errorf("not supported type: %v, only struct and ptr to struct is supported", typeOf)
}
if tFields, ok := cache[typeOf]; ok {
return tFields, nil
}
tFields := map[string]Field{}
count := new(int)
calculateFields(typeOf, count)
slice := make([]string, 0, *count)
getFieldsMapRecursive(typeOf, "", &tFields, &slice, 0)
cache[typeOf] = &storage{
asMap: tFields,
paths: slice,
}
return cache[typeOf], nil
}
func calculateFields(confTypeOf reflect.Type, count *int) {
if confTypeOf.Kind() == reflect.Ptr {
confTypeOf = confTypeOf.Elem()
}
*count += confTypeOf.NumField()
for i := 0; i < confTypeOf.NumField(); i++ {
if confTypeOf.Field(i).Type.Kind() == reflect.Struct {
calculateFields(confTypeOf.Field(i).Type, count)
}
}
}
func getFieldsMapRecursive(confTypeOf reflect.Type, path string, f *map[string]Field, s *[]string, offset uintptr) {
if confTypeOf.Kind() == reflect.Ptr {
confTypeOf = confTypeOf.Elem()
}
if path != "" {
path += "."
}
for i := 0; i < confTypeOf.NumField(); i++ {
fieldTypeOf := confTypeOf.Field(i)
var parent *field = nil
if path != "" {
parentPath := path[:len(path)-1]
parent, _ = (*f)[parentPath].(*field)
}
switch fieldTypeOf.Type.Kind() {
case reflect.Struct:
fld := &field{StructField: fieldTypeOf, structPath: path + fieldTypeOf.Name, parent: parent}
(*f)[path+fieldTypeOf.Name] = fld
*s = append(*s, fld.structPath)
getFieldsMapRecursive(fieldTypeOf.Type, path+fieldTypeOf.Name, f, s, offset+fieldTypeOf.Offset)
default:
fld := &field{StructField: fieldTypeOf, structPath: path + fieldTypeOf.Name, parent: parent}
fld.Offset = fld.Offset + offset
*s = append(*s, fld.structPath)
(*f)[path+fieldTypeOf.Name] = fld
}
}
}
func (s *storage) GetFieldByPtr(structPtr, fieldPtr any) (Field, error) {
fldType := reflect.TypeOf(fieldPtr)
if fldType.Kind() != reflect.Ptr {
return nil, fmt.Errorf("not supported type: %v, only ptr to types is supported", fldType)
}
if fldType.Kind() == reflect.Ptr {
fldType = fldType.Elem()
}
structType := reflect.TypeOf(structPtr)
if structType.Kind() != reflect.Ptr ||
(structType.Kind() == reflect.Ptr && structType.Elem().Kind() != reflect.Struct) {
return nil, fmt.Errorf("not supported type: %v, only struct and ptr to struct is supported", structType)
}
fPtr := ((*[2]unsafe.Pointer)(unsafe.Pointer(&fieldPtr)))[1]
sPtr := ((*[2]unsafe.Pointer)(unsafe.Pointer(&structPtr)))[1]
offset := uintptr(fPtr) - uintptr(sPtr)
for _, path := range s.GetAllPaths() {
fld := s.MustFind(path)
if fld.GetOffset() != offset {
continue
}
matchType := fld.GetType().Kind() == fldType.Kind()
if !matchType {
continue
}
return fld, nil
}
return nil, errors.New("field not found")
}