Skip to content

Commit

Permalink
Added GetVersionedSchemaFilesInOrder helper function to cassandra_tes…
Browse files Browse the repository at this point in the history
…t_util.go (#2436)
  • Loading branch information
Jason Roselander authored Feb 1, 2022
1 parent fe89dd4 commit 2503c84
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 83 deletions.
112 changes: 49 additions & 63 deletions common/persistence/tests/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ package tests
import (
"testing"

"go.uber.org/zap/zaptest"

"github.com/stretchr/testify/suite"

"go.temporal.io/server/common/config"
Expand All @@ -48,103 +50,87 @@ const (
testCassandraDatabaseNameSuffix = "temporal_persistence"
)

func TestCassandraExecutionMutableStateStoreSuite(t *testing.T) {
cfg := newCassandraConfig()
SetUpCassandraDatabase(cfg)
SetUpCassandraSchema(cfg)
logger := log.NewNoopLogger()
factory := cassandra.NewFactory(
*cfg,
type cassandraTestData struct {
cfg *config.Cassandra
factory *cassandra.Factory
logger log.Logger
}

func setUpCassandraTest(t *testing.T) (cassandraTestData, func()) {
var testData cassandraTestData
testData.cfg = newCassandraConfig()
testData.logger = log.NewZapLogger(zaptest.NewLogger(t))
SetUpCassandraDatabase(testData.cfg)
SetUpCassandraSchema(testData.cfg, testData.logger)

testData.factory = cassandra.NewFactory(
*testData.cfg,
resolver.NewNoopResolver(),
testCassandraClusterName,
logger,
testData.logger,
)
shardStore, err := factory.NewShardStore()

tearDown := func() {
testData.factory.Close()
TearDownCassandraKeyspace(testData.cfg)
}

return testData, tearDown
}

func TestCassandraExecutionMutableStateStoreSuite(t *testing.T) {
testData, tearDown := setUpCassandraTest(t)
defer tearDown()

shardStore, err := testData.factory.NewShardStore()
if err != nil {
t.Fatalf("unable to create Cassandra DB: %v", err)
}
executionStore, err := factory.NewExecutionStore()
executionStore, err := testData.factory.NewExecutionStore()
if err != nil {
t.Fatalf("unable to create Cassandra DB: %v", err)
}
defer func() {
factory.Close()
TearDownCassandraKeyspace(cfg)
}()

s := NewExecutionMutableStateSuite(t, shardStore, executionStore, logger)
s := NewExecutionMutableStateSuite(t, shardStore, executionStore, testData.logger)
suite.Run(t, s)
}

func TestCassandraHistoryStoreSuite(t *testing.T) {
cfg := newCassandraConfig()
SetUpCassandraDatabase(cfg)
SetUpCassandraSchema(cfg)
logger := log.NewNoopLogger()
factory := cassandra.NewFactory(
*cfg,
resolver.NewNoopResolver(),
testCassandraClusterName,
logger,
)
store, err := factory.NewExecutionStore()
testData, tearDown := setUpCassandraTest(t)
defer tearDown()

store, err := testData.factory.NewExecutionStore()
if err != nil {
t.Fatalf("unable to create Cassandra DB: %v", err)
}
defer func() {
factory.Close()
TearDownCassandraKeyspace(cfg)
}()

s := NewHistoryEventsSuite(t, store, logger)
s := NewHistoryEventsSuite(t, store, testData.logger)
suite.Run(t, s)
}

func TestCassandraTaskQueueSuite(t *testing.T) {
cfg := newCassandraConfig()
SetUpCassandraDatabase(cfg)
SetUpCassandraSchema(cfg)
logger := log.NewNoopLogger()
factory := cassandra.NewFactory(
*cfg,
resolver.NewNoopResolver(),
testCassandraClusterName,
logger,
)
taskQueueStore, err := factory.NewTaskStore()
testData, tearDown := setUpCassandraTest(t)
defer tearDown()

taskQueueStore, err := testData.factory.NewTaskStore()
if err != nil {
t.Fatalf("unable to create Cassandra DB: %v", err)
}
defer func() {
factory.Close()
TearDownCassandraKeyspace(cfg)
}()

s := NewTaskQueueSuite(t, taskQueueStore, logger)
s := NewTaskQueueSuite(t, taskQueueStore, testData.logger)
suite.Run(t, s)
}

func TestCassandraTaskQueueTaskSuite(t *testing.T) {
cfg := newCassandraConfig()
SetUpCassandraDatabase(cfg)
SetUpCassandraSchema(cfg)
logger := log.NewNoopLogger()
factory := cassandra.NewFactory(
*cfg,
resolver.NewNoopResolver(),
testCassandraClusterName,
logger,
)
taskQueueStore, err := factory.NewTaskStore()
testData, tearDown := setUpCassandraTest(t)
defer tearDown()

taskQueueStore, err := testData.factory.NewTaskStore()
if err != nil {
t.Fatalf("unable to create Cassandra DB: %v", err)
}
defer func() {
factory.Close()
TearDownCassandraKeyspace(cfg)
}()

s := NewTaskQueueTaskSuite(t, taskQueueStore, logger)
s := NewTaskQueueTaskSuite(t, taskQueueStore, testData.logger)
suite.Run(t, s)
}

Expand Down
92 changes: 72 additions & 20 deletions common/persistence/tests/cassandra_test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ package tests

import (
"fmt"
"os"
"path"
"path/filepath"
"sort"
"strings"

"github.com/blang/semver/v4"
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/log"
p "go.temporal.io/server/common/persistence"
Expand Down Expand Up @@ -65,41 +70,31 @@ func SetUpCassandraDatabase(cfg *config.Cassandra) {
}
}

func SetUpCassandraSchema(cfg *config.Cassandra) {
session, err := gocql.NewSession(*cfg, resolver.NewNoopResolver(), log.NewNoopLogger())
if err != nil {
panic(fmt.Sprintf("unable to create Cassandra session: %v", err))
}
defer session.Close()

schemaPath, err := filepath.Abs(testCassandraExecutionSchema)
if err != nil {
panic(err)
}
func SetUpCassandraSchema(cfg *config.Cassandra, logger log.Logger) {
ApplySchemaUpdate(cfg, testCassandraExecutionSchema, logger)
ApplySchemaUpdate(cfg, testCassandraVisibilitySchema, logger)
}

statements, err := p.LoadAndSplitQuery([]string{schemaPath})
func ApplySchemaUpdate(cfg *config.Cassandra, schemaFile string, logger log.Logger) {
session, err := gocql.NewSession(*cfg, resolver.NewNoopResolver(), logger)
if err != nil {
panic(err)
}
defer session.Close()

for _, stmt := range statements {
if err = session.Query(stmt).Exec(); err != nil {
panic(err)
}
}

schemaPath, err = filepath.Abs(testCassandraVisibilitySchema)
schemaPath, err := filepath.Abs(schemaFile)
if err != nil {
panic(err)
}

statements, err = p.LoadAndSplitQuery([]string{schemaPath})
statements, err := p.LoadAndSplitQuery([]string{schemaPath})
if err != nil {
panic(err)
}

for _, stmt := range statements {
if err = session.Query(stmt).Exec(); err != nil {
logger.Error(fmt.Sprintf("Unable to execute statement from file: %s\n %s", schemaFile, stmt))
panic(err)
}
}
Expand All @@ -124,3 +119,60 @@ func TearDownCassandraKeyspace(cfg *config.Cassandra) {
panic(fmt.Sprintf("unable to drop Cassandra keyspace: %v", err))
}
}

// GetSchemaFiles takes a root directory which contains subdirectories whose names are semantic versions and returns
// the .cql files within. E.g.: //schema/cassandra/temporal/versioned
// Subdirectories are ordered by semantic version, but files within the same subdirectory are in arbitrary order.
// All .cql files are returned regardless of whether they are named in manifest.json.
func GetSchemaFiles(schemaDir string, logger log.Logger) []string {
var retVal []string

versionDirPath := path.Join(schemaDir, "versioned")
subDirs, err := os.ReadDir(versionDirPath)
if err != nil {
panic(err)
}

versionDirNames := make([]string, 0, len(subDirs))
for _, subDir := range subDirs {
if !subDir.IsDir() {
logger.Warn(fmt.Sprintf("Skipping non-directory file: '%s'", subDir.Name()))
continue
}
if _, ve := semver.ParseTolerant(subDir.Name()); ve != nil {
logger.Warn(fmt.Sprintf("Skipping directory which is not a valid semver: '%s'", subDir.Name()))
}
versionDirNames = append(versionDirNames, subDir.Name())
}

sort.Slice(versionDirNames, func(i, j int) bool {
vLeft, err := semver.ParseTolerant(versionDirNames[i])
if err != nil {
panic(err) // Logic error
}
vRight, err := semver.ParseTolerant(versionDirNames[j])
if err != nil {
panic(err) // Logic error
}
return vLeft.Compare(vRight) < 0
})

for _, dir := range versionDirNames {
vDirPath := path.Join(versionDirPath, dir)
files, err := os.ReadDir(vDirPath)
if err != nil {
panic(err)
}
for _, file := range files {
if file.IsDir() {
continue
}
if !strings.HasSuffix(file.Name(), ".cql") {
continue
}
retVal = append(retVal, path.Join(vDirPath, file.Name()))
}
}

return retVal
}

0 comments on commit 2503c84

Please sign in to comment.