add_dataset_dataloader #1
base: main
Conversation
This file shouldn't have been added to the commit :)
This file shouldn't have been added to the commit :)
- Fix the minor remarks.
- Rewrite the _create_dataset function in dataset.py. A different tokenization strategy for the columns is needed.
@@ -0,0 +1,29 @@
{
    "num_labels": 170,
This parameter isn't needed yet.
"num_labels": 170, | ||
"num_gpu": 4, | ||
"save_period_in_epochs": 10, | ||
"metrics": ["f1_micro", "f1_macro", "f1_weighted"], |
Not needed.
"save_period_in_epochs": 10, | ||
"metrics": ["f1_micro", "f1_macro", "f1_weighted"], | ||
"pretrained_model_name": "bert-base-multilingual-uncased", | ||
"table_serialization_type": "column_wise", |
Not needed.
"start_from_checkpoint": false, | ||
"checkpoint_dir": "checkpoints/", | ||
"checkpoint_name": "model_best_f1_weighted.pt", | ||
"inference_model_name": "model_table_wise.pt", |
Not needed.
"checkpoint_dir": "checkpoints/", | ||
"checkpoint_name": "model_best_f1_weighted.pt", | ||
"inference_model_name": "model_table_wise.pt", | ||
"inference_dir": "data/inference/", |
Not needed.
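Taken together, the comments above point at a slimmer config. A sketch of what config.json might look like with the flagged parameters dropped (an assumption: only the keys explicitly marked "not needed" are removed, everything else kept as-is):

{
    "num_gpu": 4,
    "save_period_in_epochs": 10,
    "pretrained_model_name": "bert-base-multilingual-uncased",
    "start_from_checkpoint": false,
    "checkpoint_dir": "checkpoints/",
    "checkpoint_name": "model_best_f1_weighted.pt"
}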
train_df = dataset[~dataset["table_id"].isin(valid_mask)]
train_ids = train_df.index.to_numpy()

# valid_ids = dataset_ids[0:len_valid]
The commented-out lines can be removed.
pass


# from config import Config
The wrapper's test code is better kept in this block (uncomment it).
def __getitem__(self, idx):
    return {
        "data": self.df.iloc[idx]["data"],
        "labels": self.df.iloc[idx]["labels"],
Instead of labels we will have column headers, so replace this with headers.
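A minimal sketch of the requested change, assuming the per-table dataframe column is renamed to "header" as asked in the later comment on _create_dataset:

def __getitem__(self, idx):
    return {
        "data": self.df.iloc[idx]["data"],
        # Column headers replace the label ids as the target.
        "headers": self.df.iloc[idx]["header"],
    }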
def _create_dataset(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizerBase) -> pd.DataFrame:
    """Tokenize columns data.

    Groups columns by table_id's and tokenizes columns data.

    Tokenized columns are flattened into a sequence, like so:

    [CLS] token_11 token_12 ... [SEP] [CLS] token_21 ... [SEP]

    Args:
        df: Entire dataset as dataframe object.
        tokenizer: Pretrained BERT tokenizer.

    Returns:
        pd.DataFrame: Dataset, grouped by tables and tokenized.
    """
    data_list = []
    for table_id, table in tqdm(df.groupby("table_id")):
        num_cols = len(table)

        # Tokenize table columns.
        tokenized_table_columns = table["column_data"].apply(
            lambda x: tokenizer.encode(
                # max_length for a SINGLE COLUMN, not for the table as a sequence.
                # BERT maximum input length = 512, so max_length = (512 // num_cols).
                x, add_special_tokens=True, max_length=(512 // num_cols), truncation=True
            )
        ).tolist()

        # Concat table columns into one sequence.
        concat_tok_table_columns = list(chain.from_iterable(tokenized_table_columns))
        tokenized_columns_seq = torch.LongTensor(concat_tok_table_columns)

        # Use Long, because CrossEntropyLoss works with Long tensors.
        labels = torch.LongTensor(table["label_id"].values)

        data_list.append(
            [table_id, num_cols, tokenized_columns_seq, labels]
        )

    return pd.DataFrame(
        data_list,
        columns=["table_id", "n_cols", "data", "labels"]
    )
In this case a different, simpler tokenization strategy is needed.
In RuTaBERT we tokenized all of a table's columns into a single sequence: [CLS] token_11 token_12 ... [SEP] [CLS] token_21 ... [SEP].
Now we just need to tokenize each table column on its own, inserting the special [CLS] token at the start of the column and [SEP] at its end.
Return the result in the same format. Instead of the labels column, the resulting dataframe will contain the column header, so rename it to header.
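A minimal sketch of that simpler strategy, under stated assumptions: the surrounding class and imports match the diff above, "column_data" is the raw text column, and "header" is an assumed name for the input dataframe column holding the headers:

import pandas as pd
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase


def _create_dataset(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizerBase) -> pd.DataFrame:
    """Tokenize each column independently as [CLS] token_1 ... [SEP]."""
    data_list = []
    for table_id, table in tqdm(df.groupby("table_id")):
        num_cols = len(table)

        # One sequence per column; encode() with add_special_tokens=True
        # already wraps the tokens in [CLS] ... [SEP].
        tokenized_columns = [
            torch.LongTensor(
                tokenizer.encode(col, add_special_tokens=True, max_length=512, truncation=True)
            )
            for col in table["column_data"]
        ]

        # Column headers instead of label ids.
        headers = table["header"].tolist()

        data_list.append([table_id, num_cols, tokenized_columns, headers])

    return pd.DataFrame(
        data_list,
        columns=["table_id", "n_cols", "data", "header"]
    )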
if __name__ == "__main__":
    pass
Add a usage example for the dataset; also, you forgot to add an empty line at the end of the file.
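A hedged sketch of what such a usage block could look like; TableDataset, the CSV path, and the printed keys are illustrative assumptions, not the repo's actual API:

import pandas as pd
from transformers import BertTokenizer

if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
    raw_df = pd.read_csv("data/train.csv")     # assumed dataset location
    dataset = TableDataset(raw_df, tokenizer)  # assumed wrapper class name
    sample = dataset[0]
    print(sample["data"])     # tokenized column tensors
    print(sample["headers"])  # corresponding column headers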