diff --git a/configs/code2class-poj104.yaml b/configs/code2class-poj104.yaml
new file mode 100644
index 0000000..7354a59
--- /dev/null
+++ b/configs/code2class-poj104.yaml
@@ -0,0 +1,74 @@
+hydra:
+  run:
+    dir: .
+  output_subdir: null
+  job_logging: null
+  hydra_logging: null
+
+name: code2class
+
+seed: 9
+num_workers: 2
+log_offline: false
+
+num_classes: 5
+
+# data keys
+data_folder: data
+vocabulary_name: vocabulary.pkl
+train_holdout: train
+val_holdout: val
+test_holdout: test
+
+save_every_epoch: 1
+val_every_epoch: 1
+log_every_epoch: 10
+progress_bar_refresh_rate: 1
+
+hyper_parameters:
+  n_epochs: 3000
+  patience: 10
+  batch_size: 16
+  test_batch_size: 512
+  clip_norm: 5
+  max_context: 200
+  random_context: true
+  shuffle_data: true
+
+  optimizer: "Momentum"
+  nesterov: true
+  learning_rate: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+dataset:
+  name: poj_104
+  target:
+    max_parts: 1
+    is_wrapped: false
+    is_splitted: false
+    vocabulary_size: 27000
+  token:
+    max_parts: 5
+    is_wrapped: false
+    is_splitted: true
+    vocabulary_size: 190000
+  path:
+    max_parts: 9
+    is_wrapped: false
+    is_splitted: true
+    vocabulary_size: null
+
+encoder:
+  embedding_size: 16
+  rnn_size: 16
+  use_bi_rnn: true
+  embedding_dropout: 0.25
+  rnn_num_layers: 1
+  rnn_dropout: 0.5
+
+classifier:
+  n_hidden_layers: 2
+  hidden_size: 16
+  classifier_input_size: 16
+  activation: relu
\ No newline at end of file
diff --git a/dataset/__init__.py b/dataset/__init__.py
index 4248f23..ce46c62 100644
--- a/dataset/__init__.py
+++ b/dataset/__init__.py
@@ -1,9 +1,18 @@
-from .base_data_module import BaseDataModule
+from dataset.classification_datasets.text_dataset import TextDataset
+from dataset.data_modules import TextDataModule, PathDataModule
+from .base_data_module import BaseContrastiveDataModule
 from .contrastive_dataset import ContrastiveDataset
-from .classification_datasets.text_dataset import TextDataset
 
 __all__ = [
     "TextDataset",
+    "TextDataModule",
+    "PathDataModule",
     "ContrastiveDataset",
-    "BaseDataModule",
+    "BaseContrastiveDataModule",
+    "data_modules"
 ]
+
+data_modules = {
+    "LSTM": TextDataModule,
+    "Code2Class": PathDataModule
+}
diff --git a/dataset/base_data_module.py b/dataset/base_data_module.py
index 96f788a..cb921df 100644
--- a/dataset/base_data_module.py
+++ b/dataset/base_data_module.py
@@ -1,50 +1,42 @@
-from os import walk
+from abc import abstractmethod
 from os.path import exists
 from os.path import join
 from typing import Any, Callable
 
-import torch
 from pytorch_lightning import LightningDataModule
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 
 from .contrastive_dataset import ContrastiveDataset
 from .download import load_dataset
-from .classification_datasets.text_dataset import get_text_dataset
-
-encoder2datasets = {
-    "LSTM": get_text_dataset,
-}
 
 SEED = 9
 
 
-class BaseDataModule(LightningDataModule):
+class BaseContrastiveDataModule(LightningDataModule):
     def __init__(
         self,
-        encoder_name: str,
         dataset_name: str,
+        num_classes: int,
        batch_size: int,
         is_test: bool = False,
         transform: Callable = None
     ):
         super().__init__()
-        if encoder_name in encoder2datasets:
-            self.get_dataset = encoder2datasets[encoder_name]
-        else:
-            raise NotImplementedError(f"Dataset for {encoder_name} is currently not available")
         self.dataset_name = dataset_name
         self.dataset_path = join("data", dataset_name)
         self.batch_size = batch_size
-        _, base_dirs, _ = next(walk(join(self.dataset_path, "train")))
-        self.num_classes = len(base_dirs)
+        self.num_classes = num_classes
         self.transform = transform
         self.is_test = is_test
 
         self.clf_dataset = {}
         self.contrastive_dataset = {}
 
+    @abstractmethod
+    def create_dataset(self, dataset_path: str, stage: str) -> Any:
+        pass
+
     def prepare_data(self):
         if not exists(self.dataset_path):
             load_dataset(self.dataset_name)
 
@@ -57,11 +49,7 @@ def setup(self, stage: str = None):
             stages += ["test"]
 
         for stage in stages:
-            self.clf_dataset[stage] = self.get_dataset(
-                dataset_path=self.dataset_path,
-                stage=stage,
-                is_test=self.is_test
-            )
+            self.clf_dataset[stage] = self.create_dataset(dataset_path=self.dataset_path, stage=stage)
             self.contrastive_dataset[stage] = ContrastiveDataset(clf_dataset=self.clf_dataset[stage])
 
     def train_dataloader(self):
@@ -91,22 +79,13 @@ def test_dataloader(self):
             drop_last=True
         )
 
-    def _collate(self, batch):
-        # batch contains a list of tuples of structure (sequence, target)
-        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
-        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
-        label = torch.LongTensor([item["label"] for item in batch])
-        data = (a, b), label
+    @abstractmethod
+    def collate_fn(self, batch: Any) -> Any:
+        pass
+    def _collate(self, batch: Any) -> Any:
+        batch = self.collate_fn(batch)
 
         if self.transform is not None:
-            return self.transform(data)
+            return self.transform(batch)
         else:
-            return data
-
-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        (a, b), label = batch
-        a = a.to(device)
-        b = b.to(device)
-        if isinstance(label, torch.Tensor):
-            label = label.to(device)
-        return (a, b), label
+            return batch
diff --git a/dataset/classification_datasets/__init__.py b/dataset/classification_datasets/__init__.py
new file mode 100644
index 0000000..75b1e1f
--- /dev/null
+++ b/dataset/classification_datasets/__init__.py
@@ -0,0 +1,7 @@
+from .path_dataset import PathDataset
+from .text_dataset import TextDataset
+
+__all__ = [
+    "TextDataset",
+    "PathDataset"
+]
diff --git a/dataset/classification_datasets/path_dataset.py b/dataset/classification_datasets/path_dataset.py
new file mode 100644
index 0000000..fcfbcb2
--- /dev/null
+++ b/dataset/classification_datasets/path_dataset.py
@@ -0,0 +1,14 @@
+from typing import Optional, Tuple, Any
+
+from code2seq.dataset import PathContextDataset, PathContextSample
+from code2seq.utils.vocabulary import Vocabulary
+from omegaconf import DictConfig
+
+
+class PathDataset(PathContextDataset):
+    def __init__(self, data_file_path: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool):
+        super().__init__(data_file_path, config, vocabulary, random_context)
+
+    def __getitem__(self, index) -> Optional[Tuple[PathContextSample, Any]]:
+        pcs = super().__getitem__(index)
+        return pcs, pcs.label[0][0]
diff --git a/dataset/classification_datasets/text_dataset.py b/dataset/classification_datasets/text_dataset.py
index 85df4b8..9a1fc4e 100644
--- a/dataset/classification_datasets/text_dataset.py
+++ b/dataset/classification_datasets/text_dataset.py
@@ -10,10 +10,6 @@
 from preprocess import tokenize
 
 
-def get_text_dataset(dataset_path: str, stage: str, is_test: bool = False):
-    return TextDataset(dataset_path=dataset_path, stage=stage, is_test=is_test)
-
-
 class TextDataset(Dataset):
     def __init__(self, dataset_path: str, stage: str, is_test: bool = False):
         super().__init__()
diff --git a/dataset/data_modules/__init__.py b/dataset/data_modules/__init__.py
new file mode 100644
index 0000000..4838d7a
--- /dev/null
+++ b/dataset/data_modules/__init__.py
@@ -0,0 +1,7 @@
+from .path_data_module import PathDataModule
+from .text_data_module import TextDataModule
+
+__all__ = [
+    "TextDataModule",
+    "PathDataModule"
+]
diff --git a/dataset/data_modules/path_data_module.py b/dataset/data_modules/path_data_module.py
new file mode 100644
index 0000000..4151ab2
--- /dev/null
+++ b/dataset/data_modules/path_data_module.py
@@ -0,0 +1,72 @@
+from os.path import join
+from typing import Callable, Any, Optional, Tuple
+
+import torch
+from code2seq.dataset import PathContextBatch
+from code2seq.utils.vocabulary import Vocabulary
+from omegaconf import DictConfig, OmegaConf
+
+from dataset.base_data_module import BaseContrastiveDataModule
+from dataset.classification_datasets import PathDataset
+
+
+def get_config() -> DictConfig:
+    return OmegaConf.load("configs/code2class-poj104.yaml")
+
+
+class PathDataModule(BaseContrastiveDataModule):
+    def __init__(
+        self,
+        dataset_name: str,
+        batch_size: int,
+        num_classes: int,
+        is_test: bool = False,
+        transform: Callable = None,
+    ):
+
+        config = get_config()
+        self._config = config
+        self._vocabulary = Vocabulary.load_vocabulary(
+            join(config.data_folder, config.dataset.name, config.vocabulary_name)
+        )
+
+        self._dataset_dir = join(config.data_folder, config.dataset.name)
+        self._train_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.train_holdout}.c2s")
+        self._val_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.val_holdout}.c2s")
+        self._test_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.test_holdout}.c2s")
+
+        self.stage2path = {
+            "train": self._train_data_file,
+            "test": self._test_data_file,
+            "val": self._val_data_file
+        }
+
+        BaseContrastiveDataModule.__init__(
+            self,
+            dataset_name=dataset_name,
+            batch_size=config.hyper_parameters.batch_size,
+            is_test=is_test,
+            transform=transform,
+            num_classes=num_classes
+        )
+
+    def create_dataset(self, dataset_path: str, stage: str) -> Any:
+        return PathDataset(self.stage2path[stage], self._config, self._vocabulary, False)
+
+    def collate_fn(self, batch: Any) -> Any:
+        a_pc = [sample["a_encoding"] for sample in batch]
+        b_pc = [sample["b_encoding"] for sample in batch]
+        labels = [sample["label"] for sample in batch]
+        a_pc = PathContextBatch(a_pc)
+        b_pc = PathContextBatch(b_pc)
+        return (a_pc, b_pc), torch.LongTensor(labels)
+
+    def transfer_batch_to_device(
+        self, batch: Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor], device: Optional[torch.device] = None
+    ) -> Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor]:
+        (a_pc, b_pc), labels = batch
+        if device is not None:
+            a_pc.move_to_device(device)
+            b_pc.move_to_device(device)
+            labels = labels.to(device)
+        return (a_pc, b_pc), labels
diff --git a/dataset/data_modules/text_data_module.py b/dataset/data_modules/text_data_module.py
new file mode 100644
index 0000000..c46e86f
--- /dev/null
+++ b/dataset/data_modules/text_data_module.py
@@ -0,0 +1,43 @@
+from typing import Any, Callable
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from dataset.base_data_module import BaseContrastiveDataModule
+from dataset.classification_datasets.text_dataset import TextDataset
+
+
+class TextDataModule(BaseContrastiveDataModule):
+    def __init__(
+        self,
+        dataset_name: str,
+        batch_size: int,
+        num_classes: int,
+        is_test: bool = False,
+        transform: Callable = None
+    ):
+        super().__init__(
+            dataset_name=dataset_name,
+            batch_size=batch_size,
+            is_test=is_test,
+            transform=transform,
+            num_classes=num_classes
+        )
+
+    def create_dataset(self, dataset_path: str, stage: str) -> Any:
+        return TextDataset(dataset_path=dataset_path, stage=stage, is_test=self.is_test)
+
+    def collate_fn(self, batch: Any) -> Any:
+        # batch contains a list of tuples of structure (sequence, target)
+        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
+        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
+        label = torch.LongTensor([item["label"] for item in batch])
+        return (a, b), label
+
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        (a, b), label = batch
+        a = a.to(device)
+        b = b.to(device)
+        if isinstance(label, torch.Tensor):
+            label = label.to(device)
+        return (a, b), label
diff --git a/models/encoders/__init__.py b/models/encoders/__init__.py
index 44de1a1..76f3f73 100644
--- a/models/encoders/__init__.py
+++ b/models/encoders/__init__.py
@@ -1,10 +1,13 @@
+from .code2class import Code2ClassModel
 from .lstm import LSTMModel
 
 __all__ = [
     "LSTMModel",
+    "Code2ClassModel",
     "encoder_models"
 ]
 
 encoder_models = {
-    "LSTM": LSTMModel
+    "LSTM": LSTMModel,
+    "Code2Class": Code2ClassModel
 }
diff --git a/models/encoders/code2class.py b/models/encoders/code2class.py
new file mode 100644
index 0000000..69f6b80
--- /dev/null
+++ b/models/encoders/code2class.py
@@ -0,0 +1,15 @@
+from code2seq.dataset import PathContextBatch
+from code2seq.model import Code2Class
+from code2seq.utils.vocabulary import Vocabulary
+from omegaconf import DictConfig
+from torch import nn
+
+
+class Code2ClassModel(nn.Module):
+    def __init__(self, config: DictConfig, vocabulary: Vocabulary):
+        super().__init__()
+        self.num_classes = config.num_classes
+        self.code2class = Code2Class(config, vocabulary)
+
+    def forward(self, batch: PathContextBatch):
+        return self.code2class(batch.contexts, batch.contexts_per_label)
diff --git a/models/self_supervised/byol.py b/models/self_supervised/byol.py
index 7495b0a..ddfa4cb 100644
--- a/models/self_supervised/byol.py
+++ b/models/self_supervised/byol.py
@@ -2,7 +2,8 @@
 from dataclasses import dataclass, asdict
 
 import torch.nn as nn
-from dataset import BaseDataModule
+from code2seq.utils.vocabulary import Vocabulary
+from omegaconf import OmegaConf
 from pl_bolts.models.self_supervised import BYOL
 from pl_bolts.models.self_supervised.byol.models import MLP
 
@@ -14,7 +15,6 @@ def __init__(
         self,
         base_encoder: str,
         encoder_config: dataclass,
-        datamodule: BaseDataModule,
         learning_rate: float = 0.2,
         weight_decay: float = 1.5e-6,
         input_height: int = 32,
@@ -26,10 +26,18 @@
     ):
         self.hparams = asdict(encoder_config)
         self.encoder_config = encoder_config
-        self.datamodule = datamodule
+
+        if base_encoder == "LSTM":
+            encoder = encoder_models[base_encoder](self.encoder_config)
+            self.num_classes = self.encoder_config.num_classes
+        else:
+            _config = OmegaConf.load("configs/code2class-poj104.yaml")
+            _vocabulary = Vocabulary.load_vocabulary("data/poj_104/vocabulary.pkl")
+            encoder = encoder_models[base_encoder](config=_config, vocabulary=_vocabulary)
+            self.num_classes = _config.num_classes
 
         super().__init__(
-            num_classes=datamodule.num_classes,
+            num_classes=self.num_classes,
             learning_rate=learning_rate,
             weight_decay=weight_decay,
             input_height=input_height,
@@ -40,7 +48,7 @@
             **kwargs
         )
 
-        self.online_network = SiameseArm(encoder_models[base_encoder](self.encoder_config))
+        self.online_network = SiameseArm(encoder=encoder)
         self.target_network = deepcopy(self.online_network)
 
 
diff --git a/models/self_supervised/moco.py b/models/self_supervised/moco.py
index 783d5ac..d8fe494 100644
--- a/models/self_supervised/moco.py
+++ b/models/self_supervised/moco.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, asdict
 
-from dataset import BaseDataModule
+from code2seq.utils.vocabulary import Vocabulary
+from omegaconf import OmegaConf
 from pl_bolts.models.self_supervised import MocoV2
 
 from models.encoders import encoder_models
@@ -11,7 +12,6 @@ def __init__(
         self,
         base_encoder: str,
         encoder_config: dataclass,
-        datamodule: BaseDataModule,
         num_negatives: int = 65536,
         encoder_momentum: float = 0.999,
         softmax_temperature: float = 0.07,
@@ -25,7 +25,6 @@
     ):
         self.hparams = asdict(encoder_config)
         self.encoder_config = encoder_config
-        self.datamodule = datamodule
 
         super().__init__(
             base_encoder=base_encoder,
@@ -43,6 +42,12 @@
         )
 
     def init_encoders(self, base_encoder: str):
-        encoder_q = encoder_models[base_encoder](self.encoder_config)
-        encoder_k = encoder_models[base_encoder](self.encoder_config)
+        if base_encoder == "LSTM":
+            encoder_q = encoder_models[base_encoder](self.encoder_config)
+            encoder_k = encoder_models[base_encoder](self.encoder_config)
+        else:
+            _config = OmegaConf.load("configs/code2class-poj104.yaml")
+            _vocabulary = Vocabulary.load_vocabulary("data/poj_104/vocabulary.pkl")
+            encoder_q = encoder_models[base_encoder](config=_config, vocabulary=_vocabulary)
+            encoder_k = encoder_models[base_encoder](config=_config, vocabulary=_vocabulary)
         return encoder_q, encoder_k
diff --git a/requirements.txt b/requirements.txt
index 81db0bf..b4bef1a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
 flake8==3.8.4
-mypy==0.800
-numpy==1.19.5
-pytorch-lightning==1.1.6
+pytorch-lightning==1.1.7
 pytorch-lightning-bolts==0.3.0
 torch==1.7.1
-tqdm==4.56.0
-wandb==0.10.15
+tqdm~=4.58.0
+wandb==0.10.20
 pytest==6.2.2
 torchvision~=0.8.2
 wget~=3.2
-split-folders==0.4.3
\ No newline at end of file
+split-folders==0.4.3
+omegaconf~=2.0.6
+code2seq~=0.0.2
\ No newline at end of file
diff --git a/train.py b/train.py
index 2fb6481..c840075 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,7 @@
 from pytorch_lightning.loggers import WandbLogger
 
 from configs import default_config, test_config, default_hyperparametrs
-from dataset import BaseDataModule
+from dataset import data_modules
 from models import encoder_models, ssl_models, ssl_models_transforms
 
 SEED = 9
@@ -25,10 +25,10 @@ def train(model: str, encoder: str, dataset: str, is_test: bool, log_offline: bo
         hyperparams = default_hyperparametrs
 
     transform = ssl_models_transforms[model]() if model in ssl_models_transforms else None
-    dm = BaseDataModule(
-        encoder_name=encoder,
+    dm = data_modules[encoder](
         dataset_name=dataset,
         is_test=is_test,
+        num_classes=config.num_classes,
         batch_size=hyperparams.batch_size,
         transform=transform
     )
@@ -36,7 +36,6 @@ def train(model: str, encoder: str, dataset: str, is_test: bool, log_offline: bo
     model_ = ssl_models[model](
         base_encoder=encoder,
         encoder_config=config,
-        datamodule=dm,
         batch_size=hyperparams.batch_size,
         max_epochs=hyperparams.n_epochs,
     )
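# Usage sketch (illustrative, not part of the commit above): how the new registry-based
# data module for the Code2Class encoder is consumed, mirroring what train.py does via
# data_modules[encoder]. The literal values (dataset_name="poj_104", batch_size=16,
# num_classes=5) are taken from configs/code2class-poj104.yaml; the sketch assumes the
# POJ-104 data, vocabulary.pkl and *.c2s files already exist under data/poj_104.
from dataset import data_modules

dm = data_modules["Code2Class"](
    dataset_name="poj_104",
    batch_size=16,   # PathDataModule overrides this with hyper_parameters.batch_size from the config
    num_classes=5,
    is_test=False,
)
dm.prepare_data()
dm.setup()

# Each training batch is a pair of PathContextBatch objects plus contrastive labels,
# i.e. exactly the structure produced by PathDataModule.collate_fn.
(a_pc, b_pc), labels = next(iter(dm.train_dataloader()))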