Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Adding interchange intervention for SAEs #187

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
from .models.interventions import NoiseIntervention
from .models.interventions import SigmoidMaskIntervention
from .models.interventions import AutoencoderIntervention
from .models.interventions import JumpReLUAutoencoderIntervention
from .models.interventions import InterventionOutput


# Utils
from .models.basic_utils import *
from .models.intervention_utils import _do_intervention_by_swap
from .models.intervenable_modelcard import type_to_module_mapping, type_to_dimension_mapping
from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2
from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm
Expand Down
44 changes: 44 additions & 0 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,47 @@ def forward(self, base, source, subspaces=None):

def __str__(self):
return f"AutoencoderIntervention()"


class JumpReLUAutoencoderIntervention(TrainableIntervention):
    """Interchange intervention on a JumpReLU SAE's latent subspaces.

    Encodes both the base and source activations into the SAE latent
    space, swaps the latent dimensions selected by ``subspaces`` from
    the source latents into the base latents, and decodes the edited
    latents back to the model's activation space.

    Expected kwargs:
        embed_dim: width of the model activations being intervened on.
        low_rank_dimension: number of SAE latent dimensions.
    """

    def __init__(self, **kwargs):
        # Weights are initialised to zeros because this intervention is
        # meant to load pre-trained SAE parameters via load_state_dict();
        # it is not intended to be trained from this initialisation.
        super().__init__(**kwargs, keep_last_dim=True)
        # Encoder/decoder weights and biases following the JumpReLU SAE
        # parameterisation: W_enc (d_model, d_sae), W_dec (d_sae, d_model).
        self.W_enc = torch.nn.Parameter(torch.zeros(self.embed_dim, kwargs["low_rank_dimension"]))
        self.W_dec = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"], self.embed_dim))
        # Per-latent JumpReLU threshold: activations at or below it are zeroed.
        self.threshold = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
        self.b_enc = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
        self.b_dec = torch.nn.Parameter(torch.zeros(self.embed_dim))

    def encode(self, input_acts):
        """Map activations to SAE latents with the JumpReLU nonlinearity.

        A latent fires only when its pre-activation exceeds its learned
        threshold; otherwise it is zeroed (relu gated by the threshold mask).
        """
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Project SAE latents back to the model's activation space."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, base, source=None, subspaces=None):
        """Swap the ``subspaces`` latent dims of ``source`` into ``base``.

        Args:
            base: base activations to be edited.
            source: source activations supplying the swapped-in latent
                values; required for an interchange intervention.
            subspaces: latent dimension indices to swap.

        Returns:
            The decoded activations after the latent-space swap.

        Raises:
            ValueError: if ``source`` is None — an interchange needs a
                source run (fail early with a clear message instead of a
                cryptic matmul TypeError inside ``encode``).
        """
        if source is None:
            raise ValueError(
                "JumpReLUAutoencoderIntervention requires a `source` "
                "activation to perform an interchange intervention."
            )
        # Generate latents for the base and source runs.
        base_latent = self.encode(base)
        source_latent = self.encode(source)
        # Swap the selected latent dimensions from source into base.
        intervened_latent = _do_intervention_by_swap(
            base_latent,
            source_latent,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )
        # Decode the intervened latents back to activation space.
        recon = self.decode(intervened_latent)
        return recon

    def __str__(self):
        return "JumpReLUAutoencoderIntervention()"

192 changes: 111 additions & 81 deletions tutorials/basic_tutorials/Sparse_Autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 1,
"id": "e5d14f0a-b02d-4d1f-863a-dbb1e475e264",
"metadata": {},
"outputs": [],
"source": [
"__author__ = \"Zhengxuan Wu\"\n",
"__version__ = \"08/07/2024\""
"__version__ = \"09/23/2024\""
]
},
{
Expand All @@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "dd197c1f-71b5-4379-a9dd-2f6ff27083f6",
"metadata": {},
"outputs": [
Expand All @@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 4,
"id": "209bfc46-7685-4e66-975f-3280ed516b52",
"metadata": {},
"outputs": [],
Expand All @@ -85,7 +85,8 @@
" SourcelessIntervention,\n",
" TrainableIntervention,\n",
" DistributedRepresentationIntervention,\n",
" CollectIntervention\n",
" CollectIntervention,\n",
" JumpReLUAutoencoderIntervention\n",
")\n",
"\n",
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
Expand All @@ -108,28 +109,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "a6e7e7fb-5e73-4711-b378-bc1b04ab1e7f",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7e4aaac37998428dbe22cc95595c3fcc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/24.2k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa51ca858082486bb23796d8146aadea",
"model_id": "192a06afdbdc4c868bc6d20677b3dd38",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -143,49 +130,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77cf508dd4bc42e19452855fdeb8744b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00003.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2913e6370a364fe6823d35b31ccbd9e4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00003.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18e7e07b13394f58bab2b5c4c6b32394",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00003-of-00003.safetensors: 0%| | 0.00/481M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6defff33e9d74d8b9f754d20092dd7b4",
"model_id": "3e088f2b3808489bac70b9aeb5ae73f0",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -195,20 +140,6 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5981f2c5b6f4e41acd7a7ad5a574557",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/168 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand Down Expand Up @@ -276,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 7,
"id": "d490a50c-a1cd-4def-90c2-cd6bfe67266f",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -311,6 +242,7 @@
"class JumpReLUSAECollectIntervention(\n",
" CollectIntervention\n",
"):\n",
" \"\"\"Collect activations\"\"\"\n",
" def __init__(self, **kwargs):\n",
" # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
" # If you want to train your own SAEs then we recommend using blah\n",
Expand Down Expand Up @@ -500,7 +432,7 @@
"metadata": {},
"outputs": [],
"source": [
"class JumpReLUSAEIntervention(\n",
"class JumpReLUSAESteeringIntervention(\n",
" SourcelessIntervention,\n",
" TrainableIntervention, \n",
" DistributedRepresentationIntervention\n",
Expand Down Expand Up @@ -544,7 +476,7 @@
"metadata": {},
"outputs": [],
"source": [
"sae = JumpReLUSAEIntervention(\n",
"sae = JumpReLUSAESteeringIntervention(\n",
" embed_dim=params['W_enc'].shape[0],\n",
" low_rank_dimension=params['W_enc'].shape[1]\n",
")\n",
Expand Down Expand Up @@ -614,6 +546,104 @@
"source": [
"**Here you go: a \"Space-travel, time-travel\" Doodle!**"
]
},
{
"cell_type": "markdown",
"id": "22cabe19-2c2f-46c7-a631-d0b40fca5308",
"metadata": {},
"source": [
"### Interchange intervention with JumpReLU SAEs.\n",
"\n",
"You can also swap values between examples for a specific latent dimension. However, since an SAE usually maps a concept to a 1D subspace, swapping values between examples is similar to resetting the scalar to another value.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4f23b199-ca01-4676-9a2d-61b24b96dc2f",
"metadata": {},
"outputs": [],
"source": [
"sae = JumpReLUAutoencoderIntervention(\n",
" embed_dim=params['W_enc'].shape[0],\n",
" low_rank_dimension=params['W_enc'].shape[1]\n",
")\n",
"sae.load_state_dict(pt_params, strict=False)\n",
"sae.cuda()\n",
"\n",
"# add the intervention to the model computation graph via the config\n",
"pv_model = pyvene.IntervenableModel({\n",
" \"component\": f\"model.layers[{LAYER}].output\",\n",
" \"intervention\": sae}, model=model)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9dbe3883-3588-45fe-91bf-aeb075dea642",
"metadata": {},
"outputs": [],
"source": [
"base = tokenizer(\n",
" \"Which dog breed do people think is cuter, poodle or doodle?\", \n",
" return_tensors=\"pt\").to(\"cuda\")\n",
"source = tokenizer(\n",
" \"Origin (general) Space-travel, time-travel\", \n",
" return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"# run an interchange intervention \n",
"original_outputs, intervened_outputs = pv_model(\n",
" # the base input\n",
" base=base, \n",
" # the source input\n",
" sources=source, \n",
" # the location to intervene (swap last tokens)\n",
" unit_locations={\"sources->base\": (11, 14)},\n",
" # the SAE latent dimension mapping to the time travel concept (\"10004\")\n",
" subspaces=[10004],\n",
" output_original_output=True\n",
")\n",
"logits_diff = intervened_outputs.logits[:,-1] - original_outputs.logits[:,-1]\n",
"values, indices = logits_diff.topk(k=10, sorted=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "57b8c19a-c73f-47e5-b7f3-6b9353802a96",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"** topk logits diff **\n"
]
},
{
"data": {
"text/plain": [
"['PhysRevD',\n",
" ' transporting',\n",
" ' teleport',\n",
" ' space',\n",
" ' transit',\n",
" ' transported',\n",
" ' transporter',\n",
" ' transpor',\n",
" ' multiverse',\n",
" ' universes']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"** topk logits diff **\")\n",
"tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))"
]
}
],
"metadata": {
Expand Down
Loading