From 0fbf513d0baa70d50447b5bcc4e538849d75dad5 Mon Sep 17 00:00:00 2001
From: Steven Shimizu
Date: Tue, 10 Sep 2024 22:26:31 +0000
Subject: [PATCH] Added script for testing patching methods
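
Adds examples/huggingface/run_patching_methods.sh, which benchmarks the
three Liger kernel patching methods (pre-init, post-init class, and
post-init instance) against an unpatched baseline, and checks in result
logs for Mistral-7B. Pre-init and post-init class patching track the
baseline loss curve (~1.9 down to ~0.96 over 10 steps), while the
post-init instance run starts at a loss of ~10.4 and only reaches ~6.2,
suggesting instance-level patching does not yet work correctly for
Mistral.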
---
 ...ral_use_liger_False_patching_type_None.log | 15 ++++
 ...ger_True_patching_type_post_init_class.log | 15 ++++
 ..._True_patching_type_post_init_instance.log | 15 ++++
 ..._use_liger_True_patching_type_pre_init.log | 15 ++++
 examples/huggingface/run_llama.sh             |  2 +-
 examples/huggingface/run_patching_methods.sh  | 60 +++++++++++++
 examples/huggingface/training.py              | 87 ++++++++++++-------
 src/liger_kernel/transformers/monkey_patch.py | 42 ++++++---
 src/liger_kernel/triton/monkey_patch.py       |  6 +-
 9 files changed, 210 insertions(+), 47 deletions(-)
 create mode 100644 examples/huggingface/results/mistral_use_liger_False_patching_type_None.log
 create mode 100644 examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_class.log
 create mode 100644 examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_instance.log
 create mode 100644 examples/huggingface/results/mistral_use_liger_True_patching_type_pre_init.log
 create mode 100755 examples/huggingface/run_patching_methods.sh

diff --git a/examples/huggingface/results/mistral_use_liger_False_patching_type_None.log b/examples/huggingface/results/mistral_use_liger_False_patching_type_None.log
new file mode 100644
index 000000000..7df9d2b6b
--- /dev/null
+++ b/examples/huggingface/results/mistral_use_liger_False_patching_type_None.log
@@ -0,0 +1,15 @@
+********** No Patching ***********
+********** No Patching ***********
+********** No Patching ***********
+********** No Patching ***********
+{'loss': 1.9423, 'grad_norm': 102.73744201660156, 'learning_rate': 6e-06, 'epoch': 0.0, 'num_input_tokens_seen': 40960}
+{'loss': 1.9385, 'grad_norm': 104.32758331298828, 'learning_rate': 5.819077862357725e-06, 'epoch': 0.01, 'num_input_tokens_seen': 78336, 'step': 2, 'step_time_sec': 2.96, 'avg_step_time_sec': 2.96, 'time_to_completion_sec': 23.68, 'estimated_total_time_sec': 29.6, 'step_peak_memory_allocated_MB': 34547.76, 'step_peak_memory_reserved_MB': 46892.0, 'total_peak_memory_allocated_MB': 34547.76, 'total_peak_memory_reserved_MB': 46892.0, 'step_tokens_per_second': 12625.64, 'avg_tokens_per_second': 12625.64}
+{'loss': 1.1825, 'grad_norm': 52.640846252441406, 'learning_rate': 5.298133329356934e-06, 'epoch': 0.01, 'num_input_tokens_seen': 118784, 'step': 3, 'step_time_sec': 3.72, 'avg_step_time_sec': 3.34, 'time_to_completion_sec': 23.36, 'estimated_total_time_sec': 33.38, 'step_peak_memory_allocated_MB': 34547.83, 'step_peak_memory_reserved_MB': 51576.0, 'total_peak_memory_allocated_MB': 34547.83, 'total_peak_memory_reserved_MB': 51576.0, 'step_tokens_per_second': 10887.47, 'avg_tokens_per_second': 11658.29}
+{'loss': 1.1446, 'grad_norm': 67.16675567626953, 'learning_rate': 4.5e-06, 'epoch': 0.01, 'num_input_tokens_seen': 163328, 'step': 4, 'step_time_sec': 4.63, 'avg_step_time_sec': 3.77, 'time_to_completion_sec': 22.6, 'estimated_total_time_sec': 37.67, 'step_peak_memory_allocated_MB': 34547.91, 'step_peak_memory_reserved_MB': 57572.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 57572.0, 'step_tokens_per_second': 9630.35, 'avg_tokens_per_second': 10828.26}
+{'loss': 1.0803, 'grad_norm': 21.94203758239746, 'learning_rate': 3.5209445330007917e-06, 'epoch': 0.01, 'num_input_tokens_seen': 196608, 'step': 5, 'step_time_sec': 2.89, 'avg_step_time_sec': 3.55, 'time_to_completion_sec': 17.74, 'estimated_total_time_sec': 35.48, 'step_peak_memory_allocated_MB': 34547.75, 'step_peak_memory_reserved_MB': 57572.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 57572.0, 'step_tokens_per_second': 11518.67, 'avg_tokens_per_second': 10968.83}
+{'loss': 1.098, 'grad_norm': 24.288616180419922, 'learning_rate': 2.4790554669992093e-06, 'epoch': 0.02, 'num_input_tokens_seen': 244736, 'step': 6, 'step_time_sec': 4.63, 'avg_step_time_sec': 3.76, 'time_to_completion_sec': 15.06, 'estimated_total_time_sec': 37.64, 'step_peak_memory_allocated_MB': 34547.72, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 10393.91, 'avg_tokens_per_second': 10827.38}
+{'loss': 1.0376, 'grad_norm': 53.7581672668457, 'learning_rate': 1.5000000000000007e-06, 'epoch': 0.02, 'num_input_tokens_seen': 283648, 'step': 7, 'step_time_sec': 3.31, 'avg_step_time_sec': 3.69, 'time_to_completion_sec': 11.06, 'estimated_total_time_sec': 36.88, 'step_peak_memory_allocated_MB': 34547.76, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 11766.3, 'avg_tokens_per_second': 10967.71}
+{'loss': 0.9906, 'grad_norm': 10.769705772399902, 'learning_rate': 7.018666706430663e-07, 'epoch': 0.02, 'num_input_tokens_seen': 324608, 'step': 8, 'step_time_sec': 3.88, 'avg_step_time_sec': 3.72, 'time_to_completion_sec': 7.43, 'estimated_total_time_sec': 37.16, 'step_peak_memory_allocated_MB': 34547.84, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 10549.53, 'avg_tokens_per_second': 10905.29}
+{'loss': 0.9546, 'grad_norm': 7.883011341094971, 'learning_rate': 1.8092213764227505e-07, 'epoch': 0.02, 'num_input_tokens_seen': 362496, 'step': 9, 'step_time_sec': 3.42, 'avg_step_time_sec': 3.68, 'time_to_completion_sec': 3.68, 'estimated_total_time_sec': 36.78, 'step_peak_memory_allocated_MB': 34547.68, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 11086.94, 'avg_tokens_per_second': 10926.38}
+{'loss': 0.9645, 'grad_norm': 7.525882720947266, 'learning_rate': 0.0, 'epoch': 0.03, 'num_input_tokens_seen': 396800, 'step': 10, 'step_time_sec': 2.99, 'avg_step_time_sec': 3.6, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 36.02, 'step_peak_memory_allocated_MB': 34547.69, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 11473.59, 'avg_tokens_per_second': 10976.85}
+{'train_runtime': 38.5858, 'train_samples_per_second': 33.173, 'train_steps_per_second': 0.259, 'train_loss': 1.2333564937114716, 'epoch': 0.03, 'num_input_tokens_seen': 396800, 'step': 10, 'step_time_sec': 2.99, 'avg_step_time_sec': 3.6, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 36.02, 'step_peak_memory_allocated_MB': 34547.69, 'step_peak_memory_reserved_MB': 58180.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 58180.0, 'step_tokens_per_second': 11473.59, 'avg_tokens_per_second': 10976.85}
diff --git a/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_class.log b/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_class.log
new file mode 100644
index 000000000..790cb390d
--- /dev/null
+++ b/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_class.log
@@ -0,0 +1,15 @@
+********** Post-Init Class Patching ***********
+********** Post-Init Class Patching ***********
+********** Post-Init Class Patching ***********
+********** Post-Init Class Patching ***********
+{'loss': 2.0133, 'grad_norm': 111.51253509521484, 'learning_rate': 6e-06, 'epoch': 0.0, 'num_input_tokens_seen': 42496}
+{'loss': 1.9944, 'grad_norm': 110.4891128540039, 'learning_rate': 5.819077862357725e-06, 'epoch': 0.01, 'num_input_tokens_seen': 73728, 'step': 2, 'step_time_sec': 2.62, 'avg_step_time_sec': 2.62, 'time_to_completion_sec': 20.95, 'estimated_total_time_sec': 26.18, 'step_peak_memory_allocated_MB': 34547.75, 'step_peak_memory_reserved_MB': 45144.0, 'total_peak_memory_allocated_MB': 34547.75, 'total_peak_memory_reserved_MB': 45144.0, 'step_tokens_per_second': 11927.48, 'avg_tokens_per_second': 11927.48}
+{'loss': 1.193, 'grad_norm': 53.41102981567383, 'learning_rate': 5.298133329356934e-06, 'epoch': 0.01, 'num_input_tokens_seen': 107008, 'step': 3, 'step_time_sec': 3.14, 'avg_step_time_sec': 2.88, 'time_to_completion_sec': 20.16, 'estimated_total_time_sec': 28.79, 'step_peak_memory_allocated_MB': 34547.7, 'step_peak_memory_reserved_MB': 45860.0, 'total_peak_memory_allocated_MB': 34547.75, 'total_peak_memory_reserved_MB': 45860.0, 'step_tokens_per_second': 10598.12, 'avg_tokens_per_second': 11202.59}
+{'loss': 1.1854, 'grad_norm': 66.05502319335938, 'learning_rate': 4.5e-06, 'epoch': 0.01, 'num_input_tokens_seen': 142848, 'step': 4, 'step_time_sec': 3.37, 'avg_step_time_sec': 3.04, 'time_to_completion_sec': 18.26, 'estimated_total_time_sec': 30.43, 'step_peak_memory_allocated_MB': 34547.82, 'step_peak_memory_reserved_MB': 48548.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 48548.0, 'step_tokens_per_second': 10630.55, 'avg_tokens_per_second': 10991.35}
+{'loss': 1.1145, 'grad_norm': 19.789567947387695, 'learning_rate': 3.5209445330007917e-06, 'epoch': 0.01, 'num_input_tokens_seen': 187392, 'step': 5, 'step_time_sec': 4.45, 'avg_step_time_sec': 3.39, 'time_to_completion_sec': 16.97, 'estimated_total_time_sec': 33.94, 'step_peak_memory_allocated_MB': 34547.91, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 10014.59, 'avg_tokens_per_second': 10671.38}
+{'loss': 1.0048, 'grad_norm': 20.529048919677734, 'learning_rate': 2.4790554669992093e-06, 'epoch': 0.02, 'num_input_tokens_seen': 224768, 'step': 6, 'step_time_sec': 3.27, 'avg_step_time_sec': 3.37, 'time_to_completion_sec': 13.48, 'estimated_total_time_sec': 33.69, 'step_peak_memory_allocated_MB': 34547.75, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 11434.53, 'avg_tokens_per_second': 10819.45}
+{'loss': 0.9917, 'grad_norm': 9.391414642333984, 'learning_rate': 1.5000000000000007e-06, 'epoch': 0.02, 'num_input_tokens_seen': 260096, 'step': 7, 'step_time_sec': 3.03, 'avg_step_time_sec': 3.31, 'time_to_completion_sec': 9.94, 'estimated_total_time_sec': 33.12, 'step_peak_memory_allocated_MB': 34547.74, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 11671.7, 'avg_tokens_per_second': 10949.25}
+{'loss': 0.9286, 'grad_norm': 7.622978687286377, 'learning_rate': 7.018666706430663e-07, 'epoch': 0.02, 'num_input_tokens_seen': 306176, 'step': 8, 'step_time_sec': 4.37, 'avg_step_time_sec': 3.46, 'time_to_completion_sec': 6.93, 'estimated_total_time_sec': 34.63, 'step_peak_memory_allocated_MB': 34547.78, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 10545.82, 'avg_tokens_per_second': 10876.54}
+{'loss': 0.984, 'grad_norm': 7.1107611656188965, 'learning_rate': 1.8092213764227505e-07, 'epoch': 0.02, 'num_input_tokens_seen': 348672, 'step': 9, 'step_time_sec': 3.59, 'avg_step_time_sec': 3.48, 'time_to_completion_sec': 3.48, 'estimated_total_time_sec': 34.79, 'step_peak_memory_allocated_MB': 34547.77, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 11832.11, 'avg_tokens_per_second': 10999.84}
+{'loss': 0.9725, 'grad_norm': 7.447627544403076, 'learning_rate': 0.0, 'epoch': 0.03, 'num_input_tokens_seen': 386560, 'step': 10, 'step_time_sec': 3.99, 'avg_step_time_sec': 3.54, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 35.36, 'step_peak_memory_allocated_MB': 34547.87, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 9499.24, 'avg_tokens_per_second': 10811.76}
+{'train_runtime': 38.6812, 'train_samples_per_second': 33.091, 'train_steps_per_second': 0.259, 'train_loss': 1.2382215678691864, 'epoch': 0.03, 'num_input_tokens_seen': 386560, 'step': 10, 'step_time_sec': 3.99, 'avg_step_time_sec': 3.54, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 35.36, 'step_peak_memory_allocated_MB': 34547.87, 'step_peak_memory_reserved_MB': 52132.0, 'total_peak_memory_allocated_MB': 34547.91, 'total_peak_memory_reserved_MB': 52132.0, 'step_tokens_per_second': 9499.24, 'avg_tokens_per_second': 10811.76}
diff --git a/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_instance.log b/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_instance.log
new file mode 100644
index 000000000..75eb5b5b0
--- /dev/null
+++ b/examples/huggingface/results/mistral_use_liger_True_patching_type_post_init_instance.log
@@ -0,0 +1,15 @@
+********** Post-Init Instance Patching ***********
+********** Post-Init Instance Patching ***********
+********** Post-Init Instance Patching ***********
+********** Post-Init Instance Patching ***********
+{'loss': 10.3753, 'grad_norm': 80.14120483398438, 'learning_rate': 6e-06, 'epoch': 0.0, 'num_input_tokens_seen': 47104}
+{'loss': 10.374, 'grad_norm': 80.64556121826172, 'learning_rate': 5.819077862357725e-06, 'epoch': 0.01, 'num_input_tokens_seen': 89600, 'step': 2, 'step_time_sec': 2.98, 'avg_step_time_sec': 2.98, 'time_to_completion_sec': 23.82, 'estimated_total_time_sec': 29.78, 'step_peak_memory_allocated_MB': 34547.82, 'step_peak_memory_reserved_MB': 45690.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 45690.0, 'step_tokens_per_second': 14269.83, 'avg_tokens_per_second': 14269.83}
+{'loss': 9.5078, 'grad_norm': 30.093812942504883, 'learning_rate': 5.298133329356934e-06, 'epoch': 0.01, 'num_input_tokens_seen': 131072, 'step': 3, 'step_time_sec': 3.66, 'avg_step_time_sec': 3.32, 'time_to_completion_sec': 23.22, 'estimated_total_time_sec': 33.17, 'step_peak_memory_allocated_MB': 34547.88, 'step_peak_memory_reserved_MB': 48532.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 48532.0, 'step_tokens_per_second': 11344.44, 'avg_tokens_per_second': 12657.71}
+{'loss': 8.7012, 'grad_norm': 33.24259948730469, 'learning_rate': 4.5e-06, 'epoch': 0.01, 'num_input_tokens_seen': 176640, 'step': 4, 'step_time_sec': 3.94, 'avg_step_time_sec': 3.53, 'time_to_completion_sec': 21.15, 'estimated_total_time_sec': 35.25, 'step_peak_memory_allocated_MB': 34547.71, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 11559.63, 'avg_tokens_per_second': 12248.41}
+{'loss': 7.7453, 'grad_norm': 46.81055450439453, 'learning_rate': 3.5209445330007917e-06, 'epoch': 0.01, 'num_input_tokens_seen': 222208, 'step': 5, 'step_time_sec': 3.93, 'avg_step_time_sec': 3.63, 'time_to_completion_sec': 18.13, 'estimated_total_time_sec': 36.26, 'step_peak_memory_allocated_MB': 34547.76, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 11605.14, 'avg_tokens_per_second': 12074.24}
+{'loss': 7.1292, 'grad_norm': 25.51975440979004, 'learning_rate': 2.4790554669992093e-06, 'epoch': 0.02, 'num_input_tokens_seen': 260096, 'step': 6, 'step_time_sec': 2.79, 'avg_step_time_sec': 3.46, 'time_to_completion_sec': 13.84, 'estimated_total_time_sec': 34.59, 'step_peak_memory_allocated_MB': 34547.74, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 13556.53, 'avg_tokens_per_second': 12313.75}
+{'loss': 6.747, 'grad_norm': 17.432945251464844, 'learning_rate': 1.5000000000000007e-06, 'epoch': 0.02, 'num_input_tokens_seen': 303104, 'step': 7, 'step_time_sec': 3.66, 'avg_step_time_sec': 3.49, 'time_to_completion_sec': 10.48, 'estimated_total_time_sec': 34.92, 'step_peak_memory_allocated_MB': 34547.74, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 11759.21, 'avg_tokens_per_second': 12216.96}
+{'loss': 6.4769, 'grad_norm': 16.094770431518555, 'learning_rate': 7.018666706430663e-07, 'epoch': 0.02, 'num_input_tokens_seen': 345600, 'step': 8, 'step_time_sec': 3.94, 'avg_step_time_sec': 3.56, 'time_to_completion_sec': 7.11, 'estimated_total_time_sec': 35.56, 'step_peak_memory_allocated_MB': 34547.75, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 10792.2, 'avg_tokens_per_second': 11991.58}
+{'loss': 6.3711, 'grad_norm': 14.258646011352539, 'learning_rate': 1.8092213764227505e-07, 'epoch': 0.02, 'num_input_tokens_seen': 393216, 'step': 9, 'step_time_sec': 3.95, 'avg_step_time_sec': 3.61, 'time_to_completion_sec': 3.61, 'estimated_total_time_sec': 36.05, 'step_peak_memory_allocated_MB': 34547.79, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 12050.23, 'avg_tokens_per_second': 11999.61}
+{'loss': 6.236, 'grad_norm': 13.197342872619629, 'learning_rate': 0.0, 'epoch': 0.03, 'num_input_tokens_seen': 432128, 'step': 10, 'step_time_sec': 3.01, 'avg_step_time_sec': 3.54, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 35.4, 'step_peak_memory_allocated_MB': 34547.7, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 12913.91, 'avg_tokens_per_second': 12086.09}
+{'train_runtime': 38.3024, 'train_samples_per_second': 33.418, 'train_steps_per_second': 0.261, 'train_loss': 7.966402006149292, 'epoch': 0.03, 'num_input_tokens_seen': 432128, 'step': 10, 'step_time_sec': 3.01, 'avg_step_time_sec': 3.54, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 35.4, 'step_peak_memory_allocated_MB': 34547.7, 'step_peak_memory_reserved_MB': 49246.0, 'total_peak_memory_allocated_MB': 34547.88, 'total_peak_memory_reserved_MB': 49246.0, 'step_tokens_per_second': 12913.91, 'avg_tokens_per_second': 12086.09}
diff --git a/examples/huggingface/results/mistral_use_liger_True_patching_type_pre_init.log b/examples/huggingface/results/mistral_use_liger_True_patching_type_pre_init.log
new file mode 100644
index 000000000..f58a2c422
--- /dev/null
+++ b/examples/huggingface/results/mistral_use_liger_True_patching_type_pre_init.log
@@ -0,0 +1,15 @@
+********** Pre-Init Patching ***********
+********** Pre-Init Patching ***********
+********** Pre-Init Patching ***********
+********** Pre-Init Patching ***********
+{'loss': 2.0207, 'grad_norm': 108.31810760498047, 'learning_rate': 6e-06, 'epoch': 0.0, 'num_input_tokens_seen': 37888}
+{'loss': 1.9878, 'grad_norm': 93.73506164550781, 'learning_rate': 5.819077862357725e-06, 'epoch': 0.01, 'num_input_tokens_seen': 77824, 'step': 2, 'step_time_sec': 3.92, 'avg_step_time_sec': 3.92, 'time_to_completion_sec': 31.4, 'estimated_total_time_sec': 39.24, 'step_peak_memory_allocated_MB': 34547.7, 'step_peak_memory_reserved_MB': 42254.0, 'total_peak_memory_allocated_MB': 34547.7, 'total_peak_memory_reserved_MB': 42254.0, 'step_tokens_per_second': 10176.29, 'avg_tokens_per_second': 10176.29}
+{'loss': 1.1647, 'grad_norm': 52.44873809814453, 'learning_rate': 5.298133329356934e-06, 'epoch': 0.01, 'num_input_tokens_seen': 110592, 'step': 3, 'step_time_sec': 2.33, 'avg_step_time_sec': 3.13, 'time_to_completion_sec': 21.88, 'estimated_total_time_sec': 31.26, 'step_peak_memory_allocated_MB': 34547.75, 'step_peak_memory_reserved_MB': 45388.0, 'total_peak_memory_allocated_MB': 34547.75, 'total_peak_memory_reserved_MB': 45388.0, 'step_tokens_per_second': 14081.43, 'avg_tokens_per_second': 11629.94}
+{'loss': 1.1242, 'grad_norm': 66.05574035644531, 'learning_rate': 4.5e-06, 'epoch': 0.01, 'num_input_tokens_seen': 148480, 'step': 4, 'step_time_sec': 3.0, 'avg_step_time_sec': 3.08, 'time_to_completion_sec': 18.5, 'estimated_total_time_sec': 30.84, 'step_peak_memory_allocated_MB': 34547.82, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 12631.7, 'avg_tokens_per_second': 11954.74}
+{'loss': 1.0549, 'grad_norm': 22.685998916625977, 'learning_rate': 3.5209445330007917e-06, 'epoch': 0.01, 'num_input_tokens_seen': 178688, 'step': 5, 'step_time_sec': 2.24, 'avg_step_time_sec': 2.87, 'time_to_completion_sec': 14.37, 'estimated_total_time_sec': 28.74, 'step_peak_memory_allocated_MB': 34547.72, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 13466.62, 'avg_tokens_per_second': 12249.8}
+{'loss': 1.0396, 'grad_norm': 11.832498550415039, 'learning_rate': 2.4790554669992093e-06, 'epoch': 0.02, 'num_input_tokens_seen': 227840, 'step': 6, 'step_time_sec': 3.95, 'avg_step_time_sec': 3.09, 'time_to_completion_sec': 12.35, 'estimated_total_time_sec': 30.88, 'step_peak_memory_allocated_MB': 34547.79, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 12455.91, 'avg_tokens_per_second': 12302.48}
+{'loss': 0.9977, 'grad_norm': 8.230679512023926, 'learning_rate': 1.5000000000000007e-06, 'epoch': 0.02, 'num_input_tokens_seen': 265728, 'step': 7, 'step_time_sec': 3.18, 'avg_step_time_sec': 3.1, 'time_to_completion_sec': 9.31, 'estimated_total_time_sec': 31.03, 'step_peak_memory_allocated_MB': 34547.77, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 11930.92, 'avg_tokens_per_second': 12239.09}
+{'loss': 0.9193, 'grad_norm': 7.567809581756592, 'learning_rate': 7.018666706430663e-07, 'epoch': 0.02, 'num_input_tokens_seen': 299008, 'step': 8, 'step_time_sec': 2.89, 'avg_step_time_sec': 3.07, 'time_to_completion_sec': 6.14, 'estimated_total_time_sec': 30.72, 'step_peak_memory_allocated_MB': 34547.74, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 11534.49, 'avg_tokens_per_second': 12144.54}
+{'loss': 0.9728, 'grad_norm': 7.609869003295898, 'learning_rate': 1.8092213764227505e-07, 'epoch': 0.02, 'num_input_tokens_seen': 335360, 'step': 9, 'step_time_sec': 2.9, 'avg_step_time_sec': 3.05, 'time_to_completion_sec': 3.05, 'estimated_total_time_sec': 30.5, 'step_peak_memory_allocated_MB': 34547.71, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 12553.79, 'avg_tokens_per_second': 12193.12}
+{'loss': 0.9544, 'grad_norm': 7.200411319732666, 'learning_rate': 0.0, 'epoch': 0.03, 'num_input_tokens_seen': 373248, 'step': 10, 'step_time_sec': 2.9, 'avg_step_time_sec': 3.03, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 30.33, 'step_peak_memory_allocated_MB': 34547.71, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 13064.17, 'avg_tokens_per_second': 12285.66}
+{'train_runtime': 33.6106, 'train_samples_per_second': 38.083, 'train_steps_per_second': 0.298, 'train_loss': 1.2235911667346955, 'epoch': 0.03, 'num_input_tokens_seen': 373248, 'step': 10, 'step_time_sec': 2.9, 'avg_step_time_sec': 3.03, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 30.33, 'step_peak_memory_allocated_MB': 34547.71, 'step_peak_memory_reserved_MB': 46732.0, 'total_peak_memory_allocated_MB': 34547.82, 'total_peak_memory_reserved_MB': 46732.0, 'step_tokens_per_second': 13064.17, 'avg_tokens_per_second': 12285.66}
diff --git a/examples/huggingface/run_llama.sh b/examples/huggingface/run_llama.sh
index 22b10ff97..11eae9280 100644
--- a/examples/huggingface/run_llama.sh
+++ b/examples/huggingface/run_llama.sh
@@ -2,7 +2,7 @@ torchrun --nnodes=1 --nproc-per-node=4 training.py \
     --bf16 \
     --model_name "/shared/public/models/Meta-Llama-3-8B" \
     --dataset "/shared/public/data/tatsu-lab" \
-    --max_steps 20 \
+    --max_steps 5 \
     --num_train_epochs 1 \
     --per_device_train_batch_size 48 \
     --per_device_eval_batch_size 64 \
diff --git a/examples/huggingface/run_patching_methods.sh b/examples/huggingface/run_patching_methods.sh
new file mode 100755
index 000000000..29e6ec455
--- /dev/null
+++ b/examples/huggingface/run_patching_methods.sh
@@ -0,0 +1,60 @@
+## Benchmarking Script
+## Runs the training script with different configurations and logs the results
+
+# MODEL_TYPE="llama"
+# MODEL_PATH="/shared/public/models/Meta-Llama-3-8B"
+MODEL_TYPE="mistral"
+MODEL_PATH="/shared/public/models/mistralai/Mistral-7B-v0.1"
+# USE_LIGER_VALUES=("True" "False")
+# PATCHING_TYPE_VALUES=("pre_init" "post_init_class" "post_init_instance")
+USE_LIGER_VALUES=("False")
+PATCHING_TYPE_VALUES=("post_init_instance")
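+# NOTE: as checked in, the active arrays above run only the unpatched baseline
+# (when USE_LIGER is "False" the patching type is ignored); restore the
+# commented-out arrays to sweep all four configurations captured in results/.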
+MAX_STEPS=10
+BATCH_SIZE=32
+DATASET_PATH="/shared/public/data/tatsu-lab"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+mkdir -p "${SCRIPT_DIR}/results"
+
+for USE_LIGER in "${USE_LIGER_VALUES[@]}"; do
+
+    # Only run patching types if USE_LIGER is "True"
+    if [ "$USE_LIGER" == "True" ]; then
+        PATCHING_TYPES=("${PATCHING_TYPE_VALUES[@]}")
+    else
+        PATCHING_TYPES=("None")
+    fi
+
+    for PATCHING_TYPE in "${PATCHING_TYPES[@]}"; do
+        echo "Running with use_liger=$USE_LIGER and patching_type=$PATCHING_TYPE"
+
+        LOG_FILE="${SCRIPT_DIR}/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_patching_type_${PATCHING_TYPE}.log"
+
+        torchrun --nnodes=1 --nproc-per-node=4 training.py \
+            --bf16 \
+            --num_train_epochs 1 \
+            --max_steps $MAX_STEPS \
+            --model_name $MODEL_PATH \
+            --dataset $DATASET_PATH \
+            --per_device_train_batch_size $BATCH_SIZE \
+            --per_device_eval_batch_size 16 \
+            --eval_strategy "no" \
+            --save_strategy "no" \
+            --learning_rate 6e-6 \
+            --weight_decay 0.05 \
+            --warmup_ratio 0.1 \
+            --lr_scheduler_type "cosine" \
+            --logging_steps 1 \
+            --include_num_input_tokens_seen \
+            --report_to none \
+            --fsdp "full_shard auto_wrap" \
+            --fsdp_config config/fsdp_config.json \
+            --seed 42 \
+            --use_liger $USE_LIGER \
+            --patching_type $PATCHING_TYPE \
+            --output_dir model_output_dir \
+            > $LOG_FILE
+
+        sleep 5
+    done
+done
\ No newline at end of file
diff --git a/examples/huggingface/training.py b/examples/huggingface/training.py
index 1f5d9be0f..2dab51163 100644
--- a/examples/huggingface/training.py
+++ b/examples/huggingface/training.py
@@ -6,8 +6,11 @@
 from callback import EfficiencyCallback
 from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 
-from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_llama
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+from liger_kernel.transformers import (
+    AutoLigerKernelForCausalLM,
+    apply_liger_kernel_to_llama,
+)
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel, _apply_liger_kernel_to_instance
 
 
 @dataclass
@@ -16,10 +19,12 @@ class CustomArguments:
     dataset: str = "tatsu-lab/alpaca"
     max_seq_length: int = 512
     use_liger: bool = False
+    patching_type: str = "pre_init"  # pre_init, post_init_class, post_init_instance
 
 
+bos_token = '<s>'
 def formatting_prompts_func(example):
-    return example["text"]
+    return [text.replace("### Response:", bos_token) for text in example["text"]]
 
 
 def train():
@@ -39,31 +44,54 @@
     )
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
-    response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False)
+    response_prompt = bos_token
     collator = DataCollatorForCompletionOnlyLM(
         tokenizer=tokenizer,
        response_template=response_prompt,
         pad_to_multiple_of=16,
     )
-    # if custom_args.use_liger:
-    #     model = AutoLigerKernelForCausalLM.from_pretrained(
-    #         custom_args.model_name,
-    #         trust_remote_code=True,
-    #         use_cache=False,
-    #         torch_dtype=torch.bfloat16,
-    #         # These args will get passed to the appropriate apply_liger_kernel_to_* function
-    #         # to override the default settings
-    #         # cross_entropy=True,
-    #         # fused_linear_cross_entropy=False,
-    #     )
-    # else:
-    #     model = transformers.AutoModelForCausalLM.from_pretrained(
-    #         custom_args.model_name,
-    #         trust_remote_code=True,
-    #         use_cache=False,
-    #         torch_dtype=torch.bfloat16,
-    #     )
+
+    if custom_args.use_liger:
+        if custom_args.patching_type == "pre_init":
+            print("********** Pre-Init Patching ***********")
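+            # Pre-init patching: AutoLigerKernelForCausalLM applies the Liger
+            # kernel patches to the modeling code before the weights are
+            # loaded, so every module is constructed in its Liger form.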
+            model = AutoLigerKernelForCausalLM.from_pretrained(
+                custom_args.model_name,
+                trust_remote_code=True,
+                use_cache=False,
+                torch_dtype=torch.bfloat16,
+            )
+        elif custom_args.patching_type == "post_init_class":
+            print("********** Post-Init Class Patching ***********")
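+            # Post-init class patching: patch the transformers modeling module
+            # after the model has been built. Patches that replace module-level
+            # functions apply to this existing instance, while swapped classes
+            # only affect models constructed afterwards.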
+            model = transformers.AutoModelForCausalLM.from_pretrained(
+                custom_args.model_name,
+                trust_remote_code=True,
+                use_cache=False,
+                torch_dtype=torch.bfloat16,
+            )
+            model_type = getattr(model, "config", None) and getattr(
+                model.config, "model_type", None
+            )
+            _apply_liger_kernel(model_type=model_type)
+        elif custom_args.patching_type == "post_init_instance":
+            print("********** Post-Init Instance Patching ***********")
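+            # Post-init instance patching: swap the RMSNorm and MLP submodules
+            # on the already-constructed model in place via
+            # _apply_liger_kernel_to_instance. Note: the checked-in Mistral run
+            # for this mode diverges (its results log starts at loss ~10.4), so
+            # this path likely needs further debugging.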
"config", None) and getattr( + model.config, "model_type", None + ) if not model_type: - logger.info("Model type could not be determined from model config. No Liger kernels will be applied.") + logger.info( + "Model type could not be determined from model config. No Liger kernels will be applied." + ) return if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys(): diff --git a/src/liger_kernel/triton/monkey_patch.py b/src/liger_kernel/triton/monkey_patch.py index 590842a83..70863f4e3 100644 --- a/src/liger_kernel/triton/monkey_patch.py +++ b/src/liger_kernel/triton/monkey_patch.py @@ -37,6 +37,6 @@ def apply_liger_triton_cache_manager(): Experimental feature to get around transient FileNotFoundError in triton compilation. For more details please see https://github.com/triton-lang/triton/pull/4295 """ - os.environ[ - "TRITON_CACHE_MANAGER" - ] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" + os.environ["TRITON_CACHE_MANAGER"] = ( + "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" + )