Skip to content

Commit

Permalink
Separate test command line to a package
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev committed Jul 23, 2024
1 parent 5c7adac commit 847b9e4
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 96 deletions.
4 changes: 3 additions & 1 deletion batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package gocql
import (
"testing"
"time"

"github.com/gocql/gocql/internal/testcmdline"
)

func TestBatch_Errors(t *testing.T) {
if *flagProto == 1 {
if *testcmdline.Proto == 1 {
}

session := createSession(t)
Expand Down
16 changes: 9 additions & 7 deletions cassandra_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"sync"
"testing"
"time"

"github.com/gocql/gocql/internal/testcmdline"
)

func TestDiscoverViaProxy(t *testing.T) {
Expand Down Expand Up @@ -204,8 +206,8 @@ func TestGetKeyspaceMetadata(t *testing.T) {
if err != nil {
t.Fatalf("Error converting string to int with err: %v", err)
}
if rfInt != *flagRF {
t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt)
if rfInt != *testcmdline.RF {
t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt)
}
}

Expand Down Expand Up @@ -431,7 +433,7 @@ func TestViewMetadata(t *testing.T) {
}

textType := TypeText
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
textType = TypeVarchar
}

Expand All @@ -453,7 +455,7 @@ func TestViewMetadata(t *testing.T) {
}

func TestMaterializedViewMetadata(t *testing.T) {
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
return
}
session := createSession(t)
Expand Down Expand Up @@ -552,7 +554,7 @@ func TestAggregateMetadata(t *testing.T) {
}

// In this case cassandra is returning a blob
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0})
}

