[P0] Intervention scheduling for generation #110
base: main
Conversation
        if unit_locations is None:
            # this means, we don't filter based on location at all.
            return {"sources->base": ([None]*len(self.interventions), [None]*len(self.interventions))}

        if self.mode == "parallel":
Now that self.mode does not control this logic block, what is the difference between wait_for_forward_with_parallel_intervention() and wait_for_forward_with_serial_intervention()? Is there still a need to separate these two?
        intervention, module_hook = self.interventions[key]

        def hook_callback(model, args, kwargs, output=None):
            if self._is_generation:
Sorry if I don't understand, but could you explain the rationale for allowing the hook_callback to run when self._skip_forward is True?
This was dead code already, iirc, since it's just getting passed here. Correct me if I'm wrong, but since the getter hook is used to gather source representations, wouldn't it still need to run even if a generate() call skips intervening on the base (prompt)?
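For intuition, here is a generic PyTorch sketch (not pyvene's actual getter implementation) of why a gathering hook still needs to fire even when the base forward pass is not being modified: it only records the module's output for later use as a source representation.

```python
import torch
import torch.nn as nn

# Illustrative only: a "getter" forward hook that records a module's output so
# it can later be used as a source representation. Even when we choose not to
# modify the base forward pass, the recording still has to happen, otherwise
# there is no cached activation to swap in later.
gathered = {}

def getter_hook(module, args, output):
    gathered["source_activation"] = output.detach().clone()
    return output  # pass the output through unchanged

layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(getter_hook)

_ = layer(torch.randn(2, 8))                 # source forward pass: hook records the activation
print(gathered["source_activation"].shape)   # torch.Size([2, 8])
handle.remove()
```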
@@ -149,13 +178,12 @@ def test_with_subspace_negative(self):
         Negative test case to check input length.
         """
         intervenable = IntervenableModel(
-            self.test_subspace_intervention_link_config, self.mlp
+            self.test_negative_subspace_config, self.mlp
What happens if you replace this test_negative_subspace_config with test_subspace_intervention_link_config?
This test case was intended to test defining an intervention with subspace partitions that exceed the dimension of the model. That is why test_subspace_intervention_link_config wasn't triggering an IndexError at all: it and the inputs in this test case are both of dim 3. (It was passing in previous commits because of an entirely unrelated and problematic IndexError that should actually be fixed by this PR.)

Since changing the current config would break all the other tests in this file that rely on it, I decided to just copy it over to a new one.
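For illustration, a torch-only sketch (not the pyvene config API; the partition values are made up) of the failure this negative test is meant to exercise: a subspace partition that refers to dimensions beyond the representation's size fails at indexing time.

```python
import torch

# Torch-only illustration (not the pyvene config API; partition values are
# made up): a subspace partition that names dimensions beyond the size of the
# representation fails when used for indexing, which is what the negative
# test is supposed to exercise.
hidden = torch.randn(4, 3)                  # representation of dim 3, as in the MLP tests

valid_partition = [[0, 1], [2]]             # stays within dim 3: fine
_ = hidden[..., valid_partition[0]]

oversized_partition = [[0, 1, 2], [3, 4]]   # dims 3 and 4 do not exist
try:
    _ = hidden[..., oversized_partition[1]]
except IndexError as err:
    print("expected failure:", err)
```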
@@ -102,15 +107,15 @@ def test_scatter_neurons_gpt2_batch_diff_fast_no_head_positive(self):
         golden_output = tensor_input.clone()
Since there is no fast path anymore, we can remove all fast_path tests, and remove the fast_path parameter in modeling_utils.py as well.
I was curious about that, good to know we can remove it. I'd rather use a separate PR for that, though.
             ] = replacing_tensor_input[:, i]
         else:
-            tensor_input[_batch_idx, unit_locations] = replacing_tensor_input
+            tensor_input[_batch_idx, unit_locations] = replacing_tensor_input[_batch_idx]
Good job! Removed the for loop in the assignment.
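For reference, a small standalone torch sketch of the pattern being discussed (shapes and names are illustrative, not the exact modeling_utils.py code): a per-example Python loop versus a single advanced-indexing assignment.

```python
import torch

# Sketch of the pattern discussed above (shapes and names are illustrative,
# not the exact modeling_utils.py code): overwrite selected positions of each
# batch element with replacement values.
B, T, H = 2, 5, 4
tensor_input = torch.zeros(B, T, H)
replacing_tensor_input = torch.ones(B, 3, H)     # 3 replacement positions per example
unit_locations = torch.tensor([[0, 1, 2],        # positions to overwrite, per example
                               [1, 2, 3]])
_batch_idx = torch.arange(B).unsqueeze(-1)       # shape (B, 1), broadcasts against (B, 3)

# Loop version (what the old code effectively did):
for b in range(B):
    tensor_input[b, unit_locations[b]] = replacing_tensor_input[b]

# Vectorized version: a single advanced-indexing assignment.
tensor_input[_batch_idx, unit_locations] = replacing_tensor_input
```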
Force-pushed from 831eb33 to a17e19a.
Is this PR done? I would really like to use this functionality for my ongoing project. Also, do interventions at a given time-step carry over to all future time-steps? For example, I would want to intervene on the last token at every generation step, i.e. only intervene on the token at the current last position. Thoughts would be appreciated.
Description
Basic functionality for scheduling interventions to happen on positions not present in the prompt (i.e. generated tokens). Ideally, the GRU path should follow the same procedure.
Changelog:
- `timestep_selector`, a list of length `num_intv` of boolean callbacks with signature `Callable[[int, torch.Tensor], bool]`, can be passed to `generate()` calls. Each intervention calls its callback with the current position to determine whether the intervention should operate on that position or not (usage sketch below).
- `None` values in unit locations: if `None`s are specified at the batch dimension, then interventions are not applied to those examples in the batch.
- `_intervention_getter()` / `_intervention_setter()` were being called with single interventions even though they were written to handle an array of intervention keys and return a list of handlers; this dead code has been removed.
- Fixes in `gather_neurons()` and `scatter_neurons()`.
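A hypothetical usage sketch based only on the changelog above: only `timestep_selector` and its `Callable[[int, torch.Tensor], bool]` signature come from this PR, while the other `generate()` arguments and setup names are assumptions and may differ from the final API. It also shows the kind of selector asked about earlier: one that always targets the most recently generated position.

```python
from typing import Callable, List
import torch

# Hypothetical sketch based on the changelog above. Only `timestep_selector`
# and its Callable[[int, torch.Tensor], bool] signature come from this PR;
# the other names and arguments are assumptions.

prompt_len = 8
num_intv = 2  # assume the model was configured with two interventions

# One callback per intervention: called with the current position (and the
# current inputs) and returns whether that intervention should fire there.
intervene_on_generated_only: Callable[[int, torch.Tensor], bool] = (
    lambda pos, inputs: pos >= prompt_len
)
intervene_on_last_token: Callable[[int, torch.Tensor], bool] = (
    lambda pos, inputs: pos == inputs.shape[-1] - 1
)

selectors: List[Callable[[int, torch.Tensor], bool]] = [
    intervene_on_generated_only,
    intervene_on_last_token,
]
assert len(selectors) == num_intv

# Hypothetical call; `intervenable`, `base`, and `sources` are assumed to be
# set up as in the existing pyvene generation examples.
# _, counterfactual = intervenable.generate(
#     base,
#     sources=sources,
#     timestep_selector=selectors,
#     max_new_tokens=16,
# )
```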
Testing Done:
- `test_nulling_intervention`, `test_generation_with_source_intervened_prompt`, `test_dynamic_static_generation_intervention_parity`, `test_generation_noops`
- `test_with_subspace_negative`, `test_scatter_neurons_gpt2_attn_with_head_positive`