Skip to content

Commit

Permalink
Fix code in benchmarks, examples and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Jul 13, 2024
1 parent c3ad32e commit 61c7744
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 33 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_averager(index):
for step in range(num_rounds):
try:
success = averager.step(timeout=round_timeout) is not None
except:
except hivemind.averaging.allreduce.AllreduceException:
success = False
with lock_stats:
successful_steps += int(success)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
async def check_and_kill(self):
async with self.lock:
if (
self.shutdown_timestamps != None
self.shutdown_timestamps is not None
and self.timestamp_iter < len(self.shutdown_timestamps)
and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
):
Expand Down Expand Up @@ -96,7 +96,7 @@ async def store_and_get_task(

total_gets += len(get_result)
for result in get_result:
if result != None:
if result is not None:
attendees, expiration = result
if len(attendees.keys()) == successful_stores_per_iter:
get_ok = True
Expand Down
14 changes: 8 additions & 6 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def print_device_info(device=None):
# Additional Info when using cuda
if device.type == "cuda":
logger.info(torch.cuda.get_device_name(0))
logger.info(f"Memory Usage:")
logger.info("Memory Usage:")
logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")

Expand Down Expand Up @@ -161,11 +161,13 @@ def benchmark_throughput(

sys.stdout.flush()
sys.stderr.flush()
time_between = (
lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
if (key1 in timestamps and key2 in timestamps)
else float("nan")
)

def time_between(key1, key2):
    """Return the absolute difference between two recorded timestamps.

    Both keys are looked up in the ``timestamps`` mapping taken from the
    surrounding scope.

    Returns:
        ``abs(timestamps[key2] - timestamps[key1])`` when both keys are
        present, otherwise ``float("nan")``.
    """
    if key1 in timestamps and key2 in timestamps:
        return abs(timestamps[key2] - timestamps[key1])
    # A missing key yields NaN rather than raising KeyError.
    return float("nan")

total_examples = batch_size * num_clients * num_batches_per_client

logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
Expand Down
30 changes: 16 additions & 14 deletions examples/albert/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_model(training_args, config, tokenizer):
logger.info(f"Loading model from {latest_checkpoint_dir}")
model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
else:
logger.info(f"Training from scratch")
logger.info("Training from scratch")
model = AlbertForPreTraining(config)
model.resize_token_embeddings(len(tokenizer))

Expand Down Expand Up @@ -235,17 +235,18 @@ def main():

adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

# We need to make such a lambda function instead of just an optimizer instance
# We need to make such a function instead of just an optimizer instance
# to make hivemind.Optimizer(..., offload_optimizer=True) work
opt = lambda params: Lamb(
params,
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
clamp_value=training_args.clamp_value,
debias=True,
)
def opt(params):
    """Build a ``Lamb`` optimizer over ``params``.

    Hyperparameters (learning rate, betas, epsilon, weight decay and clamp
    value) are read from ``training_args`` in the surrounding scope;
    ``debias`` is always enabled. This is a factory function rather than an
    optimizer instance so that it can be called with the parameters later.
    """
    return Lamb(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

no_decay = ["bias", "LayerNorm.weight"]
params = [
Expand All @@ -259,9 +260,10 @@ def main():
},
]

scheduler = lambda opt: get_linear_schedule_with_warmup(
opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
)
def scheduler(opt):
    """Build a linear warmup schedule for the optimizer ``opt``.

    The warmup and total step counts are read from ``training_args`` in the
    surrounding scope.
    """
    return get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=training_args.total_steps,
    )

optimizer = Optimizer(
dht=dht,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_allreduce_fault_tolerance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import asyncio
from enum import Enum, auto
from typing import AsyncIterator

import pytest
import torch

import hivemind
from hivemind.averaging.averager import *
from hivemind.averaging.averager import AllReduceRunner, AveragingMode, GatheredData
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import MatchmakingException
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_run_coroutine():
assert dht.run_coroutine(dummy_dht_coro) == "pew"

with pytest.raises(ValueError):
res = dht.run_coroutine(dummy_dht_coro_error)
dht.run_coroutine(dummy_dht_coro_error)

bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True)
assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dht_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def test_dht_node(
for node in [me, that_guy]:
val, expiration_time = await node.get("mykey")
assert val == ["Value", 10], "Wrong value"
assert expiration_time == true_time, f"Wrong time"
assert expiration_time == true_time, "Wrong time"

assert not await detached_node.get("mykey")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_call_many(hidden_dim=16):
[ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
dht,
)
e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)
e5 = RemoteExpert(ExpertInfo("thisshouldnotexist", server_peer_info), None)

mask, expert_outputs = _RemoteCallMany.apply(
DUMMY,
Expand Down
15 changes: 9 additions & 6 deletions tests/test_utils/custom_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from hivemind.moe import register_expert_class

sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))

def sample_input(batch_size, hidden_dim):
    """Return an uninitialized tensor of shape ``(batch_size, hidden_dim)``.

    The values are whatever ``torch.empty`` leaves in memory; only the shape
    is meaningful.
    """
    return torch.empty((batch_size, hidden_dim))


@register_expert_class("perceptron", sample_input)
Expand All @@ -22,11 +24,12 @@ def forward(self, x):
return x


multihead_sample_input = lambda batch_size, hidden_dim: (
torch.empty((batch_size, hidden_dim)),
torch.empty((batch_size, 2 * hidden_dim)),
torch.empty((batch_size, 3 * hidden_dim)),
)
def multihead_sample_input(batch_size, hidden_dim):
    """Return three uninitialized tensors with increasing widths.

    Returns:
        A tuple of tensors shaped ``(batch_size, hidden_dim)``,
        ``(batch_size, 2 * hidden_dim)`` and ``(batch_size, 3 * hidden_dim)``.
        The values are whatever ``torch.empty`` leaves in memory; only the
        shapes are meaningful.
    """
    return (
        torch.empty((batch_size, hidden_dim)),
        torch.empty((batch_size, 2 * hidden_dim)),
        torch.empty((batch_size, 3 * hidden_dim)),
    )


@register_expert_class("multihead", multihead_sample_input)
Expand Down

0 comments on commit 61c7744

Please sign in to comment.