[On-Device Training] Expose Parameters through the Training API #17364
Take a quick look, I have a few comments.
Recommendation: instead of true, false you can do and pass it as an argument. This would improve readability a lot.
In reply to: 1709295692
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e.
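One way to read this recommendation, sketched in Python (one of the languages this PR targets) rather than C#: bind the flag to a descriptively named variable, or pass it by keyword, so the call site explains itself. The helper `count_parameters`, its arguments, and the sample parameter names are hypothetical and are not part of the ONNX Runtime API.

```python
def count_parameters(sizes: dict, trainable: set, only_trainable: bool) -> int:
    """Hypothetical helper: sum parameter sizes, optionally only the trainable ones."""
    return sum(
        size
        for name, size in sizes.items()
        if not only_trainable or name in trainable
    )


# Made-up parameter sizes, purely for illustration.
sizes = {"fc1.weight": 6, "fc1.bias": 2, "frozen.weight": 4}
trainable = {"fc1.weight", "fc1.bias"}

# Opaque: the reader has to look up what the bare literal means.
opaque = count_parameters(sizes, trainable, True)

# Self-describing: the flag is named before it is passed.
only_trainable = True
named = count_parameters(sizes, trainable, only_trainable)

# Equally readable: pass the flag by keyword.
everything = count_parameters(sizes, trainable, only_trainable=False)
```

All three calls behave identically for the same flag value; only the readability of the call site changes.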
Also, I suggest introducing an OrtValue-based API sometime in the near future.
In reply to: 1709295692
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e.
Suggestion: float is a blittable type, so you can read it straight from native memory. No intermediate array is needed, since the array only creates garbage, and Marshal.Copy is slow. The same applies to int. One option is an unsafe block:

    float val;
    unsafe
    {
        val = *(float*)propertyValue.ToPointer();
    }
---
In reply to: [1709298903](https://github.com/microsoft/onnxruntime/pull/17364#issuecomment-1709298903)
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:183 in df21a2e.
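The difference between the two approaches in this suggestion can be shown outside C# as well. The following is a minimal Python `ctypes` analogue, not code from the PR: one function copies the value into a temporary one-element array first, the other reads the float directly at the native address. The function names are hypothetical, and the sketch illustrates only the shape of the two approaches; it makes no claim about their relative cost in Python.

```python
import ctypes


def read_float_via_array(address: int) -> float:
    """Copy the bytes into a temporary one-element array, then read element 0."""
    tmp = (ctypes.c_float * 1)()  # the intermediate allocation
    ctypes.memmove(tmp, address, ctypes.sizeof(ctypes.c_float))
    return tmp[0]


def read_float_direct(address: int) -> float:
    """Reinterpret the address as a float and read it without a temporary array."""
    return ctypes.c_float.from_address(address).value


# A float living in native memory, standing in for the property value
# that the native API would hand back as a pointer.
native_value = ctypes.c_float(1.5)
address = ctypes.addressof(native_value)

via_array = read_float_via_array(address)
direct = read_float_direct(address)
```

Both functions return the same value; the direct read simply skips the temporary buffer, which is the point the reviewer makes about the C# code.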
See the inference code for a function that converts the string and copies it to native memory at the same time, avoiding an intermediate array allocation.
In reply to: 1709299587
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:152 in df21a2e.
nit: Pin() is very expensive.
In reply to: 1724556619
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:49 in 58ff40b.
LGTM
Thank you for the review @yuslepukhin @AdamLouly @pengwa :)
Cherry-pick PRs:
#18026
#17912
#17901 "2 lines added whitespace errors when cherry-picking"
#17293
#17364
#17505
#17885

This PR contains all the cherry-picks for the patch release except:
1. The PRs marked with sdxl_llama
2. #17772 which has a merge conflict.

---------

Co-authored-by: Chi Lo <Chi.Lo@microsoft.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: Kaz Nishimura <kazssym@linuxfront.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
This pull request exposes the checkpoint parameters to users in C, C++, C#, and Python.
After this pull request, users can query the current value of each parameter and update it.