Skip to content

Commit 319684f

Browse files
Fix race in the checker around reference types (#1224)
1 parent a36d461 commit 319684f

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

checker/checker_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,15 +2355,14 @@ func TestCheck(t *testing.T) {
23552355
}
23562356

23572357
reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})
2358+
if err != nil {
2359+
t.Fatalf("types.NewRegistry() failed: %v", err)
2360+
}
23582361
if tc.env.optionalSyntax {
2359-
err = reg.RegisterType(types.OptionalType)
2360-
if err != nil {
2362+
if err := reg.RegisterType(types.OptionalType); err != nil {
23612363
t.Fatalf("reg.RegisterType(optional_type) failed: %v", err)
23622364
}
23632365
}
2364-
if err != nil {
2365-
t.Fatalf("types.NewRegistry() failed: %v", err)
2366-
}
23672366
cont, err := containers.NewContainer(containers.Name(tc.container))
23682367
if err != nil {
23692368
t.Fatalf("containers.NewContainer() failed: %v", err)
@@ -2453,6 +2452,11 @@ func BenchmarkCheck(b *testing.B) {
24532452
if err != nil {
24542453
b.Fatalf("types.NewRegistry() failed: %v", err)
24552454
}
2455+
if tc.env.optionalSyntax {
2456+
if err := reg.RegisterType(types.OptionalType); err != nil {
2457+
b.Fatalf("reg.RegisterType(optional_type) failed: %v", err)
2458+
}
2459+
}
24562460
cont, err := containers.NewContainer(containers.Name(tc.container))
24572461
if err != nil {
24582462
b.Fatalf("containers.NewContainer() failed: %v", err)

checker/env.go

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,29 +137,20 @@ func (e *Env) LookupIdent(name string) *decls.VariableDecl {
137137
return ident
138138
}
139139

140-
// Next try to import the name as a reference to a message type. If found,
141-
// the declaration is added to the outest (global) scope of the
142-
// environment, so next time we can access it faster.
140+
// Next try to import the name as a reference to a message type.
143141
if t, found := e.provider.FindStructType(candidate); found {
144-
decl := decls.NewVariable(candidate, t)
145-
e.declarations.AddIdent(decl)
146-
return decl
142+
return decls.NewVariable(candidate, t)
147143
}
148-
149144
if i, found := e.provider.FindIdent(candidate); found {
150145
if t, ok := i.(*types.Type); ok {
151-
decl := decls.NewVariable(candidate, types.NewTypeTypeWithParam(t))
152-
e.declarations.AddIdent(decl)
153-
return decl
146+
return decls.NewVariable(candidate, types.NewTypeTypeWithParam(t))
154147
}
155148
}
156149

157150
// Next try to import this as an enum value by splitting the name in a type prefix and
158151
// the enum inside.
159152
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {
160-
decl := decls.NewConstant(candidate, types.IntType, enumValue)
161-
e.declarations.AddIdent(decl)
162-
return decl
153+
return decls.NewConstant(candidate, types.IntType, enumValue)
163154
}
164155
}
165156
return nil

ext/native_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,51 @@ func TestNativeTypesVersion(t *testing.T) {
10091009
}
10101010
}
10111011

1012+
type Custom struct {
1013+
Name string `cel:"name"`
1014+
}
1015+
1016+
func TestTypeResolutionRace(t *testing.T) {
1017+
customType := reflect.TypeFor[*Custom]()
1018+
env, err := cel.NewEnv(
1019+
cel.Container("ext"),
1020+
NativeTypes(
1021+
ParseStructTag("cel"),
1022+
customType,
1023+
),
1024+
)
1025+
if err != nil {
1026+
t.Fatal("NewEnv:", err)
1027+
}
1028+
1029+
tests := []struct {
1030+
name string
1031+
expr string
1032+
}{
1033+
{name: "custom1", expr: `Custom{ name: "name1" }`},
1034+
{name: "custom2", expr: `Custom{ name: "name2" }`},
1035+
{name: "custom3", expr: `Custom{ name: "name3" }`},
1036+
{name: "custom4", expr: `Custom{ name: "name4" }`},
1037+
{name: "custom5", expr: `Custom{ name: "name5" }`},
1038+
}
1039+
1040+
for _, test := range tests {
1041+
t.Run(test.name, func(t *testing.T) {
1042+
t.Parallel()
1043+
1044+
ast, iss := env.Compile(test.expr)
1045+
if err := iss.Err(); err != nil {
1046+
t.Fatal("Compile:", err)
1047+
}
1048+
prg, err := env.Program(ast)
1049+
if err != nil {
1050+
t.Fatalf("env.Program() failed: %s", err)
1051+
}
1052+
prg.Eval(cel.NoVars())
1053+
})
1054+
}
1055+
}
1056+
10121057
// testEnv initializes the test environment common to all tests.
10131058
func testNativeEnv(t *testing.T, opts ...any) *cel.Env {
10141059
t.Helper()

0 commit comments

Comments
 (0)