Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logictest: refactor client connection creation logic #96469

Merged
merged 1 commit into from
Feb 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 22 additions & 54 deletions pkg/sql/logictest/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1161,35 +1161,35 @@ func (t *logicTest) outf(format string, args ...interface{}) {

// setUser sets the DB client to the specified user and connects
// to the node in the cluster at index nodeIdx.
// It returns a cleanup function to be run when the credentials
// are no longer needed.
func (t *logicTest) setUser(user string, nodeIdx int) func() {
if db, ok := t.clients[user][nodeIdx]; ok {
t.db = db
t.user = user
func (t *logicTest) setUser(user string, nodeIdx int) {
db := t.getOrOpenClient(user, nodeIdx)
t.db = db
t.user = user
t.nodeIdx = nodeIdx
}

// No cleanup necessary, but return a no-op func to avoid nil pointer dereference.
return func() {}
// getOrOpenClient returns the existing client for the given user and nodeIdx,
// if one exists. Otherwise, it opens and returns a new client.
func (t *logicTest) getOrOpenClient(user string, nodeIdx int) *gosql.DB {
if db, ok := t.clients[user][nodeIdx]; ok {
return db
}

var addr string
var pgURL url.URL
var pgUser string
var cleanupFunc func()
pgUser = strings.TrimPrefix(user, "host-cluster-")
pgUser := strings.TrimPrefix(user, "host-cluster-")
if t.cfg.UseCockroachGoTestserver {
pgURL = *t.testserverCluster.PGURLForNode(nodeIdx)
pgURL.User = url.User(pgUser)
pgURL.Path = "test"
cleanupFunc = func() {}
} else {
addr = t.cluster.Server(nodeIdx).ServingSQLAddr()
addr := t.cluster.Server(nodeIdx).ServingSQLAddr()
if len(t.tenantAddrs) > 0 && !strings.HasPrefix(user, "host-cluster-") {
addr = t.tenantAddrs[nodeIdx]
}
var cleanupFunc func()
pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(pgUser))
pgURL.Path = "test"
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, cleanupFunc)
}
pgURL.Path = "test"

db := t.openDB(pgURL)

Expand All @@ -1213,11 +1213,8 @@ func (t *logicTest) setUser(user string, nodeIdx int) func() {
t.clients[user] = make(map[int]*gosql.DB)
}
t.clients[user][nodeIdx] = db
t.db = db
t.user = pgUser
t.nodeIdx = nodeIdx

return cleanupFunc
return db
}

func (t *logicTest) openDB(pgURL url.URL) *gosql.DB {
Expand Down Expand Up @@ -1306,8 +1303,9 @@ func (t *logicTest) newTestServerCluster(bootstrapBinaryPath string, upgradeBina
}

t.testserverCluster = ts
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */))
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, ts.Stop)

t.setUser(username.RootUser, 0 /* nodeIdx */)
}

// newCluster creates a new cluster. It should be called after the logic tests's
Expand Down Expand Up @@ -1721,9 +1719,7 @@ func (t *logicTest) newCluster(
)
}

// db may change over the lifetime of this function, with intermediate
// values cached in t.clients and finally closed in t.close().
t.clusterCleanupFuncs = append(t.clusterCleanupFuncs, t.setUser(username.RootUser, 0 /* nodeIdx */))
t.setUser(username.RootUser, 0 /* nodeIdx */)
}

// waitForTenantReadOnlyClusterSettingToTakeEffectOrFatal waits until all tenant
Expand Down Expand Up @@ -2960,15 +2956,14 @@ func (t *logicTest) processSubtest(
nodeIdx = int(idx)
}
}
cleanupUserFunc := t.setUser(fields[1], nodeIdx)
t.setUser(fields[1], nodeIdx)
// In multi-tenant tests, we may need to also create database test when
// we switch to a different tenant.
if t.cfg.UseTenant && strings.HasPrefix(fields[1], "host-cluster-") {
if _, err := t.db.Exec("CREATE DATABASE IF NOT EXISTS test; USE test;"); err != nil {
return errors.Wrapf(err, "error creating database on admin tenant")
}
}
defer cleanupUserFunc()

case "skip":
reason := "skipped"
Expand Down Expand Up @@ -3311,30 +3306,8 @@ func (t *logicTest) execQuery(query logicQuery) error {
t.noticeBuffer = nil

db := t.db
var closeDB func()
if query.nodeIdx != t.nodeIdx {
var pgURL url.URL
if t.testserverCluster != nil {
pgURL = *t.testserverCluster.PGURLForNode(query.nodeIdx)
pgURL.User = url.User(t.user)
pgURL.Path = "test"
} else {
addr := t.cluster.Server(query.nodeIdx).ServingSQLAddr()
if len(t.tenantAddrs) > 0 {
addr = t.tenantAddrs[query.nodeIdx]
}
var cleanupFunc func()
pgURL, cleanupFunc = sqlutils.PGUrl(t.rootT, addr, "TestLogic", url.User(t.user))
defer cleanupFunc()
pgURL.Path = "test"
}

db = t.openDB(pgURL)
closeDB = func() {
andyyang890 marked this conversation as resolved.
Show resolved Hide resolved
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
db = t.getOrOpenClient(t.user, query.nodeIdx)
}

if query.expectAsync {
Expand All @@ -3354,18 +3327,13 @@ func (t *logicTest) execQuery(query logicQuery) error {

startedChan := make(chan struct{})
go func() {
if closeDB != nil {
defer closeDB()
}
startedChan <- struct{}{}
rows, err := db.Query(query.sql)
pending.resultChan <- pendingQueryResult{rows, err}
}()

<-startedChan
return nil
} else if closeDB != nil {
defer closeDB()
}

rows, err := db.Query(query.sql)
Expand Down