Merge pull request #394 from scylladb/dk/stop-result-indexing-in-tests
Stop result indexing in tests
dkropachev authored Dec 20, 2024
2 parents dea4904 + e24acab commit 347f332
Showing 14 changed files with 159 additions and 157 deletions.
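The change is mechanical across the files below: every test that read the first row by indexing the ResultSet (`result[0]`) now calls `ResultSet.one()`, which returns the first row or `None` when the result is empty. A minimal before/after sketch of the pattern (not taken from the diff; it assumes a reachable node at 127.0.0.1 and the standard cassandra/scylla Python driver API, with an illustrative query):

```python
# Sketch only: assumes a node is reachable at 127.0.0.1.
from cassandra.cluster import Cluster

cluster = Cluster(['127.0.0.1'])
session = cluster.connect()

rs = session.execute('SELECT release_version FROM system.local')

# Old style used in the tests: list-style indexing into the ResultSet.
# first_row = rs[0]

# New style: ResultSet.one() returns the first row, or None if the
# result set is empty, without relying on list-like access.
first_row = rs.one()
print(first_row.release_version)

cluster.shutdown()
```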
2 changes: 1 addition & 1 deletion tests/integration/__init__.py
@@ -87,7 +87,7 @@ def get_server_versions():

c = TestCluster()
s = c.connect()
row = s.execute('SELECT cql_version, release_version FROM system.local')[0]
row = s.execute('SELECT cql_version, release_version FROM system.local').one()

cass_version = _tuple_version(row.release_version)
cql_version = _tuple_version(row.cql_version)
@@ -61,7 +61,7 @@ def tearDownClass(cls):
def _verify_statement(self, original):
st = SelectStatement(self.table_name)
result = execute(st)
response = result[0]
response = result.one()

for assignment in original.assignments:
self.assertEqual(response[assignment.field], assignment.value)
2 changes: 1 addition & 1 deletion tests/integration/cqlengine/test_ttl.py
@@ -169,7 +169,7 @@ def get_default_ttl(self, table_name):
except InvalidRequest:
default_ttl = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies "
"WHERE keyspace_name = 'cqlengine_test' AND columnfamily_name = '{0}'".format(table_name))
return default_ttl[0]['default_time_to_live']
return default_ttl.one()['default_time_to_live']

def test_default_ttl_not_set(self):
session = get_session()
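Several of the call sites above read the returned row by key, e.g. `default_ttl.one()['default_time_to_live']` and `response[assignment.field]`. That works when the session's row factory produces dict-like rows; `.one()` simply hands back whatever object the row factory built. A minimal sketch of that pattern using the driver's `dict_factory` (contact point and query are illustrative, not from the tests):

```python
# Sketch only: assumes a reachable node and the standard driver API.
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.query import dict_factory

# With dict_factory every row comes back as a plain dict, so the object
# returned by .one() supports key access like row['column_name'].
profile = ExecutionProfile(row_factory=dict_factory)
cluster = Cluster(['127.0.0.1'],
                  execution_profiles={EXEC_PROFILE_DEFAULT: profile})
session = cluster.connect()

row = session.execute('SELECT cql_version FROM system.local').one()
print(row['cql_version'])

cluster.shutdown()
```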
14 changes: 7 additions & 7 deletions tests/integration/standard/test_cluster.py
@@ -899,7 +899,7 @@ def test_profile_load_balancing(self):

# use a copied instance and override the row factory
# assert last returned value can be accessed as a namedtuple so we can prove something different
named_tuple_row = rs[0]
named_tuple_row = rs.one()
self.assertIsInstance(named_tuple_row, tuple)
self.assertTrue(named_tuple_row.release_version)

@@ -910,13 +910,13 @@ def test_profile_load_balancing(self):
rs = session.execute(query, execution_profile=tmp_profile)
queried_hosts.add(rs.response_future._current_host)
self.assertEqual(queried_hosts, expected_hosts)
tuple_row = rs[0]
tuple_row = rs.one()
self.assertIsInstance(tuple_row, tuple)
with self.assertRaises(AttributeError):
tuple_row.release_version

# make sure original profile is not impacted
self.assertTrue(session.execute(query, execution_profile='node1')[0].release_version)
self.assertTrue(session.execute(query, execution_profile='node1').one().release_version)

def test_setting_lbp_legacy(self):
cluster = TestCluster()
@@ -1390,7 +1390,7 @@ def test_simple_nested(self):
with cluster.connect() as session:
self.assertFalse(cluster.is_shutdown)
self.assertFalse(session.is_shutdown)
self.assertTrue(session.execute('select release_version from system.local')[0])
self.assertTrue(session.execute('select release_version from system.local').one())
self.assertTrue(session.is_shutdown)
self.assertTrue(cluster.is_shutdown)

@@ -1408,7 +1408,7 @@ def test_cluster_no_session(self):
session = cluster.connect()
self.assertFalse(cluster.is_shutdown)
self.assertFalse(session.is_shutdown)
self.assertTrue(session.execute('select release_version from system.local')[0])
self.assertTrue(session.execute('select release_version from system.local').one())
self.assertTrue(session.is_shutdown)
self.assertTrue(cluster.is_shutdown)

@@ -1428,7 +1428,7 @@ def test_session_no_cluster(self):
self.assertFalse(cluster.is_shutdown)
self.assertFalse(session.is_shutdown)
self.assertFalse(unmanaged_session.is_shutdown)
self.assertTrue(session.execute('select release_version from system.local')[0])
self.assertTrue(session.execute('select release_version from system.local').one())
self.assertTrue(session.is_shutdown)
self.assertFalse(cluster.is_shutdown)
self.assertFalse(unmanaged_session.is_shutdown)
@@ -1551,7 +1551,7 @@ def test_valid_protocol_version_beta_options_connect(self):
cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=True)
session = cluster.connect()
self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V6)
self.assertTrue(session.execute("select release_version from system.local")[0])
self.assertTrue(session.execute("select release_version from system.local").one())
cluster.shutdown()


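The test_cluster.py hunks above also exercise `.one()` under different row factories: with the default profile the row is a named tuple (attribute access like `.release_version` works), while a cloned profile using `tuple_factory` yields plain tuples (attribute access raises AttributeError). A minimal sketch of that behaviour, assuming the standard execution-profile API; the contact point is illustrative:

```python
# Sketch only: assumes a reachable node and the standard driver API.
from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT
from cassandra.query import tuple_factory

cluster = Cluster(['127.0.0.1'])
session = cluster.connect()

# Default profile: rows are named tuples, so attribute access works.
row = session.execute('SELECT release_version FROM system.local').one()
print(row.release_version)

# Cloned profile with tuple_factory: rows are plain tuples.
tmp_profile = session.execution_profile_clone_update(
    EXEC_PROFILE_DEFAULT, row_factory=tuple_factory)
row = session.execute('SELECT release_version FROM system.local',
                      execution_profile=tmp_profile).one()
print(row[0])  # positional access only; row.release_version would raise

cluster.shutdown()
```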
8 changes: 4 additions & 4 deletions tests/integration/standard/test_custom_protocol_handler.py
@@ -70,20 +70,20 @@ def test_custom_raw_uuid_row_results(self):
session = cluster.connect(keyspace="custserdes")

result = session.execute("SELECT schema_version FROM system.local")
uuid_type = result[0][0]
uuid_type = result.one()[0]
self.assertEqual(type(uuid_type), uuid.UUID)

        # use our custom protocol handler
session.client_protocol_handler = CustomTestRawRowType
result_set = session.execute("SELECT schema_version FROM system.local")
raw_value = result_set[0][0]
raw_value = result_set.one()[0]
self.assertTrue(isinstance(raw_value, bytes))
self.assertEqual(len(raw_value), 16)

# Ensure that we get normal uuid back when we re-connect
session.client_protocol_handler = ProtocolHandler
result_set = session.execute("SELECT schema_version FROM system.local")
uuid_type = result_set[0][0]
uuid_type = result_set.one()[0]
self.assertEqual(type(uuid_type), uuid.UUID)
cluster.shutdown()

@@ -113,7 +113,7 @@ def test_custom_raw_row_results_all_types(self):

# verify data
params = get_all_primitive_params(0)
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string)).one()
for expected, actual in zip(params, results):
self.assertEqual(actual, expected)
# Ensure we have covered the various primitive types
4 changes: 2 additions & 2 deletions tests/integration/standard/test_cython_protocol_handlers.py
@@ -231,14 +231,14 @@ def test_null_types(self):
table = "%s.%s" % (self.keyspace_name, self.function_table_name)
create_table_with_all_types(table, s, 10)

begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1
begin_unset = max(s.execute('select primkey from %s' % (table,)).one()['primkey']) + 1
keys_null = range(begin_unset, begin_unset + 10)

        # scatter some empty rows in here
insert = "insert into %s (primkey) values (%%s)" % (table,)
execute_concurrent_with_args(s, insert, ((k,) for k in keys_null))

result = s.execute("select * from %s" % (table,))[0]
result = s.execute("select * from %s" % (table,)).one()

from numpy.ma import masked, MaskedArray
result_keys = result.pop('primkey')
10 changes: 5 additions & 5 deletions tests/integration/standard/test_metadata.py
@@ -163,8 +163,8 @@ def test_schema_metadata_disable(self):
query = "SELECT * FROM system.local"
no_schema_rs = no_schema_session.execute(query)
no_token_rs = no_token_session.execute(query)
self.assertIsNotNone(no_schema_rs[0])
self.assertIsNotNone(no_token_rs[0])
self.assertIsNotNone(no_schema_rs.one())
self.assertIsNotNone(no_token_rs.one())
no_schema.shutdown()
no_token.shutdown()

@@ -1819,14 +1819,14 @@ def test_init_cond(self):
for init_cond in (-1, 0, 1):
cql_init = encoder.cql_encode_all_types(init_cond)
with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond=cql_init)) as va:
sum_res = s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs['name'])[0].sum
sum_res = s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs['name']).one().sum
self.assertEqual(sum_res, int(init_cond) + sum(expected_values))

# list<text>
for init_cond in ([], ['1', '2']):
cql_init = encoder.cql_encode_all_types(init_cond)
with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list<text>', init_cond=cql_init)) as va:
list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name'])[0].list_res
list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name']).one().list_res
self.assertListEqual(list_res[:len(init_cond)], init_cond)
self.assertEqual(set(i for i in list_res[len(init_cond):]),
set(str(i) for i in expected_values))
@@ -1837,7 +1837,7 @@ def test_init_cond(self):
for init_cond in ({}, {1: 2, 3: 4}, {5: 5}):
cql_init = encoder.cql_encode_all_types(init_cond)
with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map<int, int>', init_cond=cql_init)) as va:
map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name'])[0].map_res
map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name']).one().map_res
self.assertDictContainsSubset(expected_map_values, map_res)
init_not_updated = dict((k, init_cond[k]) for k in set(init_cond) - expected_key_set)
self.assertDictContainsSubset(init_not_updated, map_res)
14 changes: 7 additions & 7 deletions tests/integration/standard/test_prepared_statements.py
@@ -229,7 +229,7 @@ def test_none_values(self):

bound = prepared.bind((1,))
results = self.session.execute(bound)
self.assertEqual(results[0].v, None)
self.assertEqual(results.one().v, None)

def test_unset_values(self):
"""
@@ -272,7 +272,7 @@ def test_unset_values(self):
for params, expected in bind_expected:
self.session.execute(insert, params)
results = self.session.execute(select, (0,))
self.assertEqual(results[0], expected)
self.assertEqual(results.one(), expected)

self.assertRaises(ValueError, self.session.execute, select, (UNSET_VALUE, 0, 0))

@@ -297,7 +297,7 @@ def test_no_meta(self):
bound = prepared.bind(None)
bound.consistency_level = ConsistencyLevel.ALL
results = self.session.execute(bound)
self.assertEqual(results[0].v, 0)
self.assertEqual(results.one().v, 0)

def test_none_values_dicts(self):
"""
@@ -322,7 +322,7 @@ def test_none_values_dicts(self):

bound = prepared.bind({'k': 1})
results = self.session.execute(bound)
self.assertEqual(results[0].v, None)
self.assertEqual(results.one().v, None)

def test_async_binding(self):
"""
@@ -346,7 +346,7 @@ def test_async_binding(self):

future = self.session.execute_async(prepared, (873,))
results = future.result()
self.assertEqual(results[0].v, None)
self.assertEqual(results.one().v, None)

def test_async_binding_dicts(self):
"""
@@ -369,7 +369,7 @@ def test_async_binding_dicts(self):

future = self.session.execute_async(prepared, {'k': 873})
results = future.result()
self.assertEqual(results[0].v, None)
self.assertEqual(results.one().v, None)

def test_raise_error_on_prepared_statement_execution_dropped_table(self):
"""
@@ -616,7 +616,7 @@ def _test_updated_conditional(self, session, value):

def check_result_and_metadata(expected):
self.assertEqual(
session.execute(prepared_statement, (value, value, value))[0],
session.execute(prepared_statement, (value, value, value)).one(),
expected
)
self.assertEqual(prepared_statement.result_metadata_id, first_id)