Skip to content

Commit

Permalink
opt: check UDFs when checking metadata dependencies
Browse files Browse the repository at this point in the history
This patch ensures that the metadata dependency-checking tracks user-defined
functions. This ensures that a cached query with a UDF reference will be
invalidated when the UDF is altered or dropped, or when the database is
switched.

Fixes cockroachdb#93082
Fixes cockroachdb#93321

Release note (bug fix): The query cache now checks to ensure that
user-defined functions referenced in the query have not been altered or
dropped. This prevents a bug that could cause a query to return the
same result even after a UDF was dropped or the database was switched.
  • Loading branch information
DrewKimball committed Jan 31, 2023
1 parent 10ef5d9 commit d2eaea5
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 25 deletions.
1 change: 1 addition & 0 deletions pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ func (desc *immutable) ToOverload() (ret *tree.Overload, err error) {
ReturnSet: desc.ReturnType.ReturnSet,
Body: desc.FunctionBody,
IsUDF: true,
Version: uint64(desc.Version),
}

argTypes := make(tree.ParamTypes, 0, len(desc.Params))
Expand Down
75 changes: 68 additions & 7 deletions pkg/sql/logictest/testdata/logic_test/udf
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,7 @@ DROP FUNCTION sc2.f_tbl()

statement ok
ALTER DATABASE rename_db1 RENAME TO rename_db2;
USE rename_db2;

# Make sure that db renaming does not affect types and sequences in UDF.
query T
Expand Down Expand Up @@ -1676,7 +1677,7 @@ Mon
query I
SELECT sc1.f_seq()
----
5
1

query T
SELECT sc2.f_type()
Expand All @@ -1686,7 +1687,7 @@ Mon
query I
SELECT sc2.f_seq()
----
6
2

statement error pq: cannot rename schema because relation "rename_sc1.sc1.f_tbl" depends on relation "rename_sc1.sc1.tbl"
ALTER SCHEMA sc1 RENAME TO sc1_new
Expand All @@ -1703,16 +1704,23 @@ DROP FUNCTION sc2.f_tbl()
statement ok
ALTER SCHEMA sc1 RENAME TO sc1_new

# Make sure that db renaming does not affect types and sequences in UDF.
query T
# Cannot refer to the old schema name.
statement error schema "sc1" does not exist
SELECT sc1.f_type()

statement error schema "sc1" does not exist
SELECT sc1.f_seq()

# Make sure that schema renaming does not affect types and sequences in UDF.
query T
SELECT sc1_new.f_type()
----
Mon

query I
SELECT sc1.f_seq()
SELECT sc1_new.f_seq()
----
7
3

query T
SELECT sc2.f_type()
Expand All @@ -1722,7 +1730,7 @@ Mon
query I
SELECT sc2.f_seq()
----
8
4

statement ok
SET DATABASE = test
Expand Down Expand Up @@ -2923,3 +2931,56 @@ SELECT f95240(a) FROM t95240
----
33
NULL

# Regression test for #93082 and #93321 - don't reuse a cached query with a UDF
# if the UDF has been altered or dropped, or if the database has been switched.
subtest regression_93082

statement ok
CREATE FUNCTION fn(a INT) RETURNS INT LANGUAGE SQL AS 'SELECT a';

query I
SELECT fn(1);
----
1

statement ok
DROP FUNCTION fn;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

statement ok
CREATE FUNCTION fn(a INT) RETURNS INT LANGUAGE SQL AS 'SELECT a';

query I
SELECT fn(1);
----
1

statement ok
ALTER FUNCTION fn RENAME TO fn2;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

statement ok
CREATE DATABASE d;
USE d;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);

statement ok
CREATE FUNCTION fn(a INT) RETURNS INT LANGUAGE SQL AS 'SELECT a';

query I
SELECT fn(1);
----
1

statement ok
USE defaultdb;

