From 2503c8469f07c3fdd16fa8f5a5948fd46e62a97f Mon Sep 17 00:00:00 2001 From: Jason Roselander Date: Tue, 1 Feb 2022 15:07:33 -0800 Subject: [PATCH] Added GetVersionedSchemaFilesInOrder helper function to cassandra_test_util.go (#2436) --- common/persistence/tests/cassandra_test.go | 112 ++++++++---------- .../persistence/tests/cassandra_test_util.go | 92 ++++++++++---- 2 files changed, 121 insertions(+), 83 deletions(-) diff --git a/common/persistence/tests/cassandra_test.go b/common/persistence/tests/cassandra_test.go index e638fb6d522..865e8b3e318 100644 --- a/common/persistence/tests/cassandra_test.go +++ b/common/persistence/tests/cassandra_test.go @@ -27,6 +27,8 @@ package tests import ( "testing" + "go.uber.org/zap/zaptest" + "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" @@ -48,103 +50,87 @@ const ( testCassandraDatabaseNameSuffix = "temporal_persistence" ) -func TestCassandraExecutionMutableStateStoreSuite(t *testing.T) { - cfg := newCassandraConfig() - SetUpCassandraDatabase(cfg) - SetUpCassandraSchema(cfg) - logger := log.NewNoopLogger() - factory := cassandra.NewFactory( - *cfg, +type cassandraTestData struct { + cfg *config.Cassandra + factory *cassandra.Factory + logger log.Logger +} + +func setUpCassandraTest(t *testing.T) (cassandraTestData, func()) { + var testData cassandraTestData + testData.cfg = newCassandraConfig() + testData.logger = log.NewZapLogger(zaptest.NewLogger(t)) + SetUpCassandraDatabase(testData.cfg) + SetUpCassandraSchema(testData.cfg, testData.logger) + + testData.factory = cassandra.NewFactory( + *testData.cfg, resolver.NewNoopResolver(), testCassandraClusterName, - logger, + testData.logger, ) - shardStore, err := factory.NewShardStore() + + tearDown := func() { + testData.factory.Close() + TearDownCassandraKeyspace(testData.cfg) + } + + return testData, tearDown +} + +func TestCassandraExecutionMutableStateStoreSuite(t *testing.T) { + testData, tearDown := setUpCassandraTest(t) + defer tearDown() + + shardStore, err := testData.factory.NewShardStore() if err != nil { t.Fatalf("unable to create Cassandra DB: %v", err) } - executionStore, err := factory.NewExecutionStore() + executionStore, err := testData.factory.NewExecutionStore() if err != nil { t.Fatalf("unable to create Cassandra DB: %v", err) } - defer func() { - factory.Close() - TearDownCassandraKeyspace(cfg) - }() - s := NewExecutionMutableStateSuite(t, shardStore, executionStore, logger) + s := NewExecutionMutableStateSuite(t, shardStore, executionStore, testData.logger) suite.Run(t, s) } func TestCassandraHistoryStoreSuite(t *testing.T) { - cfg := newCassandraConfig() - SetUpCassandraDatabase(cfg) - SetUpCassandraSchema(cfg) - logger := log.NewNoopLogger() - factory := cassandra.NewFactory( - *cfg, - resolver.NewNoopResolver(), - testCassandraClusterName, - logger, - ) - store, err := factory.NewExecutionStore() + testData, tearDown := setUpCassandraTest(t) + defer tearDown() + + store, err := testData.factory.NewExecutionStore() if err != nil { t.Fatalf("unable to create Cassandra DB: %v", err) } - defer func() { - factory.Close() - TearDownCassandraKeyspace(cfg) - }() - s := NewHistoryEventsSuite(t, store, logger) + s := NewHistoryEventsSuite(t, store, testData.logger) suite.Run(t, s) } func TestCassandraTaskQueueSuite(t *testing.T) { - cfg := newCassandraConfig() - SetUpCassandraDatabase(cfg) - SetUpCassandraSchema(cfg) - logger := log.NewNoopLogger() - factory := cassandra.NewFactory( - *cfg, - resolver.NewNoopResolver(), - testCassandraClusterName, - logger, - ) - taskQueueStore, err := factory.NewTaskStore() + testData, tearDown := setUpCassandraTest(t) + defer tearDown() + + taskQueueStore, err := testData.factory.NewTaskStore() if err != nil { t.Fatalf("unable to create Cassandra DB: %v", err) } - defer func() { - factory.Close() - TearDownCassandraKeyspace(cfg) - }() - s := NewTaskQueueSuite(t, taskQueueStore, logger) + s := NewTaskQueueSuite(t, taskQueueStore, testData.logger) suite.Run(t, s) } func TestCassandraTaskQueueTaskSuite(t *testing.T) { - cfg := newCassandraConfig() - SetUpCassandraDatabase(cfg) - SetUpCassandraSchema(cfg) - logger := log.NewNoopLogger() - factory := cassandra.NewFactory( - *cfg, - resolver.NewNoopResolver(), - testCassandraClusterName, - logger, - ) - taskQueueStore, err := factory.NewTaskStore() + testData, tearDown := setUpCassandraTest(t) + defer tearDown() + + taskQueueStore, err := testData.factory.NewTaskStore() if err != nil { t.Fatalf("unable to create Cassandra DB: %v", err) } - defer func() { - factory.Close() - TearDownCassandraKeyspace(cfg) - }() - s := NewTaskQueueTaskSuite(t, taskQueueStore, logger) + s := NewTaskQueueTaskSuite(t, taskQueueStore, testData.logger) suite.Run(t, s) } diff --git a/common/persistence/tests/cassandra_test_util.go b/common/persistence/tests/cassandra_test_util.go index 635c807ec70..864a6f51579 100644 --- a/common/persistence/tests/cassandra_test_util.go +++ b/common/persistence/tests/cassandra_test_util.go @@ -26,8 +26,13 @@ package tests import ( "fmt" + "os" + "path" "path/filepath" + "sort" + "strings" + "github.com/blang/semver/v4" "go.temporal.io/server/common/config" "go.temporal.io/server/common/log" p "go.temporal.io/server/common/persistence" @@ -65,41 +70,31 @@ func SetUpCassandraDatabase(cfg *config.Cassandra) { } } -func SetUpCassandraSchema(cfg *config.Cassandra) { - session, err := gocql.NewSession(*cfg, resolver.NewNoopResolver(), log.NewNoopLogger()) - if err != nil { - panic(fmt.Sprintf("unable to create Cassandra session: %v", err)) - } - defer session.Close() - - schemaPath, err := filepath.Abs(testCassandraExecutionSchema) - if err != nil { - panic(err) - } +func SetUpCassandraSchema(cfg *config.Cassandra, logger log.Logger) { + ApplySchemaUpdate(cfg, testCassandraExecutionSchema, logger) + ApplySchemaUpdate(cfg, testCassandraVisibilitySchema, logger) +} - statements, err := p.LoadAndSplitQuery([]string{schemaPath}) +func ApplySchemaUpdate(cfg *config.Cassandra, schemaFile string, logger log.Logger) { + session, err := gocql.NewSession(*cfg, resolver.NewNoopResolver(), logger) if err != nil { panic(err) } + defer session.Close() - for _, stmt := range statements { - if err = session.Query(stmt).Exec(); err != nil { - panic(err) - } - } - - schemaPath, err = filepath.Abs(testCassandraVisibilitySchema) + schemaPath, err := filepath.Abs(schemaFile) if err != nil { panic(err) } - statements, err = p.LoadAndSplitQuery([]string{schemaPath}) + statements, err := p.LoadAndSplitQuery([]string{schemaPath}) if err != nil { panic(err) } for _, stmt := range statements { if err = session.Query(stmt).Exec(); err != nil { + logger.Error(fmt.Sprintf("Unable to execute statement from file: %s\n %s", schemaFile, stmt)) panic(err) } } @@ -124,3 +119,60 @@ func TearDownCassandraKeyspace(cfg *config.Cassandra) { panic(fmt.Sprintf("unable to drop Cassandra keyspace: %v", err)) } } + +// GetSchemaFiles takes a root directory which contains subdirectories whose names are semantic versions and returns +// the .cql files within. E.g.: //schema/cassandra/temporal/versioned +// Subdirectories are ordered by semantic version, but files within the same subdirectory are in arbitrary order. +// All .cql files are returned regardless of whether they are named in manifest.json. +func GetSchemaFiles(schemaDir string, logger log.Logger) []string { + var retVal []string + + versionDirPath := path.Join(schemaDir, "versioned") + subDirs, err := os.ReadDir(versionDirPath) + if err != nil { + panic(err) + } + + versionDirNames := make([]string, 0, len(subDirs)) + for _, subDir := range subDirs { + if !subDir.IsDir() { + logger.Warn(fmt.Sprintf("Skipping non-directory file: '%s'", subDir.Name())) + continue + } + if _, ve := semver.ParseTolerant(subDir.Name()); ve != nil { + logger.Warn(fmt.Sprintf("Skipping directory which is not a valid semver: '%s'", subDir.Name())) + } + versionDirNames = append(versionDirNames, subDir.Name()) + } + + sort.Slice(versionDirNames, func(i, j int) bool { + vLeft, err := semver.ParseTolerant(versionDirNames[i]) + if err != nil { + panic(err) // Logic error + } + vRight, err := semver.ParseTolerant(versionDirNames[j]) + if err != nil { + panic(err) // Logic error + } + return vLeft.Compare(vRight) < 0 + }) + + for _, dir := range versionDirNames { + vDirPath := path.Join(versionDirPath, dir) + files, err := os.ReadDir(vDirPath) + if err != nil { + panic(err) + } + for _, file := range files { + if file.IsDir() { + continue + } + if !strings.HasSuffix(file.Name(), ".cql") { + continue + } + retVal = append(retVal, path.Join(vDirPath, file.Name())) + } + } + + return retVal +}