Skip to content

Commit

Permalink
Address pull request review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Sep 18, 2023
1 parent dc78a7b commit 58ff40b
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,12 @@ public void AddProperty(string propertyName, string propertyValue)
var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName);
var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue);

IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length);
try
{
Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length);
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer));
}
finally
unsafe
{
Marshal.FreeHGlobal(unmanagedPointer);
fixed (byte* p = propertyValueUtf8)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p));
}
}
}

Expand All @@ -173,23 +170,32 @@ public object GetProperty(string propertyName)

if (propertyType == PropertyType.Int)
{
var longPropertyValue = Marshal.ReadInt64(propertyValue);
allocator.FreeMemory(propertyValue);
return longPropertyValue;
Int64 value;
unsafe
{
value = *(Int64*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.Float)
{
float[] value = new float[1];
Marshal.Copy(propertyValue, value, 0, 1);
allocator.FreeMemory(propertyValue);
return value[0];
float value;
unsafe
{
value = *(float*)propertyValue;
}
return value;
}
else if (propertyType == PropertyType.String)
{
return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator);
}

throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
try {
throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString());
} finally {
allocator.FreeMemory(propertyValue);
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,12 @@ public void EvalStep(
{
if (_evalOutputCount != (ulong)outputValues.Count())
{
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
}
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
const bool isInput = true;
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput);

IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}
Expand Down Expand Up @@ -509,7 +510,7 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec
/// Returns a contiguous buffer that holds a copy of all training state parameters
/// </summary>
/// <param name="onlyTrainable">Whether to only copy trainable parameters or to copy all parameters.</param>
public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
public OrtValue ToBuffer(bool onlyTrainable)
{
UIntPtr bufferSize = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable));
Expand All @@ -518,9 +519,9 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)

var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
var shape = new long[] { (long)bufferSize.ToUInt64() };
var buffer = FixedBufferOnnxValue.CreateFromMemory<float>(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float));
var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape);

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable));

return buffer;
}
Expand All @@ -529,15 +530,15 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable)
/// Loads the training session model parameters from a contiguous buffer
/// </summary>
/// <param name="buffer">Contiguous buffer to load the parameters from.</param>
public void FromBuffer(FixedBufferOnnxValue buffer)
public void FromBuffer(OrtValue buffer)
{
if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
if (buffer.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer.");
}

IntPtr typeAndShapeInfo = IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo));
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Handle, out typeAndShapeInfo));
UIntPtr numDimensions = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions));
if (numDimensions.ToUInt64() != 1)
Expand All @@ -551,22 +552,23 @@ public void FromBuffer(FixedBufferOnnxValue buffer)

// OrtGetParametersSize returns the total number of elements in the model's parameters.
UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
const bool onlyTrainable = true;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, onlyTrainable));
if ((ulong)bufferSize == (ulong)numElementsTrainingOnly)
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, onlyTrainable));
return;
}

UIntPtr numElements = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, !onlyTrainable));
if ((ulong)bufferSize != (ulong)numElements)
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
throw new ArgumentException(errorMessage);
}

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false));
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Handle, !onlyTrainable));
}

/// <summary>
Expand Down
162 changes: 80 additions & 82 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -484,43 +484,49 @@ public void TestEvalModelOutputNames()
public void TestToBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
using (var buffer = trainingSession.ToBuffer(true))
{
Assert.NotNull(buffer);
var typeShape = buffer.GetTensorTypeAndShape();
Assert.Equal(1, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(397510, fetchedShape[0]);
}
}
}

[Fact(DisplayName = "TestFromBuffer")]
public void TestFromBuffer()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

var buffer = trainingSession.ToBuffer(true);
cleanUp.Add(buffer);
using (var buffer = trainingSession.ToBuffer(true))
{
Assert.NotNull(buffer);
var typeShape = buffer.GetTensorTypeAndShape();
Assert.Equal(1, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(397510, fetchedShape[0]);

trainingSession.FromBuffer(buffer);
trainingSession.FromBuffer(buffer);
}
}
}

Expand All @@ -534,24 +540,18 @@ public void TestSetSeed()
public void TestGetParameter()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
using (var parameter = state.GetParameter("fc1.weight"))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

var parameter = state.GetParameter("fc1.weight");
cleanUp.Add(parameter);

Assert.NotNull(parameter);
var typeShape = parameter.GetTensorTypeAndShape();

var typeShape = parameter.GetTensorTypeAndShape();
Assert.Equal(2, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(500, fetchedShape[0]);
Expand All @@ -563,54 +563,52 @@ public void TestGetParameter()
public void TestUpdateParameter()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

using (var state = CheckpointState.LoadCheckpoint(checkpointPath))
using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath))
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

var parameter = state.GetParameter("fc1.weight");
cleanUp.Add(parameter);

Assert.NotNull(parameter);
var typeShape = parameter.GetTensorTypeAndShape();

Assert.Equal(2, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(500, fetchedShape[0]);
Assert.Equal(784, fetchedShape[1]);

float maxVal = 20;
Random randNum = new Random();
float[] updated_parameter_buffer = Enumerable
.Repeat(0, 500 * 784)
.Select(i => maxVal * (float)randNum.NextDouble())
.ToArray();

var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape);
cleanUp.Add(updated_parameter);

state.UpdateParameter("fc1.weight", updated_parameter);
var current_parameter = state.GetParameter("fc1.weight");
cleanUp.Add(current_parameter);

var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
Assert.NotEqual(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);

state.UpdateParameter("fc1.weight", parameter);
current_parameter = state.GetParameter("fc1.weight");
cleanUp.Add(current_parameter);

current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
using (var parameter = state.GetParameter("fc1.weight"))
{
Assert.NotNull(parameter);
var typeShape = parameter.GetTensorTypeAndShape();

Assert.Equal(2, typeShape.DimensionsCount);
var fetchedShape = typeShape.Shape;
Assert.Equal(500, fetchedShape[0]);
Assert.Equal(784, fetchedShape[1]);

float maxVal = 20;
Random randNum = new Random();
float[] updated_parameter_buffer = Enumerable
.Repeat(0, 500 * 784)
.Select(i => maxVal * (float)randNum.NextDouble())
.ToArray();

using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape))
{
state.UpdateParameter("fc1.weight", updated_parameter);
using (var current_parameter = state.GetParameter("fc1.weight"))
{
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(updated_parameter_buffer, current_parameter_tensor);
Assert.NotEqual(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
}

state.UpdateParameter("fc1.weight", parameter);

using (var current_parameter = state.GetParameter("fc1.weight"))
{
var current_parameter_tensor = current_parameter.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(parameter.GetTensorDataAsSpan<float>().ToArray(), current_parameter_tensor);
Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor);
}
}
}
}
}

Expand Down

0 comments on commit 58ff40b

Please sign in to comment.