Extends pytorch_lightning.core.datamodule.LightningDataModule and wraps QuackIterableDataset for use by pytorch_lightning.trainer.trainer.Trainer.
Receives a batch as a list of B tensors. Finds the widest tensor, whose length is T. Right-pads every narrower tensor with zeros to length T. Returns a tensor of shape (B x T).
-
Parameters
batch (List*[pt.Tensor]*) – A list of tensors in the batch.
-
Return type
pt.Tensor
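The padding described above can be illustrated with a minimal sketch. It assumes 1-D input tensors; the name pad_right_sketch is hypothetical and this is not necessarily how the library's own collate function is written.

```python
from typing import List

import torch


def pad_right_sketch(batch: List[torch.Tensor]) -> torch.Tensor:
    """Right-pad 1-D tensors with zeros to the width of the widest one.

    Hypothetical illustration only, assuming every tensor is 1-D.
    """
    # T is the length of the widest tensor in the batch.
    width = max(t.size(0) for t in batch)
    # Allocate a (B x T) tensor of zeros, then copy each tensor in from the left.
    padded = torch.zeros(len(batch), width, dtype=batch[0].dtype)
    for row, t in enumerate(batch):
        padded[row, : t.size(0)] = t
    return padded


batch = [torch.tensor([1, 2, 3]), torch.tensor([4]), torch.tensor([5, 6])]
padded = pad_right_sketch(batch)
print(padded.shape)  # B = 3 rows, T = 3 columns
```

Allocating the zero tensor first keeps the padding value and the output shape explicit; torch.nn.utils.rnn.pad_sequence with batch_first=True would likely be a shorter alternative.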
Receives a batch as a list of B TokenizedQuackData items. Finds the widest tensor, whose length is T. Right-pads every narrower tensor with zeros to length T. Returns the metadata of each item together with a tensor of shape (B x T).
-
Parameters
batch (List*[TokenizedQuackData]*) – A list of TokenizedQuackData (TypedDict) in the batch.
-
Returns
A tuple of a list of metadata and a batch tensor.
-
Return type
Tuple[List[dict], pt.Tensor]
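A rough sketch of this collate step follows. The field names "metadata", "static_size" and "text" are assumptions made for illustration, since the layout of TokenizedQuackData is not shown here, and pad_right_with_meta_sketch is a hypothetical name.

```python
from typing import List, Tuple

import torch


def pad_right_with_meta_sketch(batch: List[dict]) -> Tuple[List[dict], torch.Tensor]:
    """Split metadata from data, then right-pad the data with zeros.

    The keys "metadata", "static_size" and "text" are assumed, not confirmed.
    """
    # Keep the metadata in batch order so row i of the tensor matches meta[i].
    meta = [item["metadata"] for item in batch]
    # Join the (assumed) static and text parts of each item into one 1-D tensor.
    rows = [
        torch.cat((torch.as_tensor(item["static_size"]), torch.as_tensor(item["text"])))
        for item in batch
    ]
    width = max(r.size(0) for r in rows)
    padded = torch.zeros(len(rows), width, dtype=rows[0].dtype)
    for i, r in enumerate(rows):
        padded[i, : r.size(0)] = r
    return meta, padded


items = [
    {"metadata": {"id": 1}, "static_size": [7, 8], "text": [1, 2, 3]},
    {"metadata": {"id": 2}, "static_size": [9, 9], "text": [4]},
]
meta, padded = pad_right_with_meta_sketch(items)
```

Returning the metadata as a plain list alongside the tensor lets inference code match each prediction row back to its source item.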
Concatenates the static and text data into a single numpy array.
-
Parameters
item (dict) – A cp_flatten.TokenizedQuackData item (a TypedDict).
-
Returns
The concatenated data.
-
Return type
np.ndarray
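As an illustration of the concatenation, here is a minimal sketch. It assumes the item stores its static and text parts under the hypothetical keys "static_size" and "text"; the real field names of TokenizedQuackData may differ.

```python
import numpy as np


def concatenate_data_sketch(item: dict) -> np.ndarray:
    """Join the static and text parts of one item into a single 1-D array.

    The keys "static_size" and "text" are assumed for illustration.
    """
    static = np.asarray(item["static_size"])
    text = np.asarray(item["text"])
    # Static values come first, followed by the text values.
    return np.concatenate((static, text))


item = {"static_size": [7, 8], "text": [1, 2, 3]}
joined = concatenate_data_sketch(item)
print(joined.shape)  # one flat array holding both parts
```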
Bases: pytorch_lightning.core.datamodule.LightningDataModule
__init__(data_dir: str, batch_size: int = 64, workers: int = 0, train_transforms=None, val_transforms=None, test_transforms=None, dims=None)
Constructs QuackTokenizedDataModule.
-
Parameters
- data_dir (str) – The path to the top-level directory of the QuackIterableDataset.
- batch_size (int) – The batch size to pass to the torch.utils.data.dataloader.DataLoader.
- workers (int) – The number of workers to pass to the torch.utils.data.dataloader.DataLoader.
- train_transforms – Deprecated: the DataModule property train_transforms was deprecated in pytorch_lightning.core.datamodule.LightningDataModule v1.5 and will be removed in v1.7.
- val_transforms – Deprecated: the DataModule property val_transforms was deprecated in pytorch_lightning.core.datamodule.LightningDataModule v1.5 and will be removed in v1.7.
- test_transforms – Deprecated: the DataModule property test_transforms was deprecated in pytorch_lightning.core.datamodule.LightningDataModule v1.5 and will be removed in v1.7.
- dims – Deprecated: the DataModule property dims was deprecated in pytorch_lightning.core.datamodule.LightningDataModule v1.5 and will be removed in v1.7.
-
Constructs and returns the training dataloader using collate function pad_right.
-
Return type
torch.utils.data.dataloader.DataLoader
Constructs and returns the testing dataloader using collate function pad_right.
-
Return type
torch.utils.data.dataloader.DataLoader
Constructs and returns the validation dataloader using collate function pad_right.
-
Return type
torch.utils.data.dataloader.DataLoader
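The sketch below shows how a padding collate function is typically wired into a DataLoader over an iterable dataset. ToyIterable is a hypothetical stand-in for QuackIterableDataset, and pad_right_sketch is an assumed reimplementation of the collate function, not the library's own code.

```python
from typing import Iterator, List

import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyIterable(IterableDataset):
    """Hypothetical stand-in dataset that yields 1-D tensors of varying length."""

    def __iter__(self) -> Iterator[torch.Tensor]:
        for length in (3, 1, 2, 4):
            yield torch.arange(1, length + 1)


def pad_right_sketch(batch: List[torch.Tensor]) -> torch.Tensor:
    # Right-pad every tensor with zeros to the widest length in this batch.
    width = max(t.size(0) for t in batch)
    padded = torch.zeros(len(batch), width, dtype=batch[0].dtype)
    for row, t in enumerate(batch):
        padded[row, : t.size(0)] = t
    return padded


# batch_size and num_workers mirror the constructor's batch_size and workers.
loader = DataLoader(
    ToyIterable(),
    batch_size=2,
    num_workers=0,
    collate_fn=pad_right_sketch,
)
shapes = [tuple(b.shape) for b in loader]
print(shapes)
```

Because padding happens per batch, T can differ from one batch to the next; only the tensors within a single batch share a width.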
Constructs and returns the inference dataloader using collate function pad_right_with_meta.
-
Return type
torch.utils.data.dataloader.DataLoader
Returns data_width() from the cp_dataset.QuackIterableDataset loaded in this data module.
-
Return type
int