diff --git a/.gitignore b/.gitignore index 5a243027..1102aa00 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,5 @@ dmypy.json # Pyre type checker .pyre/ + +.vscode \ No newline at end of file diff --git a/pyvene/models/configuration_intervenable_model.py b/pyvene/models/configuration_intervenable_model.py index 3831fa48..1be60a70 100644 --- a/pyvene/models/configuration_intervenable_model.py +++ b/pyvene/models/configuration_intervenable_model.py @@ -5,7 +5,7 @@ from transformers import PreTrainedTokenizer, TensorType, is_torch_available from transformers.configuration_utils import PretrainedConfig -from .interventions import VanillaIntervention +from .interventions import VanillaIntervention, Intervention RepresentationConfig = namedtuple( @@ -25,7 +25,7 @@ class IntervenableConfig(PretrainedConfig): def __init__( self, representations=[RepresentationConfig()], - intervention_types=VanillaIntervention, + intervention_types:type[Intervention] | List[type[Intervention]]=VanillaIntervention, mode="parallel", sorted_keys=None, model_type=None, # deprecating diff --git a/pyvene/models/gpt2/modelings_intervenable_gpt2.py b/pyvene/models/gpt2/modelings_intervenable_gpt2.py index c5442f20..35040768 100644 --- a/pyvene/models/gpt2/modelings_intervenable_gpt2.py +++ b/pyvene/models/gpt2/modelings_intervenable_gpt2.py @@ -74,12 +74,14 @@ def create_gpt2(name="gpt2", cache_dir=None): """Creates a GPT2 model, config, and tokenizer from the given name and revision""" - from transformers import GPT2Model, GPT2Tokenizer, GPT2Config + from transformers import GPT2Model, AutoTokenizer, GPT2Config config = GPT2Config.from_pretrained(name) - tokenizer = GPT2Tokenizer.from_pretrained(name) + tokenizer = AutoTokenizer.from_pretrained(name) gpt = GPT2Model.from_pretrained(name, config=config, cache_dir=cache_dir) - print("loaded model") + assert isinstance(gpt, GPT2Model) + + print(f"loaded GPT2 model {name}") return config, tokenizer, gpt @@ -93,7 +95,10 @@ def create_gpt2_lm(name="gpt2", config=None, cache_dir=None): gpt = GPT2LMHeadModel.from_pretrained(name, config=config, cache_dir=cache_dir) else: gpt = GPT2LMHeadModel(config=config) - print("loaded model") + + assert isinstance(gpt, GPT2LMHeadModel) + + print(f"loaded GPT2 model {name}") return config, tokenizer, gpt def create_gpt2_classifier(name="gpt2", config=None, cache_dir=None): diff --git a/pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py b/pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py index 09f94cad..82684a97 100644 --- a/pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py +++ b/pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py @@ -19,16 +19,16 @@ "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK), "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK), "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK), - "attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK), - "head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")), + "attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK), + "head_attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")), "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK), "attention_input": ("h[%s].attn", CONST_INPUT_HOOK), - "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK), - "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK), - "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK), - "head_query_output": ("h[%s].attn.q_proj", 
CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), - "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), - "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), + "query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK), + "key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK), + "value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK), + "head_query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), + "head_key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), + "head_value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), } @@ -67,11 +67,13 @@ def create_gpt_neo( name="roneneldan/TinyStories-33M", cache_dir=None ): - """Creates a GPT2 model, config, and tokenizer from the given name and revision""" + """Creates a GPTNeo model, config, and tokenizer from the given name and revision""" from transformers import GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoConfig config = GPTNeoConfig.from_pretrained(name) tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M") # not sure gpt_neo = GPTNeoForCausalLM.from_pretrained(name) - print("loaded model") + assert isinstance(gpt_neo, GPTNeoForCausalLM) + + print(f"loaded GPTNeo model {name}") return config, tokenizer, gpt_neo diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 270494c7..314e06a5 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1,7 +1,7 @@ import json, logging, torch, types import numpy as np from collections import OrderedDict -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import Callable, List, Optional, Dict, Any, Tuple from .basic_utils import * from .modeling_utils import * @@ -15,7 +15,7 @@ TrainableIntervention, SkipIntervention, CollectIntervention, - BoundlessRotatedSpaceIntervention + BoundlessRotatedSpaceIntervention, ) from torch import optim @@ -24,6 +24,9 @@ from transformers.utils import ModelOutput from tqdm import tqdm, trange +TIMESTEP_SELECTOR_TYPE = List[Callable[[int, torch.Tensor], bool]] + + @dataclass class IntervenableModelOutput(ModelOutput): original_outputs: Optional[Any] = None @@ -39,15 +42,15 @@ class IntervenableModel(nn.Module): def __init__(self, config, model, **kwargs): super().__init__() if isinstance(config, dict) or isinstance(config, list): - config = IntervenableConfig( - representations = config - ) - self.config = config - + config = IntervenableConfig(representations=config) + self.config: IntervenableConfig = config + self.mode = config.mode intervention_type = config.intervention_types - self.is_model_stateless = is_stateless(model) - self.config.model_type = str(type(model)) # backfill + self.is_model_stateless = is_stateless( + model + ) # all sequence models need state to generate, but only the time indices + self.config.model_type = str(type(model)) # backfill self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False self.model_has_grad = False @@ -59,10 +62,8 @@ def __init__(self, config, model, **kwargs): "be considered" ) # each representation can get a different intervention type - if type(intervention_type) == list: - assert len(intervention_type) == len( - config.representations - ) + if isinstance(intervention_type, list): + assert len(intervention_type) == len(config.representations) ### # We 
instantiate intervention_layers at locations. @@ -75,14 +76,14 @@ def __init__(self, config, model, **kwargs): # To support a new model type, you need to provide a # mapping between supported abstract type and module name. ### - self.representations = {} - self.interventions = {} + self.representations: Dict[str, RepresentationConfig] = {} + self.interventions: Dict[str, Tuple[Callable | Intervention, nn.Module]] = {} self._key_collision_counter = {} self.return_collect_activations = False # Flags and counters below are for interventions in the model.generate # call. We can intervene on the prompt tokens only, on each generated # token, or on a combination of both. - self._is_generation = False + self._skip_forward = False self._intervene_on_prompt = None self._key_getter_call_counter = {} self._key_setter_call_counter = {} @@ -92,15 +93,13 @@ def __init__(self, config, model, **kwargs): # hooks are stateful internally, meaning that it's aware of how many times # it is called during the execution. # TODO: this could be merged with call counter above later. - self._intervention_state = {} + self._intervention_state: Dict[str, InterventionState] = {} # We want to associate interventions with a group to do group-wise interventions. self._intervention_group = {} _any_group_key = False _original_key_order = [] - for i, representation in enumerate( - config.representations - ): + for i, representation in enumerate(config.representations): _key = self._get_representation_key(representation) if representation.intervention is not None: @@ -109,51 +108,45 @@ def __init__(self, config, model, **kwargs): else: intervention_function = ( intervention_type - if type(intervention_type) != list + if not isinstance(intervention_type, list) else intervention_type[i] ) all_metadata = representation._asdict() component_dim = get_dimension_by_component( - get_internal_model_type(model), model.config, - representation.component + get_internal_model_type(model), + model.config, + representation.component, ) if component_dim is not None: component_dim *= int(representation.max_number_of_units) all_metadata["embed_dim"] = component_dim all_metadata["use_fast"] = self.use_fast - intervention = intervention_function( - **all_metadata - ) - + intervention = intervention_function(**all_metadata) + if representation.intervention_link_key in self._intervention_pointers: - self._intervention_reverse_link[ - _key - ] = f"link#{representation.intervention_link_key}" + self._intervention_reverse_link[_key] = ( + f"link#{representation.intervention_link_key}" + ) intervention = self._intervention_pointers[ representation.intervention_link_key ] elif representation.intervention_link_key is not None: - self._intervention_pointers[ - representation.intervention_link_key - ] = intervention - self._intervention_reverse_link[ - _key - ] = f"link#{representation.intervention_link_key}" - - if isinstance( - intervention, - CollectIntervention - ): + self._intervention_pointers[representation.intervention_link_key] = ( + intervention + ) + self._intervention_reverse_link[_key] = ( + f"link#{representation.intervention_link_key}" + ) + + if isinstance(intervention, CollectIntervention): self.return_collect_activations = True - - module_hook = get_module_hook( - model, representation - ) + + module_hook = get_module_hook(model, representation) self.representations[_key] = representation self.interventions[_key] = (intervention, module_hook) - self._key_getter_call_counter[ - _key - ] = 0 # we memo how many the hook is called, + 
self._key_getter_call_counter[_key] = ( + 0 # we memo how many the hook is called, + ) # usually, it's a one time call per # hook unless model generates. self._key_setter_call_counter[_key] = 0 @@ -166,10 +159,7 @@ def __init__(self, config, model, **kwargs): "The key is provided in the config. " "Assuming this is loaded from a pretrained module." ) - if ( - self.config.sorted_keys is not None - or "intervenables_sort_fn" not in kwargs - ): + if self.config.sorted_keys is not None or "intervenables_sort_fn" not in kwargs: self.sorted_keys = _original_key_order else: # the key order is independent of group, it is used to read out intervention locations. @@ -197,7 +187,7 @@ def __init__(self, config, model, **kwargs): for i in range(len(_validate_group_keys) - 1): if _validate_group_keys[i] > _validate_group_keys[i + 1]: logging.info( - f"This is not a valid group key order: {_validate_group_keys}" + f"This is not a valid group key order: {_validate_group_keys}" ) raise ValueError( "Must be ascending order. " @@ -229,7 +219,7 @@ def __init__(self, config, model, **kwargs): self.model_type = get_internal_model_type(model) self.disable_model_gradients() self.trainable_model_parameters = {} - + def __str__(self): """ Print out basic info about this intervenable instance @@ -280,7 +270,6 @@ def _cleanup_states(self, skip_activation_gc=False): """ Clean up all old in memo states of interventions """ - self._is_generation = False self._remove_forward_hooks() self._reset_hook_count() if not skip_activation_gc: @@ -304,7 +293,7 @@ def get_trainable_parameters(self): if p.requires_grad: ret_params += [p] return ret_params - + def named_parameters(self, recurse=True): """ The above, but for HuggingFace. @@ -312,12 +301,12 @@ def named_parameters(self, recurse=True): ret_params = [] for k, v in self.interventions.items(): if isinstance(v[0], TrainableIntervention): - ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()] + ret_params += [(k + "." + n, p) for n, p in v[0].named_parameters()] for n, p in self.model.named_parameters(): if p.requires_grad: - ret_params += [('model.' + n, p)] + ret_params += [("model." 
+ n, p)] return ret_params - + def get_cached_activations(self): """ Return the cached activations with keys @@ -335,8 +324,9 @@ def set_temperature(self, temp: torch.Tensor): Set temperature if needed """ for k, v in self.interventions.items(): - if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \ - isinstance(v[0], SigmoidMaskIntervention): + if isinstance(v[0], BoundlessRotatedSpaceIntervention) or isinstance( + v[0], SigmoidMaskIntervention + ): v[0].set_temperature(temp) def enable_model_gradients(self): @@ -346,9 +336,9 @@ def enable_model_gradients(self): # Unfreeze all model weights self.model.train() for param in self.model.parameters(): - param.requires_grad = True + param.requires_grad = True self.model_has_grad = True - + def disable_model_gradients(self): """ Disable gradient in the model @@ -358,7 +348,7 @@ def disable_model_gradients(self): for param in self.model.parameters(): param.requires_grad = False self.model_has_grad = False - + def disable_intervention_gradients(self): """ Disable gradient in the trainable intervention @@ -371,7 +361,9 @@ def set_device(self, device, set_model=True): Set device of interventions and the model """ for k, v in self.interventions.items(): - v[0].to(device) + if isinstance(v[0], Intervention): + v[0].to(device) + if set_model: self.model.to(device) @@ -397,7 +389,8 @@ def count_parameters(self, include_model=False): total_parameters += count_parameters(v[0]) if include_model: total_parameters += sum( - p.numel() for p in self.model.parameters() if p.requires_grad) + p.numel() for p in self.model.parameters() if p.requires_grad + ) return total_parameters def set_zero_grad(self): @@ -415,7 +408,7 @@ def zero_grad(self): for k, v in self.interventions.items(): if isinstance(v[0], TrainableIntervention): v[0].zero_grad() - + def save( self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model" ): @@ -431,13 +424,11 @@ def save( saving_config = copy.deepcopy(self.config) saving_config.sorted_keys = self.sorted_keys - saving_config.model_type = str( - saving_config.model_type - ) + saving_config.model_type = str(saving_config.model_type) saving_config.intervention_types = [] saving_config.intervention_dimensions = [] saving_config.intervention_constant_sources = [] - + # handle constant source reprs if passed in. 
serialized_representations = [] for reprs in saving_config.representations: @@ -456,19 +447,18 @@ def save( serialized_reprs[k] = None else: serialized_reprs[k] = v - serialized_representations += [ - RepresentationConfig(**serialized_reprs) - ] - saving_config.representations = \ - serialized_representations - + serialized_representations += [RepresentationConfig(**serialized_reprs)] + saving_config.representations = serialized_representations + for k, v in self.interventions.items(): intervention = v[0] saving_config.intervention_types += [str(type(intervention))] binary_filename = f"intkey_{k}.bin" # save intervention binary file - if isinstance(intervention, TrainableIntervention) or \ - intervention.source_representation is not None: + if ( + isinstance(intervention, TrainableIntervention) + or intervention.source_representation is not None + ): # logging.info(f"Saving trainable intervention to {binary_filename}.") torch.save( intervention.state_dict(), @@ -492,9 +482,13 @@ def save( if intervention.interchange_dim is None: saving_config.intervention_dimensions += [None] else: - saving_config.intervention_dimensions += [intervention.interchange_dim.tolist()] - saving_config.intervention_constant_sources += [intervention.is_source_constant] - + saving_config.intervention_dimensions += [ + intervention.interchange_dim.tolist() + ] + saving_config.intervention_constant_sources += [ + intervention.is_source_constant + ] + # save metadata config saving_config.save_pretrained(save_directory) if save_to_hf_hub: @@ -520,8 +514,9 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False """ if not os.path.exists(load_directory) or from_huggingface_hub: from_huggingface_hub = True - + from huggingface_hub import snapshot_download + load_directory = snapshot_download( repo_id=load_directory, local_dir=local_directory, @@ -533,16 +528,10 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False for type_str in saving_config.intervention_types: casted_intervention_types += [get_type_from_string(type_str)] - saving_config.intervention_types = ( - casted_intervention_types - ) + saving_config.intervention_types = casted_intervention_types casted_representations = [] - for ( - representation_opts - ) in saving_config.representations: - casted_representations += [ - RepresentationConfig(*representation_opts) - ] + for representation_opts in saving_config.representations: + casted_representations += [RepresentationConfig(*representation_opts)] saving_config.representations = casted_representations intervenable = IntervenableModel(saving_config, model) @@ -550,22 +539,32 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False for i, (k, v) in enumerate(intervenable.interventions.items()): intervention = v[0] binary_filename = f"intkey_{k}.bin" - intervention.is_source_constant = \ + intervention.is_source_constant = ( saving_config.intervention_constant_sources[i] + ) intervention.set_interchange_dim(saving_config.intervention_dimensions[i]) - if saving_config.intervention_constant_sources[i] and \ - not isinstance(intervention, ZeroIntervention) and \ - not isinstance(intervention, SourcelessIntervention): + if ( + saving_config.intervention_constant_sources[i] + and not isinstance(intervention, ZeroIntervention) + and not isinstance(intervention, SourcelessIntervention) + ): # logging.warn(f"Loading trainable intervention from {binary_filename}.") - saved_state_dict = torch.load(os.path.join(load_directory, 
binary_filename)) + saved_state_dict = torch.load( + os.path.join(load_directory, binary_filename) + ) try: intervention.register_buffer( - 'source_representation', saved_state_dict['source_representation'] + "source_representation", + saved_state_dict["source_representation"], ) except: - intervention.source_representation = saved_state_dict['source_representation'] + intervention.source_representation = saved_state_dict[ + "source_representation" + ] elif isinstance(intervention, TrainableIntervention): - saved_state_dict = torch.load(os.path.join(load_directory, binary_filename)) + saved_state_dict = torch.load( + os.path.join(load_directory, binary_filename) + ) intervention.load_state_dict(saved_state_dict) return intervenable @@ -576,15 +575,17 @@ def save_intervention(self, save_directory, include_model=True): trainable weights. This is not a static method, and returns nothing. """ create_directory(save_directory) - + # save binary files for k, v in self.interventions.items(): intervention = v[0] binary_filename = f"intkey_{k}.bin" # save intervention binary file if isinstance(intervention, TrainableIntervention): - torch.save(intervention.state_dict(), - os.path.join(save_directory, binary_filename)) + torch.save( + intervention.state_dict(), + os.path.join(save_directory, binary_filename), + ) # save model's trainable parameters as well if include_model: @@ -593,8 +594,10 @@ def save_intervention(self, save_directory, include_model=True): for n, p in self.model.named_parameters(): if p.requires_grad: model_state_dict[n] = p - torch.save(model_state_dict, os.path.join(save_directory, model_binary_filename)) - + torch.save( + model_state_dict, os.path.join(save_directory, model_binary_filename) + ) + def load_intervention(self, load_directory, include_model=True): """ Instead of creating an new object, this function loads existing weights onto @@ -605,17 +608,21 @@ def load_intervention(self, load_directory, include_model=True): intervention = v[0] binary_filename = f"intkey_{k}.bin" if isinstance(intervention, TrainableIntervention): - saved_state_dict = torch.load(os.path.join(load_directory, binary_filename)) + saved_state_dict = torch.load( + os.path.join(load_directory, binary_filename) + ) intervention.load_state_dict(saved_state_dict) # load model's trainable parameters as well if include_model: model_binary_filename = "pytorch_model.bin" - saved_model_state_dict = torch.load(os.path.join(load_directory, model_binary_filename)) + saved_model_state_dict = torch.load( + os.path.join(load_directory, model_binary_filename) + ) self.model.load_state_dict(saved_model_state_dict, strict=False) def _gather_intervention_output( - self, output, representations_key, unit_locations + self, output: torch.Tensor | Tuple, representations_key: str, unit_locations ) -> torch.Tensor: """ Gather intervening activations from the output based on indices @@ -648,9 +655,7 @@ def _gather_intervention_output( # gather subcomponent original_output = output_to_subcomponent( original_output, - self.representations[ - representations_key - ].component, + self.representations[representations_key].component, self.model_type, self.model_config, ) @@ -658,18 +663,15 @@ def _gather_intervention_output( # gather based on intervention locations selected_output = gather_neurons( original_output, - self.representations[ - representations_key - ].unit, + self.representations[representations_key].unit, unit_locations, ) return selected_output - def _scatter_intervention_output( self, - output, + output: 
torch.Tensor, intervened_representation, representations_key, unit_locations, @@ -677,27 +679,18 @@ def _scatter_intervention_output( """ Scatter in the intervened activations in the output """ - # data structure casting - if isinstance(output, tuple): - original_output = output[0] - else: - original_output = output # for non-sequence-based models, we simply replace # all the activations. if unit_locations is None: - original_output[:] = intervened_representation[:] - return original_output - - component = self.representations[ - representations_key - ].component - unit = self.representations[ - representations_key - ].unit - + output[:] = intervened_representation[:] + return output + + component = self.representations[representations_key].component + unit = self.representations[representations_key].unit + # scatter in-place _ = scatter_neurons( - original_output, + output, intervened_representation, component, unit, @@ -706,8 +699,8 @@ def _scatter_intervention_output( self.model_config, self.use_fast, ) - - return original_output + + return output def _intervention_getter( self, @@ -718,18 +711,11 @@ def _intervention_getter( Create a list of getter handlers that will fetch activations """ handlers = [] - for key_i, key in enumerate(keys): + for key in keys: + key_i = self.sorted_keys.index(key) intervention, module_hook = self.interventions[key] def hook_callback(model, args, kwargs, output=None): - if self._is_generation: - pass - # for getter, there is no restriction. - # is_prompt = self._key_getter_call_counter[key] == 0 - # if not self._intervene_on_prompt or is_prompt: - # self._key_getter_call_counter[key] += 1 - # if self._intervene_on_prompt ^ is_prompt: - # return # no-op if output is None: if len(args) == 0: # kwargs based calls # PR: https://github.com/frankaging/align-transformers/issues/11 @@ -755,22 +741,18 @@ def hook_callback(model, args, kwargs, output=None): # assert key not in self.activations self.activations[key] = selected_output else: - state_select_flag = [] - for unit_location in unit_locations[key_i]: - if ( - self._intervention_state[key].getter_version() - in unit_location - ): - state_select_flag += [True] - else: - state_select_flag += [False] + state_select_flag = [ + (self._intervention_state[key].getter_timestep in loc) + for loc in unit_locations[key_i] + ] + # for stateful model (e.g., gru), we save extra activations and metadata to do # stateful interventions. 
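# A rough sketch of what the append below accumulates for a stateful model
# (key name and tensors are hypothetical): one (activation, flags) tuple is
# stored per getter call, and each flags list holds one boolean per batch
# example marking whether that example intervenes at the current
# getter_timestep. E.g. with unit_locations[key_i] == [[0], [1]] (example 0
# targets timestep 0, example 1 targets timestep 1), two forward steps yield:
#   self.activations[key] == [(act_t0, [True, False]), (act_t1, [False, True])]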
self.activations.setdefault(key, []).append( (selected_output, state_select_flag) ) - # set version for stateful models - self._intervention_state[key].inc_getter_version() + # set version for stateful models + self._intervention_state[key].getter_timestep += 1 handlers.append(module_hook(hook_callback, with_kwargs=True)) @@ -812,17 +794,15 @@ def _reconcile_stateful_cached_activations( if key not in self.activations: return None - cached_activations = self.activations[key] - if self.is_model_stateless: + cached_activations = torch.tensor(self.activations[key]) + if self.is_model_stateless or cached_activations.numel() == 0: # nothing to reconcile if stateless return cached_activations - state_select_flag = [] - for unit_location in intervening_unit_locations: - if self._intervention_state[key].setter_version() in unit_location: - state_select_flag += [True] - else: - state_select_flag += [False] + state_select_flag = [ + self._intervention_state[key].setter_timestep in unit_location + for unit_location in intervening_unit_locations + ] state_select_flag = ( torch.tensor(state_select_flag).bool().to(intervening_activations.device) ) @@ -856,27 +836,30 @@ def _intervention_setter( keys, unit_locations_base, subspaces, + timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None, ) -> HandlerList: """ Create a list of setter handlers that will set activations """ self._tidy_stateful_activations() - + handlers = [] - for key_i, key in enumerate(keys): + + for key in keys: + key_i = self.sorted_keys.index(key) intervention, module_hook = self.interventions[key] - if unit_locations_base[0] is not None: + state = self._intervention_state[key] + + if unit_locations_base[0]: self._batched_setter_activation_select[key] = [ 0 for _ in range(len(unit_locations_base[0])) ] # batch_size def hook_callback(model, args, kwargs, output=None): - if self._is_generation: - is_prompt = self._key_setter_call_counter[key] == 0 - if not self._intervene_on_prompt or is_prompt: - self._key_setter_call_counter[key] += 1 - if self._intervene_on_prompt ^ is_prompt: - return # no-op + if self._skip_forward and state.setter_timestep <= 0: + state.setter_timestep += 1 + return + if output is None: if len(args) == 0: # kwargs based calls # PR: https://github.com/frankaging/align-transformers/issues/11 @@ -884,81 +867,77 @@ def hook_callback(model, args, kwargs, output=None): output = kwargs[list(kwargs.keys())[0]] else: output = args - + + if isinstance(output, tuple): + output = output[0] + + # in this code we assume that output is batched along its first axis. 
+ int_unit_loc = ( + unit_locations_base[key_i] + if state.setter_timestep <= 0 + else [ + ( + [0] + if timestep_selector != None + and timestep_selector[key_i]( + state.setter_timestep, output[i] + ) + else None + ) + for i in range(len(output)) + ] + ) + selected_output = self._gather_intervention_output( - output, key, unit_locations_base[key_i] + output, key, int_unit_loc ) # TODO: need to figure out why clone is needed if not self.is_model_stateless: selected_output = selected_output.clone() - - if isinstance( - intervention, - CollectIntervention - ): - intervened_representation = do_intervention( + + source = ( + None + if isinstance(intervention, CollectIntervention) + or isinstance(intervention, Intervention) + and intervention.is_source_constant + else self._reconcile_stateful_cached_activations( + key, selected_output, - None, - intervention, - subspaces[key_i] if subspaces is not None else None, + int_unit_loc, ) + ) + + intervened_representation = do_intervention( + selected_output, + source, + intervention, + subspaces[key_i] if subspaces is not None else None, + ) + + if isinstance(intervention, CollectIntervention): # fail if this is not a fresh collect assert key not in self.activations - + self.activations[key] = intervened_representation - # no-op to the output - - else: - if not isinstance(self.interventions[key][0], types.FunctionType): - if intervention.is_source_constant: - intervened_representation = do_intervention( - selected_output, - None, - intervention, - subspaces[key_i] if subspaces is not None else None, - ) - else: - intervened_representation = do_intervention( - selected_output, - self._reconcile_stateful_cached_activations( - key, - selected_output, - unit_locations_base[key_i], - ), - intervention, - subspaces[key_i] if subspaces is not None else None, - ) - else: - # highly unlikely it's a primitive intervention type - intervened_representation = do_intervention( - selected_output, - self._reconcile_stateful_cached_activations( - key, - selected_output, - unit_locations_base[key_i], - ), - intervention, - subspaces[key_i] if subspaces is not None else None, - ) - if intervened_representation is None: - return - - # setter can produce hot activations for shared subspace interventions if linked - if key in self._intervention_reverse_link: - self.hot_activations[ - self._intervention_reverse_link[key] - ] = intervened_representation.clone() - - if isinstance(output, tuple): - _ = self._scatter_intervention_output( - output[0], intervened_representation, key, unit_locations_base[key_i] - ) - else: - _ = self._scatter_intervention_output( - output, intervened_representation, key, unit_locations_base[key_i] - ) - - self._intervention_state[key].inc_setter_version() + return + + if intervened_representation is None: + return + + # setter can produce hot activations for shared subspace interventions if linked + if key in self._intervention_reverse_link: + self.hot_activations[self._intervention_reverse_link[key]] = ( + intervened_representation.clone() + ) + + self._scatter_intervention_output( + output, + intervened_representation, + key, + int_unit_loc, + ) + + state.setter_timestep += 1 handlers.append(module_hook(hook_callback, with_kwargs=True)) @@ -974,10 +953,14 @@ def _input_validation( ): """Fail fast input validation""" if self.mode == "parallel" and unit_locations is not None: - assert "sources->base" in unit_locations or "base" in unit_locations - elif activations_sources is None and unit_locations is not None and self.mode == "serial": + assert 
"sources->base" in unit_locations + elif ( + activations_sources is None + and unit_locations is not None + and self.mode == "serial" + ): assert "sources->base" not in unit_locations - + # sources may contain None, but length should match if sources is not None and not (len(sources) == 1 and sources[0] == None): if len(sources) != len(self._intervention_group): @@ -998,7 +981,7 @@ def _input_validation( if ( isinstance(v, list) and isinstance(v[0], tuple) - and isinstance(v[0][1], list) != True + and not isinstance(v[0][1], list) ): raise ValueError( f"Stateful models need nested activations. See our documentions." @@ -1010,7 +993,7 @@ def _output_validation( """Safe guarding the execution by checking memory states""" if self.is_model_stateless: for k, v in self._intervention_state.items(): - if v.getter_version() > 1 or v.setter_version() > 1: + if v.getter_timestep > 1 or v.setter_timestep > 1: raise Exception( f"For stateless model, each getter and setter " f"should be called only once: {self._intervention_state}" @@ -1046,6 +1029,7 @@ def _wait_for_forward_with_parallel_intervention( unit_locations, activations_sources: Optional[Dict] = None, subspaces: Optional[List] = None, + timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None, ): # torch.autograd.set_detect_anomaly(True) all_set_handlers = HandlerList([]) @@ -1059,51 +1043,40 @@ def _wait_for_forward_with_parallel_intervention( for group_id, keys in self._intervention_group.items(): if sources[group_id] is None: continue # smart jump for advance usage only - group_get_handlers = HandlerList([]) - for key in keys: - get_handlers = self._intervention_getter( - [key], - [ - unit_locations_sources[ - self.sorted_keys.index(key) - ] - ], - ) - group_get_handlers.extend(get_handlers) + + group_get_handlers = self._intervention_getter( + keys, + unit_locations_sources, + ) _ = self.model(**sources[group_id]) group_get_handlers.remove() else: # simply patch in the ones passed in - self.activations = activations_sources - for _, passed_in_key in enumerate(self.activations): - assert passed_in_key in self.sorted_keys - + for passed_in_key, v in activations_sources.items(): + assert ( + passed_in_key in self.sorted_keys + ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}" + self.activations[passed_in_key] = torch.clone(v) + # in parallel mode, we swap cached activations all into # base at once for group_id, keys in self._intervention_group.items(): - for key in keys: - # skip in case smart jump - if key in self.activations or \ - isinstance(self.interventions[key][0], types.FunctionType) or \ - self.interventions[key][0].is_source_constant: - set_handlers = self._intervention_setter( - [key], - [ - unit_locations_base[ - self.sorted_keys.index(key) - ] - ], - # assume same group targeting the same subspace - [ - subspaces[ - self.sorted_keys.index(key) - ] - ] - if subspaces is not None - else None, - ) - # for setters, we don't remove them. 
- all_set_handlers.extend(set_handlers) + keys_with_handler = [ + key + for key in keys + if ( + key in self.activations + or isinstance(self.interventions[key][0], types.FunctionType) + or self.interventions[key][0].is_source_constant + ) + ] + + all_set_handlers.extend( + self._intervention_setter( + keys_with_handler, unit_locations_base, subspaces, timestep_selector + ) + ) + return all_set_handlers def _wait_for_forward_with_serial_intervention( @@ -1112,35 +1085,36 @@ def _wait_for_forward_with_serial_intervention( unit_locations, activations_sources: Optional[Dict] = None, subspaces: Optional[List] = None, + timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None, ): all_set_handlers = HandlerList([]) for group_id, keys in self._intervention_group.items(): if sources[group_id] is None: continue # smart jump for advance usage only - for key_id, key in enumerate(keys): - if group_id != len(self._intervention_group) - 1: - unit_locations_key = f"source_{group_id}->source_{group_id+1}" - else: - unit_locations_key = f"source_{group_id}->base" - unit_locations_source = unit_locations[unit_locations_key][0][ - key_id - ] - if unit_locations_source is None: - continue # smart jump for advance usage only - unit_locations_base = unit_locations[unit_locations_key][1][ - key_id + group_dest = ( + "base" + if group_id >= len(self._intervention_group) - 1 + else f"source_{group_id+1}" + ) + group_key = f"source_{group_id}->{group_dest}" + unit_locations_source = unit_locations[group_key][0] + unit_locations_base = unit_locations[group_key][1] + + if activations_sources != None: + for passed_in_key, v in activations_sources.items(): + assert ( + passed_in_key in self.sorted_keys + ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}" + self.activations[passed_in_key] = torch.clone(v) + else: + keys_with_source = [ + k for i, k in enumerate(keys) if unit_locations_source[i] != None ] - if activations_sources is None: - # get activation from source_i - get_handlers = self._intervention_getter( - [key], - [unit_locations_source], - ) - else: - self.activations[key] = activations_sources[ - key - ] + get_handlers = self._intervention_getter( + keys_with_source, unit_locations_source + ) + # call once per group. each intervention is by its own group by default if activations_sources is None: # this is when previous setter and THEN the getter get called @@ -1151,125 +1125,110 @@ def _wait_for_forward_with_serial_intervention( all_set_handlers.remove() all_set_handlers = HandlerList([]) - for key in keys: - # skip in case smart jump - if key in self.activations or \ - isinstance(self.interventions[key][0], types.FunctionType) or \ - self.interventions[key][0].is_source_constant: - # set with intervened activation to source_i+1 - set_handlers = self._intervention_setter( - [key], - [unit_locations_base], - # assume the order - [ - subspaces[ - self.sorted_keys.index(key) - ] - ] - if subspaces is not None - else None, - ) - # for setters, we don't remove them. 
- all_set_handlers.extend(set_handlers) + keys_with_handler = [ + key + for key in keys + if ( + key in self.activations + or isinstance(self.interventions[key][0], types.FunctionType) + or self.interventions[key][0].is_source_constant + ) + ] + + all_set_handlers.extend( + self._intervention_setter( + keys_with_handler, unit_locations_base, subspaces, timestep_selector + ) + ) + return all_set_handlers - + def _broadcast_unit_locations( - self, - batch_size, - unit_locations - ): + self, batch_size, unit_locations: Optional[Dict[str, int | Sequence]] + ) -> Dict[str, Tuple[Sequence, Sequence]]: if unit_locations is None: # this means, we don't filter based on location at all. - return {"sources->base": ([None]*len(self.interventions), [None]*len(self.interventions))} - - if self.mode == "parallel": - _unit_locations = {} - for k, v in unit_locations.items(): - # special broadcast for base-only interventions - is_base_only = False - if k == "base": - is_base_only = True - k = "sources->base" - if isinstance(v, int): - if is_base_only: - _unit_locations[k] = (None, [[[v]]*batch_size]*len(self.interventions)) - else: - _unit_locations[k] = ( - [[[v]]*batch_size]*len(self.interventions), - [[[v]]*batch_size]*len(self.interventions) - ) - self.use_fast = True - elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int): - _unit_locations[k] = ( - [[[v[0]]]*batch_size]*len(self.interventions), - [[[v[1]]]*batch_size]*len(self.interventions) + return { + "sources->base": ( + [None] * len(self.interventions), + [None] * len(self.interventions), + ) + } + + self.use_fast = True + + if self.mode not in {"parallel", "serial"}: + raise ValueError(f"The mode {self.mode} is not supported.") + + _unit_locations = {} + + for k, v in unit_locations.items(): + # special broadcast for base-only interventions + if k == "base": + k = "sources->base" + + # Copies same locations into source/base if only one specified + if not isinstance(v, Sequence) or len(v) != 2: + v = (copy.deepcopy(v) if k != "base" else None, copy.deepcopy(v)) + + _v = [] + + for el in v: + # only position specified + if el == None or isinstance(el, int): + el = [el] + + # only position specified, broadcast across each batch ex + if get_list_depth(el) == 1: + el = [el for _ in range(batch_size)] + + # only position and batch dims specified, broadcast across interventions + if get_list_depth(el) == 2: + el = [el for _ in range(len(self.interventions))] + + # broadcasting once right number of dims created: + if len(el) == 1: + el *= len(self.interventions) + elif len(el) < len(self.interventions): + raise ValueError( + f"{len(self.interventions)} interventions expected, but {len(el)} locations given!" ) - self.use_fast = True - elif len(v) == 2 and v[0] == None and isinstance(v[1], int): - _unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions)) - self.use_fast = True - elif len(v) == 2 and isinstance(v[0], int) and v[1] == None: - _unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None) - self.use_fast = True - elif isinstance(v, list) and get_list_depth(v) == 1: - # [0,1,2,3] -> [[[0,1,2,3]]], ... 
- if is_base_only: - _unit_locations[k] = (None, [[v]*batch_size]*len(self.interventions)) - else: - _unit_locations[k] = ( - [[v]*batch_size]*len(self.interventions), - [[v]*batch_size]*len(self.interventions) + + for i, intervention in enumerate(el): + if intervention == None: + continue + + if len(intervention) == 1: + el[i] = intervention * batch_size + elif len(intervention) == 2: + for j in {0, 1}: + if len(intervention[j]) == 1: + el[i][j] = intervention[j] * batch_size + elif len(intervention[j]) < batch_size: + raise ValueError( + f"{batch_size} batch size expected, but {len(intervention[j])} locations given!" + ) + elif len(intervention) < batch_size: + raise ValueError( + f"{batch_size} batch size expected, but {len(intervention)} locations given!" ) - self.use_fast = True - else: - if is_base_only: - _unit_locations[k] = (None, v) - else: - _unit_locations[k] = v - elif self.mode == "serial": - _unit_locations = {} - for k, v in unit_locations.items(): - if isinstance(v, int): - _unit_locations[k] = ( - [[[v]]*batch_size]*len(self.interventions), - [[[v]]*batch_size]*len(self.interventions) - ) - self.use_fast = True - elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int): - _unit_locations[k] = ( - [[[v[0]]]*batch_size]*len(self.interventions), - [[[v[1]]]*batch_size]*len(self.interventions) - ) - self.use_fast = True - elif len(v) == 2 and v[0] == None and isinstance(v[1], int): - _unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions)) - self.use_fast = True - elif len(v) == 2 and isinstance(v[0], int) and v[1] == None: - _unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None) - self.use_fast = True - elif isinstance(v, list) and get_list_depth(v) == 1: - # [0,1,2,3] -> [[[0,1,2,3]]], ... 
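# Shorthand forms the unified broadcasting above is intended to accept
# (an illustrative summary; the fully nested form stays authoritative):
#   {"base": 3}                # intervene at position 3 of the base only
#   {"sources->base": 3}       # same position on the source and base sides
#   {"sources->base": (4, 5)}  # source position 4 patched into base position 5
#   {"sources->base": ([[[0, 1]]], [[[0, 1]]])}  # fully nested:
#       (source side, base side) x num_interventions x batch_size x positions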
- _unit_locations[k] = ( - [[v]*batch_size]*len(self.interventions), - [[v]*batch_size]*len(self.interventions) - ) - self.use_fast = True - else: - _unit_locations[k] = v - else: - raise ValueError(f"The mode {self.mode} is not supported.") + + _v.append(el) + + _unit_locations[k] = _v[0], _v[1] + return _unit_locations - + def _broadcast_source_representations( - self, - source_representations - ): + self, source_representations: Optional[Dict | List | torch.Tensor] + ) -> Dict[str, torch.Tensor] | None: """Broadcast simple inputs to a dict""" - _source_representations = {} if isinstance(source_representations, dict) or source_representations is None: # pass to broadcast for advance usage - _source_representations = source_representations - elif isinstance(source_representations, list): + return source_representations + + _source_representations = {} + if isinstance(source_representations, list): for i, key in enumerate(self.sorted_keys): _source_representations[key] = source_representations[i] elif isinstance(source_representations, torch.Tensor): @@ -1280,32 +1239,22 @@ def _broadcast_source_representations( "Accept input type for source_representations is [Dict, List, torch.Tensor]" ) return _source_representations - - def _broadcast_sources( - self, - sources - ): + + def _broadcast_sources(self, sources: Sequence) -> List: """Broadcast simple inputs to a dict""" - _sources = sources - if len(sources) == 1 and len(self._intervention_group) > 1: - for _ in range(len(self._intervention_group)-1): - _sources += [sources[0]] - else: - _sources = sources - return _sources - - def _broadcast_subspaces( - self, - batch_size, - subspaces - ): + if len(sources) > 1: + return list(sources) + + return [sources[0] for _ in range(len(self._intervention_group))] + + def _broadcast_subspaces(self, batch_size, subspaces): """Broadcast simple subspaces input""" _subspaces = subspaces if isinstance(subspaces, int): - _subspaces = [[[subspaces]]*batch_size]*len(self.interventions) - + _subspaces = [[[subspaces]] * batch_size] * len(self.interventions) + elif isinstance(subspaces, list) and isinstance(subspaces[0], int): - _subspaces = [[subspaces]*batch_size]*len(self.interventions) + _subspaces = [[subspaces] * batch_size] * len(self.interventions) else: # TODO: subspaces is easier to add more broadcast majic. pass @@ -1328,7 +1277,7 @@ def forward( actual model forward calls. It will use forward hooks to do interventions. - In essense, sources will lead to getter hooks to + In essence, sources will lead to getter hooks to get activations. We will use these activations to intervene on our base example. @@ -1356,13 +1305,14 @@ def forward( the shape can be - 2 * num_intervention * bs * num_max_unit + 2 * num_intervention * batch_size OR - 2 * num_intervention * num_intervention_level * bs * num_max_unit + 2 * num_intervention * num_intervention_level * batch_size if we intervene on h.pos which is a nested intervention location. + All the values will lie in the range [1, num_max_unit]. 
2) subspaces subspaces is a list of indices indicating which subspace will @@ -1392,20 +1342,28 @@ def forward( activations_sources = source_representations if sources is not None and not isinstance(sources, list): sources = [sources] - + self._cleanup_states() # if no source input or intervention, we return base - if sources is None and activations_sources is None \ - and unit_locations is None and len(self.interventions) == 0: + if ( + sources is None + and activations_sources is None + and unit_locations is None + and len(self.interventions) == 0 + ): return self.model(**base), None # broadcast - unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations) - sources = [None]*len(self._intervention_group) if sources is None else sources + unit_locations = self._broadcast_unit_locations( + get_batch_size(base), unit_locations + ) + sources = [None] * len(self._intervention_group) if sources is None else sources sources = self._broadcast_sources(sources) - activations_sources = self._broadcast_source_representations(activations_sources) + activations_sources = self._broadcast_source_representations( + activations_sources + ) subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces) - + self._input_validation( base, sources, @@ -1413,7 +1371,7 @@ def forward( activations_sources, subspaces, ) - + base_outputs = None if output_original_output: # returning un-intervened output with gradients @@ -1452,40 +1410,36 @@ def forward( set_handlers_to_remove.remove() self._output_validation() - - collected_activations = [] + + collected_activations = {} if self.return_collect_activations: for key in self.sorted_keys: - if isinstance( - self.interventions[key][0], - CollectIntervention - ): - collected_activations += self.activations[key] + if isinstance(self.interventions[key][0], CollectIntervention): + collected_activations[key] = self.activations[key] except Exception as e: raise e finally: self._cleanup_states( - skip_activation_gc = \ - (sources is None and activations_sources is not None) or \ - self.return_collect_activations + skip_activation_gc=(sources is None and activations_sources is not None) + or self.return_collect_activations ) - + if self.return_collect_activations: if return_dict: return IntervenableModelOutput( original_outputs=base_outputs, intervened_outputs=counterfactual_outputs, - collected_activations=collected_activations + collected_activations=collected_activations, ) - + return (base_outputs, collected_activations), counterfactual_outputs - + if return_dict: return IntervenableModelOutput( original_outputs=base_outputs, intervened_outputs=counterfactual_outputs, - collected_activations=None + collected_activations=None, ) return base_outputs, counterfactual_outputs @@ -1494,59 +1448,60 @@ def generate( self, base, sources: Optional[List] = None, - unit_locations: Optional[Dict] = None, source_representations: Optional[Dict] = None, - intervene_on_prompt: bool = False, + intervene_on_prompt: bool = True, + unit_locations: Optional[Dict] = None, + timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None, subspaces: Optional[List] = None, output_original_output: Optional[bool] = False, **kwargs, - ): + ) -> Tuple[ + Optional[ModelOutput | Tuple[Optional[ModelOutput], Dict[str, torch.Tensor]]], + ModelOutput, + ]: """ Intervenable generation function that serves a wrapper to regular model generate calls. - Currently, we support basic interventions **in the - prompt only**. We will support generation interventions - in the next release. 
- - TODO: Unroll sources and intervene in the generation step. - - Parameters: - base: The base example. - sources: A list of source examples. - unit_locations: The intervention locations of - base. - activations_sources: A list of representations. - intervene_on_prompt: Whether only intervene on prompt. - **kwargs: All other generation parameters. - - Return: - base_output: the non-intervened output of the base - input. - counterfactual_outputs: the intervened output of the - base input. + Args: + base: Base example encoding. + sources (Optional[List], optional): List of source encodings. + unit_locations (Optional[Dict], optional): Mapping from intervention edge to + locations of example pairs for every intervention along that edge. See forward() for + details of format. When the sources->base units are positions, this should only + contain indices within the length of the prompt. + timestep_selector (Optional[List[Callable[[int, torch.Tensor], bool]]], optional): list of length + num_interventions of selector functions for interventions on positions of generated text. + Each selector function takes input (idx, output): the i'th intervention intervenes on base value + output at position len(base) + idx if timestep_selector[i](idx, output) == True. + source_representations (Optional[Dict], optional): _description_. + subspaces (Optional[List], optional): _description_. + + Returns: + Tuple[ModelOutput | Tuple[ModelOutput, List[torch.Tensor]], ModelOutput]: _description_ """ # TODO: forgive me now, i will change this later. activations_sources = source_representations if sources is not None and not isinstance(sources, list): sources = [sources] - + self._cleanup_states() + self._skip_forward = not intervene_on_prompt - self._intervene_on_prompt = intervene_on_prompt - self._is_generation = True - - if not intervene_on_prompt and unit_locations is None: - # that means, we intervene on every generated tokens! 
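A rough end-to-end sketch of the generation-time intervention API documented above; create_gpt2_lm, the dict-style representation config, and all names and positions follow this repo's conventions but are illustrative assumptions, not part of the patch.

import pyvene as pv

_, tokenizer, gpt2 = pv.create_gpt2_lm("gpt2")
intervenable = pv.IntervenableModel(
    [{"layer": 8, "component": "block_output"}], gpt2
)

base = tokenizer("The capital of France is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

# Patch the last prompt position (intervene_on_prompt defaults to True here),
# then let the selector decide, step by step, whether the single intervention
# also fires on each generated token.
_, generated = intervenable.generate(
    base,
    sources=[source],
    unit_locations={"sources->base": 5},
    timestep_selector=[lambda idx, output: idx % 2 == 0],
    max_new_tokens=10,
)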
- unit_locations = {"base": 0} - # broadcast - unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations) - sources = [None]*len(self._intervention_group) if sources is None else sources - sources = self._broadcast_sources(sources) - activations_sources = self._broadcast_source_representations(activations_sources) + unit_locations = self._broadcast_unit_locations( + get_batch_size(base), unit_locations + ) + sources = ( + [None] * len(self._intervention_group) + if sources is None + else self._broadcast_sources(sources) + ) + activations_sources = self._broadcast_source_representations( + activations_sources + ) subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces) - + self._input_validation( base, sources, @@ -1554,7 +1509,7 @@ def generate( activations_sources, subspaces, ) - + base_outputs = None if output_original_output: # returning un-intervened output @@ -1570,6 +1525,7 @@ def generate( unit_locations, activations_sources, subspaces, + timestep_selector, ) ) elif self.mode == "serial": @@ -1579,37 +1535,32 @@ def generate( unit_locations, activations_sources, subspaces, + timestep_selector, ) ) - + # run intervened generate - counterfactual_outputs = self.model.generate( - **base, **kwargs - ) - - collected_activations = [] + counterfactual_outputs = self.model.generate(**base, **kwargs) + + collected_activations = {} if self.return_collect_activations: for key in self.sorted_keys: - if isinstance( - self.interventions[key][0], - CollectIntervention - ): - collected_activations += self.activations[key] + if isinstance(self.interventions[key][0], CollectIntervention): + collected_activations[key] = self.activations[key] except Exception as e: raise e finally: if set_handlers_to_remove is not None: set_handlers_to_remove.remove() - self._is_generation = False + self._skip_forward = False self._cleanup_states( - skip_activation_gc = \ - (sources is None and activations_sources is not None) or \ - self.return_collect_activations + skip_activation_gc=(sources is None and activations_sources is not None) + or self.return_collect_activations ) - + if self.return_collect_activations: return (base_outputs, collected_activations), counterfactual_outputs - + return base_outputs, counterfactual_outputs def _batch_process_unit_location(self, inputs): @@ -1853,4 +1804,4 @@ def eval_alignment( all_num_examples += [b_s] result = weighted_average(all_metrics, all_num_examples) - return result \ No newline at end of file + return result diff --git a/pyvene/models/intervention_utils.py b/pyvene/models/intervention_utils.py index ac7a7e3b..ab572ae0 100644 --- a/pyvene/models/intervention_utils.py +++ b/pyvene/models/intervention_utils.py @@ -1,41 +1,33 @@ import json import torch +import pprint class InterventionState(object): def __init__(self, key, **kwargs): self.key = key - self.reset() + self._timestep = [0, 0] - def inc_getter_version(self): - self.state_dict["getter_version"] += 1 - - def inc_setter_version(self): - self.state_dict["setter_version"] += 1 - - def getter_version(self): - return self.state_dict["getter_version"] - - def setter_version(self): - return self.state_dict["setter_version"] - - def get_states(self): - return self.state_dict - - def set_state(self, state_dict): - self.state_dict = state_dict + @property + def getter_timestep(self): + return self._timestep[0] + + @getter_timestep.setter + def getter_timestep(self, value): + self._timestep[0] = value + + @property + def setter_timestep(self): + return self._timestep[1] + + 
@setter_timestep.setter + def setter_timestep(self, value): + self._timestep[1] = value def reset(self): - self.state_dict = { - "key": self.key, - "getter_version": 0, - "setter_version": 0, - } + self._timestep = [0, 0] def __repr__(self): - return json.dumps(self.state_dict, indent=4) - - def __str__(self): - return json.dumps(self.state_dict, indent=4) + return pprint.pformat(self.__dict__, indent=4) def broadcast_tensor_v1(x, target_shape): # Ensure the last dimension of target_shape matches x's size diff --git a/pyvene/models/interventions.py b/pyvene/models/interventions.py index 42265929..67683476 100644 --- a/pyvene/models/interventions.py +++ b/pyvene/models/interventions.py @@ -36,7 +36,6 @@ def __init__(self, **kwargs): self.register_buffer('embed_dim', torch.tensor(kwargs["embed_dim"])) self.register_buffer('interchange_dim', torch.tensor(kwargs["embed_dim"])) else: - self.embed_dim = None self.interchange_dim = None if "source_representation" in kwargs and kwargs["source_representation"] is not None: diff --git a/pyvene/models/modeling_utils.py b/pyvene/models/modeling_utils.py index abcb20b5..0be3ce7b 100644 --- a/pyvene/models/modeling_utils.py +++ b/pyvene/models/modeling_utils.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import random, torch, types import numpy as np from torch import nn @@ -5,6 +6,8 @@ from .interventions import * from .constants import * +UNIT_LOC_LIST_TYPE = Sequence[int | Sequence | torch.Tensor | None] + def get_internal_model_type(model): """Return the model type.""" @@ -94,7 +97,7 @@ def getattr_for_torch_module(model, parameter_name): return current_module -def get_dimension_by_component(model_type, model_config, component) -> int: +def get_dimension_by_component(model_type, model_config, component) -> int | None: """Based on the representation, get the aligning dimension size.""" if component not in type_to_dimension_mapping[model_type]: @@ -132,8 +135,8 @@ def get_dimension_by_component(model_type, model_config, component) -> int: def get_module_hook(model, representation) -> nn.Module: """Render the intervening module with a hook.""" if ( - get_internal_model_type(model) in type_to_module_mapping and - representation.component + get_internal_model_type(model) in type_to_module_mapping + and representation.component in type_to_module_mapping[get_internal_model_type(model)] ): type_info = type_to_module_mapping[get_internal_model_type(model)][ @@ -216,7 +219,9 @@ def bs_hd_to_bhsd(tensor, h): return tensor.reshape(b, s, h, d).permute(0, 2, 1, 3) -def output_to_subcomponent(output, component, model_type, model_config): +def output_to_subcomponent( + output: torch.Tensor, component, model_type, model_config +) -> torch.Tensor: """Split the raw output to subcomponents if specified in the config. :param output: the original output from the model component. 
@@ -227,11 +232,15 @@ def output_to_subcomponent(output, component, model_type, model_config): :param model_config: Hugging Face Model Config """ subcomponent = output - if model_type in type_to_module_mapping and \ - component in type_to_module_mapping[model_type]: + if ( + model_type in type_to_module_mapping + and component in type_to_module_mapping[model_type] + ): split_last_dim_by = type_to_module_mapping[model_type][component][2:] - if len(split_last_dim_by) != 0 and len(split_last_dim_by) > 2: + + if len(split_last_dim_by) > 2: raise ValueError(f"Unsupported {split_last_dim_by}.") + for i, (split_fn, param) in enumerate(split_last_dim_by): if isinstance(param, str): param = get_dimension_by_component(model_type, model_config, param) @@ -239,7 +248,9 @@ def output_to_subcomponent(output, component, model_type, model_config): return subcomponent -def gather_neurons(tensor_input, unit, unit_locations_as_list): +def gather_neurons( + tensor_input: torch.Tensor, unit, unit_locations_as_list: UNIT_LOC_LIST_TYPE +) -> torch.Tensor: """Gather intervening neurons. :param tensor_input: tensors of shape (batch_size, sequence_length, ...) if @@ -264,50 +275,34 @@ def gather_neurons(tensor_input, unit, unit_locations_as_list): # we assume unit_locations is a tuple head_unit_locations = unit_locations[0] pos_unit_locations = unit_locations[1] + _batch_idx = torch.arange(tensor_input.shape[0])[:, None, None] - head_tensor_output = torch.gather( - tensor_input, - 1, - head_unit_locations.reshape( - *head_unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2) - ).expand(-1, -1, *tensor_input.shape[2:]), - ) # b, h, s, d - d = head_tensor_output.shape[1] - pos_tensor_input = bhsd_to_bs_hd(head_tensor_output) - pos_tensor_output = torch.gather( - pos_tensor_input, - 1, - pos_unit_locations.reshape( - *pos_unit_locations.shape, *(1,) * (len(pos_tensor_input.shape) - 2) - ).expand(-1, -1, *pos_tensor_input.shape[2:]), - ) # b, num_unit (pos), num_unit (h)*d - tensor_output = bs_hd_to_bhsd(pos_tensor_output, d) - - return tensor_output # b, num_unit (h), num_unit (pos), d + return tensor_input[ + _batch_idx, head_unit_locations[:, :, None], pos_unit_locations[:, None, :] + ] else: + # For now, when gathering neurons to set, we want to include the entire batch + # even if we are only intervening on some of them, just so there are no + # surprising changes in the base shape. I am setting all the None rows + # to 0 because the scatter function will filter these rows out anyways. + unit_locations_as_list = [(arr or [0]) for arr in unit_locations_as_list] unit_locations = torch.tensor( unit_locations_as_list, device=tensor_input.device ) - tensor_output = torch.gather( - tensor_input, - 1, - unit_locations.reshape( - *unit_locations.shape, *(1,) * (len(tensor_input.shape) - 2) - ).expand(-1, -1, *tensor_input.shape[2:]), - ) - return tensor_output + _batch_idx = torch.arange(tensor_input.shape[0])[:, None] + return tensor_input[_batch_idx, unit_locations] def scatter_neurons( - tensor_input, - replacing_tensor_input, + tensor_input: torch.Tensor, + replacing_tensor_input: torch.Tensor, component, unit, - unit_locations_as_list, + unit_locations_as_list: UNIT_LOC_LIST_TYPE, model_type, model_config, - use_fast, -): + use_fast: bool, +) -> torch.Tensor: """Replace selected neurons in `tensor_input` by `replacing_tensor_input`. :param tensor_input: tensors of shape (batch_size, sequence_length, ...) 
if @@ -330,30 +325,100 @@ def scatter_neurons( :param use_fast: whether to use fast path (TODO: fast path condition) :return the in-place modified tensor_input """ + # if the tensor is split, we need to get the start and end indices + meta_component = output_to_subcomponent( + torch.arange(tensor_input.shape[-1]).unsqueeze(dim=0).unsqueeze(dim=0), + component, + model_type, + model_config, + ) + + last_dim = meta_component.shape[-1] + if "." in unit: # extra dimension for multi-level intervention unit_locations = ( torch.tensor(unit_locations_as_list[0], device=tensor_input.device), torch.tensor(unit_locations_as_list[1], device=tensor_input.device), ) - else: - unit_locations = torch.tensor( - unit_locations_as_list, device=tensor_input.device - ) - # if tensor is splitted, we need to get the start and end indices - meta_component = output_to_subcomponent( - torch.arange(tensor_input.shape[-1]).unsqueeze(dim=0).unsqueeze(dim=0), - component, - model_type, - model_config, + if unit != "h.pos": + # TODO: let's leave batch disabling for complex interventions for later + _batch_idx = torch.arange(tensor_input.shape[0])[:, None, None] + return tensor_input[ + _batch_idx, unit_locations[0][:, :, None], unit_locations[1][:, None, :] + ] + + # head-based scattering is only special for transformer-based models + # replacing_tensor_input: b_s, num_h, s, h_dim -> b_s, s, num_h*h_dim + old_shape = tensor_input.size() # b_s, s, x*num_h*d + new_shape = tensor_input.size()[:-1] + ( + -1, + meta_component.shape[1], + last_dim, + ) # b_s, s, x, num_h, d + + # get whether the component is split by QKV + # NOTE: type_to_module_mapping[model_type][component][2] is an optional config tuple + # specifying how to index for a specific component of a single embedding: + # - the function splitting the embedding vector by component, and + # - the index of the component within the resulting split.
+ if ( + component in type_to_module_mapping[model_type] + and len(type_to_module_mapping[model_type][component]) > 2 + and type_to_module_mapping[model_type][component][2][0] == split_three + ): + _slice_idx = type_to_module_mapping[model_type][component][2][1] + else: + _slice_idx = 0 + + _batch_idx = torch.arange(tensor_input.shape[0])[:, None, None] + _head_idx = unit_locations[0][:, :, None] + _pos_idx = unit_locations[1][:, None, :] + tensor_permute = tensor_input.view(new_shape).permute( + 0, 3, 1, 2, 4 + ) # b_s, num_h, s, x, d + tensor_permute[ + _batch_idx, + _head_idx, + _pos_idx, + _slice_idx, + ] = replacing_tensor_input[:, : _head_idx.shape[1], : _pos_idx.shape[2]] + # reshape + tensor_output = tensor_permute.permute(0, 2, 3, 1, 4).view(old_shape) + return tensor_output # b_s, s, x*num_h*d + + _batch_idx = torch.tensor( + [ + i + for i in range(tensor_input.shape[0]) + if unit_locations_as_list[i] is not None + ] + ) + + if not len(_batch_idx): + return tensor_input + + unit_locations = torch.tensor( + [arr for arr in unit_locations_as_list if arr is not None], + device=tensor_input.device, ) + start_index, end_index = ( meta_component.min().tolist(), - meta_component.max().tolist() + 1, + (meta_component.max() + 1).tolist(), ) - last_dim = meta_component.shape[-1] - _batch_idx = torch.arange(tensor_input.shape[0]).unsqueeze(1) + + # print( + # f"Input shape: {tensor_input.shape}, Replacing shape: {replacing_tensor_input.shape}" + # ) + # print( + # f"Scatter neurons: {_batch_idx}, {unit_locations}, {start_index}, {end_index}" + # ) + + assert ( + unit_locations.shape[0] == _batch_idx.shape[0] + ), f"unit_locations: {unit_locations.shape}, _batch_idx: {_batch_idx.shape}" # in case it is time step, there is no sequence-related index if unit in {"t"}: @@ -361,17 +426,11 @@ def scatter_neurons( tensor_input[_batch_idx, start_index:end_index] = replacing_tensor_input return tensor_input elif unit in {"pos"}: - if use_fast: - # maybe this is all redundant, but maybe faster slightly? - tensor_input[ - _batch_idx, unit_locations[0], start_index:end_index - ] = replacing_tensor_input - else: - tensor_input[ - _batch_idx, unit_locations, start_index:end_index - ] = replacing_tensor_input + tensor_input[_batch_idx[:, None], unit_locations, start_index:end_index] = ( + replacing_tensor_input[_batch_idx, :, start_index:end_index] + ) return tensor_input - elif unit in {"h", "h.pos"}: + elif unit in {"h"}: # head-based scattering is only special for transformer-based model # replacing_tensor_input: b_s, num_h, s, h_dim -> b_s, s, num_h*h_dim old_shape = tensor_input.size() # b_s, s, -1*num_h*d @@ -391,31 +450,16 @@ def scatter_neurons( _slice_idx = 0 tensor_permute = tensor_input.view(new_shape) # b_s, s, -1, num_h, d tensor_permute = tensor_permute.permute(0, 3, 2, 1, 4) # b_s, num_h, -1, s, d - if "." in unit: - # cannot advance indexing on two columns, thus a single for loop is unavoidable. - for i in range(unit_locations[0].shape[-1]): - tensor_permute[ - _batch_idx, unit_locations[0][:, [i]], _slice_idx, unit_locations[1] - ] = replacing_tensor_input[:, i] - else: - tensor_permute[ - _batch_idx, unit_locations, _slice_idx - ] = replacing_tensor_input + tensor_permute[_batch_idx[:, None], unit_locations, _slice_idx] = ( + replacing_tensor_input[_batch_idx] + ) # permute back and reshape tensor_output = tensor_permute.permute(0, 3, 2, 1, 4) # b_s, s, -1, num_h, d tensor_output = tensor_output.view(old_shape) # b_s, s, -1*num_h*d return tensor_output else: - if "." 
in unit: - # cannot advance indexing on two columns, thus a single for loop is unavoidable. - for i in range(unit_locations[0].shape[-1]): - tensor_input[ - _batch_idx, unit_locations[0][:, [i]], unit_locations[1] - ] = replacing_tensor_input[:, i] - else: - tensor_input[_batch_idx, unit_locations] = replacing_tensor_input + tensor_input[_batch_idx, unit_locations] = replacing_tensor_input[_batch_idx] return tensor_input - assert False def do_intervention( @@ -424,18 +468,17 @@ def do_intervention( """Do the actual intervention.""" if isinstance(intervention, types.FunctionType): - if subspaces is None: - return intervention(base_representation, source_representation) - else: - return intervention(base_representation, source_representation, subspaces) + return intervention(base_representation, source_representation) num_unit = base_representation.shape[1] # flatten original_base_shape = base_representation.shape - if len(original_base_shape) == 2 or ( - isinstance(intervention, LocalistRepresentationIntervention) - ) or intervention.keep_last_dim: + if ( + len(original_base_shape) == 2 + or (isinstance(intervention, LocalistRepresentationIntervention)) + or intervention.keep_last_dim + ): # no pos dimension, e.g., gru, or opt-out concate last two dims base_representation_f = base_representation source_representation_f = source_representation @@ -449,7 +492,7 @@ def do_intervention( source_representation_f = bhsd_to_bs_hd(source_representation) else: assert False # what's going on? - + intervened_representation = intervention( base_representation_f, source_representation_f, subspaces ) @@ -457,9 +500,11 @@ def do_intervention( post_d = intervened_representation.shape[-1] # unflatten - if len(original_base_shape) == 2 or isinstance( - intervention, LocalistRepresentationIntervention - ) or intervention.keep_last_dim: + if ( + len(original_base_shape) == 2 + or isinstance(intervention, LocalistRepresentationIntervention) + or intervention.keep_last_dim + ): # no pos dimension, e.g., gru or opt-out concate last two dims pass elif len(original_base_shape) == 3: diff --git a/pyvene_101.ipynb b/pyvene_101.ipynb index 67ccd07f..893cf0c8 100644 --- a/pyvene_101.ipynb +++ b/pyvene_101.ipynb @@ -126,10 +126,27 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0", - "metadata": {}, - "outputs": [], + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fce745d6f2ca453b98f7b10868b1ab7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "generation_config.json: 0%| | 0.00/124 [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "517de63768da4f7f8f58e5018c6f75f6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 0%| | 0.00/308M [00:00base": (0, [0, 1, 2])}, max_length=32 + ) + print(tokenizer.decode(intervened_story[0], skip_special_tokens=True)) + + def test_generation_with_source_intervened_prompt(self): + torch.manual_seed(0) + + pv_model = pv.IntervenableModel( + [ + { + "layer": l, + "component": "mlp_output", + "intervention": lambda b, s: b + s * 0.5, + } + for l in range(self.config.num_layers) + ], + model=self.tinystory, + ) + + prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to( + self.device + ) + orig, intervened = pv_model.generate( + prompt, + max_length=32, + sources=self.tokenizer("Happy love", return_tensors="pt").to(self.device), + intervene_on_prompt=True, + unit_locations={"sources->base": 0}, + output_original_output=True, + ) + orig_text, intervened_text = ( + self.tokenizer.decode(orig[0], skip_special_tokens=True), + self.tokenizer.decode(intervened[0], skip_special_tokens=True), + ) + + print(orig_text) + print(intervened_text) + assert ( + orig_text != intervened_text + ), "Aggressive intervention did not change the output. Probably something wrong." + + def test_dynamic_static_generation_intervention_parity(self): + torch.manual_seed(1) + + pv_model = pv.IntervenableModel( + [ + { + "layer": l, + "component": "mlp_output", + "intervention": lambda b, s: torch.ones_like(b), + } + for l in range(self.config.num_layers) + ], + model=self.tinystory, + ) + + prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to( + self.device + ) + INTERVENTION_DELAY = 5 + + orig, intervened = pv_model.generate( + prompt, + max_length=prompt.input_ids.shape[1] + INTERVENTION_DELAY + 2, + timestep_selector=[lambda idx, o: idx == INTERVENTION_DELAY] + * self.config.num_layers, + output_original_output=True, + ) + orig_text, intervened_text = ( + self.tokenizer.decode(orig[0], skip_special_tokens=True), + self.tokenizer.decode(intervened[0], skip_special_tokens=True), + ) + + print(orig_text) + print(intervened_text) + assert ( + orig_text != intervened_text + ), "Aggressive intervention did not change the output. Probably something wrong." 
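For reference, a minimal standalone sketch of the timestep_selector pattern these tests exercise; this is a hedged illustration, not part of the patch. The create_gpt_neo helper, the generate keyword names, and the (step index, hooked output) -> bool selector signature follow the tests above, while the five-step delay and the ones-like edit are purely illustrative.

import torch
import pyvene as pv

config, tokenizer, tinystory = pv.create_gpt_neo()  # TinyStories helper used by the tests

pv_model = pv.IntervenableModel(
    [
        {
            "layer": l,
            "component": "mlp_output",
            # aggressive edit so the effect is visible in the decoded text
            "intervention": lambda b, s: torch.ones_like(b),
        }
        for l in range(config.num_layers)
    ],
    model=tinystory,
)

prompt = tokenizer("Once upon a time there was", return_tensors="pt")
# one selector per intervention; each fires only on the 5th generated token
orig, intervened = pv_model.generate(
    prompt,
    max_length=prompt.input_ids.shape[1] + 8,
    timestep_selector=[lambda idx, o: idx == 5] * config.num_layers,
    output_original_output=True,
)
print(tokenizer.decode(intervened[0], skip_special_tokens=True))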
+ + def test_generation_noops(self): + torch.manual_seed(0) + + # No-op intervention + pv_model = pv.IntervenableModel( + [ + { + "layer": l, + "component": "mlp_output", + "intervention": lambda b, s: b, + } + for l in range(self.config.num_layers) + ], + model=self.tinystory, + ) + + prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to( + self.device + ) + sources = self.tokenizer(" love", return_tensors="pt").to(self.device) + + orig, intervened = pv_model.generate( + prompt, + max_length=20, + sources=sources, + intervene_on_prompt=True, + unit_locations={"sources->base": (0, [0, 1, 2])}, + output_original_output=True, + ) + orig_text, intervened_text = ( + self.tokenizer.decode(orig[0], skip_special_tokens=True), + self.tokenizer.decode(intervened[0], skip_special_tokens=True), + ) + + print(intervened_text) + assert ( + orig_text == intervened_text + ), "No-op intervention changed the output. Probably something wrong." + + # Aggressive intervention with intervene_on_prompt=False + aggressive_model = pv.IntervenableModel( + [ + { + "layer": l, + "component": "mlp_output", + "intervention": lambda b, s: s * 1000, + } + for l in range(self.config.num_layers) + ], + model=self.tinystory, + ) + + orig, intervened = aggressive_model.generate( + prompt, + max_length=20, + sources=sources, + intervene_on_prompt=False, + output_original_output=True, + ) + + orig_text, intervened_text = ( + self.tokenizer.decode(orig[0], skip_special_tokens=True), + self.tokenizer.decode(intervened[0], skip_special_tokens=True), + ) + print(orig_text) + print(intervened_text) + assert ( + orig_text == intervened_text + ), "Aggressive intervention changed the output. Probably something wrong." + + # Aggressive intervention with no prompt intervention, disabled selectors + orig, intervened = aggressive_model.generate( + prompt, + max_length=20, + sources=sources, + intervene_on_prompt=False, + output_original_output=True, + timestep_selector=[lambda idx, o: False] * self.config.num_layers, + ) + orig_text, intervened_text = ( + self.tokenizer.decode(orig[0], skip_special_tokens=True), + self.tokenizer.decode(intervened[0], skip_special_tokens=True), + ) + assert ( + orig_text == intervened_text + ), "Aggressive intervention changed the output. Probably something wrong." 
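The lambdas used throughout these generation tests are handled as plain Python functions (do_intervention calls them with the base and source representations and writes the return value back), so a named function with the same (base, source) signature works just as well. A hedged sketch, with the helper and component names taken from the tests and the 50/50 mixing factor purely illustrative:

import pyvene as pv

config, tokenizer, tinystory = pv.create_gpt_neo()

def scaled_patch(base, source):
    # base/source are the activations gathered at the configured unit locations;
    # the returned tensor is scattered back into the base forward pass
    return 0.5 * base + 0.5 * source

pv_model = pv.IntervenableModel(
    [
        {"layer": l, "component": "mlp_output", "intervention": scaled_patch}
        for l in range(config.num_layers)
    ],
    model=tinystory,
)

prompt = tokenizer("Once upon a time there was", return_tensors="pt")
sources = tokenizer(" love", return_tensors="pt")
# intervene_on_prompt=True applies the edit while the prompt is encoded;
# (0, [0, 1, 2]) reads as patching source position 0 into base positions 0-2
_, intervened = pv_model.generate(
    prompt,
    max_length=20,
    sources=sources,
    intervene_on_prompt=True,
    unit_locations={"sources->base": (0, [0, 1, 2])},
)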
+ +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration_tests/IntervenableBasicTestCase.py b/tests/integration_tests/IntervenableBasicTestCase.py index 546c496a..23a751a0 100644 --- a/tests/integration_tests/IntervenableBasicTestCase.py +++ b/tests/integration_tests/IntervenableBasicTestCase.py @@ -1,147 +1,113 @@ import unittest + +from pyvene.models.constants import CONST_OUTPUT_HOOK from ..utils import * import copy import torch import pyvene as pv +import pprint + class IntervenableBasicTestCase(unittest.TestCase): """These are API level positive cases.""" + @classmethod - def setUpClass(self): + def setUpClass(cls): _uuid = str(uuid.uuid4())[:6] - self._test_dir = os.path.join(f"./test_output_dir_prefix-{_uuid}") + cls._test_dir = os.path.join(f"./test_output_dir_prefix-{_uuid}") + cls.DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def test_lazy_demo(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - pv_gpt2 = pv.IntervenableModel({ - "layer": 0, - "component": "mlp_output", - "source_representation": torch.zeros( - gpt2.config.n_embd) - }, model=gpt2) - - intervened_outputs = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - unit_locations={"base": 3} - ) - - def test_less_lazy_demo(self): - - _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - - config = pv.IntervenableConfig([ + pv_gpt2 = pv.IntervenableModel( { - "layer": _, + "layer": 0, "component": "mlp_output", - "source_representation": torch.zeros( - gpt2.config.n_embd) - } for _ in range(4)], - mode="parallel" + "source_representation": torch.zeros(gpt2.config.n_embd), + }, + model=gpt2, ) - print(config) - pv_gpt2 = pv.IntervenableModel(config, model=gpt2) intervened_outputs = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - unit_locations={"base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + unit_locations={"base": 3}, ) def test_less_lazy_demo(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - { - "layer": _, - "component": "mlp_output", - "source_representation": torch.zeros( - gpt2.config.n_embd) - } for _ in range(4)], - mode="parallel" + config = pv.IntervenableConfig( + [ + { + "layer": _, + "component": "mlp_output", + "source_representation": torch.zeros(gpt2.config.n_embd), + } + for _ in range(4) + ], + mode="parallel", ) print(config) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) intervened_outputs = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - unit_locations={"base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + unit_locations={"base": 3}, ) def test_source_reprs_pass_in_unit_loc_broadcast_demo(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - pv_gpt2 = pv.IntervenableModel({ - "layer": 0, - "component": "mlp_output", - }, model=gpt2) + pv_gpt2 = pv.IntervenableModel( + { + "layer": 0, + "component": "mlp_output", + }, + model=gpt2, + ) intervened_outputs = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - source_representations = torch.zeros(gpt2.config.n_embd), - unit_locations={"base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + source_representations=torch.zeros(gpt2.config.n_embd), + unit_locations={"base": 3}, ) def test_input_corrupt_multi_token(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = 
pv.IntervenableConfig({ - "layer": 0, - "component": "mlp_input"}, - pv.AdditionIntervention + config = pv.IntervenableConfig( + {"layer": 0, "component": "mlp_input"}, pv.AdditionIntervention ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) intervened_outputs = pv_gpt2( - base = tokenizer( - "The Space Needle is in downtown", - return_tensors="pt" - ), + base=tokenizer("The Space Needle is in downtown", return_tensors="pt"), unit_locations={"base": [[[0, 1, 2, 3]]]}, - source_representations = torch.rand(gpt2.config.n_embd) + source_representations=torch.rand(gpt2.config.n_embd), ) def test_trainable_backward(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig({ - "layer": 8, - "component": "block_output", - "low_rank_dimension": 1}, - pv.LowRankRotatedSpaceIntervention + config = pv.IntervenableConfig( + {"layer": 8, "component": "block_output", "low_rank_dimension": 1}, + pv.LowRankRotatedSpaceIntervention, ) - pv_gpt2 = pv.IntervenableModel( - config, model=gpt2) + pv_gpt2 = pv.IntervenableModel(config, model=gpt2) last_hidden_state = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - sources = tokenizer( - "The capital of Italy is", - return_tensors="pt" - ), - unit_locations={"sources->base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + sources=tokenizer("The capital of Italy is", return_tensors="pt"), + unit_locations={"sources->base": 3}, )[-1].last_hidden_state loss = last_hidden_state.sum() @@ -151,165 +117,112 @@ def test_reprs_collection(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig({ - "layer": 10, - "component": "block_output", - "intervention_type": pv.CollectIntervention} + config = pv.IntervenableConfig( + { + "layer": 10, + "component": "block_output", + "intervention_type": pv.CollectIntervention, + } ) - pv_gpt2 = pv.IntervenableModel( - config, model=gpt2) + pv_gpt2 = pv.IntervenableModel(config, model=gpt2) collected_activations = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - unit_locations={"sources->base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + unit_locations={"sources->base": 3}, )[0][-1] def test_reprs_collection_after_intervention(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig({ - "layer": 8, - "component": "block_output", - "intervention_type": pv.VanillaIntervention} + config = pv.IntervenableConfig( + { + "layer": 8, + "component": "block_output", + "intervention_type": pv.VanillaIntervention, + } ) - config.add_intervention({ - "layer": 10, - "component": "block_output", - "intervention_type": pv.CollectIntervention}) + config.add_intervention( + { + "layer": 10, + "component": "block_output", + "intervention_type": pv.CollectIntervention, + } + ) - pv_gpt2 = pv.IntervenableModel( - config, model=gpt2) + pv_gpt2 = pv.IntervenableModel(config, model=gpt2) collected_activations = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - sources = [tokenizer( - "The capital of Italy is", - return_tensors="pt" - ), None], - unit_locations={"sources->base": 3} + base=tokenizer("The capital of Spain is", return_tensors="pt"), + sources=[tokenizer("The capital of Italy is", return_tensors="pt"), None], + unit_locations={"sources->base": 3}, )[0][-1] def test_reprs_collection_on_one_neuron(self): _, tokenizer, gpt2 = 
pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig({ - "layer": 8, - "component": "head_attention_value_output", - "unit": "h.pos", - "intervention_type": pv.CollectIntervention} + config = pv.IntervenableConfig( + { + "layer": 8, + "component": "head_attention_value_output", + "unit": "h.pos", + "intervention_type": pv.CollectIntervention, + } ) - pv_gpt2 = pv.IntervenableModel( - config, model=gpt2) + pv_gpt2 = pv.IntervenableModel(config, model=gpt2) collected_activations = pv_gpt2( - base = tokenizer( - "The capital of Spain is", - return_tensors="pt" - ), - unit_locations={ - "base": pv.GET_LOC((3,3)) - }, - subspaces=[0] + base=tokenizer("The capital of Spain is", return_tensors="pt"), + unit_locations={"base": pv.GET_LOC((3, 3))}, + subspaces=[0], )[0][-1] def test_new_intervention_type(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - class MultiplierIntervention( - pv.ConstantSourceIntervention): + class MultiplierIntervention(pv.ConstantSourceIntervention): def __init__(self, embed_dim, **kwargs): super().__init__() - def forward( - self, base, source=None, subspaces=None): + + def forward(self, base, source=None, subspaces=None): return base * 99.0 + # run with new intervention type - pv_gpt2 = pv.IntervenableModel({ - "intervention_type": MultiplierIntervention}, - model=gpt2) + pv_gpt2 = pv.IntervenableModel( + {"intervention_type": MultiplierIntervention}, model=gpt2 + ) intervened_outputs = pv_gpt2( - base = tokenizer("The capital of Spain is", - return_tensors="pt"), - unit_locations={"base": 3}) - - def test_recurrent_nn(self): - - _, _, gru = pv.create_gru_classifier( - pv.GRUConfig(h_dim=32)) - - pv_gru = pv.IntervenableModel({ - "component": "cell_output", - "unit": "t", - "intervention_type": pv.ZeroIntervention}, - model=gru) - - rand_t = torch.rand(1,10, gru.config.h_dim) - - intervened_outputs = pv_gru( - base = {"inputs_embeds": rand_t}, - unit_locations={"base": 3}) - - def test_lm_generation(self): - - # built-in helper to get tinystore - _, tokenizer, tinystory = pv.create_gpt_neo() - emb_happy = tinystory.transformer.wte( - torch.tensor(14628)) * 0.3 - - pv_tinystory = pv.IntervenableModel([{ - "layer": _, - "component": "mlp_output", - "intervention_type": pv.AdditionIntervention - } for _ in range( - tinystory.config.num_layers)], - model=tinystory) - - prompt = tokenizer( - "Once upon a time there was", - return_tensors="pt") - _, intervened_story = pv_tinystory.generate( - prompt, - source_representations=emb_happy, - max_length=32 - ) - print(tokenizer.decode( - intervened_story[0], - skip_special_tokens=True - )) + base=tokenizer("The capital of Spain is", return_tensors="pt"), + unit_locations={"base": 3}, + ) def test_save_and_load(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) # run with new intervention type - pv_gpt2 = pv.IntervenableModel({ - "intervention_type": pv.ZeroIntervention}, - model=gpt2) + pv_gpt2 = pv.IntervenableModel( + {"intervention_type": pv.ZeroIntervention}, model=gpt2 + ) pv_gpt2.save(self._test_dir) - pv_gpt2_load = pv.IntervenableModel.load( - self._test_dir, - model=gpt2) - + pv_gpt2_load = pv.IntervenableModel.load(self._test_dir, model=gpt2) + def test_intervention_grouping(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - {"layer": 0, "component": "block_output", "group_key": 0}, - {"layer": 2, "component": "block_output", "group_key": 0}], + config = pv.IntervenableConfig( + [ + {"layer": 0, 
"component": "block_output", "group_key": 0}, + {"layer": 2, "component": "block_output", "group_key": 0}, + ], intervention_types=pv.VanillaIntervention, ) @@ -317,25 +230,30 @@ def test_intervention_grouping(self): base = tokenizer("The capital of Spain is", return_tensors="pt") sources = [tokenizer("The capital of Italy is", return_tensors="pt")] - intervened_outputs = pv_gpt2( - base, sources, - {"sources->base": ([ - [[3]], [[4]] # these two are for two interventions - ], [ # source position 3 into base position 4 - [[3]], [[4]] - ])} - ) - + _, intervened_outputs = pv_gpt2( + base, + sources, + { + "sources->base": ( + [[[3]], [[4]]], # these two are for two interventions + [[[3]], [[4]]], # source position 3 into base position 4 + ) + }, + ) + pprint.pprint(get_topk(gpt2, tokenizer, intervened_outputs)) + def test_intervention_skipping(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - # these are equivalent interventions - # we create them on purpose - {"layer": 0, "component": "block_output"}, - {"layer": 0, "component": "block_output"}, - {"layer": 0, "component": "block_output"}], + config = pv.IntervenableConfig( + [ + # these are equivalent interventions + # we create them on purpose + {"layer": 0, "component": "block_output"}, + {"layer": 0, "component": "block_output"}, + {"layer": 0, "component": "block_output"}, + ], intervention_types=pv.VanillaIntervention, ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) @@ -343,27 +261,45 @@ def test_intervention_skipping(self): base = tokenizer("The capital of Spain is", return_tensors="pt") source = tokenizer("The capital of Italy is", return_tensors="pt") # skipping 1, 2 and 3 - _, pv_out1 = pv_gpt2(base, [None, None, source], - {"sources->base": ([None, None, [[4]]], [None, None, [[4]]])}) - _, pv_out2 = pv_gpt2(base, [None, source, None], - {"sources->base": ([None, [[4]], None], [None, [[4]], None])}) - _, pv_out3 = pv_gpt2(base, [source, None, None], - {"sources->base": ([[[4]], None, None], [[[4]], None, None])}) + _, pv_out1 = pv_gpt2( + base, + [None, None, source], + {"sources->base": ([None, None, [[4]]], [None, None, [[4]]])}, + ) + _, pv_out2 = pv_gpt2( + base, + [None, source, None], + {"sources->base": ([None, [[4]], None], [None, [[4]], None])}, + ) + _, pv_out3 = pv_gpt2( + base, + [source, None, None], + {"sources->base": ([[[4]], None, None], [[[4]], None, None])}, + ) # should have the same results - self.assertTrue(torch.equal(pv_out1.last_hidden_state, pv_out2.last_hidden_state)) - self.assertTrue(torch.equal(pv_out2.last_hidden_state, pv_out3.last_hidden_state)) - + self.assertTrue( + torch.equal(pv_out1.last_hidden_state, pv_out2.last_hidden_state) + ) + self.assertTrue( + torch.equal(pv_out2.last_hidden_state, pv_out3.last_hidden_state) + ) + def test_subspace_intervention(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - # they are linked to manipulate the same representation - # but in different subspaces - {"layer": 0, "component": "block_output", - # subspaces can be partitioned into continuous chunks - # [i, j] are the boundary indices - "subspace_partition": [[0, 128], [128, 256]]}], + config = pv.IntervenableConfig( + [ + # they are linked to manipulate the same representation + # but in different subspaces + { + "layer": 0, + "component": "block_output", + # subspaces can be partitioned into continuous chunks + # [i, j] are the boundary indices + "subspace_partition": [[0, 128], [128, 
256]], + } + ], intervention_types=pv.VanillaIntervention, ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) @@ -373,23 +309,34 @@ def test_subspace_intervention(self): # using intervention skipping for subspace intervened_outputs = pv_gpt2( - base, [source], + base, + [source], {"sources->base": 4}, # intervene only only dimensions from 128 to 256 subspaces=1, ) - + def test_linked_intervention_and_weights_sharing(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - # they are linked to manipulate the same representation - # but in different subspaces - {"layer": 0, "component": "block_output", - "subspace_partition": [[0, 128], [128, 256]], "intervention_link_key": 0}, - {"layer": 0, "component": "block_output", - "subspace_partition": [[0, 128], [128, 256]], "intervention_link_key": 0}], + config = pv.IntervenableConfig( + [ + # they are linked to manipulate the same representation + # but in different subspaces + { + "layer": 0, + "component": "block_output", + "subspace_partition": [[0, 128], [128, 256]], + "intervention_link_key": 0, + }, + { + "layer": 0, + "component": "block_output", + "subspace_partition": [[0, 128], [128, 256]], + "intervention_link_key": 0, + }, + ], intervention_types=pv.VanillaIntervention, ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) @@ -399,7 +346,8 @@ def test_linked_intervention_and_weights_sharing(self): # using intervention skipping for subspace _, pv_out1 = pv_gpt2( - base, [None, source], + base, + [None, source], # 4 means token position 4 {"sources->base": ([None, [[4]]], [None, [[4]]])}, # 1 means the second partition in the config @@ -411,7 +359,9 @@ def test_linked_intervention_and_weights_sharing(self): {"sources->base": ([[[4]], None], [[[4]], None])}, subspaces=[[[1]], None], ) - self.assertTrue(torch.equal(pv_out1.last_hidden_state, pv_out2.last_hidden_state)) + self.assertTrue( + torch.equal(pv_out1.last_hidden_state, pv_out2.last_hidden_state) + ) # subspaces provide a list of index and they can be in any order _, pv_out3 = pv_gpt2( @@ -426,8 +376,10 @@ def test_linked_intervention_and_weights_sharing(self): {"sources->base": ([[[4]], [[4]]], [[[4]], [[4]]])}, subspaces=[[[1]], [[0]]], ) - self.assertTrue(torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state)) - + self.assertTrue( + torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state) + ) + def test_new_model_type(self): try: import sentencepiece @@ -436,6 +388,7 @@ def test_new_model_type(self): return # get a flan-t5 from HuggingFace from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config + config = T5Config.from_pretrained("google/flan-t5-small") tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small") t5 = T5ForConditionalGeneration.from_pretrained( @@ -445,10 +398,8 @@ def test_new_model_type(self): # config the intervention mapping with pv global vars """Only define for the block output here for simplicity""" pv.type_to_module_mapping[type(t5)] = { - "mlp_output": ("encoder.block[%s].layer[1]", - pv.models.constants.CONST_OUTPUT_HOOK), - "attention_input": ("encoder.block[%s].layer[0]", - pv.models.constants.CONST_OUTPUT_HOOK), + "mlp_output": ("encoder.block[%s].layer[1]", CONST_OUTPUT_HOOK), + "attention_input": ("encoder.block[%s].layer[0]", CONST_OUTPUT_HOOK), } pv.type_to_dimension_mapping[type(t5)] = { "mlp_output": ("d_model",), @@ -458,214 +409,294 @@ def test_new_model_type(self): } # wrap as gpt2 - pv_t5 = pv.IntervenableModel({ - "layer": 0, - "component": 
"mlp_output", - "source_representation": torch.zeros( - t5.config.d_model) - }, model=t5) + pv_t5 = pv.IntervenableModel( + { + "layer": 0, + "component": "mlp_output", + "source_representation": torch.zeros(t5.config.d_model), + }, + model=t5, + ) # then intervene! - base = tokenizer("The capital of Spain is", - return_tensors="pt") - decoder_input_ids = tokenizer( - "", return_tensors="pt").input_ids + base = tokenizer("The capital of Spain is", return_tensors="pt") + decoder_input_ids = tokenizer("", return_tensors="pt").input_ids base["decoder_input_ids"] = decoder_input_ids - intervened_outputs = pv_t5( - base, - unit_locations={"base": 3} - ) + intervened_outputs = pv_t5(base, unit_locations={"base": 3}) def test_path_patching(self): def path_patching_config( - layer, last_layer, - component="head_attention_value_output", unit="h.pos" + layer, last_layer, component="head_attention_value_output", unit="h.pos" ): intervening_component = [ - {"layer": layer, "component": component, "unit": unit, "group_key": 0}] + {"layer": layer, "component": component, "unit": unit, "group_key": 0} + ] restoring_components = [] if not component.startswith("mlp_"): restoring_components += [ - {"layer": layer, "component": "mlp_output", "group_key": 1}] - for i in range(layer+1, last_layer): + {"layer": layer, "component": "mlp_output", "group_key": 1} + ] + for i in range(layer + 1, last_layer): restoring_components += [ {"layer": i, "component": "attention_output", "group_key": 1}, - {"layer": i, "component": "mlp_output", "group_key": 1} + {"layer": i, "component": "mlp_output", "group_key": 1}, ] intervenable_config = pv.IntervenableConfig( - intervening_component + restoring_components) + intervening_component + restoring_components + ) return intervenable_config _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) pv_gpt2 = pv.IntervenableModel( - path_patching_config(4, gpt2.config.n_layer), - model=gpt2 + path_patching_config(4, gpt2.config.n_layer), model=gpt2 ) - pv_gpt2.save( - save_directory="./tmp/" - ) - - pv_gpt2 = pv.IntervenableModel.load( - "./tmp/", - model=gpt2) - + pv_gpt2.save(save_directory="./tmp/") + + pv_gpt2 = pv.IntervenableModel.load("./tmp/", model=gpt2) + def test_multisource_parallel(self): _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - {"layer": 0, "component": "mlp_output"}, - {"layer": 2, "component": "mlp_output"}], - mode="parallel" + config = pv.IntervenableConfig( + [ + {"layer": 0, "component": "mlp_output"}, + {"layer": 2, "component": "mlp_output"}, + ], + mode="parallel", ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) base = tokenizer("The capital of Spain is", return_tensors="pt") - sources = [tokenizer("The capital of Italy is", return_tensors="pt"), - tokenizer("The capital of China is", return_tensors="pt")] + sources = [ + tokenizer("The capital of Italy is", return_tensors="pt"), + tokenizer("The capital of China is", return_tensors="pt"), + ] intervened_outputs = pv_gpt2( - base, sources, + base, + sources, # on same position {"sources->base": 4}, ) - + _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - {"layer": 0, "component": "block_output", - "subspace_partition": - [[0, 128], [128, 256]]}]*2, + config = pv.IntervenableConfig( + [ + { + "layer": 0, + "component": "block_output", + "subspace_partition": [[0, 128], [128, 256]], + } + ] + * 2, intervention_types=pv.VanillaIntervention, # act in parallel - mode="parallel" + mode="parallel", 
) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) base = tokenizer("The capital of Spain is", return_tensors="pt") - sources = [tokenizer("The capital of Italy is", return_tensors="pt"), - tokenizer("The capital of China is", return_tensors="pt")] + sources = [ + tokenizer("The capital of Italy is", return_tensors="pt"), + tokenizer("The capital of China is", return_tensors="pt"), + ] intervened_outputs = pv_gpt2( - base, sources, + base, + sources, # on same position {"sources->base": 4}, # on different subspaces subspaces=[[[0]], [[1]]], ) - + def test_multisource_serial(self): - + _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - {"layer": 0, "component": "mlp_output"}, - {"layer": 2, "component": "mlp_output"}], - mode="serial" + config = pv.IntervenableConfig( + [ + {"layer": 0, "component": "mlp_output"}, + {"layer": 2, "component": "mlp_output"}, + ], + mode="serial", ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) base = tokenizer("The capital of Spain is", return_tensors="pt") - sources = [tokenizer("The capital of Italy is", return_tensors="pt"), - tokenizer("The capital of China is", return_tensors="pt")] + sources = [ + tokenizer("The capital of Italy is", return_tensors="pt"), + tokenizer("The capital of China is", return_tensors="pt"), + ] intervened_outputs = pv_gpt2( - base, sources, + base, + sources, # serialized intervention # order is based on sources list {"source_0->source_1": 3, "source_1->base": 4}, ) - + _, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir) - config = pv.IntervenableConfig([ - {"layer": 0, "component": "block_output", - "subspace_partition": [[0, 128], [128, 256]]}, - {"layer": 2, "component": "block_output", - "subspace_partition": [[0, 128], [128, 256]]}], + config = pv.IntervenableConfig( + [ + { + "layer": 0, + "component": "block_output", + "subspace_partition": [[0, 128], [128, 256]], + }, + { + "layer": 2, + "component": "block_output", + "subspace_partition": [[0, 128], [128, 256]], + }, + ], intervention_types=pv.VanillaIntervention, - # act in parallel - mode="serial" + mode="serial", ) pv_gpt2 = pv.IntervenableModel(config, model=gpt2) base = tokenizer("The capital of Spain is", return_tensors="pt") - sources = [tokenizer("The capital of Italy is", return_tensors="pt"), - tokenizer("The capital of China is", return_tensors="pt")] + sources = [ + tokenizer("The capital of Italy is", return_tensors="pt"), + tokenizer("The capital of China is", return_tensors="pt"), + ] intervened_outputs = pv_gpt2( - base, sources, + base, + sources, # serialized intervention # order is based on sources list {"source_0->source_1": 3, "source_1->base": 4}, # on different subspaces subspaces=[[[0]], [[1]]], ) - + def test_customized_intervention_function_get(self): _, tokenizer, gpt2 = pv.create_gpt2() - - pv_gpt2 = pv.IntervenableModel({ - "layer": 10, - "component": "attention_weight", - "intervention_type": pv.CollectIntervention}, model=gpt2) + pv_gpt2 = pv.IntervenableModel( + { + "layer": 10, + "component": "attention_weight", + "intervention_type": pv.CollectIntervention, + }, + model=gpt2, + ) base = "When John and Mary went to the shops, Mary gave the bag to" - collected_attn_w = pv_gpt2( - base = tokenizer(base, return_tensors="pt" - ), unit_locations={"base": [h for h in range(12)]} - )[0][-1][0] + (_, collected_attn_w), _ = pv_gpt2( + base=tokenizer(base, return_tensors="pt"), + unit_locations={"base": [h for h in range(12)]}, + ) cached_w = {} - def pv_patcher(b, s): 
cached_w["attn_w"] = copy.deepcopy(b.data) - pv_gpt2 = pv.IntervenableModel({ - "component": "h[10].attn.attn_dropout.input", - "intervention": pv_patcher}, model=gpt2) + def pv_patcher(b, s): + cached_w["attn_w"] = copy.deepcopy(b.data) + + pv_gpt2 = pv.IntervenableModel( + {"component": "h[10].attn.attn_dropout.input", "intervention": pv_patcher}, + model=gpt2, + ) base = "When John and Mary went to the shops, Mary gave the bag to" _ = pv_gpt2(tokenizer(base, return_tensors="pt")) - torch.allclose(collected_attn_w, cached_w["attn_w"].unsqueeze(dim=0)) - + torch.allclose(list(collected_attn_w.values())[0], cached_w["attn_w"].unsqueeze(dim=0)) + def test_customized_intervention_function_zeroout(self): - + _, tokenizer, gpt2 = pv.create_gpt2() # define the component to zero-out - pv_gpt2 = pv.IntervenableModel({ - "layer": 0, "component": "mlp_output", - "source_representation": torch.zeros(gpt2.config.n_embd) - }, model=gpt2) + pv_gpt2 = pv.IntervenableModel( + { + "layer": 0, + "component": "mlp_output", + "source_representation": torch.zeros(gpt2.config.n_embd), + }, + model=gpt2, + ) # run the intervened forward pass intervened_outputs = pv_gpt2( - base = tokenizer("The capital of Spain is", return_tensors="pt"), + base=tokenizer("The capital of Spain is", return_tensors="pt"), # we define the intervening token dynamically - unit_locations={"base": 3} + unit_locations={"base": 3}, ) - + # indices are specified in the intervention mask = torch.ones(1, 5, 768) - mask[:,3,:] = 0. + mask[:, 3, :] = 0.0 # define the component to zero-out - pv_gpt2 = pv.IntervenableModel({ - "component": "h[0].mlp.output", - "intervention": lambda b, s: b*mask - }, model=gpt2) + pv_gpt2 = pv.IntervenableModel( + {"component": "h[0].mlp.output", "intervention": lambda b, s: b * mask}, + model=gpt2, + ) # run the intervened forward pass intervened_outputs_fn = pv_gpt2( - base = tokenizer("The capital of Spain is", return_tensors="pt") + base=tokenizer("The capital of Spain is", return_tensors="pt") ) torch.allclose( - intervened_outputs[1].last_hidden_state, - intervened_outputs_fn[1].last_hidden_state + intervened_outputs[1].last_hidden_state, + intervened_outputs_fn[1].last_hidden_state, ) - + + def test_nulling_intervention(self): + + _, tokenizer, gpt2 = pv.create_gpt2() + gpt2.to(self.DEVICE) + base = tokenizer( + ["The capital of Spain is" for i in range(3)], return_tensors="pt" + ).to(self.DEVICE) + + base_output = gpt2(**base) + base_logits = pv.embed_to_distrib( + gpt2, base_output.last_hidden_state, logits=True + )[0] + print(base_logits.shape) + + pv_gpt2 = pv.IntervenableModel( + { + "layer": 0, + "component": "mlp_output", + "intervention": lambda b, s: b * 0.5 + s * 0.5, + }, + model=gpt2, + ) + pv_gpt2.set_device(self.DEVICE) + + _, intervened_outputs = pv_gpt2( + # the base input + base=base, + # the source input + sources=tokenizer(["Egypt" for i in range(3)], return_tensors="pt").to( + self.DEVICE + ), + # the location to intervene at (3rd token) + unit_locations={"sources->base": (0, [[[3], None, [3]]])}, + ) + + intervened_logits = pv.embed_to_distrib( + gpt2, intervened_outputs.last_hidden_state, logits=True + ) + assert not torch.allclose( + base_logits, intervened_logits[0] + ), "Intervention had no effect on example 0!" + assert torch.allclose( + base_logits, intervened_logits[1] + ), "Intervention was not nulled on example 1!" + assert not torch.allclose( + base_logits, intervened_logits[2] + ), "Intervention had no effect on example 2!" 
+ @classmethod - def tearDownClass(self): - print(f"Removing testing dir {self._test_dir}") - if os.path.exists(self._test_dir) and os.path.isdir(self._test_dir): - shutil.rmtree(self._test_dir) \ No newline at end of file + def tearDownClass(cls): + print(f"Removing testing dir {cls._test_dir}") + if os.path.exists(cls._test_dir) and os.path.isdir(cls._test_dir): + shutil.rmtree(cls._test_dir) diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py index 723c8286..d9b4e8f5 100644 --- a/tests/integration_tests/InterventionWithGPT2TestCase.py +++ b/tests/integration_tests/InterventionWithGPT2TestCase.py @@ -1,12 +1,13 @@ import unittest from ..utils import * +from transformers import GPT2Config class InterventionWithGPT2TestCase(unittest.TestCase): @classmethod - def setUpClass(self): + def setUpClass(cls): print("=== Test Suite: InterventionWithGPT2TestCase ===") - self.config, self.tokenizer, self.gpt2 = create_gpt2_lm( + cls.config, cls.tokenizer, cls.gpt2 = create_gpt2_lm( config=GPT2Config( n_embd=24, attn_pdrop=0.0, @@ -20,8 +21,8 @@ def setUpClass(self): vocab_size=10, ) ) - self.vanilla_block_output_config = IntervenableConfig( - model_type=type(self.gpt2), + cls.vanilla_block_output_config = IntervenableConfig( + model_type=type(cls.gpt2), representations=[ RepresentationConfig( 0, @@ -32,10 +33,10 @@ def setUpClass(self): ], intervention_types=VanillaIntervention, ) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.gpt2 = self.gpt2.to(self.device) + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.gpt2 = cls.gpt2.to(cls.device) - self.nonhead_streams = [ + cls.nonhead_streams = [ "block_output", "block_input", "mlp_activation", @@ -49,7 +50,7 @@ def setUpClass(self): "value_output", ] - self.head_streams = [ + cls.head_streams = [ "head_attention_value_output", "head_query_output", "head_key_output", @@ -141,7 +142,7 @@ def test_with_multiple_heads_positions_vanilla_intervention_positive(self): Multiple head and position with vanilla intervention. 
""" for stream in self.head_streams: - print(f"testing stream: {stream} with multiple heads positions") + print(f"testing stream: {stream} with multiple heads and positions") self._test_with_head_position_intervention( intervention_layer=random.randint(0, 3), intervention_stream=stream, diff --git a/tests/integration_tests/InterventionWithMLPTestCase.py b/tests/integration_tests/InterventionWithMLPTestCase.py index 2a515c14..ecbb272c 100644 --- a/tests/integration_tests/InterventionWithMLPTestCase.py +++ b/tests/integration_tests/InterventionWithMLPTestCase.py @@ -4,17 +4,17 @@ class InterventionWithMLPTestCase(unittest.TestCase): @classmethod - def setUpClass(self): + def setUpClass(cls): print("=== Test Suite: InterventionWithMLPTestCase ===") - self.config, self.tokenizer, self.mlp = create_mlp_classifier( + cls.config, cls.tokenizer, cls.mlp = create_mlp_classifier( MLPConfig( h_dim=3, n_layer=1, pdrop=0.0, num_classes=5, include_bias=False, squeeze_output=False ) ) - self.test_subspace_intervention_link_config = IntervenableConfig( - model_type=type(self.mlp), + cls.test_subspace_intervention_link_config = IntervenableConfig( + model_type=type(cls.mlp), representations=[ RepresentationConfig( 0, @@ -42,9 +42,38 @@ def setUpClass(self): intervention_types=VanillaIntervention, ) - self.test_subspace_no_intervention_link_config = ( + cls.test_negative_subspace_config = IntervenableConfig( + model_type=type(cls.mlp), + representations=[ + RepresentationConfig( + 0, + "mlp_activation", + "pos", # mlp layer creates a single token reprs + 1, + subspace_partition=[ + [1, 4], + [0, 1], + ], # partition into two sets of subspaces + intervention_link_key=0, # linked ones target the same subspace + ), + RepresentationConfig( + 0, + "mlp_activation", + "pos", # mlp layer creates a single token reprs + 1, + subspace_partition=[ + [1, 4], + [0, 1], + ], # partition into two sets of subspaces + intervention_link_key=0, # linked ones target the same subspace + ), + ], + intervention_types=VanillaIntervention, + ) + + cls.test_subspace_no_intervention_link_config = ( IntervenableConfig( - model_type=type(self.mlp), + model_type=type(cls.mlp), representations=[ RepresentationConfig( 0, @@ -71,9 +100,9 @@ def setUpClass(self): ) ) - self.test_subspace_no_intervention_link_trainable_config = ( + cls.test_subspace_no_intervention_link_trainable_config = ( IntervenableConfig( - model_type=type(self.mlp), + model_type=type(cls.mlp), representations=[ RepresentationConfig( 0, @@ -149,13 +178,12 @@ def test_with_subspace_negative(self): Negative test case to check input length. 
""" intervenable = IntervenableModel( - self.test_subspace_intervention_link_config, self.mlp + self.test_negative_subspace_config, self.mlp ) # golden label b_s = 10 base = {"inputs_embeds": torch.rand(b_s, 1, 3)} source_1 = {"inputs_embeds": torch.rand(b_s, 1, 3)} - source_2 = {"inputs_embeds": torch.rand(b_s, 1, 3)} try: intervenable( diff --git a/tests/unit_tests/IntervenableConfigUnitTestCase.py b/tests/unit_tests/IntervenableConfigUnitTestCase.py index 9f3ded9d..9a9f70b4 100644 --- a/tests/unit_tests/IntervenableConfigUnitTestCase.py +++ b/tests/unit_tests/IntervenableConfigUnitTestCase.py @@ -1,5 +1,6 @@ import unittest from ..utils import * +from transformers import GPT2Config class IntervenableConfigUnitTestCase(unittest.TestCase): diff --git a/tests/unit_tests/InterventionUtilsTestCase.py b/tests/unit_tests/InterventionUtilsTestCase.py index ed1af69e..3a5cf6db 100644 --- a/tests/unit_tests/InterventionUtilsTestCase.py +++ b/tests/unit_tests/InterventionUtilsTestCase.py @@ -4,14 +4,15 @@ from pyvene.models.intervention_utils import _do_intervention_by_swap from pyvene.models.interventions import VanillaIntervention from pyvene.models.interventions import CollectIntervention +from transformers import GPT2Config class InterventionUtilsTestCase(unittest.TestCase): - + @classmethod - def setUpClass(self): + def setUpClass(cls): print("=== Test Suite: InterventionUtilsTestCase ===") - self.config, self.tokenizer, self.gpt2 = create_gpt2_lm( + cls.config, cls.tokenizer, cls.gpt2 = create_gpt2_lm( config=GPT2Config( n_embd=24, attn_pdrop=0.0, @@ -25,9 +26,9 @@ def setUpClass(self): vocab_size=10, ) ) - self.test_output_dir_prefix = "test_tmp_output" - self.test_output_dir_pool = [] - + cls.test_output_dir_prefix = "test_tmp_output" + cls.test_output_dir_pool = [] + def test_initialization_positive(self): config = IntervenableConfig( model_type=type(self.gpt2), @@ -51,7 +52,6 @@ def test_initialization_positive(self): intervenable = IntervenableModel(config, self.gpt2) assert intervenable.mode == "parallel" - self.assertTrue(intervenable.is_model_stateless) assert intervenable.use_fast == False assert len(intervenable.interventions) == 2 @@ -144,17 +144,13 @@ def test_local_trainable_save_positive(self): ) def _test_local_trainable_load_positive(self, intervention_types): - b_s = 10 + b_s = 12 config = IntervenableConfig( model_type=type(self.gpt2), representations=[ - RepresentationConfig( - 0, "block_output", "pos", 1, low_rank_dimension=4 - ), - RepresentationConfig( - 1, "block_output", "pos", 1, low_rank_dimension=4 - ), + RepresentationConfig(0, "block_output", "pos", 1, low_rank_dimension=4), + RepresentationConfig(1, "block_output", "pos", 1, low_rank_dimension=4), ], intervention_types=intervention_types, ) @@ -190,7 +186,7 @@ def test_local_load_positive(self): self._test_local_trainable_load_positive(VanillaIntervention) self._test_local_trainable_load_positive(RotatedSpaceIntervention) self._test_local_trainable_load_positive(LowRankRotatedSpaceIntervention) - + def test_vanilla_intervention_positive(self): intervention = VanillaIntervention(embed_dim=2) base = torch.arange(36).view(2, 3, 6) @@ -356,11 +352,16 @@ def test_brs_intervention_positive(self): base = torch.arange(12).view(2, 6) source = torch.arange(12, 24).view(2, 6) output = intervention(base, source) - golden = torch.tensor([[3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14],]) + golden = torch.tensor( + [ + [3, 4, 5, 6, 7, 8], + [9, 10, 11, 12, 13, 14], + ] + ) self.assertTrue(torch.allclose(golden, output)) def 
test_brs_gradient_positive(self): - + _retry = 10 while _retry > 0: try: @@ -374,7 +375,12 @@ def test_brs_gradient_positive(self): optimizer_params += [{"params": intervention.intervention_boundaries}] optimizer = torch.optim.Adam(optimizer_params, lr=1e-1) - golden = torch.tensor([[5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16],]).float() + golden = torch.tensor( + [ + [5, 6, 7, 8, 9, 10], + [11, 12, 13, 14, 15, 16], + ] + ).float() for _ in range(1000): optimizer.zero_grad() @@ -391,13 +397,10 @@ if _retry > 0: pass # succeed else: - raise AssertionError( - "test_brs_gradient_positive with retries" - ) - + raise AssertionError("test_brs_gradient_positive with retries") def test_sigmoid_mask_gradient_positive(self): - + _retry = 10 while _retry > 0: try: @@ -410,7 +413,12 @@ def test_sigmoid_mask_gradient_positive(self): optimizer_params += [{"params": intervention.temperature}] optimizer = torch.optim.Adam(optimizer_params, lr=1e-1) - golden = torch.tensor([[0, 1, 14, 15, 16, 17], [6, 7, 20, 21, 22, 23],]).float() + golden = torch.tensor( + [ + [0, 1, 14, 15, 16, 17], + [6, 7, 20, 21, 22, 23], + ] + ).float() for _ in range(2000): optimizer.zero_grad() @@ -430,13 +438,10 @@ if _retry > 0: pass # succeed else: - raise AssertionError( - "test_sigmoid_mask_gradient_positive with retries" - ) - + raise AssertionError("test_sigmoid_mask_gradient_positive with retries") def test_low_rank_gradient_positive(self): - + _retry = 10 while _retry > 0: try: @@ -450,7 +455,12 @@ def test_low_rank_gradient_positive(self): optimizer_params += [{"params": intervention.rotate_layer.parameters()}] optimizer = torch.optim.Adam(optimizer_params, lr=1e-1) - golden = torch.tensor([[0, 1, 14, 15, 16, 17], [6, 7, 20, 21, 22, 23],]).float() + golden = torch.tensor( + [ + [0, 1, 14, 15, 16, 17], + [6, 7, 20, 21, 22, 23], + ] + ).float() for _ in range(2000): optimizer.zero_grad() @@ -458,7 +468,7 @@ loss = F.mse_loss(output, golden) loss.backward() optimizer.step() - print(output) + self.assertTrue(torch.allclose(golden, output, rtol=1e-02, atol=1e-02)) except: pass # retry @@ -468,13 +478,11 @@ if _retry > 0: pass # succeed else: - raise AssertionError( - "test_sigmoid_mask_gradient_positive with retries" - ) + raise AssertionError("test_low_rank_gradient_positive with retries") @classmethod - def tearDownClass(self): - for current_dir in self.test_output_dir_pool: + def tearDownClass(cls): + for current_dir in cls.test_output_dir_pool: print(f"Removing testing dir {current_dir}") if os.path.exists(current_dir) and os.path.isdir(current_dir): shutil.rmtree(current_dir) diff --git a/tests/unit_tests/ModelUtilsTestCase.py b/tests/unit_tests/ModelUtilsTestCase.py index d4649559..4aa04b73 100644 --- a/tests/unit_tests/ModelUtilsTestCase.py +++ b/tests/unit_tests/ModelUtilsTestCase.py @@ -1,12 +1,14 @@ import unittest from ..utils import * from pyvene.models.modeling_utils import * +from pprint import pprint, pformat +from transformers import GPT2Config class ModelUtilsTestCase(unittest.TestCase): @classmethod - def setUpClass(self): - self.gpt2_config = GPT2Config( + def setUpClass(cls): + cls.gpt2_config = GPT2Config( n_embd=6, n_head=3, attn_pdrop=0.0, @@ -19,7 +21,7 @@ def setUpClass(self): n_positions=20, vocab_size=10, ) - self.gpt2_model = hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel + cls.gpt2_model = 
hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel def test_gather_neurons_positive(self): tensor_input = torch.rand((5, 3, 2)) # batch_size, seq_len, emb_dim @@ -42,7 +44,10 @@ def test_output_to_subcomponent_gpt2_no_head_positive(self): golden_output = tensor_input.clone() tensor_output = output_to_subcomponent( - tensor_input, "attention_input", self.gpt2_model, self.gpt2_config, + tensor_input, + "attention_input", + self.gpt2_model, + self.gpt2_config, ) self.assertTrue(torch.allclose(tensor_output, golden_output)) @@ -102,7 +107,7 @@ def test_scatter_neurons_gpt2_batch_diff_fast_no_head_positive(self): golden_output = tensor_input.clone() golden_output[0, 1:3, :] = replacing_tensor_input[0, 0:2, :] # Fast path's behavior is different - golden_output[1, 1:3, :] = replacing_tensor_input[1, 0:2, :] + golden_output[1, 0:2, :] = replacing_tensor_input[1, 0:2, :] tensor_output = scatter_neurons( tensor_input, @@ -110,7 +115,7 @@ def test_scatter_neurons_gpt2_batch_diff_fast_no_head_positive(self): "attention_input", "pos", # each batch is different - ([[1, 2], [0, 1]]), + [[1, 2], [0, 1]], self.gpt2_model, self.gpt2_config, True, @@ -274,7 +279,7 @@ def test_scatter_neurons_gpt2_attn_with_head_positive(self): replacing_tensor_input = torch.arange(60, 84).view(2, 3, 2, 2) # ? - # Replace the heads 1, 2 at positions 0, 1 with the first + # Replace the heads 1, 2 at positions 0, 1 with the first two heads of the replacement golden_output = tensor_input.clone().view(2, 5, 3, 2) golden_output[:, 0:2, 1:3, :] = replacing_tensor_input[:, 0:2, :, :].permute( 0, 2, 1, 3 diff --git a/tests/utils.py b/tests/utils.py index d005a01d..168b5636 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,30 +6,35 @@ import os, shutil, torch, random, uuid import pandas as pd import numpy as np -from transformers import GPT2Config - - +from transformers.utils import ModelOutput import subprocess + def is_package_installed(package_name): try: # Execute 'pip list' command and capture the output - result = subprocess.run(['pip', 'list'], stdout=subprocess.PIPE, text=True) + pkg_all = subprocess.run( + ["pip", "list"], stdout=subprocess.PIPE, text=True + ).stdout + pkg_e = subprocess.run( + ["pip", "list", "-e"], stdout=subprocess.PIPE, text=True + ).stdout # Check if package_name is in the result - return package_name in result.stdout + return package_name in pkg_all and package_name not in pkg_e except Exception as e: print(f"An error occurred: {e}") return False + # Replace 'pyvene' with the name of the package you want to check -package_name = 'pyvene' +package_name = "pyvene" if is_package_installed(package_name): raise RuntimeError( - f"Remove your pip installed {package_name} before running tests.") + f"Remove your pip installed {package_name} before running tests." + ) else: - print(f"'{package_name}' is not installed.") - print("PASS: pyvene is not installed. Testing local dev code.") + print(f"PASS: {package_name} is not installed. 
Testing local dev code.") from pyvene.models.basic_utils import embed_to_distrib, top_vals, format_token from pyvene.models.configuration_intervenable_model import ( @@ -41,6 +46,15 @@ def is_package_installed(package_name): from pyvene.models.mlp.modelings_mlp import MLPConfig from pyvene.models.mlp.modelings_intervenable_mlp import create_mlp_classifier from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm +from pyvene import embed_to_distrib + + +def get_topk(model, tokenizer, outputs: ModelOutput, k=20): + dist = embed_to_distrib(model, outputs.last_hidden_state, logits=False) + + if dist is not None: + _, ind = torch.topk(dist[:, -1], 20, dim=-1) + return tokenizer.batch_decode(ind) ################## diff --git a/tutorials/basic_tutorials/Generation_Intervention.ipynb b/tutorials/basic_tutorials/Generation_Intervention.ipynb new file mode 100644 index 00000000..e884605c --- /dev/null +++ b/tutorials/basic_tutorials/Generation_Intervention.ipynb @@ -0,0 +1,241 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loaded GPTNeo model roneneldan/TinyStories-33M\n", + "GPTNeoForCausalLM(\n", + " (transformer): GPTNeoModel(\n", + " (wte): Embedding(50257, 768)\n", + " (wpe): Embedding(2048, 768)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " (h): ModuleList(\n", + " (0-3): 4 x GPTNeoBlock(\n", + " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (attn): GPTNeoAttention(\n", + " (attention): GPTNeoSelfAttention(\n", + " (attn_dropout): Dropout(p=0.0, inplace=False)\n", + " (resid_dropout): Dropout(p=0.0, inplace=False)\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " )\n", + " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): GPTNeoMLP(\n", + " (c_fc): Linear(in_features=768, out_features=3072, bias=True)\n", + " (c_proj): Linear(in_features=3072, out_features=768, bias=True)\n", + " (act): NewGELUActivation()\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", + ")\n" + ] + } + ], + "source": [ + "import torch\n", + "import pyvene as pv\n", + "import pprint\n", + "\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "config, tokenizer, model = pv.create_gpt_neo()\n", + "model.to(DEVICE)\n", + "pprint.pprint(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 2, 384])\n", + "16\n" + ] + } + ], + "source": [ + "from typing import Dict\n", + "\n", + "collect_model = pv.IntervenableModel(\n", + " [\n", + " {\n", + " \"layer\": l,\n", + " \"component\": \"block_output\",\n", + " \"intervention_type\": pv.CollectIntervention,\n", + " }\n", + " for l in range(1, config.num_layers)\n", + " ],\n", + " model=model,\n", + ")\n", + "\n", + "p_plus = \" love\"\n", + "p_minus = \" hate\"\n", + "\n", + "res = collect_model(\n", + " base=tokenizer([p_plus, p_minus], 
return_tensors=\"pt\").to(DEVICE),\n", + " unit_locations={\"base\": 0},\n", + " return_dict=True,\n", + ")[\"collected_activations\"]\n", + "\n", + "print(res['layer.1.comp.block_output.unit.pos.nunit.1#0'].shape)\n", + "print(config.num_heads)\n", + "\n", + "diff: Dict[str, torch.Tensor] = {}\n", + "\n", + "for k, v in res.items():\n", + " diff[k] = torch.reshape(res[k][0] - res[k][1], (-1,))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'Unintervened generation:'\n", + "['I hate you because I don\\'t want you to be my friend.\"\\n'\n", + " '\\n'\n", + " 'The little girl was sad and went home. She told her mom what happened and']\n", + "Intervened generation:\n", + "['I hate you because batch of cookies are not good for you. You are a bad '\n", + " \"sister. I don't want to play with you anymore. I want to\"]\n" + ] + } + ], + "source": [ + "intv_model = pv.IntervenableModel(\n", + " [\n", + " {\n", + " \"layer\": l,\n", + " \"component\": \"block_output\",\n", + " \"intervention\": lambda b, s: b + 10 * s,\n", + " }\n", + " for l in range(1, config.num_layers)\n", + " ],\n", + " model=model,\n", + ")\n", + "\n", + "# ActAdd on prompt (original setting)\n", + "orig, intervened = intv_model.generate(\n", + " base=tokenizer(\"I hate you because\", return_tensors=\"pt\").to(DEVICE),\n", + " source_representations=diff,\n", + " unit_locations={\"sources->base\": (0, 3)},\n", + " output_original_output=True,\n", + " max_length=32,\n", + ")\n", + "\n", + "pprint.pprint('Unintervened generation:')\n", + "pprint.pprint(tokenizer.batch_decode(orig))\n", + "\n", + "print('Intervened generation:')\n", + "pprint.pprint(tokenizer.batch_decode(intervened))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'Unintervened generation:'\n", + "['I hate you because I don\\'t want you to be my friend.\"\\n'\n", + " '\\n'\n", + " 'The little girl was sad and went home. She told her mom what happened and '\n", + " 'her mom said, \"Don\\'t worry, we can make a new friend tomorrow.\" The']\n", + "Intervened generation:\n", + "['I hate you because I don\\'t want you to be my friend.\" batch of cookies and '\n", + " 'they both laughed. The end batch of cookies were so delicious that they made '\n", + " 'the batch of cookies and they both ate them together. 
batch of cookies were '\n", + " 'so tasty']\n" + ] + } + ], + "source": [ + "# ActAdd on decoded region\n", + "orig, intervened = intv_model.generate(\n", + " base=tokenizer(\"I hate you because\", return_tensors=\"pt\").to(DEVICE),\n", + " source_representations=diff,\n", + " unit_locations={\"sources->base\": (0, 3)},\n", + " intervene_on_prompt=False,\n", + " timestep_selector=[lambda idx, o: idx % 10 == 0 for i in range(3)],\n", + " output_original_output=True,\n", + " max_length=50\n", + ")\n", + "\n", + "pprint.pprint('Unintervened generation:')\n", + "pprint.pprint(tokenizer.batch_decode(orig))\n", + "\n", + "print('Intervened generation:')\n", + "pprint.pprint(tokenizer.batch_decode(intervened))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
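The last notebook cell above uses the new timestep_selector argument to IntervenableModel.generate, which, as used there, appears to take one callable per configured intervention, each receiving the decoding step index and the current output and returning whether that intervention fires at that step. Below is a minimal sketch of an alternative selector under that assumed semantics, re-using the objects defined in the cells above (config, tokenizer, model, DEVICE, intv_model, and the contrast-pair diff dictionary); the helper first_n_steps_selectors and its n_steer parameter are illustrative names, not part of the pyvene API.

def first_n_steps_selectors(n_interventions, n_steer=5):
    # One selector per configured intervention: steer only the first `n_steer` decoded tokens.
    # `n=n_steer` binds the value at definition time to avoid the late-binding closure pitfall.
    return [lambda idx, out, n=n_steer: idx < n for _ in range(n_interventions)]


# ActAdd on the first few decoded tokens only, leaving the prompt untouched
orig, intervened = intv_model.generate(
    base=tokenizer("I hate you because", return_tensors="pt").to(DEVICE),
    source_representations=diff,  # steering vectors from the " love" / " hate" contrast pair
    unit_locations={"sources->base": (0, 3)},
    intervene_on_prompt=False,
    timestep_selector=first_n_steps_selectors(config.num_layers - 1),
    output_original_output=True,
    max_length=50,
)

print(tokenizer.batch_decode(intervened))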