Self-Attention Visualization and Fine-Tuning with Masked Language Models (MLM)

This repository contains a powerful Python tool for exploring and interacting with pre-trained Masked Language Models (MLMs), such as BERT, CamemBERT, and multilingual BERT variants. The tool allows users to:

  • Visualize self-attention mechanisms in action,
  • Predict missing tokens (masked tokens) in sentences,
  • Fine-tune the model on custom datasets,
  • Analyze contextual embeddings using dimensionality reduction techniques (e.g., PCA, t-SNE),
  • Explain model predictions using SHAP (SHapley Additive exPlanations) values for interpretability.

This project is perfect for anyone interested in natural language processing (NLP), masked language modeling, or understanding how transformers like BERT process and represent language.

Key Features

  • Self-Attention Visualization: Generate attention heatmaps for tokens across different layers and heads of a transformer-based model.
  • Masked Token Prediction: Use a pre-trained model to predict the masked token in a sentence.
  • Model Fine-Tuning: Easily fine-tune the model with your own dataset and task.
  • Embedding Analysis: Visualize the contextual embeddings of tokens in 2D using PCA or t-SNE.
  • SHAP-based Explainability: Generate SHAP explanations to understand how each word in a sentence influences the model's predictions.

Requirements

To use the tool, you need to have the following installed:

  • Python 3 (version 3.7 or later recommended)
  • TensorFlow
  • Hugging Face Transformers library
  • PIL (Python Imaging Library, typically installed as Pillow)
  • scikit-learn for dimensionality reduction (PCA and t-SNE)
  • SHAP for explainability

You can install all the dependencies using the following command:

pip install -r requirements.txt

Getting Started

Follow these simple steps to get started:

  1. Clone the repository:

    git clone https://github.com/AKKI0511/Masked-Language-Model.git
    cd Masked-Language-Model
  2. Run the script:

    python mask.py
  3. Choose your language: Upon running the script, you'll be prompted to select a language model. You can choose from English, French, German, Chinese, or Japanese.

  4. Interact with the model: You will be presented with several options to interact with the model:

    • Predict masked token: Enter a sentence with a [MASK] token and see the model's top predictions.
    • Fine-tune model: Fine-tune the model on your custom text dataset.
    • Analyze contextual embeddings: Visualize token embeddings in a 2D space using PCA or t-SNE.
    • Explain prediction: Generate SHAP explanations to visualize how each word contributes to the model's masked token prediction.
    • Exit: Terminate the program.

Features in Detail

1. Masked Token Prediction

You can input a sentence with a missing word (e.g., I love [MASK] in the morning.), and the model will predict the missing token based on the context. Additionally, you can visualize the self-attention weights to understand which words in the sentence the model focuses on when making its prediction.
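The core of this step can be sketched with the Hugging Face API the project builds on. The checkpoint name, example sentence, and K value below are illustrative placeholders, not the script's exact defaults:

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    MODEL = "bert-base-uncased"   # assumed English checkpoint
    K = 3                         # number of candidate tokens to show

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL)

    text = "I love [MASK] in the morning."
    inputs = tokenizer(text, return_tensors="tf")

    # Locate the [MASK] position and take the top-K vocabulary logits there.
    mask_index = tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0]
    logits = model(**inputs).logits[0, mask_index]
    for token_id in tf.math.top_k(logits, k=K).indices.numpy():
        print(text.replace(tokenizer.mask_token, tokenizer.decode([token_id])))

Each printed line is the original sentence with the mask replaced by one of the model's top-K guesses.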

2. Fine-Tuning the Model

You can fine-tune the pre-trained model on your dataset by providing a text file. The model will train for a number of epochs that you specify, allowing you to adapt the model to specific tasks or domains.
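As a rough sketch of what this involves under the hood (the file name, sequence length, batch size, learning rate, and epoch count below are placeholders, and the script's own training loop may differ), the standard Keras workflow for MLM fine-tuning looks like:

    import tensorflow as tf
    from transformers import (AutoTokenizer, TFAutoModelForMaskedLM,
                              DataCollatorForLanguageModeling)

    MODEL = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL)

    # One training example per non-empty line of the text file.
    with open("my_corpus.txt", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    encodings = tokenizer(lines, truncation=True, padding=True, max_length=64)

    # Randomly mask 15% of tokens -- the standard MLM objective.
    collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15,
                                               return_tensors="tf")
    batch = collator([{k: encodings[k][i] for k in encodings}
                      for i in range(len(lines))])

    dataset = tf.data.Dataset.from_tensor_slices(dict(batch)).batch(8)
    model.compile(optimizer=tf.keras.optimizers.Adam(2e-5))  # uses the model's built-in MLM loss
    model.fit(dataset, epochs=3)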

3. Contextual Embeddings Analysis

This feature allows you to visualize how the model represents each word in a sentence as embeddings. By using dimensionality reduction techniques like PCA or t-SNE, you can see how semantically similar words are grouped in the embedding space.
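A minimal sketch of the idea, assuming the final layer's hidden states are projected with scikit-learn's PCA (the sentence and model name are only examples):

    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    MODEL = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL, output_hidden_states=True)

    text = "The bank by the river is steep."
    inputs = tokenizer(text, return_tensors="tf")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())

    # One contextual vector per token from the final hidden layer.
    hidden = model(**inputs).hidden_states[-1][0].numpy()

    # Project the high-dimensional vectors down to 2D for plotting.
    points = PCA(n_components=2).fit_transform(hidden)
    plt.scatter(points[:, 0], points[:, 1])
    for (x, y), tok in zip(points, tokens):
        plt.annotate(tok, (x, y))
    plt.title("Token embeddings (PCA)")
    plt.show()

Swapping PCA for sklearn.manifold.TSNE gives the t-SNE view instead.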

4. SHAP Explainability

Using SHAP, the tool provides interpretability by showing which tokens in a sentence contribute most to the model's prediction for the masked token. This is incredibly useful for understanding model behavior and for debugging.
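One way to set this up (a sketch only; the script's actual SHAP wiring may differ, and the target word, sentence, and wrapper function below are illustrative) is to explain the probability the model assigns to a particular candidate fill word:

    import numpy as np
    import shap
    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    MODEL = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL)
    target_id = tokenizer.convert_tokens_to_ids("coffee")  # candidate fill word (example)

    def fill_probability(texts):
        """Probability of the target word at the first [MASK] in each string."""
        probs = []
        for text in texts:
            inputs = tokenizer(text, return_tensors="tf")
            ids = inputs["input_ids"][0].numpy()
            positions = np.where(ids == tokenizer.mask_token_id)[0]
            if len(positions) == 0:          # SHAP may have hidden the [MASK] itself
                probs.append(0.0)
                continue
            logits = model(**inputs).logits[0, positions[0]]
            probs.append(float(tf.nn.softmax(logits)[target_id]))
        return np.array(probs)

    # Use "..." for hidden words so they are not confused with the original [MASK].
    masker = shap.maskers.Text(tokenizer, mask_token="...")
    explainer = shap.Explainer(fill_probability, masker)
    shap_values = explainer(["I love [MASK] in the morning."])
    shap.plots.text(shap_values)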

5. Self-Attention Visualization

The tool generates attention diagrams that show how the model attends to different words in a sentence. The diagrams are generated for each layer and attention head in the model and saved as PNG images. These attention heatmaps provide insight into the inner workings of the transformer model and show how words influence each other.
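For reference, attention weights can be pulled out of the model as shown below (this sketch renders a heatmap with matplotlib rather than the PIL-based diagrams the script produces; the layer and head indices are arbitrary):

    import matplotlib.pyplot as plt
    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    MODEL = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL, output_attentions=True)

    text = "I love [MASK] in the morning."
    inputs = tokenizer(text, return_tensors="tf")
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())

    # attentions: one tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
    attentions = model(**inputs).attentions
    layer, head = 0, 1                        # "Layer 1, Head 2" in 1-based terms
    weights = attentions[layer][0, head].numpy()

    plt.imshow(weights, cmap="gray")
    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.yticks(range(len(tokens)), tokens)
    plt.title(f"Attention: layer {layer + 1}, head {head + 1}")
    plt.tight_layout()
    plt.savefig("attention_layer1_head2.png")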

Example Visualizations

Here are some examples of attention diagrams generated by the tool:

Attention heatmap from Layer 1, Head 2.

Attention heatmap from Layer 2, Head 5.

How to Fine-Tune the Model

  1. Prepare your dataset as a plain text file.
  2. Choose the fine-tune option from the menu when you run the script.
  3. Provide the path to your dataset and specify the number of training epochs.
  4. The model will be trained on your dataset, and you will see the results after each epoch.

Customization Options

The script allows you to customize several key parameters:

  • MODEL: Select from different language models (English, French, German, Chinese, Japanese).
  • K: Number of predictions to generate for the masked token.
  • FONT: Path to the font file used for rendering text on attention diagrams.
  • GRID_SIZE: Size of each grid cell in the attention diagrams.
  • PIXELS_PER_WORD: Adjusts the size of each word in pixels for the attention diagrams.

These constants can be adjusted directly within the script according to your needs.
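For example, the top of the script might be edited along these lines (the values and font path shown are illustrative, not the defaults):

    MODEL = "bert-base-uncased"            # pre-trained checkpoint to load
    K = 3                                  # number of masked-token predictions to show
    FONT = "fonts/OpenSans-Regular.ttf"    # hypothetical path to a TrueType font for diagram labels
    GRID_SIZE = 40                         # pixel size of each attention-grid cell
    PIXELS_PER_WORD = 200                  # horizontal pixels reserved per token label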

Running the Script on a Different Model

If you'd like to use a different pre-trained transformer model (for example, a domain-specific model), you can simply change the MODEL constant to any Hugging Face checkpoint supported by the AutoTokenizer and TFAutoModelForMaskedLM classes.
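For instance, pointing the tool at another checkpoint on the Hugging Face Hub is a one-line change (the model name below is just an example of an alternative checkpoint with TensorFlow weights):

    from transformers import AutoTokenizer, TFAutoModelForMaskedLM

    MODEL = "distilbert-base-uncased"   # any Hub checkpoint compatible with TFAutoModelForMaskedLM
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = TFAutoModelForMaskedLM.from_pretrained(MODEL)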

Acknowledgements

This project uses the powerful Hugging Face Transformers library to handle tokenization, model loading, and masked language modeling. The visualizations for attention scores and embeddings are created using PIL and matplotlib. The project also leverages SHAP for interpretability, allowing for detailed explanations of model predictions.
