Expose model parameters and their gradients in Python
baijumeswani committed Aug 31, 2023
1 parent 8f1d194 commit df21a2e
Showing 3 changed files with 185 additions and 25 deletions.
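Taken together, these changes let Python code read and update model parameters and their gradients by name through the checkpoint state. A minimal usage sketch based on the diffs below, assuming training artifacts and a checkpoint already exist on disk (the file paths and the parameter name `fc1.weight` are illustrative, not part of this commit):

```python
import numpy as np

from onnxruntime.training.api import CheckpointState, Module

# Illustrative paths, produced by the usual artifact-generation step.
state = CheckpointState.load_checkpoint("checkpoint")
model = Module("training_model.onnx", state, "eval_model.onnx")

# Parameters are now addressable by name through the checkpoint state.
param = state["fc1.weight"]
print(param.name, param.requires_grad)

weights = param.data  # parameter data as a numpy array
grad = param.grad     # gradient as a numpy array, or None if not yet computed

# Assigning a numpy array copies the new values back into the model.
param.data = np.zeros_like(weights)
```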
62 changes: 51 additions & 11 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -1065,17 +1065,42 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc");
checkpoint_state
.def(py::init())
.def("add_property", [](onnxruntime::training::api::CheckpointState* state,
const std::string& property_name,
const std::variant<int64_t, float, std::string>& property_value) {
state->property_bag.AddProperty(property_name, property_value);
})
.def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
})
.def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.HasProperty(property_name);
});
.def("add_property",
[](onnxruntime::training::api::CheckpointState* state,
const std::string& property_name,
const std::variant<int64_t, float, std::string>& property_value) {
state->property_bag.AddProperty(property_name, property_value);
})
.def("get_property",
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.GetProperty<onnxruntime::training::api::PropertyDataType>(property_name);
})
.def("has_property",
[](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) {
return state->property_bag.HasProperty(property_name);
})
.def("copy_parameter_from",
[](onnxruntime::training::api::CheckpointState* state,
const std::string& parameter_name, OrtValue& value) -> void {
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
}
ORT_THROW_IF_ERROR(it->second->CopyFrom(
value, state->module_checkpoint_state.train_session_data_transfer_mgr));
})
.def("get_parameter",
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
auto it = state->module_checkpoint_state.named_parameters.find(parameter_name);
if (it == state->module_checkpoint_state.named_parameters.end()) {
ORT_THROW("Parameter with name ", parameter_name, " does not exist.");
}
return it->second;
})
.def("has_parameter",
[](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) {
return state->module_checkpoint_state.named_parameters.count(parameter_name);
});

py::class_<PyOptimizer>
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
@@ -1111,6 +1136,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
ORT_THROW_IF_ERROR(scheduler->Step());
});

py::class_<onnxruntime::training::api::Parameter,
std::unique_ptr<onnxruntime::training::api::Parameter, py::nodelete>>
parameter(m, "Parameter");
parameter
.def_property_readonly("name", &onnxruntime::training::api::Parameter::Name)
.def_property_readonly("data", &onnxruntime::training::api::Parameter::Data)
.def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient)
.def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad)
.def("copy_from",
[](onnxruntime::training::api::Parameter* parameter,
onnxruntime::training::api::CheckpointState* state,
OrtValue& value) -> void {
            ORT_THROW_IF_ERROR(parameter->CopyFrom(
                value, state->module_checkpoint_state.train_session_data_transfer_mgr));
});

m.def(
"save_checkpoint",
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
96 changes: 82 additions & 14 deletions orttraining/orttraining/python/training/api/checkpoint_state.py
@@ -5,7 +5,56 @@

import os

import numpy as np

from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue


class Parameter:
    """Class that represents a model parameter

    This class represents a model parameter and provides access to its data,
    gradient and other properties. This class is not expected to be instantiated directly.
    Instead, it is returned by the `CheckpointState` object.

    Args:
        parameter: The C.Parameter object that holds the underlying parameter data.
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
        self._parameter = parameter
        self._state = state

    @property
    def name(self) -> str:
        """The name of the parameter"""
        return self._parameter.name

    @property
    def data(self) -> np.ndarray:
        """The data of the parameter"""
        return self._parameter.data.numpy()

    @data.setter
    def data(self, value: np.ndarray) -> None:
        """Sets the data of the parameter"""
        self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue)

    @property
    def grad(self) -> np.ndarray:
        """The gradient of the parameter"""
        return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None

    @property
    def requires_grad(self) -> bool:
        """Whether or not the parameter requires its gradient to be computed"""
        return self._parameter.requires_grad

    def __repr__(self) -> str:
        """Returns a string representation of the parameter"""
        return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"


class CheckpointState:
@@ -52,33 +101,52 @@ def save_checkpoint(
"""
C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state)

    def __getitem__(self, name: str) -> int | float | str | Parameter:
        """Gets the parameter or property associated with the given name

        Searches for the name in the parameters and properties of the checkpoint state.

        Args:
            name: The name of the parameter or property

        Returns:
            The value of the parameter or property
        """
        if self._state.has_parameter(name):
            return Parameter(self._state.get_parameter(name), self._state)
        elif self._state.has_property(name):
            return self._state.get_property(name)
        else:
            raise KeyError(f"Could not find {name} in the checkpoint state.")

    def __setitem__(self, name: str, value: int | float | str | np.ndarray) -> None:
        """Sets the parameter or property value for the given name

        Searches for the name in the parameters and properties of the checkpoint state.
        If the name is found in the parameters, the parameter's value is updated.
        Else, the value is added or updated in the properties.

        Args:
            name: The name of the parameter or property
            value: The value of the parameter or property.
                   Properties only support int, float and str values.
        """
        if self._state.has_parameter(name):
            self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)
        else:
            self._state.add_property(name, value)

    def __contains__(self, name: str) -> bool:
        """Checks if the parameter or property exists in the state

        The name is searched in both the parameters and the properties.

        Args:
            name: The name of the parameter or property

        Returns:
            True if the name is either a parameter or a property, False otherwise
        """
        return self._state.has_parameter(name) or self._state.has_property(name)
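A sketch of the lookup order these methods implement: parameter names take precedence, and everything else falls through to the property bag. As before, `state`, the names, and the shape are illustrative:

```python
import numpy as np

state["learning_rate"] = 1e-4  # no parameter by this name -> stored as a property
assert "learning_rate" in state
assert state["learning_rate"] == 1e-4

# "fc1.weight" is a parameter, so assignment copies data into the model.
state["fc1.weight"] = np.ones((500, 784), dtype=np.float32)

try:
    state["not_a_name"]
except KeyError:
    pass  # neither a parameter nor a property
```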
52 changes: 52 additions & 0 deletions orttraining/orttraining/test/python/orttraining_test_python_bindings.py
@@ -563,3 +563,55 @@ def test_eval_step_with_ort_values():
fetches = model(inputs, labels)
assert isinstance(fetches, OrtValue)
assert fetches


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_get_and_set_parameter_values(device):
with tempfile.TemporaryDirectory() as temp_dir:
(
checkpoint_file_path,
training_model_file_path,
eval_model_file_path,
_,
pt_model,
) = _create_training_artifacts(
temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
)

state = CheckpointState.load_checkpoint(checkpoint_file_path)

model = Module(training_model_file_path, state, eval_model_file_path, device=device)

for name, pt_param in pt_model.named_parameters():
ort_param = state[name]
assert ort_param.name == name
assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data)
if name in ["fc1.weight", "fc1.bias"]:
assert ort_param.requires_grad is False
assert ort_param.grad is None
else:
assert ort_param.requires_grad is True
assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32))

original_param = state["fc1.weight"].data
state["fc1.weight"].data = np.ones_like(state["fc1.weight"].data, dtype=np.float32)
updated_param = state["fc1.weight"].data
assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32))

model.train()
inputs = torch.randn(64, 784).numpy()
labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
loss = model(inputs, labels)
assert loss is not None
for name, _ in pt_model.named_parameters():
ort_param = state[name]
assert ort_param.name == name
if name in ["fc1.weight", "fc1.bias"]:
assert ort_param.requires_grad is False
assert ort_param.grad is None
else:
assert ort_param.requires_grad is True
assert ort_param.grad.any()

state["fc1.weight"] = original_param
assert np.allclose(state["fc1.weight"].data, original_param)
