# inter_AnnData.py
# Recovered from patch ad1a36a ("Add files via upload", 2023-12-13).
#
# Trains and evaluates a CIForm-style Transformer classifier for single-cell
# RNA-seq cell-type annotation. Each cell's expression vector is cut into
# fixed-length gene segments; each segment is z-score standardised and treated
# as one token, and a TransformerEncoder layer + mean pooling + MLP head
# predicts the cell type. Label 0 is reserved for "unassigned" cell types.

import math
import os
import random
import time as tm
import warnings

import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# The original pipeline runs the whole model in float64.
torch.set_default_tensor_type(torch.DoubleTensor)


def same_seeds(seed):
    """Seed python, numpy and torch RNGs (CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


same_seeds(2021)


def getXY(gap, adata, y_trains):
    """Convert an AnnData expression matrix into per-cell token sequences.

    Each cell's expression vector is split into consecutive segments of
    length ``gap``. When the gene count is not a multiple of ``gap``, the
    final segment is the *last* ``gap`` values, so it overlaps the previous
    segment. Every segment is z-score standardised independently.

    Parameters
    ----------
    gap : int
        Segment (token) length.
    adata : AnnData
        Expression matrix in ``.X`` (dense ndarray or sparse matrix).
    y_trains : sequence
        Cell-type labels; pass an empty sequence for unlabeled data.

    Returns
    -------
    With labels: ``(segments, y_trains, cell_types)`` where ``cell_types``
    is the list of distinct upper-cased labels in order of first appearance.
    Without labels: just ``segments``.
    (The variable return arity is kept for backward compatibility.)
    """
    X = adata.X
    if not isinstance(X, np.ndarray):
        X = X.todense()
    X = np.asarray(X)

    single_cell_list = []
    for single_cell in X:
        length = len(single_cell)
        feature = []
        for start in range(0, length, gap):
            if start + gap <= length:
                segment = single_cell[start:start + gap]
            else:
                # Tail shorter than `gap`: re-use the last full window
                # (overlaps the previous segment) so all tokens have equal width.
                segment = single_cell[length - gap:length]
            feature.append(preprocessing.scale(segment))
        single_cell_list.append(np.asarray(feature))

    single_cell_list = np.asarray(single_cell_list)

    if len(y_trains) > 0:
        # Distinct upper-cased labels, in order of first appearance.
        cell_types = []
        for label in y_trains:
            label = str(label).upper()
            if label not in cell_types:
                cell_types.append(label)
        return single_cell_list, y_trains, cell_types

    return single_cell_list


def getNewData(cells, cell_types):
    """Encode cell-type names as integer labels.

    A type found in ``cell_types`` gets ``index + 1``; anything else gets 0,
    which denotes an unknown/"unassigned" cell type. Comparison is
    case-insensitive (both sides upper-cased).
    """
    index_of = {str(t).upper(): i + 1 for i, t in enumerate(cell_types)}
    return np.asarray([index_of.get(str(cell).upper(), 0) for cell in cells])


class TrainDataSet(Dataset):
    """(segment-sequence tensor, integer label) pairs for train/eval loaders."""

    def __init__(self, data, label):
        # Convert to tensors once up front; the original converted the whole
        # numpy array on every __getitem__ call.
        self.data = torch.from_numpy(data)
        self.label = torch.from_numpy(label)
        self.length = len(data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self.data[index], self.label[index]


class TestDataSet(Dataset):
    """Unlabeled segment-sequence tensors (kept for API compatibility)."""

    def __init__(self, data):
        self.data = torch.from_numpy(data)
        self.length = len(data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self.data[index]


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017).

    Expects input of shape (seq_len, batch, d_model) and adds a fixed
    sin/cos position signal, followed by dropout.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        # Buffer: moves with the module but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class CIForm(nn.Module):
    """Transformer-encoder classifier over gene-segment token sequences.

    NOTE: ``input_dim`` is accepted for interface compatibility but unused;
    the token width is ``d_model``.
    """

    def __init__(self, input_dim, nhead=2, d_model=80, num_classes=2, dropout=0.1):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, dim_feedforward=1024, nhead=nhead, dropout=dropout
        )
        self.positionalEncoding = PositionalEncoding(d_model=d_model, dropout=dropout)
        # Per-cell classification head applied after mean pooling over tokens.
        self.pred_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, mels):
        """mels: (batch, n_segments, d_model) -> logits (batch, num_classes)."""
        out = mels.permute(1, 0, 2)          # -> (seq, batch, d_model)
        out = self.positionalEncoding(out)
        out = self.encoder_layer(out)
        out = out.transpose(0, 1)            # -> (batch, seq, d_model)
        out = out.mean(dim=1)                # mean-pool over segments
        return self.pred_layer(out)


def ciForm(s, Train_adata, train_labels, Test_adata, y_test, n_epochs=20):
    """Train a CIForm model on the reference set and evaluate on the query set.

    Parameters
    ----------
    s : int
        Segment length; also used as the model dimension ``d_model``.
        Must be divisible by the number of attention heads (64).
    Train_adata, Test_adata : AnnData
        Reference and query expression data.
    train_labels, y_test : sequence
        Cell-type labels for the reference and query cells.
    n_epochs : int
        Number of training epochs.

    Side effects: prints progress and appends accuracy/F1 to log/end_norm.txt.
    """
    gap = s
    d_models = s
    heads = 64            # d_models must be divisible by heads (e.g. 1024/64)
    lr = 0.0001
    dp = 0.1
    batch_sizes = 256

    train_data, train_cells, train_cellTypes = getXY(gap, Train_adata, train_labels)
    print("train_cellTypes", train_cellTypes)

    # BUGFIX: use ONE label ordering for both train and test encodings.
    # The original re-derived cell_types via np.unique() (sorted, not
    # upper-cased) after encoding the training labels with the
    # insertion-ordered upper-cased list, so train and test integer labels
    # disagreed — and getNewData would have crashed calling .index() on an
    # ndarray for any matching type.
    cell_types = train_cellTypes
    labels = getNewData(train_cells, cell_types)
    num_classes = len(cell_types) + 1  # +1 for the "unassigned" class (label 0)

    query_data, test_cells, _ = getXY(gap, Test_adata, y_test)
    test_labels = getNewData(test_cells, cell_types)

    model = CIForm(input_dim=d_models, nhead=heads, d_model=d_models,
                   num_classes=num_classes, dropout=dp)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    train_loader = DataLoader(TrainDataSet(data=train_data, label=labels),
                              batch_size=batch_sizes, shuffle=True,
                              pin_memory=True)
    test_loader = DataLoader(TrainDataSet(data=query_data, label=test_labels),
                             batch_size=batch_sizes, shuffle=False,
                             pin_memory=True)

    # Index 0 -> "unassigned", then the reference cell types in order.
    new_cellTypes = ["unassigned"]
    new_cellTypes.extend(cell_types)

    model.train()
    for epoch in range(n_epochs):
        train_loss = []
        train_accs = []
        train_f1s = []
        for data, batch_labels in tqdm(train_loader):
            logits = model(data)
            batch_labels = batch_labels.long()
            loss = criterion(logits, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = logits.argmax(1).cpu().numpy()
            batch_labels = batch_labels.cpu().numpy()
            train_loss.append(loss.item())
            train_accs.append(accuracy_score(batch_labels, preds))
            train_f1s.append(f1_score(batch_labels, preds, average='macro'))

        train_loss = sum(train_loss) / len(train_loss)
        train_acc = sum(train_accs) / len(train_accs)
        train_f1 = sum(train_f1s) / len(train_f1s)
        print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}, f1 = {train_f1:.5f}")

    model.eval()
    test_accs = []
    test_f1s = []
    y_predict = []
    labelss = []
    for data, batch_labels in tqdm(test_loader):
        with torch.no_grad():
            logits = model(data)
        preds = logits.argmax(1).cpu().numpy().tolist()
        batch_labels = batch_labels.cpu().numpy().tolist()
        test_accs.append(accuracy_score(batch_labels, preds))
        test_f1s.append(f1_score(batch_labels, preds, average='macro'))
        y_predict.extend(preds)
        labelss.extend(batch_labels)

    print("---------------------------------------------end test---------------------------------------------")
    print("len(y_predict)", len(y_predict))
    # Overall metrics over all query cells (not the per-batch average).
    all_acc = accuracy_score(labelss, y_predict)
    all_f1 = f1_score(labelss, y_predict, average='macro')
    print("all_acc:", all_acc, "all_f1:", all_f1)

    # Map integer labels back to cell-type names (kept for parity with the
    # original code; currently only the metrics are persisted).
    labelsss = [new_cellTypes[i] for i in labelss]
    y_predicts = [new_cellTypes[i] for i in y_predict]

    log_dir = "log/"
    log_txt = "log/"
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    last_path = log_dir + str(n_epochs) + "/"
    if not os.path.isdir(last_path):
        os.makedirs(last_path)
    with open(log_txt + "end_norm.txt", "a") as f:
        f.writelines("log_dir:" + last_path + "\n")
        f.writelines("acc:" + str(all_acc) + "\n")
        f.writelines('f1:' + str(all_f1) + "\n")


if __name__ == "__main__":
    s = 1024  # segment length / model dimension

    tissues = "Root"
    path = "../../Datasets/Arabidopsis thaliana/" + tissues + "/"
    data_name = "03SRP330542"
    test_paths = [path]
    test_names = [data_name + ".h5ad"]
    topgenes = 2000

    # Load the AnnData file and keep the top highly-variable genes.
    adata_mod1 = sc.read_h5ad(test_paths[0] + test_names[0])
    adata_mod1.var_names_make_unique()
    adata_mod1.obs['domain_id'] = 0
    sc.pp.highly_variable_genes(adata_mod1, n_top_genes=topgenes)
    adata_mod1 = adata_mod1[:, adata_mod1.var['highly_variable']]

    X = adata_mod1.X
    if not isinstance(X, np.ndarray):
        X = X.todense()
    X = np.asarray(X)
    print("X.shape", X.shape)
    Y = adata_mod1.obs['Celltype'].to_numpy()
    print("Y.shape", Y.shape)

    # 80/20 reference/query split with a fixed seed.
    test_size = 0.2
    ref_adata, query_adata, y_train, y_test = train_test_split(
        adata_mod1, Y, test_size=test_size, random_state=2024
    )
    print("ref_adata", ref_adata)
    print("query_adata", query_adata)
    start = tm.time()

    ciForm(s, ref_adata, y_train, query_adata, y_test, n_epochs=20)