Code2class #3

Merged: 19 commits, Mar 16, 2021
74 changes: 74 additions & 0 deletions configs/code2class-poj104.yaml
@@ -0,0 +1,74 @@
hydra:
  run:
    dir: .
  output_subdir: null
  job_logging: null
  hydra_logging: null

name: code2class

seed: 9
num_workers: 2
log_offline: false

num_classes: 5

# data keys
data_folder: data
vocabulary_name: vocabulary.pkl
train_holdout: train
val_holdout: val
test_holdout: test

save_every_epoch: 1
val_every_epoch: 1
log_every_epoch: 10
progress_bar_refresh_rate: 1

hyper_parameters:
  n_epochs: 3000
  patience: 10
  batch_size: 16
  test_batch_size: 512
  clip_norm: 5
  max_context: 200
  random_context: true
  shuffle_data: true

  optimizer: "Momentum"
  nesterov: true
  learning_rate: 0.01
  weight_decay: 0
  decay_gamma: 0.95

dataset:
  name: poj_104
  target:
    max_parts: 1
    is_wrapped: false
    is_splitted: false
    vocabulary_size: 27000
  token:
    max_parts: 5
    is_wrapped: false
    is_splitted: true
    vocabulary_size: 190000
  path:
    max_parts: 9
    is_wrapped: false
    is_splitted: true
    vocabulary_size: null

encoder:
  embedding_size: 16
  rnn_size: 16
  use_bi_rnn: true
  embedding_dropout: 0.25
  rnn_num_layers: 1
  rnn_dropout: 0.5

classifier:
  n_hidden_layers: 2
  hidden_size: 16
  classifier_input_size: 16
  activation: relu
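
This config is read directly with OmegaConf rather than through the Hydra runtime (see get_config in dataset/data_modules/path_data_module.py below). A minimal sketch of accessing it, assuming the file is saved at configs/code2class-poj104.yaml:

from omegaconf import OmegaConf

# Nested YAML keys become attribute lookups on the loaded config.
config = OmegaConf.load("configs/code2class-poj104.yaml")
assert config.hyper_parameters.batch_size == 16
assert config.dataset.path.max_parts == 9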
15 changes: 12 additions & 3 deletions dataset/__init__.py
@@ -1,9 +1,18 @@
from .base_data_module import BaseDataModule
from dataset.classification_datasets.text_dataset import TextDataset
from dataset.data_modules import TextDataModule, PathDataModule
from .base_data_module import BaseContrastiveDataModule
from .contrastive_dataset import ContrastiveDataset
from .classification_datasets.text_dataset import TextDataset

__all__ = [
    "TextDataset",
    "TextDataModule",
    "PathDataModule",
    "ContrastiveDataset",
    "BaseDataModule",
    "BaseContrastiveDataModule",
    "data_modules"
]

data_modules = {
    "LSTM": TextDataModule,
    "Code2Class": PathDataModule
}
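
data_modules acts as a small registry from encoder name to its data module class. A hypothetical lookup, with the dataset name and class count mirroring the config above (illustrative values only):

from dataset import data_modules

# Pick the data module matching the configured encoder.
dm_cls = data_modules["Code2Class"]  # -> PathDataModule
dm = dm_cls(dataset_name="poj_104", batch_size=16, num_classes=5)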
53 changes: 16 additions & 37 deletions dataset/base_data_module.py
@@ -1,50 +1,42 @@
from os import walk
from abc import abstractmethod
from os.path import exists
from os.path import join
from typing import Any, Callable

import torch
from pytorch_lightning import LightningDataModule
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from .contrastive_dataset import ContrastiveDataset
from .download import load_dataset
from .classification_datasets.text_dataset import get_text_dataset

encoder2datasets = {
    "LSTM": get_text_dataset,
}

SEED = 9


class BaseDataModule(LightningDataModule):
class BaseContrastiveDataModule(LightningDataModule):
    def __init__(
        self,
        encoder_name: str,
        dataset_name: str,
        num_classes: int,
        batch_size: int,
        is_test: bool = False,
        transform: Callable = None
    ):
        super().__init__()
        if encoder_name in encoder2datasets:
            self.get_dataset = encoder2datasets[encoder_name]
        else:
            raise NotImplementedError(f"Dataset for {encoder_name} is currently not available")

        self.dataset_name = dataset_name
        self.dataset_path = join("data", dataset_name)
        self.batch_size = batch_size
        _, base_dirs, _ = next(walk(join(self.dataset_path, "train")))
        self.num_classes = len(base_dirs)
        self.num_classes = num_classes
        self.transform = transform
        self.is_test = is_test

        self.clf_dataset = {}
        self.contrastive_dataset = {}

    @abstractmethod
    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        pass

    def prepare_data(self):
        if not exists(self.dataset_path):
            load_dataset(self.dataset_name)
@@ -57,11 +49,7 @@ def setup(self, stage: str = None):
            stages += ["test"]

        for stage in stages:
            self.clf_dataset[stage] = self.get_dataset(
                dataset_path=self.dataset_path,
                stage=stage,
                is_test=self.is_test
            )
            self.clf_dataset[stage] = self.create_dataset(dataset_path=self.dataset_path, stage=stage)
            self.contrastive_dataset[stage] = ContrastiveDataset(clf_dataset=self.clf_dataset[stage])

    def train_dataloader(self):
@@ -91,22 +79,13 @@ def test_dataloader(self):
            drop_last=True
        )

    def _collate(self, batch):
        # batch contains a list of tuples of structure (sequence, target)
        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
        label = torch.LongTensor([item["label"] for item in batch])
        data = (a, b), label
    @abstractmethod
    def collate_fn(self, batch: Any) -> Any:
        pass

    def _collate(self, batch: Any) -> Any:
        batch = self.collate_fn(batch)
        if self.transform is not None:
            return self.transform(data)
            return self.transform(batch)
        else:
            return data

    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        (a, b), label = batch
        a = a.to(device)
        b = b.to(device)
        if isinstance(label, torch.Tensor):
            label = label.to(device)
        return (a, b), label
        return batch
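
After this refactoring, BaseContrastiveDataModule is a template: subclasses implement only create_dataset and collate_fn, while the base class handles downloading, wrapping each split in a ContrastiveDataset, and applying the optional transform in _collate. A minimal hypothetical subclass, essentially mirroring the TextDataModule added below:

from typing import Any

import torch

from dataset.base_data_module import BaseContrastiveDataModule
from dataset.classification_datasets import TextDataset


class MyDataModule(BaseContrastiveDataModule):
    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        # Any classification dataset works; the base class wraps it in ContrastiveDataset per split.
        return TextDataset(dataset_path=dataset_path, stage=stage, is_test=self.is_test)

    def collate_fn(self, batch: Any) -> Any:
        # Items produced by ContrastiveDataset carry "a_encoding", "b_encoding", and "label".
        a = torch.stack([item["a_encoding"] for item in batch])
        b = torch.stack([item["b_encoding"] for item in batch])
        return (a, b), torch.LongTensor([item["label"] for item in batch])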
7 changes: 7 additions & 0 deletions dataset/classification_datasets/__init__.py
@@ -0,0 +1,7 @@
from .path_dataset import PathDataset
from .text_dataset import TextDataset

__all__ = [
    "TextDataset",
    "PathDataset"
]
14 changes: 14 additions & 0 deletions dataset/classification_datasets/path_dataset.py
@@ -0,0 +1,14 @@
from typing import Optional, Tuple, Any

from code2seq.dataset import PathContextDataset, PathContextSample
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig


class PathDataset(PathContextDataset):
    def __init__(self, data_file_path: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool):
        super().__init__(data_file_path, config, vocabulary, random_context)

    def __getitem__(self, index) -> Optional[Tuple[PathContextSample, Any]]:
        pcs = super().__getitem__(index)
        if pcs is None:
            # PathContextDataset may fail to parse a sample; propagate None per the Optional return type.
            return None
        return pcs, pcs.label[0][0]
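
PathDataset reuses code2seq's PathContextDataset and additionally exposes the class index: pcs.label holds the encoded label, and label[0][0] is the raw class id for POJ-104. A hypothetical access, assuming config and vocabulary are loaded the same way PathDataModule does below (the file path is illustrative):

dataset = PathDataset("data/poj_104/poj_104.train.c2s", config, vocabulary, random_context=False)
sample, class_id = dataset[0]  # PathContextSample plus its integer class index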
4 changes: 0 additions & 4 deletions dataset/classification_datasets/text_dataset.py
@@ -10,10 +10,6 @@
from preprocess import tokenize


def get_text_dataset(dataset_path: str, stage: str, is_test: bool = False):
    return TextDataset(dataset_path=dataset_path, stage=stage, is_test=is_test)


class TextDataset(Dataset):
    def __init__(self, dataset_path: str, stage: str, is_test: bool = False):
        super().__init__()
7 changes: 7 additions & 0 deletions dataset/data_modules/__init__.py
@@ -0,0 +1,7 @@
from .path_data_module import PathDataModule
from .text_data_module import TextDataModule

__all__ = [
    "TextDataModule",
    "PathDataModule"
]
71 changes: 71 additions & 0 deletions dataset/data_modules/path_data_module.py
@@ -0,0 +1,71 @@
from os.path import join
from typing import Callable, Any, Optional, Tuple

import torch
from code2seq.dataset import PathContextBatch
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig, OmegaConf

from dataset.base_data_module import BaseContrastiveDataModule
from dataset.classification_datasets import PathDataset


def get_config() -> DictConfig:
    return OmegaConf.load("configs/code2class-poj104.yaml")


class PathDataModule(BaseContrastiveDataModule):
    def __init__(
        self,
        dataset_name: str,
        batch_size: int,
        num_classes: int,
        is_test: bool = False,
        transform: Callable = None,
    ):

        config = get_config()
        self._config = config
        self._vocabulary = Vocabulary.load_vocabulary(
            join(config.data_folder, config.dataset.name, config.vocabulary_name)
        )

        self._dataset_dir = join(config.data_folder, config.dataset.name)
        self._train_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.train_holdout}.c2s")
        self._val_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.val_holdout}.c2s")
        self._test_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.test_holdout}.c2s")

        self.stage2path = {
            "train": self._train_data_file,
            "test": self._test_data_file,
            "val": self._val_data_file
        }

        BaseContrastiveDataModule.__init__(
            self,
            dataset_name=dataset_name,
            batch_size=config.hyper_parameters.batch_size,
            is_test=is_test,
            transform=transform,
            num_classes=num_classes
        )

    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        return PathDataset(self.stage2path[stage], self._config, self._vocabulary, False)

    def collate_fn(self, batch: Any) -> Any:
        a_pc = [sample["a_encoding"] for sample in batch]
        b_pc = [sample["b_encoding"] for sample in batch]
        labels = [sample["label"] for sample in batch]
        a_pc = PathContextBatch(a_pc)
        b_pc = PathContextBatch(b_pc)
        return (a_pc, b_pc), torch.LongTensor(labels)
    def transfer_batch_to_device(
        self, batch: Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor], device: Optional[torch.device] = None
    ) -> Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor]:
        # collate_fn returns ((a_pc, b_pc), labels), so both path-context batches must be moved;
        # Tensor.to is not in-place, so its result has to be reassigned.
        (a_pc, b_pc), labels = batch
        if device is not None:
            a_pc.move_to_device(device)
            b_pc.move_to_device(device)
            labels = labels.to(device)
        return (a_pc, b_pc), labels
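
End to end, the module yields two PathContextBatch views plus class labels per step. A hedged usage sketch, assuming the preprocessed *.c2s files and vocabulary.pkl already sit under data/poj_104:

dm = PathDataModule(dataset_name="poj_104", batch_size=16, num_classes=5)
dm.prepare_data()
dm.setup("fit")
(a_pc, b_pc), labels = next(iter(dm.train_dataloader()))  # two augmented path-context views + class ids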
43 changes: 43 additions & 0 deletions dataset/data_modules/text_data_module.py
@@ -0,0 +1,43 @@
from typing import Any, Callable

import torch
from torch.nn.utils.rnn import pad_sequence

from dataset.base_data_module import BaseContrastiveDataModule
from dataset.classification_datasets.text_dataset import TextDataset


class TextDataModule(BaseContrastiveDataModule):
    def __init__(
        self,
        dataset_name: str,
        batch_size: int,
        num_classes: int,
        is_test: bool = False,
        transform: Callable = None
    ):
        super().__init__(
            dataset_name=dataset_name,
            batch_size=batch_size,
            is_test=is_test,
            transform=transform,
            num_classes=num_classes
        )

    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        return TextDataset(dataset_path=dataset_path, stage=stage, is_test=self.is_test)

    def collate_fn(self, batch: Any) -> Any:
        # batch contains a list of tuples of structure (sequence, target)
        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
        label = torch.LongTensor([item["label"] for item in batch])
        return (a, b), label

    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        (a, b), label = batch
        a = a.to(device)
        b = b.to(device)
        if isinstance(label, torch.Tensor):
            label = label.to(device)
        return (a, b), label
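
Note that pad_sequence defaults to batch_first=False, so a and b come out time-major with shape (max_seq_len, batch_size), and the LSTM encoder has to consume them in that layout. A quick standalone check:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.ones(3, dtype=torch.long), torch.ones(5, dtype=torch.long)]
print(pad_sequence(seqs).shape)  # torch.Size([5, 2]): time-major, zero-padded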
5 changes: 4 additions & 1 deletion models/encoders/__init__.py
@@ -1,10 +1,13 @@
from .code2class import Code2ClassModel
from .lstm import LSTMModel

__all__ = [
    "LSTMModel",
    "Code2ClassModel",
    "encoder_models"
]

encoder_models = {
    "LSTM": LSTMModel
    "LSTM": LSTMModel,
    "Code2Class": Code2ClassModel
}
15 changes: 15 additions & 0 deletions models/encoders/code2class.py
@@ -0,0 +1,15 @@
from code2seq.dataset import PathContextBatch
from code2seq.model import Code2Class
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig
from torch import nn


class Code2ClassModel(nn.Module):
    def __init__(self, config: DictConfig, vocabulary: Vocabulary):
        super().__init__()
        self.num_classes = config.num_classes
        self.code2class = Code2Class(config, vocabulary)

    def forward(self, batch: PathContextBatch):
        return self.code2class(batch.contexts, batch.contexts_per_label)
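
The wrapper delegates to code2seq's Code2Class, which consumes the flattened contexts plus per-sample context counts from a PathContextBatch. A hypothetical forward pass, loading config and vocabulary the same way PathDataModule does:

from omegaconf import OmegaConf
from code2seq.utils.vocabulary import Vocabulary

from models.encoders import Code2ClassModel

config = OmegaConf.load("configs/code2class-poj104.yaml")
vocabulary = Vocabulary.load_vocabulary("data/poj_104/vocabulary.pkl")
model = Code2ClassModel(config, vocabulary)
# a_pc is a PathContextBatch, e.g. from PathDataModule.collate_fn above
logits = model(a_pc)  # expected shape: (batch_size, num_classes)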