statement error pq: unknown function: fn\(\): function undefined
SELECT fn(1);
1 change: 1 addition & 0 deletions pkg/sql/opt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/server/telemetry",
"//pkg/sql/catalog",
"//pkg/sql/catalog/catpb",
"//pkg/sql/catalog/colinfo",
"//pkg/sql/catalog/descpb",
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/memo/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func (m *Memo) IsStale(
// Memo is stale if the fingerprint of any object in the memo's metadata has
// changed, or if the current user no longer has sufficient privilege to
// access the object.
if depsUpToDate, err := m.Metadata().CheckDependencies(ctx, catalog); err != nil {
if depsUpToDate, err := m.Metadata().CheckDependencies(ctx, evalCtx, catalog); err != nil {
return true, err
} else if !depsUpToDate {
return true, nil
Expand Down
92 changes: 78 additions & 14 deletions pkg/sql/opt/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"math/bits"
"strings"

"github.com/cockroachdb/cockroach/pkg/sql/catalog"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/multiregion"
"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
Expand Down Expand Up @@ -101,6 +102,11 @@ type Metadata struct {
userDefinedTypes map[oid.Oid]struct{}
userDefinedTypesSlice []*types.T

// userDefinedFunctions contains all user defined functions present in the
// query. TODO: there could be multiple calls with different qualified names
// to the same overload.
userDefinedFunctions []funcMDDep

// deps stores information about all data source objects depended on by the
// query, as well as the privileges required to access them. The objects are
// deduplicated: any name/object pair shows up at most once.
Expand Down Expand Up @@ -150,6 +156,13 @@ func (n *MDDepName) equals(other *MDDepName) bool {
return n.byID == other.byID && n.byName.Equals(&other.byName)
}

// funcMDDep tracks the information needed to resolve a UDF, as well as the
// previously resolved overload. CheckDependencies re-resolves name and compares
// the resulting overload's Oid and Version against the one stored here.
type funcMDDep struct {
// name is the unresolved name used at the function callsite; it is what gets
// re-resolved when checking whether the dependency is still up-to-date.
name *tree.UnresolvedName
// overload is the overload the callsite resolved to when the query was built.
overload *tree.Overload
}

// Init prepares the metadata for use (or reuse).
func (md *Metadata) Init() {
// Clear the metadata objects to release memory (this clearing pattern is
Expand Down Expand Up @@ -206,7 +219,8 @@ func (md *Metadata) Init() {
func (md *Metadata) CopyFrom(from *Metadata, copyScalarFn func(Expr) Expr) {
if len(md.schemas) != 0 || len(md.cols) != 0 || len(md.tables) != 0 ||
len(md.sequences) != 0 || len(md.deps) != 0 || len(md.views) != 0 ||
len(md.userDefinedTypes) != 0 || len(md.userDefinedTypesSlice) != 0 {
len(md.userDefinedTypes) != 0 || len(md.userDefinedTypesSlice) != 0 ||
len(md.userDefinedFunctions) != 0 {
panic(errors.AssertionFailedf("CopyFrom requires empty destination"))
}
md.schemas = append(md.schemas, from.schemas...)
Expand All @@ -223,6 +237,12 @@ func (md *Metadata) CopyFrom(from *Metadata, copyScalarFn func(Expr) Expr) {
}
}

if len(from.userDefinedFunctions) > 0 {
for i := range from.userDefinedFunctions {
md.userDefinedFunctions = append(md.userDefinedFunctions, from.userDefinedFunctions[i])
}
}

if cap(md.tables) >= len(from.tables) {
md.tables = md.tables[:len(from.tables)]
} else {
Expand Down Expand Up @@ -288,21 +308,23 @@ func (md *Metadata) AddDependency(name MDDepName, ds cat.DataSource, priv privil
// objects. If the dependencies are no longer up-to-date, then CheckDependencies
// returns false.
//
// This function cannot swallow errors and return only a boolean, as it may
// perform KV operations on behalf of the transaction associated with the
// provided catalog, and those errors are required to be propagated.
// This function cannot swallow arbitrary errors and return only a boolean, as
// it may perform KV operations on behalf of the transaction associated with the
// provided catalog, and those errors are required to be propagated. Note that
// it is ok to swallow "undefined" or "dropped" object errors, since these are
// expected when dependencies are not up-to-date.
func (md *Metadata) CheckDependencies(
ctx context.Context, catalog cat.Catalog,
ctx context.Context, evalCtx *eval.Context, optCatalog cat.Catalog,
) (upToDate bool, err error) {
for i := range md.deps {
name := &md.deps[i].name
var toCheck cat.DataSource
var err error
if name.byID != 0 {
toCheck, _, err = catalog.ResolveDataSourceByID(ctx, cat.Flags{}, name.byID)
toCheck, _, err = optCatalog.ResolveDataSourceByID(ctx, cat.Flags{}, name.byID)
} else {
// Resolve data source object.
toCheck, _, err = catalog.ResolveDataSource(ctx, cat.Flags{}, &name.byName)
toCheck, _, err = optCatalog.ResolveDataSource(ctx, cat.Flags{}, &name.byName)
}
if err != nil {
return false, err
Expand All @@ -321,7 +343,7 @@ func (md *Metadata) CheckDependencies(
// privileges do not need to be checked). Ignore the "zero privilege".
priv := privilege.Kind(bits.TrailingZeros32(uint32(privs)))
if priv != 0 {
if err := catalog.CheckPrivilege(ctx, toCheck, priv); err != nil {
if err := optCatalog.CheckPrivilege(ctx, toCheck, priv); err != nil {
return false, err
}
}
Expand All @@ -330,20 +352,44 @@ func (md *Metadata) CheckDependencies(
privs &= ^(1 << priv)
}
}
// handleUndefined swallows "undefined" and "dropped" object errors, since
// these are expected when an object no longer exists. Any other error (for
// example, one produced by a KV operation performed during resolution) is
// returned unchanged so that the caller propagates it instead of silently
// treating the dependency as merely stale.
handleUndefined := func(err error) error {
	if pgerror.GetPGCode(err) == pgcode.UndefinedObject ||
		errors.Is(err, catalog.ErrDescriptorDropped) {
		return nil
	}
	return err
}
// Check that all of the user defined types present have not changed.
for _, typ := range md.AllUserDefinedTypes() {
toCheck, err := catalog.ResolveTypeByOID(ctx, typ.Oid())
toCheck, err := optCatalog.ResolveTypeByOID(ctx, typ.Oid())
if err != nil {
// Handle when the type no longer exists.
if pgerror.GetPGCode(err) == pgcode.UndefinedObject {
return false, nil
}
return false, err
return false, handleUndefined(err)
}
if typ.TypeMeta.Version != toCheck.TypeMeta.Version {
return false, nil
}
}
// Check that all of the user defined functions have not changed.
for i := range md.userDefinedFunctions {
dep := &md.userDefinedFunctions[i]
toCheck, err := optCatalog.ResolveFunction(ctx, dep.name, &evalCtx.SessionData().SearchPath)
if err != nil {
return false, handleUndefined(err)
}
overload, err := toCheck.MatchOverload(
dep.overload.Types.Types(), "", &evalCtx.SessionData().SearchPath,
)
if err != nil {
return false, handleUndefined(err)
}
if dep.overload.Oid != overload.Oid || dep.overload.Version != overload.Version {
// The function call resolved to either a different overload or a
// different version of the same overload.
return false, nil
}
}
return true, nil
}

Expand Down Expand Up @@ -378,6 +424,24 @@ func (md *Metadata) AllUserDefinedTypes() []*types.T {
return md.userDefinedTypesSlice
}

// AddUserDefinedFunc adds a user defined function overload to the metadata for
// this query, together with the unresolved name used at the callsite. Overloads
// that are not UDFs are ignored. A (overload, name) pair that is already
// tracked is not added again; pairs are compared by pointer identity.
//
// It panics with an assertion failure if a UDF is added with a nil name, since
// the name is required to re-resolve the function when checking dependencies.
func (md *Metadata) AddUserDefinedFunc(fun *tree.Overload, name *tree.UnresolvedName) {
	if !fun.IsUDF {
		return
	}
	if name == nil {
		panic(errors.AssertionFailedf("attempted to add UDF with nil name"))
	}
	for i := range md.userDefinedFunctions {
		dep := &md.userDefinedFunctions[i]
		if dep.overload == fun && dep.name == name {
			// This is a duplicate. Return rather than break: breaking would only
			// leave the loop and fall through to the append below.
			return
		}
	}
	md.userDefinedFunctions = append(md.userDefinedFunctions, funcMDDep{name: name, overload: fun})
}

// AddTable indexes a new reference to a table within the query. Separate
// references to the same table are assigned different table ids (e.g. in a
// self-join query). All columns are added to the metadata. If mutation columns
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestMetadata(t *testing.T) {
}

md.AddDependency(opt.DepByName(&tab.TabName), tab, privilege.CREATE)
depsUpToDate, err := md.CheckDependencies(context.Background(), testCat)
depsUpToDate, err := md.CheckDependencies(context.Background(), &evalCtx, testCat)
if err == nil || depsUpToDate {
t.Fatalf("expected table privilege to be revoked")
}
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestMetadata(t *testing.T) {
t.Fatalf("unexpected type")
}

depsUpToDate, err = md.CheckDependencies(context.Background(), testCat)
depsUpToDate, err = md.CheckDependencies(context.Background(), &evalCtx, testCat)
if err == nil || depsUpToDate {
t.Fatalf("expected table privilege to be revoked in metadata copy")
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ func (b *Builder) buildUDF(
colRefs *opt.ColSet,
) (out opt.ScalarExpr) {
o := f.ResolvedOverload()
b.factory.Metadata().AddUserDefinedFunc(o, f.Func.CallsiteName)

// Build the argument expressions.
var args memo.ScalarListExpr
Expand Down
6 changes: 5 additions & 1 deletion pkg/sql/sem/tree/function_name.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type FunctionReferenceResolver interface {
// ResolvableFunctionReference implements the editable reference call of a
// FuncExpr.
type ResolvableFunctionReference struct {
// CallsiteName is the original UnresolvedName used at the function callsite.
// It may be unset if the function is unresolved. Resolve records it when it
// resolves an unresolved name; WrapFunction leaves it nil.
CallsiteName *UnresolvedName
// FunctionReference is the embedded reference. Resolve overwrites it with the
// resolved function definition.
FunctionReference
}

Expand Down Expand Up @@ -112,6 +115,7 @@ func (ref *ResolvableFunctionReference) Resolve(
return nil, err
}
ref.FunctionReference = fd
ref.CallsiteName = t
return fd, nil
default:
return nil, errors.AssertionFailedf("unknown resolvable function reference type %s", t)
Expand All @@ -129,7 +133,7 @@ func WrapFunction(n string) ResolvableFunctionReference {
if !ok {
panic(errors.AssertionFailedf("function %s() not defined", redact.Safe(n)))
}
return ResolvableFunctionReference{fd}
return ResolvableFunctionReference{FunctionReference: fd}
}

// FunctionReference is the common interface to UnresolvedName and QualifiedFunctionName.
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/sem/tree/overload.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ type Overload struct {
// ReturnSet is set to true when a user-defined function is defined to return
// a set of values.
ReturnSet bool
// Version is the descriptor version of the descriptor used to construct
// this version of the function overload. Only used for UDFs.
Version uint64
}

// params implements the overloadImpl interface.
Expand Down

0 comments on commit d2eaea5

Please sign in to comment.