Merge pull request #3 from JetBrains-Research/code2seq
Code2class
maximzubkov committed Mar 16, 2021
2 parents 98cd7ae + 79c1a62 commit 1c1b03b
Showing 15 changed files with 295 additions and 65 deletions.
74 changes: 74 additions & 0 deletions configs/code2class-poj104.yaml
@@ -0,0 +1,74 @@
hydra:
  run:
    dir: .
  output_subdir: null
  job_logging: null
  hydra_logging: null

name: code2class

seed: 9
num_workers: 2
log_offline: false

num_classes: 5

# data keys
data_folder: data
vocabulary_name: vocabulary.pkl
train_holdout: train
val_holdout: val
test_holdout: test

save_every_epoch: 1
val_every_epoch: 1
log_every_epoch: 10
progress_bar_refresh_rate: 1

hyper_parameters:
  n_epochs: 3000
  patience: 10
  batch_size: 16
  test_batch_size: 512
  clip_norm: 5
  max_context: 200
  random_context: true
  shuffle_data: true

  optimizer: "Momentum"
  nesterov: true
  learning_rate: 0.01
  weight_decay: 0
  decay_gamma: 0.95

dataset:
  name: poj_104
  target:
    max_parts: 1
    is_wrapped: false
    is_splitted: false
    vocabulary_size: 27000
  token:
    max_parts: 5
    is_wrapped: false
    is_splitted: true
    vocabulary_size: 190000
  path:
    max_parts: 9
    is_wrapped: false
    is_splitted: true
    vocabulary_size: null

encoder:
  embedding_size: 16
  rnn_size: 16
  use_bi_rnn: true
  embedding_dropout: 0.25
  rnn_num_layers: 1
  rnn_dropout: 0.5

classifier:
  n_hidden_layers: 2
  hidden_size: 16
  classifier_input_size: 16
  activation: relu
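
For reference, this Hydra-style config is plain YAML that OmegaConf can load directly; a minimal sketch of reading it (not part of the commit):

from omegaconf import OmegaConf

# Load the training config added above; yields a DictConfig with attribute access.
config = OmegaConf.load("configs/code2class-poj104.yaml")
print(config.hyper_parameters.batch_size)  # 16
print(config.dataset.name)                 # poj_104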
15 changes: 12 additions & 3 deletions dataset/__init__.py
@@ -1,9 +1,18 @@
-from .base_data_module import BaseDataModule
-from dataset.classification_datasets.text_dataset import TextDataset
+from dataset.data_modules import TextDataModule, PathDataModule
+from .base_data_module import BaseContrastiveDataModule
 from .contrastive_dataset import ContrastiveDataset
+from .classification_datasets.text_dataset import TextDataset
 
 __all__ = [
     "TextDataset",
+    "TextDataModule",
+    "PathDataModule",
+    "ContrastiveDataset",
     "BaseDataModule",
+    "BaseContrastiveDataModule",
+    "data_modules"
 ]
 
+data_modules = {
+    "LSTM": TextDataModule,
+    "Code2Class": PathDataModule
+}
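
The new data_modules dict acts as a small registry: callers look up the data module class by encoder name instead of hard-coding it. A sketch of the intended lookup (argument values are illustrative, mirroring the config above):

from dataset import data_modules

# Hypothetical: pick and build the right data module for the chosen encoder.
module_cls = data_modules["Code2Class"]
dm = module_cls(dataset_name="poj_104", batch_size=16, num_classes=5)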
53 changes: 16 additions & 37 deletions dataset/base_data_module.py
@@ -1,50 +1,42 @@
-from os import walk
+from abc import abstractmethod
 from os.path import exists
 from os.path import join
 from typing import Any, Callable
 
 import torch
 from pytorch_lightning import LightningDataModule
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 
 from .contrastive_dataset import ContrastiveDataset
 from .download import load_dataset
-from .classification_datasets.text_dataset import get_text_dataset
-
-encoder2datasets = {
-    "LSTM": get_text_dataset,
-}
 
 SEED = 9
 
 
-class BaseDataModule(LightningDataModule):
+class BaseContrastiveDataModule(LightningDataModule):
     def __init__(
         self,
-        encoder_name: str,
         dataset_name: str,
+        num_classes: int,
         batch_size: int,
         is_test: bool = False,
         transform: Callable = None
     ):
         super().__init__()
-        if encoder_name in encoder2datasets:
-            self.get_dataset = encoder2datasets[encoder_name]
-        else:
-            raise NotImplementedError(f"Dataset for {encoder_name} is currently not available")
-
         self.dataset_name = dataset_name
         self.dataset_path = join("data", dataset_name)
         self.batch_size = batch_size
-        _, base_dirs, _ = next(walk(join(self.dataset_path, "train")))
-        self.num_classes = len(base_dirs)
+        self.num_classes = num_classes
         self.transform = transform
         self.is_test = is_test
 
         self.clf_dataset = {}
         self.contrastive_dataset = {}
 
+    @abstractmethod
+    def create_dataset(self, dataset_path: str, stage: str) -> Any:
+        pass
+
     def prepare_data(self):
         if not exists(self.dataset_path):
             load_dataset(self.dataset_name)
@@ -57,11 +49,7 @@ def setup(self, stage: str = None):
         stages += ["test"]
 
         for stage in stages:
-            self.clf_dataset[stage] = self.get_dataset(
-                dataset_path=self.dataset_path,
-                stage=stage,
-                is_test=self.is_test
-            )
+            self.clf_dataset[stage] = self.create_dataset(dataset_path=self.dataset_path, stage=stage)
             self.contrastive_dataset[stage] = ContrastiveDataset(clf_dataset=self.clf_dataset[stage])
 
     def train_dataloader(self):
@@ -91,22 +79,13 @@ def test_dataloader(self):
             drop_last=True
         )
 
-    def _collate(self, batch):
-        # batch contains a list of tuples of structure (sequence, target)
-        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
-        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
-        label = torch.LongTensor([item["label"] for item in batch])
-        data = (a, b), label
+    @abstractmethod
+    def collate_fn(self, batch: Any) -> Any:
+        pass
+
+    def _collate(self, batch: Any) -> Any:
+        batch = self.collate_fn(batch)
         if self.transform is not None:
-            return self.transform(data)
+            return self.transform(batch)
         else:
-            return data
-
-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
-        (a, b), label = batch
-        a = a.to(device)
-        b = b.to(device)
-        if isinstance(label, torch.Tensor):
-            label = label.to(device)
-        return (a, b), label
+            return batch
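
With this change the base class follows the template-method pattern: downloading, setup, and dataloader construction stay in BaseContrastiveDataModule, while each subclass supplies only create_dataset and collate_fn. A minimal sketch of a conforming subclass (names are illustrative, not from the commit):

from typing import Any

from dataset.base_data_module import BaseContrastiveDataModule


class MyDataModule(BaseContrastiveDataModule):
    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        # Return any map-style dataset for the given holdout.
        ...

    def collate_fn(self, batch: Any) -> Any:
        # Turn a list of samples into model-ready tensors.
        ...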
7 changes: 7 additions & 0 deletions dataset/classification_datasets/__init__.py
@@ -0,0 +1,7 @@
from .path_dataset import PathDataset
from .text_dataset import TextDataset

__all__ = [
    "TextDataset",
    "PathDataset"
]
14 changes: 14 additions & 0 deletions dataset/classification_datasets/path_dataset.py
@@ -0,0 +1,14 @@
from typing import Optional, Tuple, Any

from code2seq.dataset import PathContextDataset, PathContextSample
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig


class PathDataset(PathContextDataset):
    def __init__(self, data_file_path: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool):
        super().__init__(data_file_path, config, vocabulary, random_context)

    def __getitem__(self, index) -> Optional[Tuple[PathContextSample, Any]]:
        pcs = super().__getitem__(index)
        return pcs, pcs.label[0][0]
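
Each item thus pairs the full path-context sample with a plain class id: pcs.label holds the tokenized target, and since the target here is a single class (target.max_parts: 1 in the config above), pcs.label[0][0] unwraps it to a scalar. An illustrative peek, assuming that label layout:

# Hypothetical: one item from a constructed PathDataset.
pcs, label = dataset[0]
# pcs   -> PathContextSample with the sampled path contexts
# label -> plain class id, e.g. 42, unwrapped from pcs.label[0][0]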
4 changes: 0 additions & 4 deletions dataset/classification_datasets/text_dataset.py
@@ -10,10 +10,6 @@
 from preprocess import tokenize
 
 
-def get_text_dataset(dataset_path: str, stage: str, is_test: bool = False):
-    return TextDataset(dataset_path=dataset_path, stage=stage, is_test=is_test)
-
-
 class TextDataset(Dataset):
     def __init__(self, dataset_path: str, stage: str, is_test: bool = False):
         super().__init__()
7 changes: 7 additions & 0 deletions dataset/data_modules/__init__.py
@@ -0,0 +1,7 @@
from .path_data_module import PathDataModule
from .text_data_module import TextDataModule

__all__ = [
    "TextDataModule",
    "PathDataModule"
]
71 changes: 71 additions & 0 deletions dataset/data_modules/path_data_module.py
@@ -0,0 +1,71 @@
from os.path import join
from typing import Callable, Any, Optional, Tuple

import torch
from code2seq.dataset import PathContextBatch
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig, OmegaConf

from dataset.base_data_module import BaseContrastiveDataModule
from dataset.classification_datasets import PathDataset


def get_config() -> DictConfig:
    return OmegaConf.load("configs/code2class-poj104.yaml")


class PathDataModule(BaseContrastiveDataModule):
    def __init__(
        self,
        dataset_name: str,
        batch_size: int,
        num_classes: int,
        is_test: bool = False,
        transform: Callable = None,
    ):
        config = get_config()
        self._config = config
        self._vocabulary = Vocabulary.load_vocabulary(
            join(config.data_folder, config.dataset.name, config.vocabulary_name)
        )

        self._dataset_dir = join(config.data_folder, config.dataset.name)
        self._train_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.train_holdout}.c2s")
        self._val_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.val_holdout}.c2s")
        self._test_data_file = join(self._dataset_dir, f"{config.dataset.name}.{config.test_holdout}.c2s")

        self.stage2path = {
            "train": self._train_data_file,
            "test": self._test_data_file,
            "val": self._val_data_file
        }

        BaseContrastiveDataModule.__init__(
            self,
            dataset_name=dataset_name,
            batch_size=config.hyper_parameters.batch_size,
            is_test=is_test,
            transform=transform,
            num_classes=num_classes
        )

    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        return PathDataset(self.stage2path[stage], self._config, self._vocabulary, False)

    def collate_fn(self, batch: Any) -> Any:
        a_pc = [sample["a_encoding"] for sample in batch]
        b_pc = [sample["b_encoding"] for sample in batch]
        labels = [sample["label"] for sample in batch]
        a_pc = PathContextBatch(a_pc)
        b_pc = PathContextBatch(b_pc)
        return (a_pc, b_pc), torch.LongTensor(labels)

    def transfer_batch_to_device(
        self, batch: Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor],
        device: Optional[torch.device] = None
    ) -> Tuple[Tuple[PathContextBatch, PathContextBatch], torch.Tensor]:
        # collate_fn returns a pair of PathContextBatch objects plus labels, so both
        # halves must be moved; Tensor.to is not in-place, hence the reassignment.
        (a_pc, b_pc), labels = batch
        if device is not None:
            a_pc.move_to_device(device)
            b_pc.move_to_device(device)
            labels = labels.to(device)
        return (a_pc, b_pc), labels
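
Note that the batch_size argument is effectively ignored: the call to the base constructor passes config.hyper_parameters.batch_size from the YAML instead. A minimal usage sketch (values mirror the config; setup behavior assumed from the base class):

from dataset.data_modules import PathDataModule

dm = PathDataModule(dataset_name="poj_104", batch_size=16, num_classes=5)
dm.prepare_data()   # downloads the dataset if it is missing
dm.setup()          # builds the per-holdout datasets
(a_pc, b_pc), labels = next(iter(dm.train_dataloader()))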
43 changes: 43 additions & 0 deletions dataset/data_modules/text_data_module.py
@@ -0,0 +1,43 @@
from typing import Any, Callable

import torch
from torch.nn.utils.rnn import pad_sequence

from dataset.base_data_module import BaseContrastiveDataModule
from dataset.classification_datasets.text_dataset import TextDataset


class TextDataModule(BaseContrastiveDataModule):
    def __init__(
        self,
        dataset_name: str,
        batch_size: int,
        num_classes: int,
        is_test: bool = False,
        transform: Callable = None
    ):
        super().__init__(
            dataset_name=dataset_name,
            batch_size=batch_size,
            is_test=is_test,
            transform=transform,
            num_classes=num_classes
        )

    def create_dataset(self, dataset_path: str, stage: str) -> Any:
        return TextDataset(dataset_path=dataset_path, stage=stage, is_test=self.is_test)

    def collate_fn(self, batch: Any) -> Any:
        # batch contains a list of tuples of structure (sequence, target)
        a = pad_sequence([item["a_encoding"].squeeze() for item in batch])
        b = pad_sequence([item["b_encoding"].squeeze() for item in batch])
        label = torch.LongTensor([item["label"] for item in batch])
        return (a, b), label

    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        (a, b), label = batch
        a = a.to(device)
        b = b.to(device)
        if isinstance(label, torch.Tensor):
            label = label.to(device)
        return (a, b), label
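
One detail worth flagging: pad_sequence is used with its default batch_first=False, so a and b come out time-major, shaped (max_len, batch_size); presumably the LSTM encoder consumes that layout. A self-contained illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = pad_sequence(seqs)  # shape (3, 2): time-major, zero-padded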
5 changes: 4 additions & 1 deletion models/encoders/__init__.py
@@ -1,10 +1,13 @@
+from .code2class import Code2ClassModel
 from .lstm import LSTMModel
 
 __all__ = [
     "LSTMModel",
+    "Code2ClassModel",
     "encoder_models"
 ]
 
 encoder_models = {
-    "LSTM": LSTMModel
+    "LSTM": LSTMModel,
+    "Code2Class": Code2ClassModel
 }
15 changes: 15 additions & 0 deletions models/encoders/code2class.py
@@ -0,0 +1,15 @@
from code2seq.dataset import PathContextBatch
from code2seq.model import Code2Class
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig
from torch import nn


class Code2ClassModel(nn.Module):
    def __init__(self, config: DictConfig, vocabulary: Vocabulary):
        super().__init__()
        self.num_classes = config.num_classes
        self.code2class = Code2Class(config, vocabulary)

    def forward(self, batch: PathContextBatch):
        return self.code2class(batch.contexts, batch.contexts_per_label)
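
Code2ClassModel is a thin adapter: it unpacks a PathContextBatch into the positional arguments that code2seq's Code2Class expects and returns class logits. A usage sketch, with config and vocabulary loaded as in path_data_module.py above (output shape assumed, not verified against code2seq internals):

model = Code2ClassModel(config, vocabulary)
logits = model(path_context_batch)  # assumed shape: (batch_size, num_classes)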