Skip to content

Commit

Permalink
feat: Optimize bytewax pod resource with zero-copy
Browse files Browse the repository at this point in the history
Signed-off-by: Hai Nguyen <quanghai.ng1512@gmail.com>
  • Loading branch information
sudohainguyen authored and achals committed Nov 14, 2023
1 parent 1f91fc6 commit 9cf9d96
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from typing import List

Expand All @@ -7,11 +8,11 @@
from bytewax.execution import cluster_main
from bytewax.inputs import ManualInputConfig
from bytewax.outputs import ManualOutputConfig
from tqdm import tqdm

from feast import FeatureStore, FeatureView, RepoConfig
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping

logger = logging.getLogger(__name__)
DEFAULT_BATCH_SIZE = 1000


Expand All @@ -29,14 +30,20 @@ def __init__(
self.feature_view = feature_view
self.worker_index = worker_index
self.paths = paths
self.mini_batch_size = int(
os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE)
)

self._run_dataflow()

def process_path(self, path):
    """Load the parquet dataset at ``path`` as a list of record batches.

    Every fragment of the dataset is read into a table and split into
    record batches holding at most ``self.mini_batch_size`` rows, so each
    downstream write handles one bounded mini-batch.

    Args:
        path: Location of the parquet data, passed to ``pq.ParquetDataset``.

    Returns:
        A list with the record batches of all fragments, in fragment order.
    """
    logger.info(f"Processing path {path}")
    dataset = pq.ParquetDataset(path, use_legacy_dataset=False)
    batches = []
    for fragment in dataset.fragments:
        # One pass per fragment: the table is chunked directly at the
        # configured mini-batch size instead of being re-sliced later.
        batches.extend(
            fragment.to_table().to_batches(max_chunksize=self.mini_batch_size)
        )

    return batches
Expand All @@ -45,40 +52,26 @@ def input_builder(self, worker_index, worker_count, _state):
return [(None, self.paths[self.worker_index])]

def output_builder(self, worker_index, worker_count):
    """Build the callback that writes one mini-batch to the online store.

    Args:
        worker_index: Accepted for the builder signature; not used here.
        worker_count: Accepted for the builder signature; not used here.

    Returns:
        A function taking a single record batch. It wraps the batch in a
        table, applies the batch source's field mapping when one is set,
        converts the rows to protos and writes them with one
        ``online_write_batch`` call.
    """

    def output_fn(mini_batch):
        table: pa.Table = pa.Table.from_batches([mini_batch])

        if self.feature_view.batch_source.field_mapping is not None:
            table = _run_pyarrow_field_mapping(
                table, self.feature_view.batch_source.field_mapping
            )

        join_key_to_value_type = {
            entity.name: entity.dtype.to_value_type()
            for entity in self.feature_view.entity_columns
        }

        rows_to_write = _convert_arrow_to_proto(
            table, self.feature_view, join_key_to_value_type
        )
        # Incoming batches are already capped at the mini-batch size, so the
        # rows go out in a single write with no progress callback.
        self.feature_store._get_provider().online_write_batch(
            config=self.config,
            table=self.feature_view,
            data=rows_to_write,
            progress=None,
        )

    return output_fn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ class BytewaxMaterializationEngineConfig(FeastConfigBaseModel):
mini_batch_size: int = 1000
""" (optional) Number of rows to process per write operation (default 1000)"""

bytewax_replicas: int = 5
""" (optional) Number of processes to spawn in each pod to handle a file in parallel"""

bytewax_worker_per_process: int = 1
""" (optional) Number of worker threads per bytewax process"""

active_deadline_seconds: int = 86400
""" (optional) Maximum amount of time a materialization job is allowed to run"""

Expand Down Expand Up @@ -111,7 +117,6 @@ def __init__(
self.offline_store = offline_store
self.online_store = online_store

# TODO: Configure k8s here
k8s_config.load_config()

self.k8s_client = client.api_client.ApiClient()
Expand Down Expand Up @@ -299,6 +304,9 @@ def _create_kubernetes_job(self, job_id, paths, feature_view):
len(paths), # Create a pod for each parquet file
self.batch_engine_config.env,
)
logger.info(
f"Created job `dataflow-{job_id}` on namespace `{self.namespace}`"
)
except FailToCreateError as failures:
return BytewaxMaterializationJob(job_id, self.namespace, error=failures)

Expand Down Expand Up @@ -345,7 +353,7 @@ def _create_job_definition(self, job_id, namespace, pods, env, index_offset=0):
{"name": "BYTEWAX_WORKDIR", "value": "/bytewax"},
{
"name": "BYTEWAX_WORKERS_PER_PROCESS",
"value": "1",
"value": f"{self.batch_engine_config.bytewax_worker_per_process}",
},
{
"name": "BYTEWAX_POD_NAME",
Expand All @@ -358,7 +366,7 @@ def _create_job_definition(self, job_id, namespace, pods, env, index_offset=0):
},
{
"name": "BYTEWAX_REPLICAS",
"value": f"{pods}",
"value": f"{self.batch_engine_config.bytewax_replicas}",
},
{
"name": "BYTEWAX_KEEP_CONTAINER_ALIVE",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os

import yaml
Expand All @@ -8,6 +9,8 @@
)

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

with open("/var/feast/feature_store.yaml") as f:
feast_config = yaml.safe_load(f)

Expand Down

0 comments on commit 9cf9d96

Please sign in to comment.