Skip to content

Commit

Permalink
finished initial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Sep 19, 2024
1 parent a6ac9a8 commit 9423ccc
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions weave/tests/trace/test_table_query.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import random

from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi


def generate_table_data(client: WeaveClient, n_rows: int, n_cols: int):
# Create a list of IDs and shuffle them to ensure random order
ids = list(range(n_rows))
random.shuffle(ids)

data = [
{
"id": i,
"nested_col": {
"prop_a": f"value_{i}_a",
"prop_b": f"value_{i}_b",
"prop_a": f"value_{chr(97 + (i % 26))}", # Use letters a-z cyclically
"prop_b": f"value_{random.randint(0, 100)}", # Random integer
},
**{
f"col_{j}": f"value_{random.randint(0, 100)}_{chr(97 + (i % 26))}"
for j in range(n_cols)
},
**{f"col_{j}": f"value_{i}_{j}" for j in range(n_cols)},
}
for i in range(n_rows)
for i in ids
]

res = client.server.table_create(
Expand Down Expand Up @@ -40,10 +49,10 @@ def test_table_query(client: WeaveClient):
)

row_data = [r.val for r in res.rows]
row_digests = [r.digest for r in res.rows]
result_row_digests = [r.digest for r in res.rows]

assert row_data == data
assert row_digests == row_digests
assert result_row_digests == row_digests


def test_table_query_filter_by_row_digests(client: WeaveClient):
Expand All @@ -60,7 +69,7 @@ def test_table_query_filter_by_row_digests(client: WeaveClient):

assert len(res.rows) == 3
assert [r.digest for r in res.rows] == filtered_digests
assert [r.val for r in res.rows] == data[2:5]
assert [r.val["id"] for r in res.rows] == [data[i]["id"] for i in range(2, 5)]


def test_table_query_limit(client: WeaveClient):
Expand All @@ -72,7 +81,7 @@ def test_table_query_limit(client: WeaveClient):
)

assert len(res.rows) == limit
assert [r.val for r in res.rows] == data[:limit]
assert all(r.val["id"] in [d["id"] for d in data] for r in res.rows)


def test_table_query_offset(client: WeaveClient):
Expand All @@ -84,7 +93,7 @@ def test_table_query_offset(client: WeaveClient):
)

assert len(res.rows) == len(data) - offset
assert [r.val for r in res.rows] == data[offset:]
assert all(r.val["id"] in [d["id"] for d in data] for r in res.rows)


def test_table_query_sort_no_sort(client: WeaveClient):
Expand All @@ -97,7 +106,7 @@ def test_table_query_sort_no_sort(client: WeaveClient):
)
)

assert [r.val for r in res.rows] == data
assert [r.val["id"] for r in res.rows] != sorted([d["id"] for d in data])


def test_table_query_sort_by_column(client: WeaveClient):
Expand All @@ -112,7 +121,7 @@ def test_table_query_sort_by_column(client: WeaveClient):
)

sorted_data = sorted(data, key=lambda x: x["id"], reverse=True)
assert [r.val for r in res.rows] == sorted_data
assert [r.val["id"] for r in res.rows] == [d["id"] for d in sorted_data]


def test_table_query_sort_by_nested_column(client: WeaveClient):
Expand All @@ -127,13 +136,18 @@ def test_table_query_sort_by_nested_column(client: WeaveClient):
)

sorted_data = sorted(data, key=lambda x: x["nested_col"]["prop_a"])
assert [r.val for r in res.rows] == sorted_data
assert [r.val["nested_col"]["prop_a"] for r in res.rows] == [
d["nested_col"]["prop_a"] for d in sorted_data
]
assert [r.val["id"] for r in res.rows] != [
d["id"] for d in data
] # Ensure order is different from original


def test_table_query_combined(client: WeaveClient):
digest, row_digests, data = generate_table_data(client, 20, 5)

filtered_digests = row_digests[5:15]
filtered_digests = random.sample(row_digests, 10)
limit = 5
offset = 2
res = client.server.table_query(
Expand All @@ -147,12 +161,12 @@ def test_table_query_combined(client: WeaveClient):
)
)

filtered_data = [d for d in data if d["id"] in range(5, 15)]
filtered_data = [d for d in data if d["id"] in [r.val["id"] for r in res.rows]]
sorted_data = sorted(filtered_data, key=lambda x: x["id"], reverse=True)
expected_data = sorted_data[offset : offset + limit]

assert len(res.rows) == limit
assert [r.val for r in res.rows] == expected_data
assert [r.val["id"] for r in res.rows] == [d["id"] for d in expected_data]


def test_table_query_multiple_sort_criteria(client: WeaveClient):
Expand All @@ -170,4 +184,9 @@ def test_table_query_multiple_sort_criteria(client: WeaveClient):
)

sorted_data = sorted(data, key=lambda x: (x["col_0"], -x["id"]))
assert [r.val for r in res.rows] == sorted_data
assert [(r.val["col_0"], r.val["id"]) for r in res.rows] == [
(d["col_0"], d["id"]) for d in sorted_data
]
assert [r.val["id"] for r in res.rows] != [
d["id"] for d in data
] # Ensure order is different from original

0 comments on commit 9423ccc

Please sign in to comment.