From 76cc703cd12e6fede98b3daf51e2bd8de70d64e1 Mon Sep 17 00:00:00 2001 From: Andy Yang Date: Thu, 2 Feb 2023 16:24:06 -0500 Subject: [PATCH] logictest: refactor client connection creation logic This patch abstracts the logic for creating client connections and persisting them into a separate function. As a side benefit, it also fixes the problem of clients created in `execQuery` not being properly tracked in the `clients` field of `logicTest`. Release note: None --- pkg/sql/logictest/logic.go | 76 +++++++++++--------------------------- 1 file changed, 22 insertions(+), 54 deletions(-) diff --git a/pkg/sql/logictest/logic.go b/pkg/sql/logictest/logic.go index 20ccce092f77..738347aea6ae 100644 --- a/pkg/sql/logictest/logic.go +++ b/pkg/sql/logictest/logic.go @@ -1161,35 +1161,35 @@ func (t *logicTest) outf(format string, args ...interface{}) { // setUser sets the DB client to the specified user and connects // to the node in the cluster at index nodeIdx. -// It returns a cleanup function to be run when the credentials -// are no longer needed. -func (t *logicTest) setUser(user string, nodeIdx int) func() { - if db, ok := t.clients[user][nodeIdx]; ok { - t.db = db - t.user = user +func (t *logicTest) setUser(user string, nodeIdx int) { + db := t.getOrOpenClient(user, nodeIdx) + t.db = db + t.user = user + t.nodeIdx = nodeIdx +} - // No cleanup necessary, but return a no-op func to avoid nil pointer dereference. - return func() {} +// getOrOpenClient returns the existing client for the given user and nodeIdx, +// if one exists. Otherwise, it opens and returns a new client. +func (t *logicTest) getOrOpenClient(user string, nodeIdx int) *gosql.DB { + if db, ok := t.clients[user][nodeIdx]; ok { + return db } - var addr string var pgURL url.URL - var pgUser string - var cleanupFunc func() - pgUser = strings.TrimPrefix(user, "host-cluster-") + pgUser := strings.TrimPrefix(user, "host-cluster-") if t.cfg.UseCockroachGoTestserver { pgURL = *t.testserverCluster.PGURLForNode(nodeIdx) pgURL.User = url.User(pgUser) - pgURL.Path = "test" - cleanupFunc = func() {} } else { - addr = t.cluster.Server(nodeIdx).ServingSQLAddr() + addr := t.cluster.Server(nodeIdx).ServingSQLAddr() if len(t.tenantAddrs) > 0 && !strings.HasPrefix(user, "host-cluster-") { addr = t.tenantAddrs[nodeIdx] } + var cleanupFunc func() pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(pgUser)) - pgURL.Path = "test" + t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, cleanupFunc) } + pgURL.Path = "test" db := t.openDB(pgURL) @@ -1213,11 +1213,8 @@ func (t *logicTest) setUser(user string, nodeIdx int) func() { t.clients[user] = make(map[int]*gosql.DB) } t.clients[user][nodeIdx] = db - t.db = db - t.user = pgUser - t.nodeIdx = nodeIdx - return cleanupFunc + return db } func (t *logicTest) openDB(pgURL url.URL) *gosql.DB { @@ -1306,8 +1303,9 @@ func (t *logicTest) newTestServerCluster(bootstrapBinaryPath string, upgradeBina } t.testserverCluster = ts - t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */)) t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, ts.Stop) + + t.setUser(username.RootUser, 0 /* nodeIdx */) } // newCluster creates a new cluster. It should be called after the logic tests's @@ -1721,9 +1719,7 @@ func (t *logicTest) newCluster( ) } - // db may change over the lifetime of this function, with intermediate - // values cached in t.clients and finally closed in t.close(). - t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */)) + t.setUser(username.RootUser, 0 /* nodeIdx */) } // waitForTenantReadOnlyClusterSettingToTakeEffectOrFatal waits until all tenant @@ -2960,7 +2956,7 @@ func (t *logicTest) processSubtest( nodeIdx = int(idx) } } - cleanupUserFunc := t.setUser(fields[1], nodeIdx) + t.setUser(fields[1], nodeIdx) // In multi-tenant tests, we may need to also create database test when // we switch to a different tenant. if t.cfg.UseTenant && strings.HasPrefix(fields[1], "host-cluster-") { @@ -2968,7 +2964,6 @@ func (t *logicTest) processSubtest( return errors.Wrapf(err, "error creating database on admin tenant") } } - defer cleanupUserFunc() case "skip": reason := "skipped" @@ -3311,30 +3306,8 @@ func (t *logicTest) execQuery(query logicQuery) error { t.noticeBuffer = nil db := t.db - var closeDB func() if query.nodeIdx != t.nodeIdx { - var pgURL url.URL - if t.testserverCluster != nil { - pgURL = *t.testserverCluster.PGURLForNode(query.nodeIdx) - pgURL.User = url.User(t.user) - pgURL.Path = "test" - } else { - addr := t.cluster.Server(query.nodeIdx).ServingSQLAddr() - if len(t.tenantAddrs) > 0 { - addr = t.tenantAddrs[query.nodeIdx] - } - var cleanupFunc func() - pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(t.user)) - defer cleanupFunc() - pgURL.Path = "test" - } - - db = t.openDB(pgURL) - closeDB = func() { - if err := db.Close(); err != nil { - t.Fatal(err) - } - } + db = t.getOrOpenClient(t.user, query.nodeIdx) } if query.expectAsync { @@ -3354,9 +3327,6 @@ func (t *logicTest) execQuery(query logicQuery) error { startedChan := make(chan struct{}) go func() { - if closeDB != nil { - defer closeDB() - } startedChan <- struct{}{} rows, err := db.Query(query.sql) pending.resultChan <- pendingQueryResult{rows, err} @@ -3364,8 +3334,6 @@ func (t *logicTest) execQuery(query logicQuery) error { <-startedChan return nil - } else if closeDB != nil { - defer closeDB() } rows, err := db.Query(query.sql)