Skip to content

Commit

Permalink
wandb bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jesbu1 committed Apr 25, 2024
1 parent 62645ab commit 8fea766
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions sprint/saycan_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@
)
from sprint.utils.utils import make_primitive_annotation_eval_dataset
from sprint.utils.data_utils import process_annotation
from sprint.utils.wandb_info import WANDB_PROJECT_NAME, WANDB_ENTITY_NAME
from sprint.rollouts.saycan_rollout import run_policy_multi_process

os.environ["TOKENIZERS_PARALLELISM"] = "false"

WANDB_ENTITY_NAME = "clvr"
WANDB_PROJECT_NAME = "p-bootstrap-llm"


def setup_mp(
result_queue,
Expand Down Expand Up @@ -93,9 +91,7 @@ def multithread_dataset_aggregation(
# asynchronously collect results from result_queue
num_env_samples = 0
num_finished_tasks = 0
num_rollouts = (
config.num_eval_tasks if eval else config.num_rollouts_per_epoch
)
num_rollouts = config.num_eval_tasks if eval else config.num_rollouts_per_epoch
with tqdm(total=num_rollouts) as pbar:
while num_finished_tasks < num_rollouts:
result = result_queue.get()
Expand Down Expand Up @@ -142,18 +138,11 @@ def multiprocess_rollout(
video_captions = []
extra_info = defaultdict(list)

num_rollouts = (
config.num_eval_tasks if eval else config.num_rollouts_per_epoch
)
num_rollouts = config.num_eval_tasks if eval else config.num_rollouts_per_epoch
# create tasks for MP Queue

# create tasks for thread/process Queue
args_func = lambda subgoal: (
True,
True,
epsilon,
subgoal
)
args_func = lambda subgoal: (True, True, epsilon, subgoal)

for subgoal in range(num_rollouts):
task_queue.put(args_func(subgoal))
Expand Down Expand Up @@ -323,7 +312,7 @@ def signal_handler(sig, frame):
result_queue,
config,
0,
True,
True,
)
wandb.log(
eval_metrics,
Expand Down

0 comments on commit 8fea766

Please sign in to comment.