Skip to content

Commit

Permalink
Accuracy of full composition init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Apr 17, 2024
1 parent 60f9990 commit 7083a60
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
27 changes: 9 additions & 18 deletions configs/current_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ training_arguments:
experiment_arguments: # common experiment arguments
define_experiment: False
numeric_experiment: True
name_prefix: "ReplicatingRyanTwoDiffPwd_LockWithSFT"
n_stages: 3
n_seeds: 10
name_prefix: "Test"
n_stages: 2
n_seeds: 5
# n_seeds_stage2: 5
start_seed: 1010
slurm: True
slurm: False
n_gpu_hours: 3


Expand All @@ -55,7 +55,7 @@ define_experiment_arguments:
numeric_experiment_arguments:
# Args for pwd composition experiment below
pwd_locked_experiment: True
n_datapoints: 200000
n_datapoints: 20000
max_unlocking_datapoints: 4
max_x: 10
n_func_in_chain: 2
Expand All @@ -67,23 +67,14 @@ numeric_experiment_arguments:

# overrides specified parameters
first_stage_arguments:
train_subset: 'stage1'
train_subset: 'stage2'
num_train_epochs: 5
eval_each_epochs: 1
gradient_accumulation_steps: 1

second_stage_arguments:
train_subset: 'stage2'
num_train_epochs: 1
eval_each_epochs: 1
gradient_accumulation_steps: 1
save_each_epochs: 0
n_datapoints: 50000

third_stage_arguments:
train_subset: 'stage3'
num_train_epochs: 200
eval_each_epochs: 10
num_train_epochs: 129
eval_each_epochs: 64
gradient_accumulation_steps: 1
dont_save_in_the_end: True
save_each_epochs: 0
save_each_epochs: 0
9 changes: 9 additions & 0 deletions data_generation/pwd_locked_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ def accuracy(y_pred, y_true):
res[f'{fn.fn_name}_weak'] = accuracy(y_true_fn2, y)
except:
return res

# calculate the accuracy of the full composition with all fns unlocked
# compute correct output
x = [int(num) for num in input_x.split()]
for i, fn in enumerate(fns):
x = fn.fn1(x)
y_true = x
y_pred = [int(num) for num in chain_of_thought_without_input[-1].split()]
res['full_composition'] = accuracy(y_pred, y_true)

return res

Expand Down

0 comments on commit 7083a60

Please sign in to comment.