Skip to content

Commit

Permalink
ext.NativeTypes: Recursively add sub-types (#892)
Browse files Browse the repository at this point in the history
This change extends the `NativeTypes` provider to not only add the
passed-in type but also all of its sub-types in order to simplify using
it in the context of nested structs.
  • Loading branch information
alvaroaleman authored Feb 1, 2024
1 parent ba58735 commit 5883379
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
48 changes: 44 additions & 4 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,21 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
t, err := newNativeType(rt)
result, err := newNativeTypes(rt)
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
t, err := newNativeType(rt.Type())
result, err := newNativeTypes(rt.Type())
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
default:
return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt)
}
Expand Down Expand Up @@ -465,6 +469,42 @@ func (o *nativeObj) Value() any {
return o.val
}

func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(rawType)
if err != nil {
return nil, err
}
result := []*nativeType{nt}

alreadySeen := make(map[string]struct{})
var iterateStructMembers func(reflect.Type)
iterateStructMembers = func(t reflect.Type) {
if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return
}
if _, seen := alreadySeen[t.String()]; seen {
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(t)
if ntErr != nil {
err = ntErr
return
}
result = append(result, nt)

for idx := 0; idx < t.NumField(); idx++ {
iterateStructMembers(t.Field(idx).Type)
}
}
iterateStructMembers(rawType)

return result, err
}

func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
Expand Down
17 changes: 15 additions & 2 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func TestNativeTypes(t *testing.T) {
},
],
MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}},
CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}],
CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}},
}`,
out: &TestAllTypes{
NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}},
Expand All @@ -83,7 +85,9 @@ func TestNativeTypes(t *testing.T) {
NestedMapVal: map[int64]bool{42: true},
},
},
MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}},
MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}},
CustomSliceVal: []TestNestedSliceType{{Value: "none"}},
CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}},
},
},
{
Expand Down Expand Up @@ -645,7 +649,6 @@ func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
envOpts = append(envOpts, opts...)
envOpts = append(envOpts,
NativeTypes(
reflect.TypeOf(&TestNestedType{}),
reflect.ValueOf(&TestAllTypes{}),
),
)
Expand Down Expand Up @@ -687,6 +690,8 @@ type TestAllTypes struct {
ListVal []*TestNestedType
MapVal map[string]TestAllTypes
PbVal *proto3pb.TestAllTypes
CustomSliceVal []TestNestedSliceType
CustomMapVal map[string]TestMapVal

// channel types are not supported
UnsupportedVal chan string
Expand All @@ -696,3 +701,11 @@ type TestAllTypes struct {
// unexported types can be found but not set or accessed
privateVal map[string]string
}

type TestNestedSliceType struct {
Value string
}

type TestMapVal struct {
Value string
}

0 comments on commit 5883379

Please sign in to comment.