A Python library for calculating and visualizing mutual information in neural networks. This repository includes methods to calculate mutual information using various techniques (binning, KDE, Kraskov) and tools to train neural networks and plot information plane dynamics.

aidinattar/info-bottleneck

MNIST Classification with Mutual Information Analysis

This repository contains code for training simple fully connected and convolutional neural networks on the MNIST dataset. It also provides tools for estimating mutual information with several methods, so you can analyze how much information each layer retains during training.

Installation

  1. Clone the repository:

    git clone https://github.com/aidinattar/info-bottleneck.git
    cd info-bottleneck
  2. Install the required packages:

    pip install -r requirements.txt

Usage

Simple MLP

To train a simple fully connected network on the MNIST dataset:

  1. Navigate to the directory:

    cd info_bottleneck
  2. Run the training script:

    python main.py --model mlp

Simple CNN

To train a simple convolutional neural network on the MNIST dataset:

  1. Navigate to the directory:

    cd info_bottleneck
  2. Run the training script:

    python main.py --model cnn

Mutual Information Calculation

Mutual information can be estimated with one of three methods: binning, KDE, or Kraskov. The calculation is integrated into the NetworkTrainer class, and the method can be specified during initialization.
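To make the binning method concrete, here is a minimal, standard-library-only sketch of how a binning estimate for one layer might be computed. It is an illustration under stated assumptions, not the code used by NetworkTrainer: the function names (`discretize`, `entropy`, `binned_mutual_information`), the default of 30 bins, and the activation range of [-1, 1] are all hypothetical. It relies on the fact that a layer's output T is a deterministic function of the input X, so I(X;T) reduces to H(T), and it computes I(T;Y) as H(T) - H(T|Y).

```python
import math
from collections import Counter


def discretize(activations, num_bins=30, low=-1.0, high=1.0):
    """Map each activation vector to a tuple of bin indices (hypothetical helper)."""
    width = (high - low) / num_bins
    binned = []
    for vec in activations:
        idx = []
        for a in vec:
            # Clamp so out-of-range values fall into the edge bins.
            clamped = min(max(a, low), high)
            idx.append(min(int((clamped - low) / width), num_bins - 1))
        binned.append(tuple(idx))
    return binned


def entropy(symbols):
    """Shannon entropy (in bits) of the empirical distribution of `symbols`."""
    counts = Counter(symbols)
    n = len(symbols)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())


def binned_mutual_information(activations, labels, num_bins=30, low=-1.0, high=1.0):
    """Return (I(X;T), I(T;Y)) in bits for one layer's activations.

    Assumes T is a deterministic function of X, so I(X;T) = H(T).
    """
    t = discretize(activations, num_bins, low, high)
    h_t = entropy(t)
    # H(T|Y) = sum over classes y of p(y) * H(T | Y = y)
    n = len(labels)
    h_t_given_y = 0.0
    for y, count in Counter(labels).items():
        t_y = [ti for ti, yi in zip(t, labels) if yi == y]
        h_t_given_y += (count / n) * entropy(t_y)
    return h_t, h_t - h_t_given_y


# Toy example: four samples, one hidden unit, two classes.
toy_activations = [[-0.9], [-0.8], [0.8], [0.9]]
toy_labels = [0, 0, 1, 1]
i_xt, i_ty = binned_mutual_information(toy_activations, toy_labels, num_bins=2)
```

In the toy example the two bins separate the classes perfectly, so both estimates come out to one bit. With real activations the result depends strongly on the bin count and range, which is one reason to compare it against the KDE and Kraskov estimates.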

Example Usage

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# SimpleCNNTrainer is defined in this repository; import it from the module that provides it.

# MNIST dataset and DataLoader (pixel values normalized to [-1, 1])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# The MNIST test split is used for validation
val_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the trainer and train the model
trainer = SimpleCNNTrainer(train_loader, val_loader, activation='relu', optimizer=optim.Adam, epochs=10, device=device)
trainer.train()

References

This code is based on the Information Bottleneck principle for neural networks.

Future Work

The current implementation focuses on fully connected and convolutional neural networks. Future work could include:

  • Extending the mutual information analysis to other network architectures such as recurrent neural networks (RNNs) or transformers.
  • Applying the mutual information analysis to more complex datasets beyond MNIST.
  • Investigating the impact of different mutual information estimation methods on the performance and interpretability of neural networks.

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgements

Special thanks to the authors of the referenced papers and the open-source community for providing valuable resources and tools that made this project possible.
