Skip to content

Commit

Permalink
Dix abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
maximzubkov committed Mar 16, 2021
1 parent 0863b80 commit 79c1a62
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
3 changes: 1 addition & 2 deletions dataset/data_modules/path_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def __init__(
)

def create_dataset(self, dataset_path: str, stage: str) -> Any:
dataset = PathDataset(self.stage2path[stage], self._config, self._vocabulary, False)
return dataset
return PathDataset(self.stage2path[stage], self._config, self._vocabulary, False)

def collate_fn(self, batch: Any) -> Any:
a_pc = [sample["a_encoding"] for sample in batch]
Expand Down
4 changes: 0 additions & 4 deletions dataset/data_modules/text_data_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import abstractmethod
from typing import Any, Callable

import torch
Expand All @@ -25,19 +24,16 @@ def __init__(
num_classes=num_classes
)

@abstractmethod
def create_dataset(self, dataset_path: str, stage: str) -> Any:
return TextDataset(dataset_path=dataset_path, stage=stage, is_test=self.is_test)

@abstractmethod
def collate_fn(self, batch: Any) -> Any:
# batch contains a list of tuples of structure (sequence, target)
a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
label = torch.LongTensor([item["label"] for item in batch])
return (a, b), label

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
(a, b), label = batch
a = a.to(device)
Expand Down

0 comments on commit 79c1a62

Please sign in to comment.