Expand Down Expand Up @@ -736,7 +738,7 @@ func TestKeyspaceMetadata(t *testing.T) {
t.Fatal("failed to find the types in metadata")
}
textType := TypeText
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
textType = TypeVarchar
}
expectedType := UserTypeMetadata{
Expand All @@ -753,7 +755,7 @@ func TestKeyspaceMetadata(t *testing.T) {
if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) {
t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType)
}
if flagCassVersion.Major >= 3 {
if testcmdline.CassVersion.Major >= 3 {
materializedView, found := keyspaceMetadata.MaterializedViews["view_view"]
if !found {
t.Fatal("failed to find materialized view view_view in metadata")
Expand Down
12 changes: 7 additions & 5 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"unicode"

inf "gopkg.in/inf.v0"

"github.com/gocql/gocql/internal/testcmdline"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -2126,8 +2128,8 @@ func TestGetKeyspaceMetadata(t *testing.T) {
if err != nil {
t.Fatalf("Error converting string to int with err: %v", err)
}
if rfInt != *flagRF {
t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt)
if rfInt != *testcmdline.RF {
t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt)
}
}

Expand Down Expand Up @@ -2494,8 +2496,8 @@ func TestUnmarshallNestedTypes(t *testing.T) {
}

func TestSchemaReset(t *testing.T) {
if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) {
t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion)
if testcmdline.CassVersion.Major == 0 || testcmdline.CassVersion.Before(2, 1, 3) {
t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", testcmdline.CassVersion)
}

cluster := createCluster()
Expand Down Expand Up @@ -2560,7 +2562,7 @@ func TestCreateSession_DontSwallowError(t *testing.T) {
t.Fatal("expected to get an error for unsupported protocol")
}

if flagCassVersion.Major < 3 {
if testcmdline.CassVersion.Major < 3 {
// TODO: we should get a distinct error type here which include the underlying
// cassandra error about the protocol version, for now check this here.
if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") {
Expand Down
6 changes: 4 additions & 2 deletions cloud_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ import (
"testing"
"time"

"sigs.k8s.io/yaml"

"github.com/gocql/gocql"
"github.com/gocql/gocql/internal/testcmdline"
"github.com/gocql/gocql/scyllacloud"
"sigs.k8s.io/yaml"
)

func TestCloudConnection(t *testing.T) {
if !*gocql.FlagRunSslTest {
if !*testcmdline.RunSslTest {
t.Skip("Skipping because SSL is not enabled on cluster")
}

Expand Down
58 changes: 20 additions & 38 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gocql

import (
"flag"
"fmt"
"log"
"net"
Expand All @@ -10,41 +9,24 @@ import (
"sync"
"testing"
"time"
)

var (
flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
flagProto = flag.Int("proto", 0, "protcol version")
flagCQL = flag.String("cql", "3.0.0", "CQL version")
flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
flagRetry = flag.Int("retries", 5, "number of times to retry queries")
flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
flagCompressTest = flag.String("compressor", "", "compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")

flagCassVersion cassVersion
"github.com/gocql/gocql/internal/testcmdline"
)

func init() {
flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")

log.SetFlags(log.Lshortfile | log.LstdFlags)
}

func getClusterHosts() []string {
return strings.Split(*flagCluster, ",")
return strings.Split(*testcmdline.Cluster, ",")
}

func getMultiNodeClusterHosts() []string {
return strings.Split(*flagMultiNodeCluster, ",")
return strings.Split(*testcmdline.MultiNodeCluster, ",")
}

func addSslOptions(cluster *ClusterConfig) *ClusterConfig {
if *flagRunSslTest {
if *testcmdline.RunSslTest {
cluster.Port = 9142
cluster.SslOpts = &SslOptions{
CertPath: "testdata/pki/gocql.crt",
Expand Down Expand Up @@ -81,21 +63,21 @@ func createTable(s *Session, table string) error {
func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
clusterHosts := getClusterHosts()
cluster := NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.ProtoVersion = *testcmdline.Proto
cluster.CQLVersion = *testcmdline.CQL
cluster.Timeout = *testcmdline.Timeout
cluster.Consistency = Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
if *testcmdline.Retry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry}
}

switch *flagCompressTest {
switch *testcmdline.CompressTest {
case "snappy":
cluster.Compressor = &SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
panic("invalid compressor: " + *testcmdline.CompressTest)
}

cluster = addSslOptions(cluster)
Expand All @@ -110,21 +92,21 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
func createMultiNodeCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
clusterHosts := getMultiNodeClusterHosts()
cluster := NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.ProtoVersion = *testcmdline.Proto
cluster.CQLVersion = *testcmdline.CQL
cluster.Timeout = *testcmdline.Timeout
cluster.Consistency = Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry}
if *testcmdline.Retry > 0 {
cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry}
}

switch *flagCompressTest {
switch *testcmdline.CompressTest {
case "snappy":
cluster.Compressor = &SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
panic("invalid compressor: " + *testcmdline.CompressTest)
}

cluster = addSslOptions(cluster)
Expand Down Expand Up @@ -156,7 +138,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) {
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF))
}`, keyspace, *testcmdline.RF))

if err != nil {
panic(fmt.Sprintf("unable to create keyspace: %v", err))
Expand Down Expand Up @@ -232,7 +214,7 @@ func createViews(t *testing.T, session *Session) {
}

func createMaterializedViews(t *testing.T, session *Session) {
if flagCassVersion.Before(3, 0, 0) {
if testcmdline.CassVersion.Before(3, 0, 0) {
return
}
if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table (
Expand Down
1 change: 0 additions & 1 deletion export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

package gocql

var FlagRunSslTest = flagRunSslTest
var CreateCluster = createCluster
var TestLogger = &testLogger{}
var WaitUntilPoolsStopFilling = waitUntilPoolsStopFilling
Expand Down
14 changes: 8 additions & 6 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import (
"reflect"
"testing"
"time"

"github.com/gocql/gocql/internal/testcmdline"
)

// TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections
func TestAuthentication(t *testing.T) {

if *flagProto < 2 {
if *testcmdline.Proto < 2 {
t.Skip("Authentication is not supported with protocol < 2")
}

if !*flagRunAuthTest {
if !*testcmdline.RunAuthTest {
t.Skip("Authentication is not configured in the target cluster")
}

Expand Down Expand Up @@ -60,21 +62,21 @@ func TestRingDiscovery(t *testing.T) {
session := createSessionFromCluster(cluster, t)
defer session.Close()

if *clusterSize > 1 {
if *testcmdline.ClusterSize > 1 {
// wait for autodiscovery to update the pool with the list of known hosts
time.Sleep(*flagAutoWait)
time.Sleep(*testcmdline.AutoWait)
}

session.pool.mu.RLock()
defer session.pool.mu.RUnlock()
size := len(session.pool.hostConnPools)

if *clusterSize != size {
if *testcmdline.ClusterSize != size {
for p, pool := range session.pool.hostConnPools {
t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String())

}
t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size)
t.Errorf("Expected a cluster size of %d, but actual size was %d", *testcmdline.ClusterSize, size)
}
}

Expand Down
42 changes: 16 additions & 26 deletions internal/testutils/flags.go → internal/testcmdline/flags.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
package testutils
package testcmdline

import (
"flag"
"fmt"
"log"
"strconv"
"strings"
"time"

"github.com/gocql/gocql"
)

var (
flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
flagProto = flag.Int("proto", 0, "protcol version")
flagCQL = flag.String("cql", "3.0.0", "CQL version")
flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
flagRetry = flag.Int("retries", 5, "number of times to retry queries")
flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
flagCompressTest = flag.String("compressor", "", "compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")

flagCassVersion cassVersion
Cluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
MultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples")
Proto = flag.Int("proto", 0, "protcol version")
CQL = flag.String("cql", "3.0.0", "CQL version")
RF = flag.Int("rf", 1, "replication factor for test keyspace")
ClusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster")
Retry = flag.Int("retries", 5, "number of times to retry queries")
AutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll")
RunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test")
RunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test")
CompressTest = flag.String("compressor", "", "compressor to use")
Timeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
CassVersion cassVersion
)

type cassVersion struct {
Expand All @@ -37,11 +33,7 @@ func (c *cassVersion) Set(v string) error {
return nil
}

return c.UnmarshalCQL(nil, []byte(v))
}

func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
return c.unmarshal(data)
return c.unmarshal([]byte(v))
}

func (c *cassVersion) unmarshal(data []byte) error {
Expand Down Expand Up @@ -108,7 +100,5 @@ func (c cassVersion) nodeUpDelay() time.Duration {
}

func init() {
flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")

log.SetFlags(log.Lshortfile | log.LstdFlags)
flag.Var(&CassVersion, "gocql.cversion", "the cassandra version being tested against")
}
Loading

0 comments on commit 847b9e4

Please sign in to comment.