Commit
Entity Key Filter Integration Test (#520) (#521)
Adds a test for the entity key filter
kevinjnguyen authored Jul 18, 2023
1 parent 085c3a0 commit da9eac9
Showing 7 changed files with 313 additions and 29 deletions.
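Taken together, the changes below make the module-level default slice filter flow into query creation at call time. As a minimal end-to-end sketch of that flow, assuming only APIs exercised in this commit (my_client is assumed to be an already-initialized kaskada.client.Client, and "Purchase" is a placeholder Fenl expression):

import kaskada.client
import kaskada.query
from kaskada.slice_filters import EntityPercentFilter

# Set the process-wide default slice; create_query reads it at call time,
# so queries issued after this point are sliced to 25% of entities.
kaskada.client.set_default_slice(EntityPercentFilter(25))
kaskada.query.create_query("Purchase", client=my_client)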
22 changes: 5 additions & 17 deletions clients/python/src/fenlmagic/__init__.py
@@ -9,6 +9,7 @@
 import logging
 import os
 import sys
+from typing import Optional

 import IPython
 import pandas
@@ -40,12 +41,10 @@ def set_dataframe(self, dataframe: pandas.DataFrame):

 @magics_class
 class FenlMagics(Magics):
-    client = None
-
-    def __init__(self, shell, client):
+    def __init__(self, shell, client: Optional[client.Client]):
         super(FenlMagics, self).__init__(shell)
-        self.client = client
         logger.info("extension loaded")
+        self.client = client

     @magic_arguments()
     @argument(
@@ -174,19 +173,8 @@ def fenl(self, arg, cell=None):
             raise UsageError(e)


-def load_ipython_extension(ipython):
-    if client.KASKADA_DEFAULT_CLIENT is None:
-        logger.warn(
-            "No client was initialized. Initializing default client to connect to localhost:50051."
-        )
-        default_client = client.Client(
-            client_id=os.getenv("KASKADA_CLIENT_ID", None),
-            endpoint=client.KASKADA_DEFAULT_ENDPOINT,
-            is_secure=client.KASKADA_IS_SECURE,
-        )
-        client.set_default_client(default_client)
-
-    magics = FenlMagics(ipython, client.get_client())
+def load_ipython_extension(ipython, client: Optional[client.Client] = None):
+    magics = FenlMagics(ipython, client)
     ipython.register_magics(magics)


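With the implicit default-client bootstrap removed, callers now hand the extension a client explicitly (or None). A minimal sketch of the new loading flow inside an IPython session; the Client constructor arguments simply mirror the ones deleted above:

import fenlmagic
import kaskada.client

# get_ipython() is provided by IPython inside an interactive session.
ip = get_ipython()
explicit_client = kaskada.client.Client(
    client_id=None,
    endpoint=kaskada.client.KASKADA_DEFAULT_ENDPOINT,
    is_secure=kaskada.client.KASKADA_IS_SECURE,
)
fenlmagic.load_ipython_extension(ip, client=explicit_client)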
6 changes: 4 additions & 2 deletions clients/python/src/kaskada/client.py
@@ -270,10 +270,12 @@ def set_default_slice(slice: SliceFilter):
     Args:
         slice (SliceFilter): SliceFilter to set the default
     """
-    logger.debug(f"Default slice set to type {type(slice)}")
-
     global KASKADA_DEFAULT_SLICE
     KASKADA_DEFAULT_SLICE = slice
+    if KASKADA_DEFAULT_SLICE is None:
+        logger.info("Slicing disabled")
+    else:
+        logger.info(f"Slicing set to: {slice.to_request()}")


def set_default_client(client: Client):
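The reworked logging makes both states observable. A two-line sketch of the branches; EntityPercentFilter is the filter type used by the new tests, and passing None is how slicing is disabled here:

import kaskada.client
from kaskada.slice_filters import EntityPercentFilter

kaskada.client.set_default_slice(EntityPercentFilter(65))  # logs the slice request
kaskada.client.set_default_slice(None)                     # logs "Slicing disabled"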
16 changes: 14 additions & 2 deletions clients/python/src/kaskada/query.py
@@ -11,7 +11,7 @@
 import kaskada.formatters
 import kaskada.kaskada.v1alpha.destinations_pb2 as destinations_pb
 import kaskada.kaskada.v1alpha.query_service_pb2 as query_pb
-from kaskada.client import KASKADA_DEFAULT_SLICE, Client, get_client
+from kaskada.client import Client, get_client
 from kaskada.slice_filters import SliceFilter
 from kaskada.utils import get_timestamp, handleException, handleGrpcError

@@ -134,7 +134,19 @@ def create_query(
         query_pb.CreateQueryResponse
     """
     if slice_filter is None:
-        slice_filter = KASKADA_DEFAULT_SLICE
+        """
+        Subtle Python implementation note:
+
+        KASKADA_DEFAULT_SLICE is a module-level global that users may rebind
+        at runtime, so its value must be read when this method executes.
+
+        Incorrect: from kaskada.client import KASKADA_DEFAULT_SLICE
+            This binds the value once, when the query module is imported.
+
+        Correct: import kaskada.client
+            This fetches the value from the module each time a query runs.
+        """
+        slice_filter = kaskada.client.KASKADA_DEFAULT_SLICE

     change_since_time = get_timestamp(changed_since_time)
     final_result_time = get_timestamp(final_result_time)
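The note above is a general Python binding rule, not something specific to kaskada. A self-contained illustration with a synthetic module (no kaskada imports):

import types

cfg = types.ModuleType("cfg")
cfg.DEFAULT_SLICE = None

snapshot = cfg.DEFAULT_SLICE   # like "from kaskada.client import KASKADA_DEFAULT_SLICE"
cfg.DEFAULT_SLICE = "percent"  # like set_default_slice() rebinding the global

print(snapshot)           # None — the from-import froze the value at import time
print(cfg.DEFAULT_SLICE)  # "percent" — attribute lookup sees the current value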
51 changes: 51 additions & 0 deletions clients/python/tests/test_query.py
@@ -8,6 +8,7 @@
 import kaskada.kaskada.v1alpha.destinations_pb2 as destinations_pb
 import kaskada.kaskada.v1alpha.query_service_pb2 as query_pb
 import kaskada.query
+from kaskada.slice_filters import EntityPercentFilter
"""
def create_query(
@@ -48,6 +49,56 @@ def test_create_query_with_defaults(mockClient):
     )


+@patch("kaskada.client.Client")
+def test_query_uses_client_global_slice_filter(mockClient):
+    filter_percentage = 65
+    entity_filter = EntityPercentFilter(filter_percentage)
+    kaskada.client.set_default_slice(entity_filter)
+    expression = "test_with_defaults"
+    expected_request = query_pb.CreateQueryRequest(
+        query=query_pb.Query(
+            expression=expression,
+            destination={
+                "object_store": destinations_pb.ObjectStoreDestination(
+                    file_type=common_pb.FILE_TYPE_PARQUET
+                )
+            },
+            result_behavior="RESULT_BEHAVIOR_ALL_RESULTS",
+            slice=common_pb.SliceRequest(
+                percent=common_pb.SliceRequest.PercentSlice(percent=65),
+            ),
+        ),
+        query_options=query_pb.QueryOptions(presign_results=True),
+    )
+    kaskada.query.create_query(expression, client=mockClient)
+    mockClient.query_stub.CreateQuery.assert_called_with(
+        expected_request, metadata=mockClient.get_metadata()
+    )
+
+    filter_percentage = 10
+    entity_filter = EntityPercentFilter(filter_percentage)
+    kaskada.client.set_default_slice(entity_filter)
+    expected_request = query_pb.CreateQueryRequest(
+        query=query_pb.Query(
+            expression=expression,
+            destination={
+                "object_store": destinations_pb.ObjectStoreDestination(
+                    file_type=common_pb.FILE_TYPE_PARQUET
+                )
+            },
+            result_behavior="RESULT_BEHAVIOR_ALL_RESULTS",
+            slice=common_pb.SliceRequest(
+                percent=common_pb.SliceRequest.PercentSlice(percent=10),
+            ),
+        ),
+        query_options=query_pb.QueryOptions(presign_results=True),
+    )
+    kaskada.query.create_query(expression, client=mockClient)
+    mockClient.query_stub.CreateQuery.assert_called_with(
+        expected_request, metadata=mockClient.get_metadata()
+    )


 @patch("kaskada.client.Client")
 def test_get_query(mockClient):
     query_id = "12345"
117 changes: 117 additions & 0 deletions crates/sparrow-runtime/src/prepare.rs
@@ -278,7 +278,11 @@ async fn reader_from_csv<'a, R: std::io::Read + std::io::Seek + Send + 'static>(

 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field, Schema};
     use futures::StreamExt;
+    use sparrow_api::kaskada::v1alpha::slice_plan::{EntityKeysSlice, Slice};
     use sparrow_api::kaskada::v1alpha::{source_data, SourceData, TableConfig};
     use uuid::Uuid;

@@ -409,4 +413,117 @@ mod tests {
         let _prepared_schema = prepared_batch.schema();
         let _metadata_schema = metadata.schema();
     }
+
+    #[tokio::test]
+    async fn test_preparation_single_entity_key_slicing() {
+        let entity_keys = vec!["0b00083c-5c1e-47f5-abba-f89b12ae3cf4".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 23, 1).await;
+    }
+
+    #[tokio::test]
+    async fn test_preparation_no_matching_entity_key_slicing() {
+        let entity_keys = vec!["some-random-invalid-entity-key".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 0, 0).await;
+    }
+
+    #[tokio::test]
+    async fn test_preparation_multiple_matching_entity_key_slicing() {
+        let entity_keys = vec![
+            "0b00083c-5c1e-47f5-abba-f89b12ae3cf4".to_owned(),
+            "8a16beda-c07a-4625-a805-2d28f5934107".to_owned(),
+        ];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 41, 2).await;
+    }
+
+    #[tokio::test]
+    async fn test_slicing_issue() {
+        let input_path = sparrow_testing::testdata_path("transactions/transactions_part1.parquet");
+
+        let input_path =
+            source_data::Source::ParquetPath(format!("file:///{}", input_path.display()));
+        let source_data = SourceData {
+            source: Some(input_path),
+        };
+
+        let table_config = TableConfig::new_with_table_source(
+            "transactions_slicing",
+            &Uuid::new_v4(),
+            "transaction_time",
+            Some("idx"),
+            "purchaser_id",
+            "",
+        );
+
+        let entity_keys = vec!["2798e270c7cab8c9eeacc046a3100a57".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+
+        let prepared_batches = super::prepared_batches(
+            &ObjectStoreRegistry::default(),
+            &source_data,
+            &table_config,
+            &slice,
+        )
+        .await
+        .unwrap()
+        .collect::<Vec<_>>()
+        .await;
+        assert_eq!(prepared_batches.len(), 1);
+        let (prepared_batch, metadata) = prepared_batches[0].as_ref().unwrap();
+        assert_eq!(prepared_batch.num_rows(), 300);
+        let _prepared_schema = prepared_batch.schema();
+        assert_metadata_schema_eq(metadata.schema());
+        assert_eq!(metadata.num_rows(), 1);
+    }
+
+    async fn test_slicing_config(
+        slice: &Option<Slice>,
+        num_prepared_rows: usize,
+        num_metadata_rows: usize,
+    ) {
+        let input_path = sparrow_testing::testdata_path("eventdata/sample_event_data.parquet");
+
+        let input_path =
+            source_data::Source::ParquetPath(format!("file:///{}", input_path.display()));
+        let source_data = SourceData {
+            source: Some(input_path),
+        };
+
+        let table_config = TableConfig::new_with_table_source(
+            "Event",
+            &Uuid::new_v4(),
+            "timestamp",
+            Some("subsort_id"),
+            "anonymousId",
+            "user",
+        );
+
+        let prepared_batches = super::prepared_batches(
+            &ObjectStoreRegistry::default(),
+            &source_data,
+            &table_config,
+            slice,
+        )
+        .await
+        .unwrap()
+        .collect::<Vec<_>>()
+        .await;
+        assert_eq!(prepared_batches.len(), 1);
+        let (prepared_batch, metadata) = prepared_batches[0].as_ref().unwrap();
+        assert_eq!(prepared_batch.num_rows(), num_prepared_rows);
+        let _prepared_schema = prepared_batch.schema();
+        assert_metadata_schema_eq(metadata.schema());
+        assert_eq!(metadata.num_rows(), num_metadata_rows);
+    }
+
+    fn assert_metadata_schema_eq(metadata_schema: Arc<Schema>) {
+        let fields = vec![
+            Field::new("_hash", DataType::UInt64, false),
+            Field::new("_entity_key", DataType::Utf8, true),
+        ];
+        let schema = Arc::new(Schema::new(fields));
+        assert_eq!(metadata_schema, schema);
+    }
 }
