Skip to content

Commit

Permalink
opt: check UDFs when checking metadata dependencies
Browse files Browse the repository at this point in the history
This patch ensures that the metadata dependency-checking tracks user-defined
functions. This ensures that a cached query with a UDF reference will be
invalidated when the UDF is altered or dropped, or when the database is
switched.

Fixes cockroachdb#93082
Fixes cockroachdb#93321

Release note (bug fix): The query cache now checks to ensure that
user-defined functions referenced in the query have been altered or
dropped. This prevents a bug that could cause a query to return the
same result even after a UDF was dropped or the database was switched.
  • Loading branch information
DrewKimball committed Feb 27, 2023
1 parent 9beb8ee commit b30310c
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 18 deletions.
1 change: 1 addition & 0 deletions pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ func (desc *immutable) ToOverload() (ret *tree.Overload, err error) {
ReturnSet: desc.ReturnType.ReturnSet,
Body: desc.FunctionBody,
IsUDF: true,
Version: uint64(desc.Version),
}

argTypes := make(tree.ParamTypes, 0, len(desc.Params))
Expand Down
74 changes: 74 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/schema
Original file line number Diff line number Diff line change
Expand Up @@ -1050,3 +1050,77 @@ statement ok
USE test;
DROP DATABASE d;
DROP TABLE xy;

subtest alter_udf_schema

# Renaming the schema should invalidate a schema-qualified UDF reference.
statement ok
CREATE SCHEMA sc;
CREATE FUNCTION sc.fn(INT) RETURNS INT LANGUAGE SQL AS 'SELECT $1';

query I
SELECT sc.fn(1);
----
1

statement ok
ALTER SCHEMA sc RENAME TO sc1;

query error pq: schema "sc" does not exist
SELECT sc.fn(1);

query I
SELECT sc1.fn(1);
----
1

statement ok
DROP SCHEMA sc1 CASCADE;

# Renaming the database should invalidate a database-qualified UDF reference.
statement ok
CREATE DATABASE d;
USE d;
CREATE FUNCTION fn(INT) RETURNS INT LANGUAGE SQL AS 'SELECT $1';

query I
SELECT d.public.fn(1);
----
1

statement ok
ALTER DATABASE d RENAME TO d1;
USE d1;

query error cross-database function references not allowed
SELECT d.public.fn(1);

query I
SELECT d1.public.fn(1);
----
1

statement ok
USE test;
DROP DATABASE d1 CASCADE;

# Changing the current database should invalidate an unqualified UDF reference.
statement ok
CREATE FUNCTION fn(INT) RETURNS INT LANGUAGE SQL AS 'SELECT $1';

query I
SELECT fn(1);
----
1

statement ok
CREATE DATABASE d;
USE d;

query error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

statement ok
USE test;
DROP DATABASE d;
DROP FUNCTION fn;
64 changes: 57 additions & 7 deletions pkg/sql/logictest/testdata/logic_test/udf
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,7 @@ DROP FUNCTION sc2.f_tbl()

statement ok
ALTER DATABASE rename_db1 RENAME TO rename_db2;
USE rename_db2;

# Make sure that db renaming does not affect types and sequences in UDF.
query T
Expand Down Expand Up @@ -1661,7 +1662,7 @@ Mon
query I
SELECT sc1.f_seq()
----
5
1

query T
SELECT sc2.f_type()
Expand All @@ -1671,7 +1672,7 @@ Mon
query I
SELECT sc2.f_seq()
----
6
2

statement error pq: cannot rename schema because relation "rename_sc1.sc1.f_tbl" depends on relation "rename_sc1.sc1.tbl"
ALTER SCHEMA sc1 RENAME TO sc1_new
Expand All @@ -1688,16 +1689,23 @@ DROP FUNCTION sc2.f_tbl()
statement ok
ALTER SCHEMA sc1 RENAME TO sc1_new

# Make sure that db renaming does not affect types and sequences in UDF.
query T
# Cannot refer to the old schema name.
statement error pq: schema "sc1" does not exist
SELECT sc1.f_type()

statement error pq: schema "sc1" does not exist
SELECT sc1.f_seq()

# Make sure that schema renaming does not affect types and sequences in UDF.
query T
SELECT sc1_new.f_type()
----
Mon

query I
SELECT sc1.f_seq()
SELECT sc1_new.f_seq()
----
7
3

