Commit: Merge branch 'develop_summary' into main
Showing 23 changed files with 1,117 additions and 0 deletions.
@@ -0,0 +1,36 @@
path:
  train_path: ./datas/Training/all.csv
  dev_path: ./datas/Validation/all_dual_05.csv
  predict_path: ./datas/Test/all_dual_05.csv

# TODO
exp:
  exp_name: dialog_R3F_only_dual_new_only_R3F  # must be changed!! experiment name
  project_name: sglee_sum  # must be changed!! experiment workspace, e.g. Moongi_exp

# TODO
model:
  model_name: gogamza/kobart-summarization
  mode_load_path: None

trainer:
  mode: 'blend'  # base, blend

train:
  loss_name: 'focal'
  gpus: 1
  batch_size: 8
  max_epoch: 10
  learning_rate: 4e-5
  logging_step: 25
  save_total_limit: 5  # total number of saved models
  save_steps: 5000  # model saving step
  warmup_steps: 30000  # number of warmup steps for the learning rate scheduler
  weight_decay: 0.01  # strength of weight decay
  logging_steps: 25  # log saving step
  eval_steps: 5000  # evaluation step

test:
  # When loading a pytorch_model.bin, include pytorch_model.bin in the path.
  # When loading KoBART weights directly from Hugging Face, give that repo path, e.g. papari1123/summary_bart_dual_R3F_aihub
  model_path: "papari1123/summary_bart_dual_R3F_aihub"
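
For reference, a config in this shape can be read with PyYAML and indexed as a nested dict; a minimal sketch (the file name and the load_config helper are illustrative, not code from this repository):

import yaml

def load_config(path: str) -> dict:
    # Read one of the YAML configs above into a plain nested dict.
    with open(path, encoding='utf-8') as f:
        return yaml.safe_load(f)

cfg = load_config('dialog_R3F.yaml')  # hypothetical file name
print(cfg['path']['train_path'])      # ./datas/Training/all.csv
# Caveat: PyYAML's YAML 1.1 resolver parses '4e-5' (no decimal point) as a
# string rather than a float, so cast before handing it to an optimizer:
lr = float(cfg['train']['learning_rate'])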
@@ -0,0 +1,36 @@
path:
  train_path: ./datas/Training/all.csv
  dev_path: ./datas/Validation/all_dual_05.csv
  predict_path: ./datas/Validation/all_dual_05.csv

# TODO
exp:
  exp_name: dialog_R3F_only_dual_new_only_R3F  # must be changed!! experiment name
  project_name: sglee_sum  # must be changed!! experiment workspace, e.g. Moongi_exp

# TODO
model:
  model_name: gogamza/kobart-summarization
  mode_load_path: None

trainer:
  mode: 'blend'  # base, blend
  kl_div_lambda: 0.1

train:
  loss_name: 'focal'
  gpus: 1
  batch_size: 8
  max_epoch: 10
  learning_rate: 4e-5
  logging_step: 25
  save_total_limit: 5  # total number of saved models
  save_steps: 5000  # model saving step
  warmup_steps: 30000  # number of warmup steps for the learning rate scheduler
  weight_decay: 0.01  # strength of weight decay
  logging_steps: 25  # log saving step
  eval_steps: 5000  # evaluation step

test:
  model_path: saved/dialog_dual_only_R3F/checkpoint-135000/pytorch_model.bin
  prediction: None
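
The kl_div_lambda here is consistent with the R3F-style regularization the experiment names suggest: a symmetric KL term between the model's output distributions on clean and noise-perturbed embeddings, scaled by the lambda. The trainer code is not part of this excerpt; a minimal sketch of that term (function and variable names are assumptions):

import torch.nn.functional as F

def symmetric_kl(clean_logits, noised_logits):
    # Symmetric KL between the output distributions computed from clean
    # and from noise-perturbed embeddings (both arguments are raw logits).
    p = F.log_softmax(clean_logits, dim=-1)
    q = F.log_softmax(noised_logits, dim=-1)
    return (F.kl_div(p, q, reduction='batchmean', log_target=True) +
            F.kl_div(q, p, reduction='batchmean', log_target=True))

# loss = cross_entropy + kl_div_lambda * symmetric_kl(clean_logits, noised_logits)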
@@ -0,0 +1,36 @@
path:
  train_path: ./datas/subject_text/train_data.csv
  dev_path: ./datas/subject_text/val_data.csv
  predict_path: ./datas/Validation/val_data.csv

# TODO
exp:
  exp_name: subject_cls_binary  # must be changed!! experiment name
  project_name: sglee_sub_cls  # must be changed!! experiment workspace, e.g. Moongi_exp

# TODO
model:
  model_name: papari1123/summary_bart_dual_R3F_aihub
  mode_load_path: None  # saved/subject_cls/checkpoint-3300/pytorch_model.bin
  cls: 'binary'  # binary, multi

trainer:
  mode: 'base'  # only base

train:
  gpus: 1
  batch_size: 256
  max_epoch: 30
  learning_rate: 1e-4
  logging_step: 25
  save_total_limit: 3  # total number of saved models
  save_steps: 330  # model saving step
  warmup_steps: 0  # number of warmup steps for the learning rate scheduler
  weight_decay: 0.01  # strength of weight decay
  logging_steps: 25  # log saving step
  eval_steps: 330  # evaluation step

test:
  model_path: None
  prediction: None
@@ -0,0 +1,36 @@
path:
  train_path: ./datas/subject_text/train_data.csv
  dev_path: ./datas/subject_text/val_data.csv
  predict_path: ./datas/Validation/val_data.csv

# TODO
exp:
  exp_name: subject_cls_multi_no_freeze  # must be changed!! experiment name
  project_name: sglee_sub_cls  # must be changed!! experiment workspace, e.g. Moongi_exp

# TODO
model:
  model_name: papari1123/summary_bart_dual_R3F_aihub
  mode_load_path: None  # saved/subject_cls/checkpoint-3300/pytorch_model.bin
  cls: 'multi'  # binary, multi

trainer:
  mode: 'base'  # only base

train:
  gpus: 1
  batch_size: 64
  max_epoch: 10
  learning_rate: 1e-4
  logging_step: 25
  save_total_limit: 4  # total number of saved models
  save_steps: 1000  # model saving step
  warmup_steps: 0  # number of warmup steps for the learning rate scheduler
  weight_decay: 0.01  # strength of weight decay
  logging_steps: 25  # log saving step
  eval_steps: 1000  # evaluation step

test:
  model_path: None
  prediction: None
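
These two configs reuse the fine-tuned summarization checkpoint as the backbone for subject classification, in binary and multi-class variants. The classifier code is not part of this diff; one plausible shape, purely hypothetical, is a BART sequence-classification head over the same weights:

from transformers import BartForSequenceClassification

# Hypothetical sketch: the encoder/decoder weights come from the summarization
# checkpoint, while the classification head is newly initialized (transformers
# warns about the missing head weights). num_labels=2 would correspond to
# cls: 'binary'; a larger value to cls: 'multi'.
model = BartForSequenceClassification.from_pretrained(
    'papari1123/summary_bart_dual_R3F_aihub', num_labels=2)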
@@ -0,0 +1,111 @@
import argparse

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader


class KoBARTSummaryDataset(Dataset):
    def __init__(self, file, tokenizer, max_len, ignore_index=-100):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.docs = pd.read_csv(file)
        self.len = self.docs.shape[0]

        self.pad_index = self.tokenizer.pad_token_id
        self.ignore_index = ignore_index

    def add_padding_data(self, inputs):
        # Pad with the pad token up to max_len, or truncate.
        if len(inputs) < self.max_len:
            pad = np.array([self.pad_index] * (self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]
        return inputs

    def add_ignored_data(self, inputs):
        # Pad the labels with ignore_index (-100) so that padded positions
        # are skipped by the cross-entropy loss.
        if len(inputs) < self.max_len:
            pad = np.array([self.ignore_index] * (self.max_len - len(inputs)))
            inputs = np.concatenate([inputs, pad])
        else:
            inputs = inputs[:self.max_len]
        return inputs

    def __getitem__(self, idx):
        instance = self.docs.iloc[idx]
        input_ids = self.tokenizer.encode(instance['context'])
        input_ids = self.add_padding_data(input_ids)

        # Teacher forcing: the decoder input is the label sequence shifted
        # right by one position, starting with the EOS token.
        label_ids = self.tokenizer.encode(instance['summary'])
        label_ids.append(self.tokenizer.eos_token_id)
        dec_input_ids = [self.tokenizer.eos_token_id]
        dec_input_ids += label_ids[:-1]
        dec_input_ids = self.add_padding_data(dec_input_ids)
        label_ids = self.add_ignored_data(label_ids)

        return {'input_ids': np.array(input_ids, dtype=np.int_),
                'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_),
                'labels': np.array(label_ids, dtype=np.int_)}

    def __len__(self):
        return self.len


class KobartSummaryModule(pl.LightningDataModule):
    def __init__(self, train_file,
                 test_file, tok,
                 max_len=512,
                 batch_size=8,
                 num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.max_len = max_len
        self.train_file_path = train_file
        self.test_file_path = test_file
        self.tok = tok
        self.num_workers = num_workers

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(
            parents=[parent_parser], add_help=False)
        parser.add_argument('--num_workers',
                            type=int,
                            default=4,
                            help='number of workers for the dataloader')
        return parser

    # OPTIONAL, called for every GPU/machine (assigning state is OK)
    def setup(self, stage):
        # Build the train and validation/test datasets from the CSV paths.
        self.train = KoBARTSummaryDataset(self.train_file_path,
                                          self.tok,
                                          self.max_len)
        self.test = KoBARTSummaryDataset(self.test_file_path,
                                         self.tok,
                                         self.max_len)

    def train_dataloader(self):
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
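
A minimal usage sketch for this module, assuming the KoBART tokenizer from the configs above and CSVs with 'context' and 'summary' columns (the paths are illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gogamza/kobart-summarization')
dm = KobartSummaryModule('./datas/Training/all.csv',
                         './datas/Validation/all_dual_05.csv',
                         tok, max_len=512, batch_size=8)
dm.setup('fit')
batch = next(iter(dm.train_dataloader()))
print(batch['input_ids'].shape)  # torch.Size([8, 512])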
@@ -0,0 +1,17 @@
import argparse

import yaml

from train import KoBARTConditionalGeneration

parser = argparse.ArgumentParser()
parser.add_argument("--hparams", default=None, type=str)
parser.add_argument("--model_binary", default=None, type=str)
parser.add_argument("--output_dir", default=None, type=str)
args = parser.parse_args()

with open(args.hparams) as f:
    # An explicit Loader is required on PyYAML >= 6 (and silences the
    # deprecation warning on 5.x).
    hparams = yaml.load(f, Loader=yaml.FullLoader)

# Restore the Lightning module from the checkpoint, then export the inner
# Hugging Face model in save_pretrained format.
inf = KoBARTConditionalGeneration.load_from_checkpoint(args.model_binary, hparams=hparams)

inf.model.save_pretrained(args.output_dir)
@@ -0,0 +1 @@
python get_model_binary.py --hparams ./{DEFAULT_ROOT_DIR}/tb_logs/default/version_0/hparams.yaml --model_binary ./{DEFAULT_ROOT_DIR}/model_chp/{CKPT_FILE} --output_dir {OUTPUT_DIR}
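
Here {DEFAULT_ROOT_DIR} is presumably the default_root_dir given to the Lightning Trainer during training: tb_logs/default/version_0/hparams.yaml is the hyperparameter file written by a TensorBoard logger, and {CKPT_FILE} is a checkpoint saved under model_chp/ by the checkpoint callback.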
@@ -0,0 +1,36 @@
import json
import os
import re  # needed for re.sub below

import pandas as pd

trainTypes = ['Training', 'Validation']
domains = ['개인및관계', '미용과건강', '상거래(쇼핑)', '시사교육', '식음료', '여가생활', '일과직업', '주거와생활', '행사']
# Strip stand-alone Korean consonants/vowels (e.g. 'ㅋㅋ', 'ㅠㅠ') from utterances.
pattern = '([ㄱ-ㅎㅏ-ㅣ]+)'

for trainType in trainTypes:
    for domain in domains:
        file_path = './data/Korean_speech_summarization/' + trainType + '/' + domain + '.json'
        with open(file_path, encoding='utf-8') as f:
            data = json.load(f)

        X_data, y_data = [], []
        for item in data['data']:
            dialogue = item['body']['dialogue']
            summary = item['body']['summary']
            # Concatenate every utterance of the dialogue into one passage.
            string = ''
            for turn in dialogue:
                string += turn['utterance'] + ' '
            string = re.sub(pattern=pattern, repl='', string=string)
            X_data.append(string)
            y_data.append(summary)

        df = pd.DataFrame({'passage': X_data, 'summary': y_data})

        output_dir = './data_csv/' + trainType
        save_path = output_dir + '/' + domain + '.csv'

        # makedirs (unlike mkdir) also creates the missing './data_csv' parent.
        os.makedirs(output_dir, exist_ok=True)

        df.to_csv(save_path)
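
For reference, the record layout this script walks, with field names taken from the parsing code above and made-up values (the real AI Hub JSON carries additional fields, omitted here):

sample = {
    "data": [{
        "body": {
            "dialogue": [
                {"utterance": "안녕하세요 ㅋㅋ"},
                {"utterance": "오늘 뭐 해?"},
            ],
            "summary": "인사를 나누고 일정을 물었다.",
        }
    }]
}

Note that the output column is named 'passage' here, while KoBARTSummaryDataset above reads a 'context' column, so a rename presumably happens in a step not shown in this diff.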