-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
115 lines (94 loc) · 4 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from sklearn.model_selection import train_test_split
class AdataDataset(Dataset):
    """
    A map-style dataset pairing each row of gene data with its batch entry.

    Args:
        genes: Row-indexable gene data exposing ``shape`` (e.g. a
            numpy.ndarray); row ``i`` is the sample at position ``i``.
        batches: Batch information aligned *positionally* with ``genes``.
            May be a numpy.ndarray or a pandas Series/Index.

    Attributes:
        genes: The gene data, stored exactly as passed in.
        batches: The batch information, stored exactly as passed in.
    """

    def __init__(self, genes, batches):
        self.genes = genes
        self.batches = batches

    def __len__(self):
        """Return the number of samples, i.e. the first dimension of ``genes``."""
        return self.genes.shape[0]

    def __getitem__(self, idx):
        """
        Get the sample stored at position ``idx``.

        Args:
            idx (int): Zero-based position of the item to retrieve.

        Returns:
            dict: A dictionary containing the gene data and batch information.
                - "X": Gene data at the specified position.
                - "batch": Batch information at the specified position.
        """
        labels = self.batches
        # pandas objects resolve ``obj[int]`` by index *label*, not by
        # position. After a shuffling split the labels no longer match the
        # row positions of ``genes``, so plain indexing would return the wrong
        # entry (or raise KeyError). Use positional ``iloc`` when available and
        # fall back to ordinary indexing for arrays/lists.
        if hasattr(labels, "iloc"):
            label = labels.iloc[idx]
        else:
            label = labels[idx]
        return {
            "X": self.genes[idx],
            "batch": label,
        }
class AdataDataModule(pl.LightningDataModule):
    """LightningDataModule that splits an AnnData-like object into train/val/test sets."""

    def __init__(self, adata, class_key='scRNASeq_sample_ID', batch_size=2048, val_split=0.1, test_split=0.1):
        """
        LightningDataModule for handling data loading and processing for AdataDataset.

        Args:
            adata (AnnData): Annotated data object containing gene expression data.
                ``adata.X`` and ``adata.obs[class_key]`` are read in ``setup``.
            class_key (str): Column of ``adata.obs`` served as the "batch" entry
                of every sample (default: 'scRNASeq_sample_ID').
            batch_size (int): Number of samples per batch (default: 2048).
            val_split (float): Fraction of data to be used for validation (default: 0.1).
            test_split (float): Fraction of data to be used for testing (default: 0.1).
        """
        super().__init__()
        self.adata = adata
        self.batch_size = batch_size
        self.val_split = val_split
        self.test_split = test_split
        self.class_key = class_key

    @staticmethod
    def _densify(matrix):
        """Return ``matrix.todense()`` when that method exists, else ``matrix`` unchanged."""
        # A dense ``adata.X`` (plain ndarray) has no ``todense``; calling it
        # unconditionally raised AttributeError. Sparse inputs still take the
        # exact same ``todense()`` path as before.
        # NOTE(review): for scipy sparse matrices ``todense()`` presumably
        # yields a numpy.matrix, whose rows stay 2-D when indexed - confirm the
        # model expects the same per-sample shape for dense and sparse inputs.
        if hasattr(matrix, "todense"):
            return matrix.todense()
        return matrix

    def setup(self, stage=None):
        """
        Prepare the train, validation, and test datasets.

        The data is split twice with a fixed ``random_state`` of 42: first into
        train vs. a combined hold-out of size ``val_split + test_split``, then
        the hold-out is divided into validation and test.

        Args:
            stage (str, optional): The current stage (e.g., 'fit', 'validate', 'test').
                Not used: all three datasets are built on every call. Defaults to None.
        """
        holdout_fraction = self.val_split + self.test_split
        train_genes, val_test_genes, train_batches, val_test_batches = train_test_split(
            self.adata.X, self.adata.obs[self.class_key],
            test_size=holdout_fraction, random_state=42
        )
        # The second split works on the hold-out only, so the test share must
        # be expressed relative to the hold-out size, not the full dataset.
        val_genes, test_genes, val_batches, test_batches = train_test_split(
            val_test_genes, val_test_batches,
            test_size=self.test_split / holdout_fraction, random_state=42
        )
        self.train_dataset = AdataDataset(self._densify(train_genes), train_batches)
        self.val_dataset = AdataDataset(self._densify(val_genes), val_batches)
        self.test_dataset = AdataDataset(self._densify(test_genes), test_batches)

    def train_dataloader(self):
        """
        Returns a DataLoader for the training dataset.

        Returns:
            torch.utils.data.DataLoader: Shuffling DataLoader for the training dataset.
        """
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """
        Returns a DataLoader for the validation dataset.

        Returns:
            torch.utils.data.DataLoader: DataLoader for the validation dataset.
        """
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """
        Returns a DataLoader for the test dataset.

        Returns:
            torch.utils.data.DataLoader: DataLoader for the test dataset.
        """
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    def __iter__(self):
        """
        Iterates over the train, validation, and test dataloaders.

        Each iteration builds fresh dataloaders, so ``setup`` must have been
        called first.

        Yields:
            torch.utils.data.DataLoader: The train, validation, and test
            dataloaders, in that order.
        """
        yield from (self.train_dataloader(), self.val_dataloader(), self.test_dataloader())