diff --git a/batch_test.go b/batch_test.go index 490ae796d..f5c40dde8 100644 --- a/batch_test.go +++ b/batch_test.go @@ -6,10 +6,12 @@ package gocql import ( "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestBatch_Errors(t *testing.T) { - if *flagProto == 1 { + if *testcmdline.Proto == 1 { } session := createSession(t) diff --git a/cassandra_only_test.go b/cassandra_only_test.go index fd02d01b0..383326d6f 100644 --- a/cassandra_only_test.go +++ b/cassandra_only_test.go @@ -13,6 +13,8 @@ import ( "sync" "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestDiscoverViaProxy(t *testing.T) { @@ -204,8 +206,8 @@ func TestGetKeyspaceMetadata(t *testing.T) { if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } - if rfInt != *flagRF { - t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) + if rfInt != *testcmdline.RF { + t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt) } } @@ -431,7 +433,7 @@ func TestViewMetadata(t *testing.T) { } textType := TypeText - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { textType = TypeVarchar } @@ -453,7 +455,7 @@ func TestViewMetadata(t *testing.T) { } func TestMaterializedViewMetadata(t *testing.T) { - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { return } session := createSession(t) @@ -552,7 +554,7 @@ func TestAggregateMetadata(t *testing.T) { } // In this case cassandra is returning a blob - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0}) } @@ -736,7 +738,7 @@ func TestKeyspaceMetadata(t *testing.T) { t.Fatal("failed to find the types in metadata") } textType := TypeText - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { textType = TypeVarchar } expectedType := UserTypeMetadata{ @@ -753,7 +755,7 @@ func TestKeyspaceMetadata(t *testing.T) { if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType) } - if flagCassVersion.Major >= 3 { + if testcmdline.CassVersion.Major >= 3 { materializedView, found := keyspaceMetadata.MaterializedViews["view_view"] if !found { t.Fatal("failed to find materialized view view_view in metadata") diff --git a/cassandra_test.go b/cassandra_test.go index f7539ded4..fa599ff95 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -20,6 +20,8 @@ import ( "unicode" inf "gopkg.in/inf.v0" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestEmptyHosts(t *testing.T) { @@ -2126,8 +2128,8 @@ func TestGetKeyspaceMetadata(t *testing.T) { if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } - if rfInt != *flagRF { - t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) + if rfInt != *testcmdline.RF { + t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt) } } @@ -2494,8 +2496,8 @@ func TestUnmarshallNestedTypes(t *testing.T) { } func TestSchemaReset(t *testing.T) { - if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) { - t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion) + if testcmdline.CassVersion.Major == 0 || testcmdline.CassVersion.Before(2, 1, 3) { + t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", testcmdline.CassVersion) } cluster := createCluster() @@ -2560,7 +2562,7 @@ func TestCreateSession_DontSwallowError(t *testing.T) { t.Fatal("expected to get an error for unsupported protocol") } - if flagCassVersion.Major < 3 { + if testcmdline.CassVersion.Major < 3 { // TODO: we should get a distinct error type here which include the underlying // cassandra error about the protocol version, for now check this here. if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") { diff --git a/cloud_cluster_test.go b/cloud_cluster_test.go index 4133ac56e..ec67af949 100644 --- a/cloud_cluster_test.go +++ b/cloud_cluster_test.go @@ -16,13 +16,15 @@ import ( "testing" "time" + "sigs.k8s.io/yaml" + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/testcmdline" "github.com/gocql/gocql/scyllacloud" - "sigs.k8s.io/yaml" ) func TestCloudConnection(t *testing.T) { - if !*gocql.FlagRunSslTest { + if !*testcmdline.RunSslTest { t.Skip("Skipping because SSL is not enabled on cluster") } diff --git a/common_test.go b/common_test.go index abbe91cce..462fb2b6d 100644 --- a/common_test.go +++ b/common_test.go @@ -1,7 +1,6 @@ package gocql import ( - "flag" "fmt" "log" "net" @@ -10,41 +9,24 @@ import ( "sync" "testing" "time" -) -var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion + "github.com/gocql/gocql/internal/testcmdline" ) func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - log.SetFlags(log.Lshortfile | log.LstdFlags) } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func getMultiNodeClusterHosts() []string { - return strings.Split(*flagMultiNodeCluster, ",") + return strings.Split(*testcmdline.MultiNodeCluster, ",") } func addSslOptions(cluster *ClusterConfig) *ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &SslOptions{ CertPath: "testdata/pki/gocql.crt", @@ -81,21 +63,21 @@ func createTable(s *Session, table string) error { func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getClusterHosts() cluster := NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -110,21 +92,21 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { func createMultiNodeCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getMultiNodeClusterHosts() cluster := NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -156,7 +138,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -232,7 +214,7 @@ func createViews(t *testing.T, session *Session) { } func createMaterializedViews(t *testing.T, session *Session) { - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { return } if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table ( diff --git a/export_test.go b/export_test.go index 830436303..3295697db 100644 --- a/export_test.go +++ b/export_test.go @@ -3,7 +3,6 @@ package gocql -var FlagRunSslTest = flagRunSslTest var CreateCluster = createCluster var TestLogger = &testLogger{} var WaitUntilPoolsStopFilling = waitUntilPoolsStopFilling diff --git a/integration_test.go b/integration_test.go index f548a829f..11f22c445 100644 --- a/integration_test.go +++ b/integration_test.go @@ -9,16 +9,18 @@ import ( "reflect" "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections func TestAuthentication(t *testing.T) { - if *flagProto < 2 { + if *testcmdline.Proto < 2 { t.Skip("Authentication is not supported with protocol < 2") } - if !*flagRunAuthTest { + if !*testcmdline.RunAuthTest { t.Skip("Authentication is not configured in the target cluster") } @@ -60,21 +62,21 @@ func TestRingDiscovery(t *testing.T) { session := createSessionFromCluster(cluster, t) defer session.Close() - if *clusterSize > 1 { + if *testcmdline.ClusterSize > 1 { // wait for autodiscovery to update the pool with the list of known hosts - time.Sleep(*flagAutoWait) + time.Sleep(*testcmdline.AutoWait) } session.pool.mu.RLock() defer session.pool.mu.RUnlock() size := len(session.pool.hostConnPools) - if *clusterSize != size { + if *testcmdline.ClusterSize != size { for p, pool := range session.pool.hostConnPools { t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) } - t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) + t.Errorf("Expected a cluster size of %d, but actual size was %d", *testcmdline.ClusterSize, size) } } diff --git a/internal/testutils/flags.go b/internal/testcmdline/flags.go similarity index 53% rename from internal/testutils/flags.go rename to internal/testcmdline/flags.go index ce19c3080..938c346f5 100644 --- a/internal/testutils/flags.go +++ b/internal/testcmdline/flags.go @@ -1,31 +1,27 @@ -package testutils +package testcmdline import ( "flag" "fmt" - "log" "strconv" "strings" "time" - - "github.com/gocql/gocql" ) var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion + Cluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + MultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") + Proto = flag.Int("proto", 0, "protcol version") + CQL = flag.String("cql", "3.0.0", "CQL version") + RF = flag.Int("rf", 1, "replication factor for test keyspace") + ClusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") + Retry = flag.Int("retries", 5, "number of times to retry queries") + AutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") + RunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") + RunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") + CompressTest = flag.String("compressor", "", "compressor to use") + Timeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") + CassVersion cassVersion ) type cassVersion struct { @@ -37,11 +33,7 @@ func (c *cassVersion) Set(v string) error { return nil } - return c.UnmarshalCQL(nil, []byte(v)) -} - -func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { - return c.unmarshal(data) + return c.unmarshal([]byte(v)) } func (c *cassVersion) unmarshal(data []byte) error { @@ -108,7 +100,5 @@ func (c cassVersion) nodeUpDelay() time.Duration { } func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - - log.SetFlags(log.Lshortfile | log.LstdFlags) + flag.Var(&CassVersion, "gocql.cversion", "the cassandra version being tested against") } diff --git a/internal/testutils/cluster.go b/internal/testutils/cluster.go index 431614be4..fb81715f4 100644 --- a/internal/testutils/cluster.go +++ b/internal/testutils/cluster.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/testcmdline" ) var initOnce sync.Once @@ -22,21 +23,21 @@ func CreateSession(tb testing.TB, opts ...func(config *gocql.ClusterConfig)) *go func CreateCluster(opts ...func(*gocql.ClusterConfig)) *gocql.ClusterConfig { clusterHosts := getClusterHosts() cluster := gocql.NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = gocql.Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &gocql.SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -69,7 +70,7 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) { @@ -92,7 +93,7 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -120,7 +121,7 @@ func CreateTable(s *gocql.Session, table string) error { } func addSslOptions(cluster *gocql.ClusterConfig) *gocql.ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &gocql.SslOptions{ CertPath: "testdata/pki/gocql.crt",