
[minor] allow overriding args/kwargs behavior in Runtime #587

Merged: 4 commits, Aug 25, 2023
Changes from 1 commit
24 changes: 13 additions & 11 deletions hivemind/moe/server/runtime.py
@@ -13,6 +13,7 @@
from prefetch_generator import BackgroundGenerator

from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils import get_logger

logger = get_logger(__name__)
@@ -85,18 +86,8 @@

for pool, batch_index, batch in batch_iterator:
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

start = time()
try:
outputs = pool.process_func(*batch)
batch_processing_time = time() - start

batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

outputs = self.process_batch(pool, batch_index, *batch)
output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
except KeyboardInterrupt:
raise
@@ -108,6 +99,17 @@
if not self.shutdown_trigger.is_set():
self.shutdown()

def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor):
"""process one batch of tasks from a given pool, return a batch of results"""
start = time()
outputs = pool.process_func(*batch)
batch_processing_time = time() - start
batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

Codecov warning on hivemind/moe/server/runtime.py#L110: added line #L110 was not covered by tests.
Review comment (Member):
I think it might lead to unintended "0 batches processed" log entries if the user overrides this without carefully considering the original function. Best to leave only the batch size computation inside the function and keep all the logging/time measuring logic outside the function

Reply (Member, Author):
good catch, fixed it now, please take another look

return outputs

def shutdown(self):
"""Gracefully terminate a running runtime."""
logger.info("Shutting down")
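The extracted `process_batch` method is the override point this PR introduces: a `Runtime` subclass can change how the batch is handed to the pool's `process_func` without touching the main loop. A minimal self-contained sketch of that pattern — `Pool`, `Runtime`, and `KwargsRuntime` below are illustrative stand-ins, not the real hivemind classes:

```python
class Pool:
    """Stand-in for TaskPoolBase: holds a name and a processing function."""

    def __init__(self, name, process_func):
        self.name = name
        self.process_func = process_func


class Runtime:
    """Stand-in for hivemind's Runtime, showing only the hook in question."""

    def process_batch(self, pool, batch_index, *batch):
        # Default behavior: unpack the batch as positional arguments.
        return pool.process_func(*batch)


class KwargsRuntime(Runtime):
    """Hypothetical override: pass the batch as keyword arguments instead.

    Treats the first element of the batch as the argument names and the
    rest as the corresponding values.
    """

    def process_batch(self, pool, batch_index, *batch):
        keys, values = batch[0], batch[1:]
        return pool.process_func(**dict(zip(keys, values)))


pool = Pool("demo", lambda x, y: [x + y])
print(Runtime().process_batch(pool, 0, 2, 3))                   # [5]
print(KwargsRuntime().process_batch(pool, 0, ("x", "y"), 2, 3))  # [5]
```

Because the main loop only calls `self.process_batch(pool, batch_index, *batch)`, any such subclass plugs in without further changes — which is also why the review above asked to keep the logging and timing logic outside the overridable method.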