Feat/2658 add argilla training module for openai with several bug fixes #2691
Conversation
…for openai support
…ture and hold init args
Lots of small improvements, I like it. I found a few nits that should be pretty easy to resolve.
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Can we change this function name? (`argilla/tests/training/helpers.py`, lines 22 to 30 in 09ce4a2)
Perhaps into …
Hello!

## Pull Request overview
* Add [SpanMarker](https://github.com/tomaarsen/SpanMarkerNER) Argilla Trainer for Named Entity Recognition.

## Details
The SpanMarker Argilla trainer is based on the Transformers Trainer, as SpanMarker is tightly implemented on top of transformers. However, unlike with Transformers, we don't need to do the tokenization, data collation, or evaluation on the Argilla side. This makes the SpanMarker Argilla trainer relatively small.

## Usage
First, we need an annotated dataset:

```python
import argilla as rg
from datasets import load_dataset

dataset = "conll2003"
dataset_ds = load_dataset(
    "conll2003",
    split="train[:1000]",
)
dataset_ds = dataset_ds.rename_column("ner_tags", "tags")
dataset_rb = rg.read_datasets(dataset_ds, task="TokenClassification")
rg.delete(dataset)
rg.log(name=dataset, records=dataset_rb)
```

And then we can use the new Trainer to train with this dataset:

```python
import argilla as rg
from argilla.training.base import ArgillaTrainer

dataset = "conll2003"
trainer = ArgillaTrainer(name=dataset, framework="span_marker", train_size=0.8)
trainer.update_config(
    num_train_epochs=10,
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    marker_max_length=128,
    entity_max_length=8,
)
trainer.train(output_dir="tmp_span_marker_train")
```

(You can use lower batch sizes or `model_max_length=256` if you have memory issues. You can also use `fp16` instead of `bf16` if you get an error.)

This produces the following logs:

<details><summary>Click to see the logs</summary>

```
[04/13/23 16:25:25] WARNING  WARNING:argilla.client.datasets:No label schema provided. Using all_labels:        datasets.py:1222
                             TokenClassificationSettings(['LOC', 'MISC', 'ORG', 'PER']). We recommend
                             providing a `TokenClassificationSettings()` or setting it
                             `rg.configure_dataset_settings()`/`rg.load_dataset_settings()` to ensure
                             reproducibility.
[04/13/23 16:25:30] WARNING  WARNING:ArgillaTrainer:ArgillaBaseTrainer info:                                    base.py:175
_________________________________________________________________
These baseline params are fixed:
    dataset: conll2003
    task: DatasetForTokenClassification
    multi_label: False
    train_size: 0.8
    seed: None

<class 'argilla.training.span_marker.ArgillaSpanMarkerTrainer'> info:
_________________________________________________________________
The parameters are configurable via `trainer.update_config()`:
'SpanMarkerModel'
    pretrained_model_name_or_path: bert-base-cased
    labels: ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
'Trainer'
    overwrite_output_dir: False
    do_train: False
    do_eval: False
    do_predict: False
    evaluation_strategy: epoch
    prediction_loss_only: False
    per_device_train_batch_size: 8
    per_device_eval_batch_size: 8
    per_gpu_train_batch_size: None
    per_gpu_eval_batch_size: None
    gradient_accumulation_steps: 1
    eval_accumulation_steps: None
    eval_delay: 0
    learning_rate: 5e-05
    weight_decay: 0.01
    adam_beta1: 0.9
    adam_beta2: 0.999
    adam_epsilon: 1e-08
    max_grad_norm: 1.0
    num_train_epochs: 3.0
    max_steps: -1
    lr_scheduler_type: linear
    warmup_ratio: 0.0
    warmup_steps: 0
    log_level: passive
    log_level_replica: warning
    log_on_each_node: True
    logging_dir: None
    logging_strategy: steps
    logging_first_step: False
    logging_steps: 30
    logging_nan_inf_filter: True
    save_strategy: steps
    save_steps: 500
    save_total_limit: None
    save_on_each_node: False
    no_cuda: False
    use_mps_device: False
    seed: 42
    data_seed: None
    jit_mode_eval: False
    use_ipex: False
    bf16: False
    fp16: False
    fp16_opt_level: O1
    half_precision_backend: auto
    bf16_full_eval: False
    fp16_full_eval: False
    tf32: None
    local_rank: -1
    xpu_backend: None
    tpu_num_cores: None
    tpu_metrics_debug: False
    debug:
    dataloader_drop_last: False
    eval_steps: None
    dataloader_num_workers: 0
    past_index: -1
    run_name: None
    disable_tqdm: None
    remove_unused_columns: True
    label_names: None
    load_best_model_at_end: False
    metric_for_best_model: None
    greater_is_better: None
    ignore_data_skip: False
    sharded_ddp:
    fsdp:
    fsdp_min_num_params: 0
    fsdp_config: None
    fsdp_transformer_layer_cls_to_wrap: None
    deepspeed: None
    label_smoothing_factor: 0.0
    optim: adamw_hf
    optim_args: None
    adafactor: False
    group_by_length: False
    length_column_name: length
    report_to: None
    ddp_find_unused_parameters: None
    ddp_bucket_cap_mb: None
    dataloader_pin_memory: True
    skip_memory_metrics: True
    use_legacy_prediction_loop: False
    push_to_hub: False
    resume_from_checkpoint: None
    hub_model_id: None
    hub_strategy: every_save
    hub_token: None
    hub_private_repo: False
    gradient_checkpointing: False
    include_inputs_for_metrics: False
    fp16_backend: auto
    push_to_hub_model_id: None
    push_to_hub_organization: None
    push_to_hub_token: None
    mp_parameters:
    auto_find_batch_size: False
    full_determinism: False
    torchdynamo: None
    ray_scope: last
    ddp_timeout: 1800
    torch_compile: False
    torch_compile_backend: None
    torch_compile_mode: None
    output_dir: None

Using the trainer:
_________________________________________________________________
`trainer.train(output_dir)` to train to start training. `output_dir` is the directory to save the model automatically.
`trainer.predict(text, as_argilla_records=True)` to make predictions.
`trainer.save(output_dir)` to save the model manually.
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
C:\Users\tom\.conda\envs\argilla\lib\site-packages\transformers\optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
{'loss': 0.2601, 'learning_rate': 4.754098360655738e-05, 'epoch': 0.49}
{'loss': 0.066, 'learning_rate': 4.508196721311476e-05, 'epoch': 0.98}
{'eval_loss': 0.03822080418467522, 'eval_overall_precision': 0.8810289389067524, 'eval_overall_recall': 0.6372093023255814, 'eval_overall_f1': 0.7395411605937922, 'eval_overall_accuracy': 0.9204204204204204, 'eval_runtime': 4.5471, 'eval_samples_per_second': 57.619, 'eval_steps_per_second': 3.739, 'epoch': 1.0}
{'loss': 0.0378, 'learning_rate': 4.262295081967213e-05, 'epoch': 1.48}
{'loss': 0.029, 'learning_rate': 4.016393442622951e-05, 'epoch': 1.97}
{'eval_loss': 0.021275097504258156, 'eval_overall_precision': 0.8808933002481389, 'eval_overall_recall': 0.8255813953488372, 'eval_overall_f1': 0.8523409363745499, 'eval_overall_accuracy': 0.9602102102102102, 'eval_runtime': 4.5328, 'eval_samples_per_second': 57.8, 'eval_steps_per_second': 3.75, 'epoch': 2.0}
{'loss': 0.0169, 'learning_rate': 3.7704918032786885e-05, 'epoch': 2.46}
{'loss': 0.015, 'learning_rate': 3.524590163934427e-05, 'epoch': 2.95}
{'eval_loss': 0.013853945769369602, 'eval_overall_precision': 0.9527363184079602, 'eval_overall_recall': 0.8906976744186047, 'eval_overall_f1': 0.920673076923077, 'eval_overall_accuracy': 0.9774774774774775, 'eval_runtime': 4.4899, 'eval_samples_per_second': 58.353, 'eval_steps_per_second': 3.786, 'epoch': 3.0}
{'loss': 0.01, 'learning_rate': 3.2786885245901635e-05, 'epoch': 3.44}
{'loss': 0.0089, 'learning_rate': 3.0327868852459017e-05, 'epoch': 3.93}
{'eval_loss': 0.013151775114238262, 'eval_overall_precision': 0.948780487804878, 'eval_overall_recall': 0.9046511627906977, 'eval_overall_f1': 0.9261904761904762, 'eval_overall_accuracy': 0.9786036036036037, 'eval_runtime': 4.5472, 'eval_samples_per_second': 57.618, 'eval_steps_per_second': 3.739, 'epoch': 4.0}
{'loss': 0.0056, 'learning_rate': 2.7868852459016392e-05, 'epoch': 4.43}
{'loss': 0.0059, 'learning_rate': 2.540983606557377e-05, 'epoch': 4.92}
{'eval_loss': 0.012591547332704067, 'eval_overall_precision': 0.9587378640776699, 'eval_overall_recall': 0.9186046511627907, 'eval_overall_f1': 0.9382422802850356, 'eval_overall_accuracy': 0.9812312312312312, 'eval_runtime': 4.5184, 'eval_samples_per_second': 57.986, 'eval_steps_per_second': 3.762, 'epoch': 5.0}
{'loss': 0.0036, 'learning_rate': 2.295081967213115e-05, 'epoch': 5.41}
{'loss': 0.0044, 'learning_rate': 2.0491803278688525e-05, 'epoch': 5.9}
{'eval_loss': 0.012911035679280758, 'eval_overall_precision': 0.9539951573849879, 'eval_overall_recall': 0.9162790697674419, 'eval_overall_f1': 0.9347568208778173, 'eval_overall_accuracy': 0.9808558558558559, 'eval_runtime': 4.3026, 'eval_samples_per_second': 60.893, 'eval_steps_per_second': 3.951, 'epoch': 6.0}
{'loss': 0.0031, 'learning_rate': 1.8032786885245903e-05, 'epoch': 6.39}
{'loss': 0.0024, 'learning_rate': 1.557377049180328e-05, 'epoch': 6.89}
{'eval_loss': 0.010593990795314312, 'eval_overall_precision': 0.9567307692307693, 'eval_overall_recall': 0.9255813953488372, 'eval_overall_f1': 0.9408983451536642, 'eval_overall_accuracy': 0.9838588588588588, 'eval_runtime': 4.5172, 'eval_samples_per_second': 58.001, 'eval_steps_per_second': 3.763, 'epoch': 7.0}
{'loss': 0.0018, 'learning_rate': 1.3114754098360657e-05, 'epoch': 7.38}
{'loss': 0.0025, 'learning_rate': 1.0655737704918032e-05, 'epoch': 7.87}
{'eval_loss': 0.010297469794750214, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.4597, 'eval_samples_per_second': 58.748, 'eval_steps_per_second': 3.812, 'epoch': 8.0}
{'loss': 0.0011, 'learning_rate': 8.196721311475409e-06, 'epoch': 8.36}
{'loss': 0.0009, 'learning_rate': 5.737704918032787e-06, 'epoch': 8.85}
{'eval_loss': 0.009832620620727539, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.6756, 'eval_samples_per_second': 56.036, 'eval_steps_per_second': 3.636, 'epoch': 9.0}
{'loss': 0.0015, 'learning_rate': 3.278688524590164e-06, 'epoch': 9.34}
{'loss': 0.0013, 'learning_rate': 8.19672131147541e-07, 'epoch': 9.84}
{'eval_loss': 0.010347267612814903, 'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall': 0.9325581395348838, 'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.7101, 'eval_samples_per_second': 55.625, 'eval_steps_per_second': 3.609, 'epoch': 10.0}
{'train_runtime': 297.0633, 'train_samples_per_second': 32.855, 'train_steps_per_second': 2.053, 'train_loss': 0.023517039891515597, 'epoch': 10.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 610/610 [04:57<00:00, 2.05it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00, 3.84it/s]
[04/13/23 16:30:34] INFO  INFO:ArgillaSpanMarkerTrainer:{'eval_loss': 0.010347267612814903,                     span_marker.py:133
                          'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall': 0.9325581395348838,
                          'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096,
                          'eval_runtime': 4.6647, 'eval_samples_per_second': 56.166, 'eval_steps_per_second': 3.644,
                          'epoch': 10.0}
```

</details>

In short, I trained to 0.939 eval F1 on CoNLL03 in 5 minutes.

### Type of change
- [x] New feature (non-breaking change which adds functionality)

### How Has This Been Tested
Tests still need to be written. I'll be working on this, but I'm publishing this as a draft so it's already available for review.
<!--
**Checklist**
- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
-->

- Tom Aarsen

---------

Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Just some minor comments cc. @davidberenstein1957
`openai` isn't in the CI dependencies yet.
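A minimal sketch of one way to make it available to CI, assuming a setuptools `extras_require` layout; the package name, extra names, and layout are illustrative, not argilla's actual packaging:

```python
# Hypothetical setup.py fragment; the extras layout is an assumption for
# illustration, not argilla's real packaging configuration.
from setuptools import setup

setup(
    name="example-package",
    extras_require={
        "openai": ["openai"],           # optional integration dependency
        "tests": ["pytest", "openai"],  # CI could then run: pip install -e ".[tests]"
    },
)
```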
chore: switched `click` to `typer`
Co-authored-by: Alvaro Bartolome <alvarobartt@yahoo.com>
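As context for the `click` to `typer` switch above, a minimal sketch of what such a migration typically looks like; the `train` command and its options are hypothetical, not argilla's actual CLI:

```python
# Illustrative only: a command defined with typer rather than click.
# The command name and options are hypothetical.
import typer

app = typer.Typer()

@app.command()
def train(name: str, framework: str = "transformers") -> None:
    """Train a model on a logged dataset."""
    typer.echo(f"Training {name!r} with framework {framework!r}")

if __name__ == "__main__":
    app()
```

typer derives argument parsing and help text from type hints, which removes most of the decorator boilerplate that click requires.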
chore: removed while-loop from the openai trainer
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
fix: replaced `pass` with docstrings in the ABC
chore: added tests for spacy without training
chore: added tests for transformers without training
chore: added specific setfit versioning
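A hedged sketch of what the "without training" tests above might look like: they exercise trainer construction but deliberately never call `train()`, keeping CI fast. The dataset name is illustrative, a running Argilla instance with that dataset logged is assumed, and the real tests live in `argilla/tests/training/`:

```python
# Hypothetical sketch, not the actual test suite in argilla/tests/training/.
import pytest

from argilla.training.base import ArgillaTrainer

@pytest.mark.parametrize("framework", ["spacy", "transformers"])
def test_trainer_init_without_training(framework):
    # Assumes the dataset "test_dataset" was already logged to a running Argilla server.
    trainer = ArgillaTrainer(name="test_dataset", framework=framework, train_size=0.8)
    assert trainer is not None  # train() is intentionally never called
```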
Description
Updated the `argilla.training` integration, adding OpenAI support along with several bug fixes; a hedged usage sketch follows the linked issues below.
Closes #2658
Closes #2665
Closes #2659
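Assuming the new module follows the same `ArgillaTrainer` pattern as the SpanMarker example quoted above, usage might look roughly like this; `framework="openai"` follows the PR title, and the config key is an assumption rather than a confirmed parameter of this PR:

```python
# Hedged sketch of the new OpenAI support; n_epochs is an assumed
# fine-tuning parameter, not a confirmed part of this PR's API.
from argilla.training.base import ArgillaTrainer

trainer = ArgillaTrainer(name="my_dataset", framework="openai", train_size=0.8)
trainer.update_config(n_epochs=1)  # assumption: forwarded to the OpenAI fine-tune config
trainer.train(output_dir="tmp_openai_train")
```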
Type of change
(Please delete options that are not relevant. Remember to title the PR according to the type of change)
How Has This Been Tested
(Please describe the tests that you ran to verify your changes. And ideally, reference tests.)

`argilla/tests/training/*`
Checklist