Merge pull request #20 from pomonam/dev
dev branch to main
pomonam authored Jun 11, 2024
2 parents cff2a88 + 7287d30 commit 1464e19
Showing 41 changed files with 1,430 additions and 906 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/python-test.yml
@@ -25,3 +25,11 @@ jobs:
run: |
pip install -e ."[dev]"
pytest -vx tests/test_analyzer.py
pytest -vx tests/test_dataset_utils.py
pytest -vx tests/test_testable_tasks.py
pytest -vx tests/factors/test_covariances.py
pytest -vx tests/factors/test_eigens.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
pytest -vx tests/scores/test_pairwise_scores.py
pytest -vx tests/scores/test_self_scores.py
93 changes: 46 additions & 47 deletions DOCUMENTATION.md
@@ -1,25 +1,22 @@
# Kronfluence: Technical Documentation & FAQs

For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.
For a detailed description of the methodology, please refer to the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.

## Requirements

Kronfluence has been tested on the following versions of [PyTorch](https://pytorch.org/):
- PyTorch >= 2.1;
- Python >= 3.9.
Kronfluence has been tested and is compatible with the following versions of [PyTorch](https://pytorch.org/):
- PyTorch 2.1 or higher
- Python 3.9 or higher

## Supported Modules & Strategies

Kronfluence supports:
- Computing influence functions on selected PyTorch modules. At the moment, we support `nn.Linear` and `nn.Conv2d`;
- Computing influence functions with several Hessian approximation strategies: `identity`, `diagonal`, `KFAC`, and `EKFAC`;
- Computing pairwise and self-influence scores.
Kronfluence offers support for:
- Computing influence functions on selected PyTorch modules. Currently, we support `nn.Linear` and `nn.Conv2d`.
- Computing influence functions with several Hessian approximation strategies, including `identity`, `diagonal`, `KFAC`, and `EKFAC`.
- Computing pairwise and self-influence scores (the latter with or without a measurement).

> [!NOTE]
> We plan to support ensembling of influence scores in the next release.
> [!NOTE]
> If there are specific modules you would like to see supported, please submit an issue.
> If there are additional modules you would like to see supported, please submit an issue on our GitHub repository.
---

@@ -103,6 +100,7 @@ After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](ht
**Set up the Analyzer and Fit Factors.**
Initialize the `Analyzer` and execute `fit_all_factors` to compute all factors that aim to approximate the Hessian
(or Gauss-Newton Hessian). The computed factors will be stored on disk.

```python
from kronfluence.analyzer import Analyzer
from kronfluence.utils.dataset import DataLoaderKwargs
@@ -121,6 +119,7 @@ analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset)
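For reference, a minimal self-contained sketch of this step, with the collapsed parts filled in under assumptions: the `analysis_name` and `factors_name` values are placeholders, and `model`, `task`, and `train_dataset` are assumed to be defined as in the preceding steps.

```python
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.utils.dataset import DataLoaderKwargs

# Wrap supported modules (nn.Linear / nn.Conv2d) so tracking hooks can be installed.
model = prepare_model(model=model, task=task)

analyzer = Analyzer(analysis_name="my_analysis", model=model, task=task)

# Optionally configure the DataLoader used internally (workers, pinning, etc.).
analyzer.set_dataloader_kwargs(DataLoaderKwargs(num_workers=4))

# Fit all factors (covariance, eigendecomposition, Lambda) and store them on disk.
analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset)
```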
**Compute Influence Scores.**
Once the factors have been computed, you can compute pairwise and self-influence scores. When computing the scores,
you can specify the factor name you would like to use.

```python
...
scores = analyzer.compute_pairwise_scores(
@@ -139,7 +138,7 @@ You can organize all factors and scores for the specific model with `factors_nam
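A filled-in version of this snippet might look like the following sketch; the dataset variables, batch sizes, and score/factor names are placeholders, and `analyzer` is assumed to carry over from the factor-fitting step.

```python
import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(score_dtype=torch.float32)

# Compute pairwise scores between every query (eval) and training example.
analyzer.compute_pairwise_scores(
    scores_name="initial_score",
    factors_name="initial_factor",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
    per_device_train_batch_size=1024,
    score_args=score_args,
)

# The result has shape (num_query_examples, num_train_examples).
scores = analyzer.load_pairwise_scores(scores_name="initial_score")
```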

**What should I do if my model does not have any nn.Linear or nn.Conv2d modules?**
Currently, the implementation does not support influence computations for modules other than `nn.Linear` or `nn.Conv2d`.
Try rewriting the model so that it uses supported modules (as done for the `conv1d` module in [GPT-2](https://github.com/pomonam/kronfluence/tree/documentation/examples/wikitext)).
Try rewriting the model so that it uses supported modules (as done for the `conv1d` module in the [GPT-2 example](https://github.com/pomonam/kronfluence/tree/documentation/examples/wikitext)).
Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
If there are specific modules you would like to see supported, please submit an issue.

@@ -150,11 +149,11 @@ inspect `model.named_modules()` to determine what modules to use. You can specif

> [!NOTE]
> If the embedding layer of a transformer is defined with `nn.Linear`, you must set
> `task.tracked_modules` to avoid influence computations on embedding matrices (it is too expensive).
> `task.tracked_modules` to avoid influence computations on embedding matrices.
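To decide what to list in `task.tracked_modules`, a quick inspection loop over the model's supported modules helps (this is plain PyTorch, independent of Kronfluence; `model` is your trained `nn.Module`):

```python
from torch import nn

# Print the names of all modules Kronfluence can track; these strings are
# what `task.tracked_modules` expects.
for module_name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        print(module_name)
```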
**How should I implement Task.compute_train_loss?**
Implement the loss function used to train the model. Note that the function should return
the summed loss (over batches and tokens) and should not include regularizations.
the summed loss (over batches and tokens).
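As a hedged sketch, a classification task could implement the summed training loss like this; the true-Fisher branch that samples labels from the model's predictions mirrors the UCI regression example later in this diff, and the class name and batch layout are illustrative.

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

from kronfluence.task import Task

BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]


class ClassificationTask(Task):
    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            # Summed (not averaged) cross-entropy, without regularization terms.
            return F.cross_entropy(logits, labels, reduction="sum")
        # True Fisher: draw labels from the model's own predictive distribution.
        with torch.no_grad():
            probs = torch.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch: BATCH_TYPE, model: nn.Module) -> torch.Tensor:
        # Simplest choice: reuse the training loss as the measurement.
        return self.compute_train_loss(batch, model, sample=False)
```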

**How should I implement Task.compute_measurement?**
It depends on the analysis you would like to perform. Influence functions approximate the [effect of downweighting/upweighting
@@ -167,7 +166,8 @@ cause `TrackedModuleNotFoundError`.

**My model uses supported modules, but influence scores are not computed.**
Kronfluence uses module hooks to compute factors and influence scores. For these to be tracked and computed,
the model should directly call the module.
the model's forward pass should directly call the module.

```python
import torch
from torch import nn
@@ -180,22 +180,21 @@ def forward(x: torch.Tensor) -> torch.Tensor:
```
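To make the contrast concrete, here is an illustrative pair of models: hooks fire for the first, which calls the `nn.Linear` module object directly, but not for the second, which uses the module's parameters functionally.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TrackedModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The module itself is called, so its hooks run and influence is tracked.
        return self.linear(x)


class UntrackedModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the parameters are used; the module's hooks never run, so no
        # factors or influence scores are computed for it.
        return F.linear(x, self.linear.weight, self.linear.bias)
```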

**I get X error when fitting factors/computing scores.**
Please feel free to contact us by [filing an issue](https://github.com/pomonam/kronfluence/issues) or [through email](mailto:jbae@cs.toronto.edu).
Please feel free to contact me by [filing an issue](https://github.com/pomonam/kronfluence/issues) or [through email](mailto:jbae@cs.toronto.edu).

---

## Configuring Factors with FactorArguments


```python
import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac".
use_empirical_fisher=False,
immediate_gradient_removal=False,
ignore_bias=False,
distributed_sync_steps=1000,
amp_dtype=None,

# Settings for covariance matrix fitting.
covariance_max_examples=100_000,
@@ -221,17 +220,14 @@ analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset, f
```

You can change:
- `strategy`: Selects the Hessian approximation strategy (`identity`, `diagonal`, `KFAC`, or `EKFAC`).
- `strategy`: Selects the Hessian approximation strategy (`identity`, `diagonal`, `kfac`, or `ekfac`).
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using the actual labels from the batch)
instead of the true Fisher (using labels sampled from the model's predictions). It is recommended to keep this `False`.
- `immediate_gradient_removal`: Specifies whether to instantly set `param.grad = None` within module hooks. Generally,
recommended to be `False`, as it requires installing additional hooks. This should not affect the fitted factors, but
can potentially reduce peak memory.
- `ignore_bias`: Specifies whether to ignore factor computations on bias.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.

### Fitting Covariance Matrices

`KFAC` and `EKFAC` require computing the uncentered activation and pre-activation pseudo-gradient covariance matrices.
`kfac` and `ekfac` require computing the uncentered activation and pre-activation pseudo-gradient covariance matrices.
To fit covariance matrices, you can use `analyzer.fit_covariance_matrices`.
```python
# Fitting covariance matrices.
@@ -260,13 +256,11 @@ or `torch.float16`.
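A filled-in version of this call might look like the sketch below; the batch size is a placeholder, and the final line assumes the analyzer exposes a matching `load_covariance_matrices` helper.

```python
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(strategy="ekfac")
analyzer.fit_covariance_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    factor_args=factor_args,
    per_device_batch_size=256,
)
covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")
```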
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting covariance matrices.
2. Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`.
3. Try setting `immediate_gradient_removal=True`.
4. Try setting `covariance_module_partition_size > 1`.

3. Try setting `covariance_module_partition_size > 1`.

### Performing Eigendecomposition

After computing the covariance matrices, `KFAC` and `EKFAC` require performing eigendecomposition.
After computing the covariance matrices, `kfac` and `ekfac` require performing eigendecomposition.

```python
# Performing Eigendecomposition.
@@ -281,7 +275,7 @@ but `torch.float64` is recommended.
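As a sketch (assuming a `load_eigendecomposition` helper symmetrical to the other loaders):

```python
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(strategy="ekfac")
analyzer.perform_eigendecomposition(
    factors_name="initial_factor",
    factor_args=factor_args,
)
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
```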

### Fitting Lambda Matrices

`EKFAC` and `diagonal` require computing the Lambda (eigenvalue) matrices for all modules.
`ekfac` and `diagonal` require computing the Lambda (eigenvalue) matrices for all modules.

```python
# Fitting Lambda matrices.
@@ -306,8 +300,7 @@ or `torch.float16`.
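A filled-in version might look like this sketch; the batch size is a placeholder, and the final line assumes a matching `load_lambda_matrices` helper.

```python
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(strategy="ekfac")
analyzer.fit_lambda_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    factor_args=factor_args,
    per_device_batch_size=64,
)
lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
```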
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`.
3. Try using lower precision for `lambda_dtype`.
4. Try setting `immediate_gradient_removal=True`.
5. Try using `lambda_module_partition_size > 1`.
4. Try using `lambda_module_partition_size > 1`.

### FAQs

@@ -329,36 +322,41 @@ import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
damping=None,
immediate_gradient_removal=False,
damping=1e-08,
cached_activation_cpu_offload=False,
distributed_sync_steps=1000,
amp_dtype=None,

data_partition_size=1,
module_partition_size=1,
per_module_score=False,

# Configuration for query batching.
query_gradient_rank=None,
query_gradient_svd_dtype=torch.float64,
query_gradient_svd_dtype=torch.float32,
num_query_gradient_aggregations=1,
use_measurement_for_self_influence=False,

cached_activation_cpu_offload=False,
score_dtype=torch.float32,
per_sample_gradient_dtype=torch.float32,
precondition_dtype=torch.float32,
)
```

- `damping`: A damping factor for the damped matrix-vector product. Uses a heuristic based on mean eigenvalues
- `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
(0.1 x mean eigenvalues) if None.
- `immediate_gradient_removal`: Whether to immediately remove `param.grad` within a hook.
- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `data_partition_size`: Number of data partitions for computing influence scores.
- `module_partition_size`: Number of module partitions for computing influence scores.
- `per_module_score`: Whether to return per-module influence scores. Instead of summing influences across
all modules, this keeps track of intermediate module-wise scores.

- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the query gradient; see Section 3.2.2). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float32`.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`.
- `num_query_gradient_aggregations`: Number of query gradients to aggregate over.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.

- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
Expand All @@ -367,6 +365,7 @@ but `torch.float32` is recommended.
### Computing Influence Scores

To compute pairwise influence scores (Equation 5 in the paper), you can run:

```python
# Computing pairwise influence scores.
analyzer.compute_pairwise_scores(scores_name="pairwise", factors_name="ekfac", score_args=score_args)
@@ -375,21 +374,21 @@ scores = analyzer.load_pairwise_scores(scores_name="pairwise")
```

To compute self-influence scores (see Section 5.4 from [paper](https://arxiv.org/pdf/1703.04730.pdf)), you can run:

```python
# Computing pairwise influence scores.
# Computing self-influence scores.
analyzer.compute_self_scores(scores_name="self", factors_name="ekfac", score_args=score_args)
# Loading pairwise influence scores.
# Loading self-influence scores.
scores = analyzer.load_self_scores(scores_name="self")
```

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
2. Try setting `cached_activation_cpu_offload=True`.
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
4. Try setting `immediate_gradient_removal=True`.
5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
batching is only supported for computing pairwise influence scores, not self-infleucen scores.
6. Try setting `module_partition_size > 1`.
4. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
batching is only supported for computing pairwise influence scores, not self-influence scores.
5. Try setting `module_partition_size > 1`.
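Several of these mitigations combined, as an illustrative configuration:

```python
import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
    cached_activation_cpu_offload=True,  # Step 2: offload cached activations.
    per_sample_gradient_dtype=torch.bfloat16,  # Step 3: lower precision.
    score_dtype=torch.bfloat16,
    module_partition_size=2,  # Step 5: compute scores one module partition at a time.
)
```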

### FAQs

27 changes: 12 additions & 15 deletions README.md
@@ -2,7 +2,6 @@
<a href="#"><img width="380" img src="https://raw.githubusercontent.com/pomonam/kronfluence/main/.assets/kronfluence.svg" alt="Kronfluence"/></a>
</p>


<p align="center">
<a href="https://pypi.org/project/kronfluence">
<img alt="License" src="https://img.shields.io/pypi/v/kronfluence.svg?style=flat-square">
@@ -38,13 +37,13 @@ For a detailed description of the methodology, see the [**paper**](https://arxiv
> - Python: Version 3.9 or later
> - PyTorch: Version 2.1 or later
To install the latest version, use the following `pip` command:
To install the latest stable version, use the following `pip` command:

```bash
pip install kronfluence
```

Alternatively, you can install the library directly from the source:
Alternatively, you can install directly from source:

```bash
git clone https://github.com/pomonam/kronfluence.git
@@ -54,13 +53,11 @@ pip install -e .

## Getting Started

Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md)
page for a comprehensive guide.
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page for a comprehensive guide.

### Learn More

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence.
More examples will be added in the future.
The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence. More examples will be added in the future.
**TL;DR** You need to prepare a trained model and datasets, and pass them into the `Analyzer` class.

```python
@@ -95,7 +92,7 @@ eval_dataset = torchvision.datasets.MNIST(
train=True,
)

# Define the task.
# Define the task. See the Technical Documentation page for details.
task = MnistTask()

# Prepare the model for influence computation.
@@ -120,12 +117,12 @@ scores = analyzer.load_pairwise_scores(scores_name="my_scores")

## Contributing

Your contributions are welcome! For bug fixes, please submit a pull request without prior discussion. For proposing
new features, examples, or extensions, kindly start a discussion through an issue before proceeding.
Contributions are welcome! To get started, please review our [Code of Conduct](https://github.com/pomonam/kronfluence/blob/main/CODE_OF_CONDUCT.md). For bug fixes, please submit a pull request.
If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.

### Setting Up Development Environment

To contribute, you will need to set up a development environment on your machine.
To contribute to Kronfluence, you will need to set up a development environment on your machine.
This setup includes installing all the dependencies required for linting and testing.

```bash
@@ -134,10 +131,10 @@ cd kronfluence
pip install -e ."[dev]"
```

### Contributors
[Juhan Bae](https://github.com/pomonam/kronfluence), [Omkar Dige](https://github.com/xeon27), and [Adil Asif](https://github.com/adil-a/)
are the main contributors to this repository.
We thank Lev McKinney, Sang Keun Choe, Hwijeen Ahn, Minsoo Kang, Youngseog Chung, Kewen Zhao, and Laura Ruis for their feedback during the development process.
## Acknowledgements

[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

## License

5 changes: 3 additions & 2 deletions examples/uci/README.md
@@ -8,7 +8,7 @@ pip install -r requirements.txt

## Training

To train a regression model on the Concrete dataset, run the following command:
To train a regression model on the `Concrete` dataset, run the following command:
```bash
python train.py --dataset_name concrete \
--dataset_dir ./data \
@@ -34,7 +34,8 @@ You can also use `identity`, `diagonal`, and `kfac`.

## Counterfactual Evaluation

You can check the notebook `tutorial.ipynb` to run the counterfactual evaluation.
You can check the notebook `tutorial.ipynb` to run the subset removal counterfactual evaluation.
(Note that `TracIn` uses the final checkpoint instead of the intermediate checkpoints throughout training.)

<p align="center">
<a href="#"><img width="380" img src="figure/counterfactual.png" alt="Counterfactual"/></a>
7 changes: 4 additions & 3 deletions examples/uci/analyze.py
@@ -71,15 +71,16 @@ def compute_train_loss(
if not sample:
return F.mse_loss(outputs, targets, reduction="sum")
with torch.no_grad():
sampled_targets = torch.normal(outputs, std=math.sqrt(0.5))
return F.mse_loss(outputs, sampled_targets.detach(), reduction="sum")
sampled_targets = torch.normal(outputs.detach(), std=math.sqrt(0.5))
return F.mse_loss(outputs, sampled_targets, reduction="sum")

def compute_measurement(
self,
batch: BATCH_TYPE,
model: nn.Module,
) -> torch.Tensor:
# The measurement function is set as a training loss.
# The measurement function is set as a training loss. Alternatively, we can
# use mean absolute error, as done in https://arxiv.org/abs/2405.12186.
return self.compute_train_loss(batch, model, sample=False)

