theislab/inVAE

inVAE

inVAE is a conditionally invariant variational autoencoder that identifies both spurious (distractor) and invariant features. It leverages domain variability to learn conditionally invariant representations. We show that inVAE captures biological variation in single-cell datasets obtained from diverse conditions and labs. inVAE incorporates biological covariates and mechanisms, such as disease states, to learn an invariant data representation, which significantly improves cell classification accuracy.

Installation

  1. From PyPI:
    pip install invae

  2. Development version (latest version on GitHub):
    git clone https://github.com/theislab/inVAE.git
    cd inVAE
    pip install .

Example

Integration of the Human Lung Cell Atlas using both healthy and diseased samples

Usage

  1. Load the data:
    import scanpy as sc
    adata = sc.read('path/to/data')  # e.g. an .h5ad file
  2. Optional: split the data into training, validation, and test sets (in the supervised case, the splits are also used to train the classifier).
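The README does not prescribe how to split. A minimal sketch, assuming a simple random split by cell is acceptable for your data (splitting by donor or study may be more appropriate); split_indices is a hypothetical helper, not part of inVAE:

```python
import numpy as np

def split_indices(n_obs, val_frac=0.1, test_frac=0.1, seed=0):
    """Randomly partition range(n_obs) into train/val/test index arrays.

    Hypothetical helper for illustration; not part of the inVAE package.
    """
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("val_frac and test_frac must be >= 0 and sum to < 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_obs)          # shuffled cell indices
    n_test = round(n_obs * test_frac)
    n_val = round(n_obs * val_frac)
    test_idx = perm[:n_test]
    val_idx = perm[n_test:n_test + n_val]
    train_idx = perm[n_test + n_val:]      # everything that is left
    return train_idx, val_idx, test_idx

# Usage with the AnnData object from step 1:
# train_idx, val_idx, test_idx = split_indices(adata.n_obs)
# adata_train = adata[train_idx].copy()
# adata_val = adata[val_idx].copy()
# adata_test = adata[test_idx].copy()
```

Fixing the seed keeps the split reproducible across runs.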
  3. Initialize the model, either factorized (FinVAE) or non-factorized (NFinVAE):
from inVAE import FinVAE, NFinVAE

inv_covar_keys = {
    'cont': [],
    'cat': ['cell_type', 'disease'] # set to the matching column names in adata.obs
}

spur_covar_keys = {
    'cont': [],
    'cat': ['batch'] # set to the matching column names in adata.obs
}

model = FinVAE(
    adata = adata_train, # training split from step 2 (or the full adata if you did not split)
    layer = 'counts', # key in adata.layers holding the raw counts (default None uses adata.X)
    inv_covar_keys = inv_covar_keys,
    spur_covar_keys = spur_covar_keys,
    latent_dim_inv = 20, 
    latent_dim_spur = 5,
    device = 'cpu',
    decoder_dist = 'nb'
)

Set inject_covar_in_latent=True if you want the spurious covariates added directly to the latent space (instead of learning spurious latent variables). This gives the version most compatible with scVI.
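Conceptually, injecting covariates means the decoder sees the observed spurious covariates (for example a one-hot batch encoding) next to the invariant latents, rather than learned spurious latents; scVI conditions its decoder on batch in a similar way. The NumPy sketch below only illustrates that idea; decoder_input and the shapes are assumptions, not inVAE's internal code:

```python
import numpy as np

def decoder_input(z_inv, batch_labels, batch_categories):
    """Concatenate invariant latents with one-hot encoded batch labels.

    Hypothetical illustration of conditioning on observed covariates
    instead of learned spurious latents; not inVAE's implementation.
    """
    index = {cat: i for i, cat in enumerate(batch_categories)}
    # Row i of the identity matrix is the one-hot vector for category i
    one_hot = np.eye(len(batch_categories))[[index[b] for b in batch_labels]]
    return np.concatenate([z_inv, one_hot], axis=1)

z_inv = np.random.default_rng(0).normal(size=(3, 20))  # 3 cells, latent_dim_inv = 20
x = decoder_input(z_inv, ['b1', 'b2', 'b1'], ['b1', 'b2'])
print(x.shape)  # (3, 22)
```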

For the non-factorized model, use:

model = NFinVAE(
    adata = adata_train, # training split from step 2 (or the full adata if you did not split)
    layer = 'counts', # key in adata.layers holding the raw counts (default None uses adata.X)
    inv_covar_keys = inv_covar_keys,
    spur_covar_keys = spur_covar_keys,
    latent_dim_inv = 20, 
    latent_dim_spur = 5,
    device = 'cpu',
    decoder_dist = 'nb'
)
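Since inv_covar_keys and spur_covar_keys must name covariates present in the data, it can help to check them before building either model. A minimal sketch, assuming the keys refer to columns of adata.obs; check_covar_keys is a hypothetical helper, not part of inVAE:

```python
def check_covar_keys(obs_columns, covar_keys):
    """Return the covariate keys that are missing from obs_columns.

    covar_keys follows the {'cont': [...], 'cat': [...]} layout used above.
    Raises ValueError if the dict has unexpected top-level groups.
    """
    unexpected = set(covar_keys) - {'cont', 'cat'}
    if unexpected:
        raise ValueError(f"unexpected covariate groups: {sorted(unexpected)}")
    columns = set(obs_columns)
    return [key
            for group in ('cont', 'cat')
            for key in covar_keys.get(group, [])
            if key not in columns]

# Usage (assumes adata_train from step 2):
# missing = check_covar_keys(adata_train.obs.columns, inv_covar_keys)
# if missing:
#     raise KeyError(f"covariates not found in adata.obs: {missing}")

inv_covar_keys = {'cont': [], 'cat': ['cell_type', 'disease']}
print(check_covar_keys(['cell_type', 'batch'], inv_covar_keys))  # ['disease']
```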
  4. Train the generative model:
    model.train(n_epochs=500, lr_train=0.001, weight_decay=0.0001)
  5. Get the latent representation. If a covariate that was used during training is missing for a sample, the encoder receives zeros as input for that sample and covariate.
# Works for any AnnData object, not only the training data
# Other options for latent_type: 'full' or 'spurious'
latent = model.get_latent_representation(adata, latent_type='invariant')
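The zero-filling described in step 5 can be pictured with a one-hot encoded categorical covariate: a sample whose value is missing gets an all-zero row. A minimal NumPy sketch of that behavior, not inVAE's actual encoding code; encode_covariate is hypothetical, and treating labels unseen during training the same way as missing ones is an assumption of this sketch:

```python
import numpy as np

def encode_covariate(values, train_categories):
    """One-hot encode a categorical covariate; missing values become all zeros.

    values: sequence of labels, with None marking a missing entry.
    train_categories: categories seen when the model was trained.
    Assumption: labels not in train_categories are also zero-filled.
    """
    index = {cat: i for i, cat in enumerate(train_categories)}
    encoded = np.zeros((len(values), len(train_categories)))
    for row, value in enumerate(values):
        if value in index:  # None or unseen labels leave the row as zeros
            encoded[row, index[value]] = 1.0
    return encoded

print(encode_covariate(['healthy', None, 'COPD'], ['COPD', 'healthy']))
# [[0. 1.]
#  [0. 0.]
#  [1. 0.]]
```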
  6. Optional: train the classifier (for cell types). If adata_val is not given or has no labels, the classifier is trained only on the AnnData object used to train the generative model (here: adata_train).
model.train_classifier(
    adata_val,
    batch_key = 'batch',
    label_key = 'cell_type',
)
  7. Optional: predict cell types:
# Other possible values for dataset_type: 'train' or 'val'
# 'train' corresponds to the adata_train object above
# 'val' corresponds to the adata passed to train_classifier
# 'test' is for a new, unseen AnnData object
pred_test = model.predict(adata_test, dataset_type='test')
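If adata_test carries ground-truth labels, the predictions can be scored directly. A minimal sketch, assuming pred_test is an array-like of cell-type labels aligned with the rows of adata_test (the return type of predict is not documented here); accuracy and balanced_accuracy are hypothetical helpers. Balanced accuracy is included because cell-type frequencies are often very uneven:

```python
import numpy as np

def accuracy(pred, true):
    """Fraction of cells whose predicted label matches the true label."""
    pred, true = np.asarray(pred), np.asarray(true)
    return float(np.mean(pred == true))

def balanced_accuracy(pred, true):
    """Mean per-class recall, so rare cell types count as much as common ones."""
    pred, true = np.asarray(pred), np.asarray(true)
    recalls = [np.mean(pred[true == c] == c) for c in np.unique(true)]
    return float(np.mean(recalls))

# Usage (assumes true labels in adata_test.obs['cell_type']):
# print(accuracy(pred_test, adata_test.obs['cell_type']))

true = ['T', 'T', 'T', 'B']
pred = ['T', 'T', 'T', 'T']
print(accuracy(pred, true))           # 0.75
print(balanced_accuracy(pred, true))  # 0.5
```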
  8. Optional: infer the latent representation via the trained classifier:
# Use 'val' or 'test' as the key, depending on which dataset_type was used in predict() above
# E.g. for the invariant latent representation (use the array unsliced for the
# full representation, or take the last dimensions for the spurious one):
latent_samples_inv = model.saved_latent['val'][:, :model.latent_dim_inv]
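The slicing in step 8 relies on the full latent storing the invariant dimensions first and the spurious dimensions last. A NumPy sketch of that layout, using the sizes from the example above (latent_dim_inv = 20, latent_dim_spur = 5) with random values standing in for real latents; split_latent is a hypothetical helper:

```python
import numpy as np

def split_latent(full_latent, latent_dim_inv):
    """Split a full latent matrix into its invariant and spurious blocks."""
    full_latent = np.asarray(full_latent)
    return full_latent[:, :latent_dim_inv], full_latent[:, latent_dim_inv:]

# Real usage would be, e.g.:
# z_inv, z_spur = split_latent(model.saved_latent['val'], model.latent_dim_inv)
full = np.random.default_rng(0).normal(size=(100, 20 + 5))  # 100 cells
z_inv, z_spur = split_latent(full, latent_dim_inv=20)
print(z_inv.shape, z_spur.shape)  # (100, 20) (100, 5)
```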
  9. Optional: save and load the model:
model.save('./checkpoints/path.pt')
model.load('./checkpoints/path.pt')

The newest version can save and load the model parameters and weights in one step:

# Same syntax for saving, but the model parameters are now saved as well
model.save('./checkpoints/path.pt')
# New loading function (the old load method still works for older checkpoints)
device = 'cpu'  # same device string as used when initializing the model
model = FinVAE.load_model('./checkpoints/path.pt', adata_train, device)
# or for NFinVAE
model = NFinVAE.load_model('./checkpoints/path.pt', adata_train, device)

Dependencies

  • scanpy==1.9.3
  • torch==2.0.1
  • tensorboard==2.13.0
  • anndata==0.8.0
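To compare an existing environment against these pinned versions, a small standard-library sketch may help. check_versions is a hypothetical helper and the pins are copied from the list above; note that local build tags such as torch's "+cu118" suffix will be reported as mismatches:

```python
from importlib import metadata

# Pins copied from the Dependencies list above
PINNED = {
    'scanpy': '1.9.3',
    'torch': '2.0.1',
    'tensorboard': '2.13.0',
    'anndata': '0.8.0',
}

def check_versions(pinned):
    """Map each package whose installed version differs from its pin to
    (installed, pinned); installed is None if the package is absent."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:  # exact string comparison, like '==' pins
            mismatches[name] = (installed, wanted)
    return mismatches

if __name__ == '__main__':
    for name, (installed, wanted) in check_versions(PINNED).items():
        print(f'{name}: installed {installed}, pinned {wanted}')
```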

Citation

H. Aliee, F. Kapl, S. Hediyeh-Zadeh, F. J. Theis, Conditionally Invariant Representation Learning for Disentangling Cellular Heterogeneity, 2023