Commit
Refactor generate.py; fixes #118, fixes #119
Yannic Bracke committed Feb 10, 2025
1 parent 5102508 commit 8d5d92a
Showing 2 changed files with 62 additions and 26 deletions.
79 changes: 57 additions & 22 deletions src/transnormer/models/generate.py
@@ -1,6 +1,8 @@
import argparse
from typing import Dict, List, Optional, Union
import os

from functools import partial
from typing import Dict, List, Optional, Union

import tomli
import random
@@ -10,7 +12,7 @@
import transformers

from transnormer.preprocess import translit
from transnormer.data.process import sort_dataset_by_length
from transnormer.data.process import sort_dataset_by_length, filter_dataset_by_length


def set_seeds(n: int = 42):
@@ -43,6 +45,31 @@ def generate_normalization(
    return batch


def load_data(
    path: str, n_examples: Optional[int] = None, split: Optional[str] = None
) -> datasets.Dataset:
    """
    Load the dataset.

    The dataset can be a JSON file, a directory of JSON files, or the
    name of a Hugging Face dataset.
    """

    # Which split to use?
    s = split if split is not None else "train"
    if os.path.isfile(path):
        ds = datasets.load_dataset("json", data_files=path, split=s)
    elif os.path.isdir(path):
        ds = datasets.load_dataset("json", data_dir=path, split=s)
    else:
        try:
            ds = datasets.load_dataset(path, split=s)
        except datasets.exceptions.DatasetNotFoundError as e:
            raise datasets.exceptions.DatasetNotFoundError(
                f"Path '{path}' is not an existing file, directory, or dataset name."
            ) from e

    return ds


def parse_and_check_arguments(
    arguments: Optional[List[str]] = None,
) -> argparse.Namespace:
@@ -75,7 +102,7 @@ def main(arguments: Optional[List[str]] = None) -> None:
        CONFIGS = tomli.load(fp)

    # (2) Preparations
    # (2.1) Fix random, numpy and torch seeds for reproducibility
    # (2.1) Set fixed seed (random, numpy, torch) for reproducibility
    set_seeds(CONFIGS["random_seed"])

    # (2.2) GPU set-up
@@ -85,12 +112,9 @@ def main(arguments: Optional[List[str]] = None) -> None:
    )

    # (3) Data
    ds = datasets.load_dataset("json", data_files=CONFIGS["data"]["path_test"])

    # Optional: Use a fixed number of test set examples
    n = CONFIGS["data"].get("n_examples_test")
    if n:
        ds["train"] = ds["train"].shuffle().select(range(n))
    data_path = CONFIGS["data"]["path_test"]
    split = CONFIGS["data"]["split"]
    ds = load_data(data_path, split=split)

    # (4) Tokenizers and transliterator
    # Load tokenizer
@@ -111,21 +135,32 @@ def main(arguments: Optional[List[str]] = None) -> None:
    checkpoint = CONFIGS["model"]["checkpoint"]
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

    # (6) Generation
    # Parameters for model output
    gen_cfg = transformers.GenerationConfig(**CONFIGS["generation_config"])

    # (6) Data preparation
    # Sort by length
    dataset = ds["train"]
    index_column = "#"
    dataset = sort_dataset_by_length(
        dataset,
    len_column = "len"
    ds = sort_dataset_by_length(
        ds,
        "orig",
        descending=True,
        keep_length_column=False,
        name_index_column=index_column,
        name_length_column=len_column,
        use_bytelength=True,
    )
    ds["train"] = dataset

    # Optional: Filter out samples that exceed a given byte length
    k = CONFIGS["data"].get("max_bytelength")
    if k:
        ds = filter_dataset_by_length(ds, max_length=k, name_length_column=len_column)

    # Optional: Clip dataset to a fixed number of samples
    n = CONFIGS["data"].get("n_examples_test")
    if n:
        ds = ds.shuffle().select(range(n))

    # (7) Generation
    # Parameters for model output
    gen_cfg = transformers.GenerationConfig(**CONFIGS["generation_config"])

    # Prepare generation function as a partial function (only the batch argument missing)
    normalize = partial(
Expand All @@ -146,11 +181,11 @@ def main(arguments: Optional[List[str]] = None) -> None:
    )

    # Restore the original order
    ds["train"] = ds["train"].sort(index_column)
    ds["train"] = ds["train"].remove_columns(index_column)
    ds = ds.sort(index_column)
    ds = ds.remove_columns([index_column, len_column])

    # (7) Save outputs
    ds["train"].to_json(args.out, force_ascii=False)
    # (8) Save outputs
    ds.to_json(args.out, force_ascii=False)


if __name__ == "__main__":
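For readers skimming the diff, a minimal usage sketch of the refactored loading path (not part of the commit; the paths and the dataset name below are illustrative, and the import path follows the file header above):

    from transnormer.models.generate import load_data

    # A single JSON/JSONL file is loaded via the "json" builder with data_files=...
    ds = load_data("data/test/dtak-test.jsonl", split="test")

    # A directory of JSON files is loaded via the "json" builder with data_dir=...
    ds = load_data("data/dtak-transnormer-basic-v1/data/test/1700-1799", split="test")

    # Anything else is treated as a Hugging Face dataset name; a missing dataset
    # surfaces as DatasetNotFoundError. Note that `split` must be passed by
    # keyword, because the second positional parameter is `n_examples`.
    ds = load_data("some-org/some-dataset", split="test")
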
9 changes: 5 additions & 4 deletions test_config.toml
@@ -7,19 +7,20 @@ gpu = "cuda:0" # when you did `export CUDA_VISIBLE_DEVICES=1` first
random_seed = 42

[data]
path_test = "/home/bracke/code/transnormer/data/raw/dta/jsonl/v09-lm/not-dtaec/1700-1799/dtak-test.jsonl"
# n_examples_test = 24
path_test = "data/dtak-transnormer-basic-v1/data/test/1700-1799"
split = "test"
# n_examples_test = 100
max_bytelength = 510

[tokenizer]
checkpoint_in = "google/byt5-small"

[tokenizer_configs]
padding = "longest" # pad to the longest sequence in the batch
truncation = false
max_length = 1024

[model]
checkpoint = "/home/bracke/code/transnormer/models/models_2024-10-15"
checkpoint = "models/models_2024-12-20"

[generation]
batch_size = 32
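For reference, a minimal sketch of how `generate.py` consumes the `[data]` section above (assuming the key names shown; the config file name is illustrative):

    import tomli

    with open("test_config.toml", "rb") as fp:  # tomli requires binary mode
        CONFIGS = tomli.load(fp)

    data_path = CONFIGS["data"]["path_test"]         # required
    split = CONFIGS["data"]["split"]                 # required
    max_len = CONFIGS["data"].get("max_bytelength")  # optional; None if absent
    n = CONFIGS["data"].get("n_examples_test")       # optional; commented out above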