query T
SELECT sc2.f_type()
Expand All @@ -1707,7 +1715,7 @@ Mon
query I
SELECT sc2.f_seq()
----
8
4

statement ok
SET DATABASE = test
Expand Down Expand Up @@ -3116,3 +3124,45 @@ query I
SELECT f_94146(1::INT2)
----
2

# Regression test for #93082 - invalidate a cached query with a UDF if the UDF
# has been dropped.
subtest regression_93082

statement ok
CREATE FUNCTION fn(a INT) RETURNS INT LANGUAGE SQL AS 'SELECT a';

query I
SELECT fn(1);
----
1

statement ok
DROP FUNCTION fn;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

# Regression test for #93321 - invalidate a cached query with an unqualified UDF
# reference after the database is switched.
subtest regression_93321

statement ok
CREATE FUNCTION fn(a INT) RETURNS INT LANGUAGE SQL AS 'SELECT a';

query I
SELECT fn(1);
----
1

statement ok
CREATE DATABASE d;
USE d;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

statement ok
USE test;
DROP DATABASE d CASCADE;
DROP FUNCTION fn;
42 changes: 41 additions & 1 deletion pkg/sql/opt/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ type Metadata struct {
userDefinedTypes map[oid.Oid]struct{}
userDefinedTypesSlice []*types.T

// userDefinedFunctions contains all user defined functions present in the
// query.
userDefinedFunctions []*tree.Overload

// deps stores information about all data source objects depended on by the
// query, as well as the privileges required to access them.
deps []mdDep
Expand Down Expand Up @@ -194,7 +198,8 @@ func (md *Metadata) Init() {
func (md *Metadata) CopyFrom(from *Metadata, copyScalarFn func(Expr) Expr) {
if len(md.schemas) != 0 || len(md.cols) != 0 || len(md.tables) != 0 ||
len(md.sequences) != 0 || len(md.deps) != 0 || len(md.views) != 0 ||
len(md.userDefinedTypes) != 0 || len(md.userDefinedTypesSlice) != 0 {
len(md.userDefinedTypes) != 0 || len(md.userDefinedTypesSlice) != 0 ||
len(md.userDefinedFunctions) != 0 {
panic(errors.AssertionFailedf("CopyFrom requires empty destination"))
}
md.schemas = append(md.schemas, from.schemas...)
Expand All @@ -211,6 +216,13 @@ func (md *Metadata) CopyFrom(from *Metadata, copyScalarFn func(Expr) Expr) {
}
}

if len(from.userDefinedFunctions) > 0 {
if cap(md.userDefinedFunctions) < len(from.userDefinedFunctions) {
md.userDefinedFunctions = make([]*tree.Overload, len(from.userDefinedFunctions))
}
md.userDefinedFunctions = append(md.userDefinedFunctions, from.userDefinedFunctions...)
}

if cap(md.tables) >= len(from.tables) {
md.tables = md.tables[:len(from.tables)]
} else {
Expand Down Expand Up @@ -342,6 +354,16 @@ func (md *Metadata) CheckDependencies(
return false, nil
}
}
// Check that all of the user defined functions have not changed.
for _, overload := range md.userDefinedFunctions {
_, toCheck, err := optCatalog.ResolveFunctionByOID(ctx, overload.Oid)
if err != nil {
return false, handleDescError(err)
}
if overload.Version != toCheck.Version {
return false, nil
}
}
return true, nil
}

Expand Down Expand Up @@ -394,6 +416,24 @@ func (md *Metadata) AllUserDefinedTypes() []*types.T {
return md.userDefinedTypesSlice
}

// AddUserDefinedFunc adds a user defined function call to the metadata for this
// query.
func (md *Metadata) AddUserDefinedFunc(overload *tree.Overload) {
if !overload.IsUDF {
// We check IsUDF here instead of HasSQLBody() because we only care about
// user-defined functions, which can be altered or dropped, unlike builtin
// functions defined using a SQL string.
return
}
for i := range md.userDefinedFunctions {
if md.userDefinedFunctions[i] == overload {
// This is a duplicate.
return
}
}
md.userDefinedFunctions = append(md.userDefinedFunctions, overload)
}

// AddTable indexes a new reference to a table within the query. Separate
// references to the same table are assigned different table ids (e.g. in a
// self-join query). All columns are added to the metadata. If mutation columns
Expand Down
18 changes: 9 additions & 9 deletions pkg/sql/opt/optbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ func (b *Builder) Build() (err error) {
evalCtx: b.evalCtx,
}
typeTracker := &optTrackingTypeResolver{
res: b.semaCtx.TypeResolver,
optTrackingResolverHelper: &b.resolverHelper,
res: b.semaCtx.TypeResolver,
helper: &b.resolverHelper,
}
b.semaCtx.TypeResolver = typeTracker

Expand Down Expand Up @@ -548,33 +548,33 @@ func (o *optTrackingResolverHelper) trackObjectPath(
// optTrackingTypeResolver is a wrapper around a TypeReferenceResolver that
// remembers all of the resolved types in the provided Metadata.
type optTrackingTypeResolver struct {
res tree.TypeReferenceResolver
*optTrackingResolverHelper
res tree.TypeReferenceResolver
helper *optTrackingResolverHelper
}

// ResolveType implements the TypeReferenceResolver interface.
// ResolveType implements the tree.TypeReferenceResolver interface.
func (o *optTrackingTypeResolver) ResolveType(
ctx context.Context, name *tree.UnresolvedObjectName,
) (*types.T, error) {
typ, err := o.res.ResolveType(ctx, name)
if err != nil {
return nil, err
}
o.metadata.AddUserDefinedType(typ)
if err = o.trackObjectPath(ctx, name); err != nil {
o.helper.metadata.AddUserDefinedType(typ)
if err = o.helper.trackObjectPath(ctx, name); err != nil {
return nil, err
}
return typ, nil
}

// ResolveTypeByOID implements the tree.TypeResolver interface.
// ResolveTypeByOID implements the tree.TypeReferenceResolver interface.
func (o *optTrackingTypeResolver) ResolveTypeByOID(
ctx context.Context, oid oid.Oid,
) (*types.T, error) {
typ, err := o.res.ResolveTypeByOID(ctx, oid)
if err != nil {
return nil, err
}
o.metadata.AddUserDefinedType(typ)
o.helper.metadata.AddUserDefinedType(typ)
return typ, nil
}
7 changes: 7 additions & 0 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,13 @@ func (b *Builder) buildUDF(
colRefs *opt.ColSet,
) (out opt.ScalarExpr) {
o := f.ResolvedOverload()
b.factory.Metadata().AddUserDefinedFunc(o)
if f.Func.ReferenceByName != nil {
// Track the schema that was used to resolve the function.
if err := b.resolverHelper.trackObjectPath(b.ctx, f.Func.ReferenceByName); err != nil {
panic(err)
}
}

// Validate that the return types match the original return types defined in
// the function. Return types like user defined return types may change since
Expand Down
7 changes: 6 additions & 1 deletion pkg/sql/sem/tree/function_name.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type FunctionReferenceResolver interface {
// ResolvableFunctionReference implements the editable reference call of a
// FuncExpr.
type ResolvableFunctionReference struct {
// ReferenceByName stores the unresolved name that was used to reference the
// function (if one was) after the function has been resolved.
ReferenceByName *UnresolvedObjectName
FunctionReference
}

Expand Down Expand Up @@ -111,6 +114,8 @@ func (ref *ResolvableFunctionReference) Resolve(
if err != nil {
return nil, err
}
reference, _ := t.ToUnresolvedObjectName(NoAnnotation)
ref.ReferenceByName = &reference
ref.FunctionReference = fd
return fd, nil
case *FunctionOID:
Expand Down Expand Up @@ -143,7 +148,7 @@ func WrapFunction(n string) ResolvableFunctionReference {
if !ok {
panic(errors.AssertionFailedf("function %s() not defined", redact.Safe(n)))
}
return ResolvableFunctionReference{fd}
return ResolvableFunctionReference{FunctionReference: fd}
}

// FunctionReference is the common interface to UnresolvedName and QualifiedFunctionName.
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/sem/tree/overload.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ type Overload struct {
// ReturnSet is set to true when a user-defined function is defined to return
// a set of values.
ReturnSet bool
// Version is the descriptor version of the descriptor used to construct
// this version of the function overload. Only used for UDFs.
Version uint64
}

// params implements the overloadImpl interface.
Expand Down

0 comments on commit b30310c

Please sign in to comment.