Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[On-Device Training] Expose Parameters through the Training API #17364

Merged
merged 14 commits into from
Sep 26, 2023

Conversation

baijumeswani
Copy link
Contributor

This pull request exposes the checkpoint parameters to users in C, C++, C# and Python.

Users will be able to query the current value of the parameters and update the parameters after this pull request.

@baijumeswani baijumeswani added the training issues related to ONNX Runtime training; typically submitted using template label Aug 31, 2023
Copy link
Contributor

@pengwa pengwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a quick look, I have a few comments.

orttraining/orttraining/training_api/module.cc Outdated Show resolved Hide resolved
orttraining/orttraining/python/orttraining_pybind_state.cc Outdated Show resolved Hide resolved
orttraining/orttraining/training_api/module.h Outdated Show resolved Hide resolved
@yuslepukhin
Copy link
Member

yuslepukhin commented Sep 7, 2023

        IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

Recommendation:

instead of true, false you can do
const bool descriptive_name_true = true;

and pass it as an argument.

This would improve readability a lot.


In reply to: 1709295692


In reply to: 1709295692


In reply to: 1709295692


Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e. [](commit_id = df21a2e, deletion_comment = False)

@yuslepukhin
Copy link
Member

        IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

Also, I am suggesting to introduce a OrtValue based API some time in the near future.


In reply to: 1709295692


Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs:365 in df21a2e. [](commit_id = df21a2e, deletion_comment = False)

@yuslepukhin
Copy link
Member

yuslepukhin commented Sep 7, 2023

            Marshal.Copy(propertyValue, value, 0, 1);

Suggestion: float type is blittable, so you can simply copy it from native memory in the unsafe block and no array is needed, bc it introduces garbage. Marshal.Copy is slow.

Same is for Int

Something like this:
https://github.com/microsoft/onnxruntime/blob/main/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs#L219

Or one can do unsafe block:

float val;
unsafe
{
    val = \*(float\*)propertyValue.ToPointer();
}


---
In reply to: [1709298903](https://github.com/microsoft/onnxruntime/pull/17364#issuecomment-1709298903) [](http://example.com/codeflow?ancestors=1709298903)

---
In reply to: [1709298903](https://github.com/microsoft/onnxruntime/pull/17364#issuecomment-1709298903) [](http://example.com/codeflow?ancestors=1709298903)

---
In reply to: [1709298903](https://github.com/microsoft/onnxruntime/pull/17364#issuecomment-1709298903) [](http://example.com/codeflow?ancestors=1709298903)

---
Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:183 in df21a2e. [](commit_id = df21a2e20e52436297950511058c3626203d2e92, deletion_comment = False)

@yuslepukhin
Copy link
Member

yuslepukhin commented Sep 7, 2023

        }

See inference code for a function that converts and copies the string to native memory at the same time, and avoids intermediate array allocation.

https://github.com/microsoft/onnxruntime/blob/main/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs#L47


In reply to: 1709299587


In reply to: 1709299587


In reply to: 1709299587


Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:152 in df21a2e. [](commit_id = df21a2e, deletion_comment = False)

Copy link
Member

@yuslepukhin yuslepukhin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕐

@yuslepukhin
Copy link
Member

yuslepukhin commented Sep 18, 2023

        using (var memHandle = memory.Pin())

nit: Pin() is very expensive.
fixed() block is cheap.


In reply to: 1724556619


Refers to: csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs:49 in 58ff40b. [](commit_id = 58ff40b, deletion_comment = False)

@yuslepukhin
Copy link
Member

Make it to pass everything in a form of a string, and then convert to the type.


Refers to: orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h:628 in 58ff40b. [](commit_id = 58ff40b, deletion_comment = False)

yuslepukhin
yuslepukhin previously approved these changes Sep 20, 2023
Copy link
Contributor

@AdamLouly AdamLouly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@baijumeswani baijumeswani merged commit ccb73fd into main Sep 26, 2023
92 checks passed
@baijumeswani baijumeswani deleted the baijumeswani/update-checkpoint-params branch September 26, 2023 03:03
@baijumeswani
Copy link
Contributor Author

Thank you for the review @yuslepukhin @AdamLouly @pengwa :)

snnn added a commit that referenced this pull request Nov 2, 2023
Cherry-pick PRs: 
#18026 
#17912 
#17901 “2 lines added whitespace errors when cherry-picking"
#17293 
#17364 
#17505 
#17885

This PR contains all the cherry-picks for the patch release except:
1. The PRs marked with sdxl_llama
2. #17772 which has a merge conflict.

---------

Co-authored-by: Chi Lo <Chi.Lo@microsoft.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: Scott McKay <Scott.McKay@microsoft.com>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: Kaz Nishimura <kazssym@linuxfront.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants