Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Centered rib #257

Merged
merged 75 commits into from
Jan 17, 2024
Merged

Centered rib #257

merged 75 commits into from
Jan 17, 2024

Conversation

nix-apollo
Copy link
Collaborator

@nix-apollo nix-apollo commented Dec 11, 2023

Centred rib

Description

  • Incorporates a centring matrix Y (aka Gamma) into the calculation of the rib rotation
  • Isolates the bias direction when center = true, returning it in the 0th position.
  • Reorganizes the order of the residual stream so the last position is always the bias position.

Related Issue

Closes #248

Motivation and Context

This is a first version of centered rib. It seems likely that we might want to handle lambda and/or edge calculation differently in the future to make the calculation more principled. For instance, the baseline for IG is still the 0 point, instead of the mean activation.

How Has This Been Tested?

I have a test that checks invariants for the output of the centered rib build. Including that there is a single constant direction pointing in the direction we expect, and that activations in all other rib directions are centered.

This code was also used for various analysis in the OP report, where it seemed to do reasonable things.

Does this PR introduce a breaking change?

Residual stream reorder may break some analysis code. No interface changes.

rib/interaction_algos.py Outdated Show resolved Hide resolved
rib/linalg.py Outdated Show resolved Hide resolved
@nix-apollo nix-apollo changed the title [wip] centred rib Centered rib Jan 11, 2024
Copy link
Contributor

@danbraunai-apollo danbraunai-apollo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't looked deeply yet, but flagging that the test_centered_rib_modadd() is failing heavily on this line:

        # Check 2: no other rib dir has non-zero component at bias positions
        assert_is_zeros(C_inv[1:][:, bias_positions], atol=atol, m_name=m_name)

where the tensor has many values that are on the order 1e-2.

Copy link
Contributor

@danbraunai-apollo danbraunai-apollo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I accidentally pressed submit review too early, so combine this review with the subsequent one. Sorry.

Misc:

  • EDIT: Oh, I see you have a todo for this as a comment, all good. Maybe put todos in the PR description so it's easier to find. Still an issue with the centered_rib_test causing tests to fail.
  • I don't like or understand why all of the test__build_graph went from an atol=0 to a bigger atol (and mnist went from 1e-5 to 1e-4), when none of those use centering. Hopefully a solution to the above issue will fix this one. Any ideas what this could be, or where the computation is different?
  • I think shift_matrix could do with unit tests. And maybe even elaborate on the example in the docstring to give an example x where it's used to shift the activations by the mean. I just find that function very confusing even after previously spending time and understanding it.
  • I wouldn't mind a find + replace on centered -> centred. My bad for not being consistent here.

rib/analysis_utils.py Show resolved Hide resolved
rib/interaction_algos.py Outdated Show resolved Hide resolved
rib/interaction_algos.py Outdated Show resolved Hide resolved
rib/linalg.py Outdated Show resolved Hide resolved
rib/linalg.py Outdated Show resolved Hide resolved
rib/utils.py Outdated Show resolved Hide resolved
Copy link
Contributor

@danbraunai-apollo danbraunai-apollo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Some requested changes/confusions.

tests/test_build_graph.py Outdated Show resolved Hide resolved
tests/test_build_graph.py Outdated Show resolved Hide resolved
@nix-apollo
Copy link
Collaborator Author

We only ever need a bias position for both:

  • making a shift matrix
  • identifying the constant rib dir (to move it to the 0th in ordering)
    both of these only need a single constant pos. A good simplification would be to have get_dataset_means return a single int for the bias pos instead of a tensor. Even better this could be a separate helper function that only needs the means as input.

Possibly I'll postpone this and not improve the current situation in this PR.

@nix-apollo
Copy link
Collaborator Author

Re: test tolerance. This was because I had accidentally made a test stricter by going from rtol=1e-5 (pytorch's default) to rtol=0 (pytest's default). Not because computation got less precise. I've reverted the change for consistency.

Copy link
Contributor

@danbraunai-apollo danbraunai-apollo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved. Looks great, and much cleaner after swapping residual stream to be at the bottom. I really like your updates to calculate_interaction_rotations. That function is actually (on the way to becoming) understandable now.

I made a few comments. Have a quick look, and update yourself or shout out if you see an issue with them. Then you can merge.

rib/data_accumulator.py Outdated Show resolved Hide resolved
tests/test_build_graph.py Show resolved Hide resolved
@nix-apollo nix-apollo merged commit aaa53ea into main Jan 17, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support rib centering
3 participants