Skip to content

Commit

Permalink
logictest: refactor client connection creation logic
Browse files Browse the repository at this point in the history
This patch abstracts the logic for creating client connections and
persisting them into a separate function. As a side benefit, it also
fixes the problem of clients created in `execQuery` not being properly
tracked in the `clients` field of `logicTest`.

Release note: None
  • Loading branch information
andyyang890 committed Feb 3, 2023
1 parent 44d9f3c commit 76cc703
Showing 1 changed file with 22 additions and 54 deletions.
76 changes: 22 additions & 54 deletions pkg/sql/logictest/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1161,35 +1161,35 @@ func (t *logicTest) outf(format string, args ...interface{}) {

// setUser sets the DB client to the specified user and connects
// to the node in the cluster at index nodeIdx.
// It returns a cleanup function to be run when the credentials
// are no longer needed.
func (t *logicTest) setUser(user string, nodeIdx int) func() {
if db, ok := t.clients[user][nodeIdx]; ok {
t.db = db
t.user = user
func (t *logicTest) setUser(user string, nodeIdx int) {
db := t.getOrOpenClient(user, nodeIdx)
t.db = db
t.user = user
t.nodeIdx = nodeIdx
}

// No cleanup necessary, but return a no-op func to avoid nil pointer dereference.
return func() {}
// getOrOpenClient returns the existing client for the given user and nodeIdx,
// if one exists. Otherwise, it opens and returns a new client.
func (t *logicTest) getOrOpenClient(user string, nodeIdx int) *gosql.DB {
if db, ok := t.clients[user][nodeIdx]; ok {
return db
}

var addr string
var pgURL url.URL
var pgUser string
var cleanupFunc func()
pgUser = strings.TrimPrefix(user, "host-cluster-")
pgUser := strings.TrimPrefix(user, "host-cluster-")
if t.cfg.UseCockroachGoTestserver {
pgURL = *t.testserverCluster.PGURLForNode(nodeIdx)
pgURL.User = url.User(pgUser)
pgURL.Path = "test"
cleanupFunc = func() {}
} else {
addr = t.cluster.Server(nodeIdx).ServingSQLAddr()
addr := t.cluster.Server(nodeIdx).ServingSQLAddr()
if len(t.tenantAddrs) > 0 && !strings.HasPrefix(user, "host-cluster-") {
addr = t.tenantAddrs[nodeIdx]
}
var cleanupFunc func()
pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(pgUser))
pgURL.Path = "test"
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, cleanupFunc)
}
pgURL.Path = "test"

db := t.openDB(pgURL)

Expand All @@ -1213,11 +1213,8 @@ func (t *logicTest) setUser(user string, nodeIdx int) func() {
t.clients[user] = make(map[int]*gosql.DB)
}
t.clients[user][nodeIdx] = db
t.db = db
t.user = pgUser
t.nodeIdx = nodeIdx

return cleanupFunc
return db
}

func (t *logicTest) openDB(pgURL url.URL) *gosql.DB {
Expand Down Expand Up @@ -1306,8 +1303,9 @@ func (t *logicTest) newTestServerCluster(bootstrapBinaryPath string, upgradeBina
}

t.testserverCluster = ts
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */))
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, ts.Stop)

t.setUser(username.RootUser, 0 /* nodeIdx */)
}

// newCluster creates a new cluster. It should be called after the logic tests's
Expand Down Expand Up @@ -1721,9 +1719,7 @@ func (t *logicTest) newCluster(
)
}

// db may change over the lifetime of this function, with intermediate
// values cached in t.clients and finally closed in t.close().
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */))
t.setUser(username.RootUser, 0 /* nodeIdx */)
}

// waitForTenantReadOnlyClusterSettingToTakeEffectOrFatal waits until all tenant
Expand Down Expand Up @@ -2960,15 +2956,14 @@ func (t *logicTest) processSubtest(
nodeIdx = int(idx)
}
}
cleanupUserFunc := t.setUser(fields[1], nodeIdx)
t.setUser(fields[1], nodeIdx)
// In multi-tenant tests, we may need to also create database test when
// we switch to a different tenant.
if t.cfg.UseTenant && strings.HasPrefix(fields[1], "host-cluster-") {
if _, err := t.db.Exec("CREATE DATABASE IF NOT EXISTS test; USE test;"); err != nil {
return errors.Wrapf(err, "error creating database on admin tenant")
}
}
defer cleanupUserFunc()

case "skip":
reason := "skipped"
Expand Down Expand Up @@ -3311,30 +3306,8 @@ func (t *logicTest) execQuery(query logicQuery) error {
t.noticeBuffer = nil

db := t.db
var closeDB func()
if query.nodeIdx != t.nodeIdx {
var pgURL url.URL
if t.testserverCluster != nil {
pgURL = *t.testserverCluster.PGURLForNode(query.nodeIdx)
pgURL.User = url.User(t.user)
pgURL.Path = "test"
} else {
addr := t.cluster.Server(query.nodeIdx).ServingSQLAddr()
if len(t.tenantAddrs) > 0 {
addr = t.tenantAddrs[query.nodeIdx]
}
var cleanupFunc func()
pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(t.user))
defer cleanupFunc()
pgURL.Path = "test"
}

db = t.openDB(pgURL)
closeDB = func() {
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
db = t.getOrOpenClient(t.user, query.nodeIdx)
}

if query.expectAsync {
Expand All @@ -3354,18 +3327,13 @@ func (t *logicTest) execQuery(query logicQuery) error {

startedChan := make(chan struct{})
go func() {
if closeDB != nil {
defer closeDB()
}
startedChan <- struct{}{}
rows, err := db.Query(query.sql)
pending.resultChan <- pendingQueryResult{rows, err}
}()

<-startedChan
return nil
} else if closeDB != nil {
defer closeDB()
}

rows, err := db.Query(query.sql)
Expand Down

0 comments on commit 76cc703

Please sign in to comment.