An Image is Worth Multiple Words: Discovering Object Level Concepts using Multi-Concept Prompt Learning (ICML 2024)



An Image is Worth Multiple Words: Discovering Object Level Concepts using Multi-Concept Prompt Learning (ICML 2024)

Chen Jin¹, Ryutaro Tanno², Amrutha Saseendran¹, Tom Diethe¹, Philip Teare¹

Multi-Concept Prompt Learning (MCPL) pioneers mask-free text-guided learning for multiple prompts from one scene. Our approach not only enhances current methodologies but also paves the way for novel applications, such as facilitating knowledge discovery through natural language-driven interactions between humans and machines.

Motivation

We use Textual Inversion (T.I.) to learn concepts from either masked (left, first) or cropped (left, second) images; MCPL-one learns both concepts jointly from the full image with a single string; and MCPL-diverse additionally accounts for per-image specific relationships.

Naively learning multiple text embeddings from a single image-sentence pair without imagery guidance leads to misalignment in the per-word cross-attention (top). We propose three regularisation terms to enhance the accuracy of prompt-object-level correlation (bottom).

Method

Input images from our natural_2_concepts dataset.

Applications

Multiple concepts from single image

Input images from our natural_2_concepts dataset.

Per-image different multiple concepts

Input images from P2P demo images.

Out-of-Distribution concept discovery and hypothesis generation

Input images from the LGE CMR and MIMIC-CXR datasets.

Dataset

We generated and collected a Multi-Concept Dataset containing a total of ~1,400 images with masked objects/concepts, organised as follows:

Data file name          Size    # of images
medical_2_concepts      2.5M    370
natural_2_concepts      36M     415
natural_345_concepts    13M     525
real_natural_concepts   5.6M    137

Setup

Our code builds on, and shares requirements with, Latent Diffusion Models (LDM). To set up the environment, please run:

conda env create -f environment.yml
conda activate ldm
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
cd ./src/taming-transformers
pip install -e .

You will also need the official LDM text-to-image checkpoint, available through the LDM project page.

Currently, the model can be downloaded by running:

mkdir -p models/ldm/text2img-large/
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

Learning

MCPL-all: a naive approach that learns embeddings for all prompts in the string (including adjectives, prepositions, nouns, etc.)

  • specify the placeholder_string to describe your multi-concept images;
  • in presudo_words, specify every word of the placeholder_string so that all of them are learned;
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * and orange @' \
                --presudo_words 'green,*,and,orange,@'

MCPL-one: simplifies the objective by learning a single prompt (noun) per concept

  • in this case, presudo_words specifies only a subset of the words in the placeholder_string to be learned;
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * and orange @' \
                --presudo_words '*,@'

MCPL-diverse: learns different strings per image, to capture variances among examples

  • before starting, name each training image with a single word representing its relation;
  • e.g. in the ball and box experiment, we train with: <'front.jpg, next.jpg, on.jpg, under.jpg'>;
  • in placeholder_string we describe the multi-concept scene, using 'RELATE' as a placeholder for the relationship between the concepts;
  • in presudo_words we list all pseudo-words to be learnt, including the relations; the per-image relation is injected by replacing 'RELATE' with the relation given by each image's filename (see the illustrative sketch after the command below);
python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'green * RELATE orange @' \
                --presudo_words '*,@,on,under,next,front'
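
For intuition, the per-image relation injection amounts to a filename-driven string substitution. The snippet below is an illustrative sketch only; the function name and paths are hypothetical and do not reflect the repo's actual implementation:

import os

def build_per_image_prompt(image_path, placeholder_string="green * RELATE orange @"):
    # Illustrative only: derive the relation word from the image filename
    # (e.g. 'under.jpg' -> 'under') and substitute it for 'RELATE'.
    relation = os.path.splitext(os.path.basename(image_path))[0]
    return placeholder_string.replace("RELATE", relation)

print(build_per_image_prompt("ball_box/under.jpg"))  # 'green * under orange @'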

Regularisation-1: adding PromptCL and Bind adjective (teddybear skateboard example)

python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'a brown @ on a rolling * at times square' \
                --presudo_words 'a,brown,on,rolling,at,times,square,@,*' \
                --attn_words 'brown,rolling,@,*' \
                --presudo_words_softmax '@,*' \
                --presudo_words_infonce '@,*' \
                --infonce_temperature 0.2 \
                --infonce_scale 0.0005 \
                --adj_aug_infonce 'brown,rolling' \
                --attn_mask_type 'skip'

Regularisation-2: adding PromptCL, Bind adjective and Attention Mask (teddybear skateboard example)

python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
                -t \
                --actual_resume </path/to/pretrained/model.ckpt> \
                -n <run_name> \
                --gpus 0, \
                --data_root </path/to/directory/with/images> \
                --init_word <initialization_word> \
                --placeholder_string 'a brown @ on a rolling * at times square' \
                --presudo_words 'a,brown,on,rolling,at,times,square,@,*' \
                --attn_words 'brown,rolling,@,*' \
                --presudo_words_softmax '@,*' \
                --presudo_words_infonce '@,*' \
                --infonce_temperature 0.3 \
                --infonce_scale 0.00075 \
                --adj_aug_infonce 'brown,rolling'
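
The --presudo_words_infonce, --infonce_temperature and --infonce_scale flags control the PromptCL contrastive term, and --adj_aug_infonce lists the adjectives used for the Bind adj. term. For intuition only, the sketch below shows the general shape of an InfoNCE-style loss over prompt embeddings under assumed tensor shapes; it is not the exact loss implemented in this codebase:

import torch
import torch.nn.functional as F

def prompt_infonce(anchor, positives, negatives, temperature=0.2, scale=0.0005):
    """Illustrative InfoNCE over L2-normalised embeddings.
    anchor: (D,), positives: (P, D), negatives: (N, D).
    temperature sharpens the similarity distribution; scale weights this
    term relative to the diffusion objective."""
    anchor = F.normalize(anchor, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    pos_logits = positives @ anchor / temperature   # (P,)
    neg_logits = negatives @ anchor / temperature   # (N,)
    # each positive is contrasted against all negatives (class index 0 = positive)
    logits = torch.cat([pos_logits.unsqueeze(1),
                        neg_logits.expand(len(pos_logits), -1)], dim=1)
    labels = torch.zeros(len(pos_logits), dtype=torch.long)
    return scale * F.cross_entropy(logits, labels)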

Generation

To generate new images of the learned concept, run:

python scripts/txt2img.py --ddim_eta 0.0 \
            --n_samples 8 \
            --n_iter 2 \
            --scale 10.0 \
            --ddim_steps 50 \
            --embedding_path /path/to/logs/trained_model/checkpoints/embeddings_gs-6099.pt \
            --ckpt_path /path/to/pretrained/model.ckpt \
            --prompt "a photo of green * and orange @"

where * and @ are the placeholder strings used during inversion.

Code structure

Our code builds on the Textual Inversion (MIT licence) codebase as well as the Prompt-to-Prompt (Apache-2.0 licence) codebase.

The majority of our modifications are in the following files, where we provide docstrings for all functions:

./main.py
./src/p2p/p2p_ldm_utils.py
./src/p2p/ptp_utils.py
./ldm/modules/embedding_manager.py
./ldm/models/diffusion/ddpm.py

The remaining library files are largely unchanged and inherited from prior work.

FAQ

bert tokenizer error: Occasionally you may get the following error because the tokenizer splits a word into multiple tokens; simply try a different word with a similar meaning. For the example error below, replacing 'peachy' in your prompt with 'splendid' would resolve the issue.

File "/YOUR-HOME-PATH/MCPL/ldm/modules/embedding_manager.py", line 22, in get_bert_token_for_string
    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
AssertionError: String 'peachy' maps to more than a single token. Please use another string
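
To check in advance whether a candidate word maps to a single token, you can query the BERT tokenizer directly. This is a minimal sketch assuming the standard Hugging Face bert-base-uncased tokenizer (which the LDM text encoder wraps); a word is usable if it encodes to exactly three ids ([CLS], word, [SEP]):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

def maps_to_single_token(word: str) -> bool:
    # [CLS] + word + [SEP] -> exactly 3 ids when the word is a single token
    return len(tokenizer(word)["input_ids"]) == 3

print(maps_to_single_token("peachy"))    # expected False per the FAQ above
print(maps_to_single_token("splendid"))  # expected True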

Citation

If you make use of our work, please cite our paper:

@inproceedings{anonymous2024an,
  title={An Image is Worth Multiple Words: Discovering Object Level Concepts using Multi-Concept Prompt Learning},
  author={Anonymous},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://openreview.net/forum?id=F3x6uYILgL}
}
