From 956319786a5a2df6040ef1643f9652344300e2d0 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 12 Sep 2024 06:37:10 +0000 Subject: [PATCH 01/10] feat: add syncing models utility to ivy --- ivy/functional/backends/tensorflow/module.py | 408 ++++++++++++++++++- 1 file changed, 406 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/tensorflow/module.py b/ivy/functional/backends/tensorflow/module.py index faf06fb618bc..13abde709a30 100644 --- a/ivy/functional/backends/tensorflow/module.py +++ b/ivy/functional/backends/tensorflow/module.py @@ -3,13 +3,417 @@ import re import os import tensorflow as tf +import keras +import numpy as np import functools from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union +from typing import ( + NamedTuple, + Callable, + Any, + Tuple, + List, + Dict, + Type, + Union, + TYPE_CHECKING, +) import inspect from collections import OrderedDict from packaging.version import parse -import keras + +if TYPE_CHECKING: + import torch.nn as nn + + +if parse(keras.__version__).major > 2: + KerasVariable = keras.src.backend.Variable +else: + KerasVariable = tf.Variable + + +def _compute_module_dict_tf(model, prefix=""): + _module_dict = dict() + for key, value in model.__dict__.items(): + if isinstance(value, (tf.keras.Model, tf.keras.layers.Layer)): + if not hasattr(value, "named_parameters"): + _module_dict.update(_compute_module_dict_tf(value, prefix=f"{key}.")) + else: + _module_dict[prefix + key] = value + return _module_dict + + +def _compute_module_dict_pt(model, keychains): + _module_dict = dict() + for keychain in keychains: + keys = keychain.split(".") + value = model + for key in keys: + value = getattr(value, key) + _module_dict[keychain] = value + return _module_dict + + +def _retrive_layer(model, key): + if len(key.split(".")) == 1: + return model, key + + module_path, weight_name = key.rsplit(".", 1) + + # Retrieve the layer using the module path + layer = 
model + for attr in module_path.split("."): + layer = getattr(layer, attr) + + return layer, weight_name + + +def _maybe_update_keras_layer_weights(layer, weight_name, new_weight): + # Update the weight in the retrieved layer + if hasattr(layer, weight_name): + weight_var = getattr(layer, weight_name) + if isinstance(weight_var, tf.Variable): + weight_var.assign(tf.Variable(new_weight, dtype=weight_var.dtype)) + elif isinstance(weight_var, KerasVariable): + weight_var.assign( + KerasVariable(new_weight, dtype=weight_var.dtype, name=weight_var.name) + ) + else: + setattr( + layer, + weight_name, + tf.convert_to_tensor(new_weight, dtype=weight_var.dtype), + ) + else: + raise AttributeError( + f"Layer '{layer}' does not have a weight named '{weight_name}'" + ) + + +def _sync_models_torch_and_tf(model1: "nn.Module", model2: Any[Model, Layer]): + """ + Synchronizes the parameters and buffers between two models: a PyTorch + model and a TensorFlow model that is an instance of `Model` or `Layer`. + + Args: + model1 (torch.nn.Module): The original PyTorch model. + model2 (keras.Model | keras.Layer): The custom TensorFlow model, + which must inherit from both + `keras.Model`/`keras.Layer` + and expose a `torch.nn.Module`-like + interface (with `named_parameters()` + and `named_buffers()` methods). + + Returns: + None + + Example: + ```python + import torch.nn as nn + import keras + + # `CustomKerasLinear` is a subclass of keras.layers.Layer and exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). 
+ class CustomKerasLinear(Layer): + def __init__(self, in_features, out_features): + super(CustomKerasLinear, self).__init__() + self.weight = tf.Variable(tf.random.normal([out_features, in_features])) + self.bias = tf.Variable(tf.random.normal([out_features])) + + def call(self, x): + return tf.matmul(x, self.weight) + self.bias + + def named_parameters(self): + return [("weight", self.weight), ("bias", self.bias)] + + def named_buffers(self): + return [] + + def eval(self): + return False + + # `CustomKerasModel` is a subclass of keras.Model and exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). + class CustomKerasModel(Model): + def __init__(self): + super(CustomKerasModel, self).__init__() + self.linear = CustomKerasLinear(10, 5) + + def call(self, x): + return self.linear(x) + + def named_parameters(self): + return [("linear.weight", self.linear.weight), ("linear.bias", self.linear.bias)] + + def named_buffers(self): + return [] + + def eval(self): + return False + + class PyTorchModel(nn.Module): + def __init__(self): + super(PyTorchModel, self).__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + # Instantiate both models + model_pt = PyTorchModel() # PyTorch model + model_tf = CustomKerasModel() # Custom Keras model + + # Sync all submodules between the PyTorch and Keras models + _sync_models_torch_and_tf(model_pt, model_tf) + ``` + """ + import torch + + has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + transpose_weights = ( + has_keras_layers + or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" + ) + + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + buffers1 = dict(model1.named_buffers()) + buffers2 = dict(model2.named_buffers()) + # TODO: remove this once the stateful attribute name-conflict has been resolved. 
+ key_mapping = {} + for k in params2.keys(): + key_mapping[k.replace("pt_", "")] = k + + for k in buffers2.keys(): + key_mapping[k.replace("pt_", "")] = k + + params2 = {k.replace("pt_", ""): v for k, v in params2.items()} + buffers2 = {k.replace("pt_", ""): v for k, v in buffers2.items()} + + # Check if both models have the same parameters and buffers + assert params1.keys() == params2.keys() + assert buffers1.keys() == buffers2.keys() + + # Set the parameters and buffers of the second model to be the same as the first model + with torch.no_grad(): + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + # Transpose the parameters to match the TensorFlow format + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # DepthConvolutional layer + params1_np = np.transpose(params1_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # Convolutional layer + params1_np = np.transpose(params1_np, (2, 3, 1, 0)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + params1_np = np.transpose(params1_np, (1, 0)) + + # inplace update the native keras layer. This is done as the parameters in + # self.v are a different copy than the parameters in self.weights. Hence, we + # need to explicitly update self.weights, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Keras"): + _maybe_update_keras_layer_weights( + layer=layer, weight_name=weight_name, new_weight=params1_np + ) + params2[name] = getattr(layer, weight_name) + continue + + params2[name].assign(tf.Variable(params1_np, dtype=params2[name].dtype)) + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # DepthConvolutional layer + params1_np = np.transpose(params1_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # Convolutional layer + buffers1_np = np.transpose(buffers1_np, (2, 3, 1, 0)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + buffers1_np = np.transpose(buffers1_np, (1, 0)) + + # inplace update the native keras layer. This is done as the parameters in + # self.v are a different copy than the parameters in self.weights. Hence, we + # need to explicitly update self.weights, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Keras"): + _maybe_update_keras_layer_weights( + layer=layer, weight_name=weight_name, new_weight=buffers1_np + ) + buffers2[name] = getattr(layer, weight_name) + continue + + if isinstance(buffers2[name], tf.Variable): + buffers2[name].assign( + tf.Variable(buffers1_np, dtype=buffers2[name].dtype) + ) + else: + buffers2[name] = tf.convert_to_tensor( + buffers1_np, dtype=buffers2[name].dtype + ) + + # Check if the parameters and buffers are the same + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + params2_np = params2[name].numpy() + # Transpose the parameters back to the PyTorch format for comparison + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (3, 2, 0, 1)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + params2_np = np.transpose(params2_np, (1, 0)) + + assert np.allclose( + params1_np, params2_np + ), f"Mismatch found in parameters: {name}" + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + buffers2_np = buffers2[name].numpy() + + # Transpose the parameters back to the PyTorch format for comparison + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + buffers2_np = 
np.transpose(buffers2_np, (3, 2, 0, 1)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + buffers2_np = np.transpose(buffers2_np, (1, 0)) + + assert np.allclose( + buffers1_np, buffers2_np + ), f"Mismatch found in buffers: {name}" + + +def sync_models_torch_and_tf(model_pt: "nn.Module", model_tf: keras.Model): + """ + Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`) + and a TensorFlow model (`keras.Model`) that uses custom submodules. + + This function ensures that the PyTorch model and the TensorFlow model + have identical parameters and buffers by iterating through their submodules + and synchronizing them. The TensorFlow model's submodules must be instances + of `Model`/`Layer` and exposes an interface similar to `torch.nn.Module`, + particularly the `named_parameters()` and `named_buffers()` methods. + + Args: + model_pt (torch.nn.Module): The original PyTorch model. + model_tf (keras.Model): The TensorFlow model, which should consist of + submodules that inherit from the custom + Model/Layer class. + + Returns: + None + + Example: + ```python + import torch.nn as nn + import keras + + class CustomKerasLinear(Layer): + def __init__(self, in_features, out_features): + super(CustomKerasLinear, self).__init__() + self.weight = tf.Variable(tf.random.normal([out_features, in_features])) + self.bias = tf.Variable(tf.random.normal([out_features])) + + def call(self, x): + return tf.matmul(x, self.weight) + self.bias + + def named_parameters(self): + return [("weight", self.weight), ("bias", self.bias)] + + def named_buffers(self): + return [] + + def eval(self): + return False + + #`NativeKerasModel` is a subclass of keras.Model and does NOT exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). 
+ class NativeKerasModel(keras.Model): + def __init__(self): + super(NativeKerasModel, self).__init__() + self.linear = CustomKerasLinear(10, 5) + + def call(self, x): + return self.linear(x) + + class PyTorchModel(nn.Module): + def __init__(self): + super(PyTorchModel, self).__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + # Instantiate both models + model_pt = PyTorchModel() # PyTorch model + model_tf = NativeKerasModel() # Native Keras model inheriting from keras.Model + + # Sync all submodules between the PyTorch and Keras models + sync_models_torch_and_tf(model_pt, model_tf) + ``` + """ + + all_submods_tf = _compute_module_dict_tf(model_tf) + all_submods_pt = _compute_module_dict_pt( + model_pt, keychains=list(all_submods_tf.keys()) + ) + + for pt_model, tf_model in zip(all_submods_pt.values(), all_submods_tf.values()): + pt_model.eval() + tf_model.eval() + _sync_models_torch_and_tf(pt_model, tf_model) def get_assignment_dict(): From f3ad98dc5b1d756ac5d5285a8586b0816e59dfb1 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 16:11:52 +0000 Subject: [PATCH 02/10] chore: removing the sync models logic from the stateful module.py --- ivy/functional/backends/tensorflow/module.py | 382 ------------------- 1 file changed, 382 deletions(-) diff --git a/ivy/functional/backends/tensorflow/module.py b/ivy/functional/backends/tensorflow/module.py index 13abde709a30..dce43ed7671c 100644 --- a/ivy/functional/backends/tensorflow/module.py +++ b/ivy/functional/backends/tensorflow/module.py @@ -32,388 +32,6 @@ KerasVariable = tf.Variable -def _compute_module_dict_tf(model, prefix=""): - _module_dict = dict() - for key, value in model.__dict__.items(): - if isinstance(value, (tf.keras.Model, tf.keras.layers.Layer)): - if not hasattr(value, "named_parameters"): - _module_dict.update(_compute_module_dict_tf(value, prefix=f"{key}.")) - else: - _module_dict[prefix + key] = value - return _module_dict - - -def 
_compute_module_dict_pt(model, keychains): - _module_dict = dict() - for keychain in keychains: - keys = keychain.split(".") - value = model - for key in keys: - value = getattr(value, key) - _module_dict[keychain] = value - return _module_dict - - -def _retrive_layer(model, key): - if len(key.split(".")) == 1: - return model, key - - module_path, weight_name = key.rsplit(".", 1) - - # Retrieve the layer using the module path - layer = model - for attr in module_path.split("."): - layer = getattr(layer, attr) - - return layer, weight_name - - -def _maybe_update_keras_layer_weights(layer, weight_name, new_weight): - # Update the weight in the retrieved layer - if hasattr(layer, weight_name): - weight_var = getattr(layer, weight_name) - if isinstance(weight_var, tf.Variable): - weight_var.assign(tf.Variable(new_weight, dtype=weight_var.dtype)) - elif isinstance(weight_var, KerasVariable): - weight_var.assign( - KerasVariable(new_weight, dtype=weight_var.dtype, name=weight_var.name) - ) - else: - setattr( - layer, - weight_name, - tf.convert_to_tensor(new_weight, dtype=weight_var.dtype), - ) - else: - raise AttributeError( - f"Layer '{layer}' does not have a weight named '{weight_name}'" - ) - - -def _sync_models_torch_and_tf(model1: "nn.Module", model2: Any[Model, Layer]): - """ - Synchronizes the parameters and buffers between two models: a PyTorch - model and a TensorFlow model that is an instance of `Model` or `Layer`. - - Args: - model1 (torch.nn.Module): The original PyTorch model. - model2 (keras.Model | keras.Layer): The custom TensorFlow model, - which must inherit from both - `keras.Model`/`keras.Layer` - and expose a `torch.nn.Module`-like - interface (with `named_parameters()` - and `named_buffers()` methods). - - Returns: - None - - Example: - ```python - import torch.nn as nn - import keras - - # `CustomKerasLinear` is a subclass of keras.layers.Layer and exposes a similar - # interface to torch.nn.Module (with named_parameters and named_buffers). 
- class CustomKerasLinear(Layer): - def __init__(self, in_features, out_features): - super(CustomKerasLinear, self).__init__() - self.weight = tf.Variable(tf.random.normal([out_features, in_features])) - self.bias = tf.Variable(tf.random.normal([out_features])) - - def call(self, x): - return tf.matmul(x, self.weight) + self.bias - - def named_parameters(self): - return [("weight", self.weight), ("bias", self.bias)] - - def named_buffers(self): - return [] - - def eval(self): - return False - - # `CustomKerasModel` is a subclass of keras.Model and exposes a similar - # interface to torch.nn.Module (with named_parameters and named_buffers). - class CustomKerasModel(Model): - def __init__(self): - super(CustomKerasModel, self).__init__() - self.linear = CustomKerasLinear(10, 5) - - def call(self, x): - return self.linear(x) - - def named_parameters(self): - return [("linear.weight", self.linear.weight), ("linear.bias", self.linear.bias)] - - def named_buffers(self): - return [] - - def eval(self): - return False - - class PyTorchModel(nn.Module): - def __init__(self): - super(PyTorchModel, self).__init__() - self.linear = nn.Linear(10, 5) - - def forward(self, x): - return self.linear(x) - - # Instantiate both models - model_pt = PyTorchModel() # PyTorch model - model_tf = CustomKerasModel() # Custom Keras model - - # Sync all submodules between the PyTorch and Keras models - _sync_models_torch_and_tf(model_pt, model_tf) - ``` - """ - import torch - - has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" - transpose_weights = ( - has_keras_layers - or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" - ) - - params1 = dict(model1.named_parameters()) - params2 = dict(model2.named_parameters()) - buffers1 = dict(model1.named_buffers()) - buffers2 = dict(model2.named_buffers()) - # TODO: remove this once the stateful attribute name-conflict has been resolved. 
- key_mapping = {} - for k in params2.keys(): - key_mapping[k.replace("pt_", "")] = k - - for k in buffers2.keys(): - key_mapping[k.replace("pt_", "")] = k - - params2 = {k.replace("pt_", ""): v for k, v in params2.items()} - buffers2 = {k.replace("pt_", ""): v for k, v in buffers2.items()} - - # Check if both models have the same parameters and buffers - assert params1.keys() == params2.keys() - assert buffers1.keys() == buffers2.keys() - - # Set the parameters and buffers of the second model to be the same as the first model - with torch.no_grad(): - for name in params1: - layer, weight_name = _retrive_layer(model2, key_mapping[name]) - - params1_np = params1[name].cpu().detach().numpy() - # Transpose the parameters to match the TensorFlow format - if ( - transpose_weights - and "DepthwiseConv" in layer.__class__.__name__ - and len(params1_np.shape) == 4 - ): # DepthConvolutional layer - params1_np = np.transpose(params1_np, (2, 3, 0, 1)) - elif ( - transpose_weights - and "Conv" in layer.__class__.__name__ - and len(params1_np.shape) == 4 - ): # Convolutional layer - params1_np = np.transpose(params1_np, (2, 3, 1, 0)) - elif ( - "Dense" in layer.__class__.__name__ - and len(params1_np.shape) == 2 - and layer.built - ): # Dense layer - params1_np = np.transpose(params1_np, (1, 0)) - - # inplace update the native keras layer. This is done as the parameters in - # self.v are a different copy than the parameters in self.weights. Hence, we - # need to explicitly update self.weights, otherwise the changes won't reflect. 
- if layer.__class__.__name__.startswith("Keras"): - _maybe_update_keras_layer_weights( - layer=layer, weight_name=weight_name, new_weight=params1_np - ) - params2[name] = getattr(layer, weight_name) - continue - - params2[name].assign(tf.Variable(params1_np, dtype=params2[name].dtype)) - - for name in buffers1: - layer, weight_name = _retrive_layer(model2, key_mapping[name]) - - buffers1_np = buffers1[name].cpu().detach().numpy() - if ( - transpose_weights - and "DepthwiseConv" in layer.__class__.__name__ - and len(params1_np.shape) == 4 - ): # DepthConvolutional layer - params1_np = np.transpose(params1_np, (2, 3, 0, 1)) - elif ( - transpose_weights - and "Conv" in layer.__class__.__name__ - and len(params1_np.shape) == 4 - ): # Convolutional layer - buffers1_np = np.transpose(buffers1_np, (2, 3, 1, 0)) - elif ( - "Dense" in layer.__class__.__name__ - and len(params1_np.shape) == 2 - and layer.built - ): # Dense layer - buffers1_np = np.transpose(buffers1_np, (1, 0)) - - # inplace update the native keras layer. This is done as the parameters in - # self.v are a different copy than the parameters in self.weights. Hence, we - # need to explicitly update self.weights, otherwise the changes won't reflect. 
- if layer.__class__.__name__.startswith("Keras"): - _maybe_update_keras_layer_weights( - layer=layer, weight_name=weight_name, new_weight=buffers1_np - ) - buffers2[name] = getattr(layer, weight_name) - continue - - if isinstance(buffers2[name], tf.Variable): - buffers2[name].assign( - tf.Variable(buffers1_np, dtype=buffers2[name].dtype) - ) - else: - buffers2[name] = tf.convert_to_tensor( - buffers1_np, dtype=buffers2[name].dtype - ) - - # Check if the parameters and buffers are the same - for name in params1: - layer, weight_name = _retrive_layer(model2, key_mapping[name]) - - params1_np = params1[name].cpu().detach().numpy() - params2_np = params2[name].numpy() - # Transpose the parameters back to the PyTorch format for comparison - if ( - transpose_weights - and "DepthwiseConv" in layer.__class__.__name__ - and len(params2_np.shape) == 4 - ): # Convolutional layer - params2_np = np.transpose(params2_np, (2, 3, 0, 1)) - elif ( - transpose_weights - and "Conv" in layer.__class__.__name__ - and len(params2_np.shape) == 4 - ): # Convolutional layer - params2_np = np.transpose(params2_np, (3, 2, 0, 1)) - elif ( - "Dense" in layer.__class__.__name__ - and len(params1_np.shape) == 2 - and layer.built - ): # Dense layer - params2_np = np.transpose(params2_np, (1, 0)) - - assert np.allclose( - params1_np, params2_np - ), f"Mismatch found in parameters: {name}" - - for name in buffers1: - layer, weight_name = _retrive_layer(model2, key_mapping[name]) - - buffers1_np = buffers1[name].cpu().detach().numpy() - buffers2_np = buffers2[name].numpy() - - # Transpose the parameters back to the PyTorch format for comparison - if ( - transpose_weights - and "DepthwiseConv" in layer.__class__.__name__ - and len(params2_np.shape) == 4 - ): # Convolutional layer - params2_np = np.transpose(params2_np, (2, 3, 0, 1)) - elif ( - transpose_weights - and "Conv" in layer.__class__.__name__ - and len(params2_np.shape) == 4 - ): # Convolutional layer - buffers2_np = 
np.transpose(buffers2_np, (3, 2, 0, 1)) - elif ( - "Dense" in layer.__class__.__name__ - and len(params1_np.shape) == 2 - and layer.built - ): # Dense layer - buffers2_np = np.transpose(buffers2_np, (1, 0)) - - assert np.allclose( - buffers1_np, buffers2_np - ), f"Mismatch found in buffers: {name}" - - -def sync_models_torch_and_tf(model_pt: "nn.Module", model_tf: keras.Model): - """ - Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`) - and a TensorFlow model (`keras.Model`) that uses custom submodules. - - This function ensures that the PyTorch model and the TensorFlow model - have identical parameters and buffers by iterating through their submodules - and synchronizing them. The TensorFlow model's submodules must be instances - of `Model`/`Layer` and exposes an interface similar to `torch.nn.Module`, - particularly the `named_parameters()` and `named_buffers()` methods. - - Args: - model_pt (torch.nn.Module): The original PyTorch model. - model_tf (keras.Model): The TensorFlow model, which should consist of - submodules that inherit from the custom - Model/Layer class. - - Returns: - None - - Example: - ```python - import torch.nn as nn - import keras - - class CustomKerasLinear(Layer): - def __init__(self, in_features, out_features): - super(CustomKerasLinear, self).__init__() - self.weight = tf.Variable(tf.random.normal([out_features, in_features])) - self.bias = tf.Variable(tf.random.normal([out_features])) - - def call(self, x): - return tf.matmul(x, self.weight) + self.bias - - def named_parameters(self): - return [("weight", self.weight), ("bias", self.bias)] - - def named_buffers(self): - return [] - - def eval(self): - return False - - #`NativeKerasModel` is a subclass of keras.Model and does NOT exposes a similar - # interface to torch.nn.Module (with named_parameters and named_buffers). 
- class NativeKerasModel(keras.Model): - def __init__(self): - super(NativeKerasModel, self).__init__() - self.linear = CustomKerasLinear(10, 5) - - def call(self, x): - return self.linear(x) - - class PyTorchModel(nn.Module): - def __init__(self): - super(PyTorchModel, self).__init__() - self.linear = nn.Linear(10, 5) - - def forward(self, x): - return self.linear(x) - - # Instantiate both models - model_pt = PyTorchModel() # PyTorch model - model_tf = NativeKerasModel() # Native Keras model inheriting from keras.Model - - # Sync all submodules between the PyTorch and Keras models - sync_models_torch_and_tf(model_pt, model_tf) - ``` - """ - - all_submods_tf = _compute_module_dict_tf(model_tf) - all_submods_pt = _compute_module_dict_pt( - model_pt, keychains=list(all_submods_tf.keys()) - ) - - for pt_model, tf_model in zip(all_submods_pt.values(), all_submods_tf.values()): - pt_model.eval() - tf_model.eval() - _sync_models_torch_and_tf(pt_model, tf_model) def get_assignment_dict(): From 47ca552e9e134dad05268939873d06bedd3e9db4 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 16:14:36 +0000 Subject: [PATCH 03/10] feat: adding model syncing helpers to ivy.stateful.utilites.py --- ivy/stateful/__init__.py | 2 + ivy/stateful/utilities.py | 741 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 743 insertions(+) create mode 100644 ivy/stateful/utilities.py diff --git a/ivy/stateful/__init__.py b/ivy/stateful/__init__.py index a04dd5d292fb..dcc7827bd341 100644 --- a/ivy/stateful/__init__.py +++ b/ivy/stateful/__init__.py @@ -16,3 +16,5 @@ from .optimizers import * from . import sequential from .sequential import * +from . 
import utilities +from .utilities import sync_models_torch diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py new file mode 100644 index 000000000000..a36414f43e95 --- /dev/null +++ b/ivy/stateful/utilities.py @@ -0,0 +1,741 @@ +import os +from typing import Union, TYPE_CHECKING +from packaging.version import parse +import numpy as np +import ivy + +if TYPE_CHECKING: + from ivy.functional.backends.tensorflow import Model as KerasModel + from ivy.functional.backends.jax import Model as FlaxModel + import torch.nn as nn + import keras + import flax.nnx as nnx + + +def _compute_module_dict_pt(model, keychains): + _module_dict = dict() + for keychain in keychains: + keys = keychain.split(".") + value = model + for key in keys: + value = getattr(value, key) + _module_dict[keychain] = value + return _module_dict + + +def _retrive_layer(model, key): + if len(key.split(".")) == 1: + return model, key + + module_path, weight_name = key.rsplit(".", 1) + + # Retrieve the layer using the module path + layer = model + for attr in module_path.split("."): + layer = getattr(layer, attr) + + return layer, weight_name + + +def _sync_models_torch_and_jax(model1: "nn.Module", model2: "FlaxModel"): + """ + Synchronizes the parameters and buffers of the original and the translated model. + + Args: + model1 (torch.nn.Module): The original PyTorch model. + model2 (ivy.Module converted Flax.nnx.Module)): The converted ivy.Module converted Flax.nnx.Module. 
+ + Returns: + None + """ + + def _pt_name_to_flax_name(layer, weight_name): + if layer.__class__.__name__ in ("FlaxConv", "FlaxLinear"): + param_and_buff_map = { + "weight": "kernel", + "bias": "bias", + } + elif layer.__class__.__name__ == "FlaxBatchNorm": + param_and_buff_map = { + "weight": "scale", + "bias": "bias", + "running_mean": "mean", + "running_var": "var", + "num_batches_tracked": "num_batches_tracked", + } + else: + raise ValueError(f"Layer '{layer}' is not supported.") + + return param_and_buff_map[weight_name] + + def _maybe_update_flax_layer_weights(layer, weight_name, new_weight): + # Update the weight in the retrieved layer + if hasattr(layer, weight_name): + weight_var = getattr(layer, weight_name) + if isinstance(weight_var, nnx.Variable): + weight_var.value = jnp.asarray(new_weight, dtype=weight_var.value.dtype) + else: + setattr( + layer, + weight_name, + jnp.asarray(new_weight, dtype=weight_var.dtype), + ) + else: + raise AttributeError( + f"Layer '{layer}' does not have a weight named '{weight_name}'" + ) + + import torch + import flax.nnx as nnx + import jax.numpy as jnp + + has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + transpose_weights = ( + has_keras_layers + or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" + ) + + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + buffers1 = dict(model1.named_buffers()) + buffers2 = dict(model2.named_buffers()) + # TODO: remove this once the stateful attribute name-conflict has been resolved. 
+ key_mapping = {} + for k in params2.keys(): + key_mapping[k.replace("pt_", "")] = k + + for k in buffers2.keys(): + key_mapping[k.replace("pt_", "")] = k + + params2 = {k.replace("pt_", ""): v for k, v in params2.items()} + buffers2 = {k.replace("pt_", ""): v for k, v in buffers2.items()} + + # Check if both models have the same parameters and buffers + missing_in_params2 = params1.keys() - params2.keys() + if missing_in_params2: + raise AssertionError( + f"Mismatch in param keys:\n" + f"Missing params Flax model: {missing_in_params2}\n" + ) + missing_in_buffers2 = buffers1.keys() - buffers2.keys() + if missing_in_buffers2: + raise AssertionError( + f"Mismatch in buffers keys:\n" + f"Missing buffers in Flax model: {missing_in_buffers2}\n" + ) + + # Set the parameters and buffers of the second model to be the same as the first model + with torch.no_grad(): + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + # Transpose the parameters to match the TensorFlow format + if ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # Convolutional layer + params1_np = np.transpose(params1_np, (2, 3, 1, 0)) + elif ( + "Linear" in layer.__class__.__name__ and len(params1_np.shape) == 2 + ): # Dense layer + params1_np = np.transpose(params1_np, (1, 0)) + + # inplace update the native Flax layer. This is done as the parameters in + # self.v are a different copy than the layer's original parameters. Hence, we + # need to explicitly update the layer's original parameters, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Flax"): + flax_name = _pt_name_to_flax_name(layer, weight_name) + _maybe_update_flax_layer_weights( + layer=layer, weight_name=flax_name, new_weight=params1_np + ) + params2[name] = getattr(layer, flax_name) + continue + + params2[name].value = jnp.asarray( + params1_np, dtype=params2[name].value.dtype + ) + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + if ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # Convolutional layer + buffers1_np = np.transpose(buffers1_np, (2, 3, 1, 0)) + elif ( + "Linear" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + buffers1_np = np.transpose(buffers1_np, (1, 0)) + + # inplace update the native Flax layer. This is done as the buffers in + # self.buffers are a different copy than the layer's original buffers. Hence, we + # need to explicitly update the layer's original buffers, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Flax"): + flax_name = _pt_name_to_flax_name(layer, weight_name) + _maybe_update_flax_layer_weights( + layer=layer, weight_name=flax_name, new_weight=buffers1_np + ) + buffers2[name] = getattr(layer, flax_name) + continue + + if isinstance(buffers2[name], nnx.Variable): + buffers2[name].value = jnp.asarray( + buffers1_np, dtype=buffers2[name].value.dtype + ) + + else: + buffers2[name] = jnp.asarray(buffers1_np, dtype=buffers2[name].dtype) + + # Check if the parameters and buffers are the same + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + params2_np = params2[name].value._value + # Transpose the parameters back to the PyTorch format for comparison + if ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (3, 2, 0, 1)) + elif ( + "Linear" in layer.__class__.__name__ and len(params1_np.shape) == 2 + ): # Dense layer + params2_np = np.transpose(params2_np, (1, 0)) + + assert np.allclose( + params1_np, params2_np + ), f"Mismatch found in parameters: {name}" + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + buffers2_np = ( + buffers2[name].value._value + if isinstance(buffers2[name], nnx.Variable) + else buffers2[name]._value + ) + + # Transpose the parameters back to the PyTorch format for comparison + if ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + buffers2_np = np.transpose(buffers2_np, (3, 2, 0, 1)) + elif ( + "Linear" in layer.__class__.__name__ and len(params1_np.shape) == 2 + ): # Dense layer + buffers2_np = np.transpose(buffers2_np, (1, 0)) + + assert np.allclose( + buffers1_np, buffers2_np + ), f"Mismatch found in buffers: {name}" + + +def 
_sync_models_torch_and_tf(model1: "nn.Module", model2: "KerasModel"): + """ + Synchronizes the parameters and buffers of the original and the translated model. + + Args: + model1 (torch.nn.Module): The original PyTorch model. + model2 (ivy.Module converted keras.Model)): The converted ivy.Module converted keras.Model. + + Returns: + None + """ + + def _maybe_update_keras_layer_weights(layer, weight_name, new_weight): + # Update the weight in the retrieved layer + if hasattr(layer, weight_name): + weight_var = getattr(layer, weight_name) + if isinstance(weight_var, tf.Variable): + weight_var.assign(tf.Variable(new_weight, dtype=weight_var.dtype)) + elif isinstance(weight_var, KerasVariable): + weight_var.assign( + KerasVariable( + new_weight, dtype=weight_var.dtype, name=weight_var.name + ) + ) + else: + setattr( + layer, + weight_name, + tf.convert_to_tensor(new_weight, dtype=weight_var.dtype), + ) + else: + raise AttributeError( + f"Layer '{layer}' does not have a weight named '{weight_name}'" + ) + + import torch + import tensorflow as tf + import keras + + if parse(keras.__version__).major > 2: + KerasVariable = keras.src.backend.Variable + else: + KerasVariable = tf.Variable + + has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + transpose_weights = ( + has_keras_layers + or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" + ) + + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + buffers1 = dict(model1.named_buffers()) + buffers2 = dict(model2.named_buffers()) + # TODO: remove this once the stateful attribute name-conflict has been resolved. 
+ key_mapping = {} + for k in params2.keys(): + key_mapping[k.replace("pt_", "")] = k + + for k in buffers2.keys(): + key_mapping[k.replace("pt_", "")] = k + + params2 = {k.replace("pt_", ""): v for k, v in params2.items()} + buffers2 = {k.replace("pt_", ""): v for k, v in buffers2.items()} + + # Check if both models have the same parameters and buffers + assert params1.keys() == params2.keys() + assert buffers1.keys() == buffers2.keys() + + # Set the parameters and buffers of the second model to be the same as the first model + with torch.no_grad(): + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + # Transpose the parameters to match the TensorFlow format + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # DepthConvolutional layer + params1_np = np.transpose(params1_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params1_np.shape) == 4 + ): # Convolutional layer + params1_np = np.transpose(params1_np, (2, 3, 1, 0)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + params1_np = np.transpose(params1_np, (1, 0)) + + # inplace update the native keras layer. This is done as the parameters in + # self.v are a different copy than the parameters in self.weights. Hence, we + # need to explicitly update self.weights, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Keras"): + _maybe_update_keras_layer_weights( + layer=layer, weight_name=weight_name, new_weight=params1_np + ) + params2[name] = getattr(layer, weight_name) + continue + + params2[name].assign(tf.Variable(params1_np, dtype=params2[name].dtype)) + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(buffers1_np.shape) == 4 + ): # DepthConvolutional layer + buffers1_np = np.transpose(buffers1_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(buffers1_np.shape) == 4 + ): # Convolutional layer + buffers1_np = np.transpose(buffers1_np, (2, 3, 1, 0)) + elif ( + "Dense" in layer.__class__.__name__ + and len(buffers1_np.shape) == 2 + and layer.built + ): # Dense layer + buffers1_np = np.transpose(buffers1_np, (1, 0)) + + # inplace update the native keras layer. This is done as the parameters in + # self.v are a different copy than the parameters in self.weights. Hence, we + # need to explicitly update self.weights, otherwise the changes won't reflect. 
+ if layer.__class__.__name__.startswith("Keras"): + _maybe_update_keras_layer_weights( + layer=layer, weight_name=weight_name, new_weight=buffers1_np + ) + buffers2[name] = getattr(layer, weight_name) + continue + + if isinstance(buffers2[name], tf.Variable): + buffers2[name].assign( + tf.Variable(buffers1_np, dtype=buffers2[name].dtype) + ) + else: + buffers2[name] = tf.convert_to_tensor( + buffers1_np, dtype=buffers2[name].dtype + ) + + # Check if the parameters and buffers are the same + for name in params1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + params1_np = params1[name].cpu().detach().numpy() + params2_np = params2[name].numpy() + # Transpose the parameters back to the PyTorch format for comparison + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(params2_np.shape) == 4 + ): # Convolutional layer + params2_np = np.transpose(params2_np, (3, 2, 0, 1)) + elif ( + "Dense" in layer.__class__.__name__ + and len(params1_np.shape) == 2 + and layer.built + ): # Dense layer + params2_np = np.transpose(params2_np, (1, 0)) + + assert np.allclose( + params1_np, params2_np + ), f"Mismatch found in parameters: {name}" + + for name in buffers1: + layer, weight_name = _retrive_layer(model2, key_mapping[name]) + + buffers1_np = buffers1[name].cpu().detach().numpy() + buffers2_np = buffers2[name].numpy() + + # Transpose the buffers back to the PyTorch format for comparison + if ( + transpose_weights + and "DepthwiseConv" in layer.__class__.__name__ + and len(buffers2_np.shape) == 4 + ): # DepthConvolutional layer + buffers2_np = np.transpose(buffers2_np, (2, 3, 0, 1)) + elif ( + transpose_weights + and "Conv" in layer.__class__.__name__ + and len(buffers2_np.shape) == 4 + ): # Convolutional layer + buffers2_np = 
np.transpose(buffers2_np, (3, 2, 0, 1)) + elif ( + "Dense" in layer.__class__.__name__ + and len(buffers1_np.shape) == 2 + and layer.built + ): # Dense layer + buffers2_np = np.transpose(buffers2_np, (1, 0)) + + assert np.allclose( + buffers1_np, buffers2_np + ), f"Mismatch found in buffers: {name}" + + +def sync_models_torch_and_tf( + model_pt: "nn.Module", model_tf: Union["keras.Model", "KerasModel"] +): + """ + Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`) + and a TensorFlow model (`keras.Model`). + + This function ensures that both models have identical parameters and buffers by + iterating through their submodules and synchronizing them. The TensorFlow model + must either be an instance of `KerasModel` or have submodules that inherit from the + translated `KerasModel`/`KerasLayer`, and expose interfaces similar to `torch.nn.Module`, + including `named_parameters()` and `named_buffers()`. + + Args: + model_pt (torch.nn.Module): The PyTorch model to synchronize from. + model_tf (keras.Model): The TensorFlow model to synchronize to, with submodules + inheriting from the custom `KerasModel`/`KerasLayer` class. + + Returns: + None + + + Example: + ```python + import torch.nn as nn + import keras + + #`CustomKerasLinear` is a subclass of `Layer` that exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). 
+ class CustomKerasLinear(Layer): + def __init__(self, in_features, out_features): + super(CustomKerasLinear, self).__init__() + self.weight = tf.Variable(tf.random.normal([out_features, in_features])) + self.bias = tf.Variable(tf.random.normal([out_features])) + + def call(self, x): + return tf.matmul(x, self.weight) + self.bias + + def named_parameters(self): + return [("weight", self.weight), ("bias", self.bias)] + + def named_buffers(self): + return [] + + def eval(self): + return False + + #`NativeKerasModel` is a subclass of keras.Model and does NOT exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). + class NativeKerasModel(keras.Model): + def __init__(self): + super(NativeKerasModel, self).__init__() + self.linear = CustomKerasLinear(10, 5) + + def call(self, x): + return self.linear(x) + + class PyTorchModel(nn.Module): + def __init__(self): + super(PyTorchModel, self).__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + # Instantiate both models + model_pt = PyTorchModel() # PyTorch model + model_tf = NativeKerasModel() # Native Keras model inheriting from keras.Model + + # Sync all submodules between the PyTorch and Keras models + sync_models_torch_and_tf(model_pt, model_tf) + ``` + """ + + def _compute_module_dict_tf(model, prefix=""): + _module_dict = dict() + for key, value in model.__dict__.items(): + if isinstance(value, (tf.keras.Model, tf.keras.layers.Layer)): + if not hasattr(value, "named_parameters"): + _module_dict.update( + _compute_module_dict_tf(value, prefix=f"{key}.") + ) + else: + _module_dict[prefix + key] = value + return _module_dict + + try: + import tensorflow as tf + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`tensorflow` was not found installed on your system. Please proceed " + "to install it and restart your interpreter to see the changes." 
+ ) from exc + + try: + assert isinstance( + model_tf, tf.keras.Model + ), "The second model must be an instance of `tf.keras.Model` (TensorFlow)." + except AssertionError as e: + raise TypeError("The second model must be a TensorFlow model.") from e + + if hasattr(model_tf, "named_parameters"): + _sync_models_torch_and_tf(model_pt, model_tf) + else: + all_submods_tf = _compute_module_dict_tf(model_tf) + all_submods_pt = _compute_module_dict_pt( + model_pt, keychains=list(all_submods_tf.keys()) + ) + + for pt_model, tf_model in zip(all_submods_pt.values(), all_submods_tf.values()): + pt_model.eval() + tf_model.eval() + _sync_models_torch_and_tf(pt_model, tf_model) + + +def sync_models_torch_and_jax( + model_pt: "nn.Module", model_jax: Union["nnx.Module", "FlaxModel"] +): + """ + Synchronizes the weights and buffers between a PyTorch model (`torch.nn.Module`) + and a Flax model (`flax.nnx.Module`). + + This function ensures both models have identical parameters and buffers by + iterating through their submodules and synchronizing them. The Flax model must + either be an instance of `FlaxModel` or have submodules that inherit from the + translated `FlaxModel`, and expose interfaces similar to `torch.nn.Module`, + including `named_parameters()` and `named_buffers()`. + + Args: + model_pt (torch.nn.Module): The PyTorch model to synchronize from. + model_flax (flax.nnx.Module): The Flax model to synchronize to, with submodules + inheriting from the custom `FlaxModel` class. + Returns: + None + + Example: + ```python + import torch.nn as nn + import jax.numpy as jnp + import flax.nnx as nnx + + #`CustomFlaxLinear` is a subclass of `FlaxModel` that exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). 
+ class CustomFlaxLinear(FlaxModel): + def __init__(self, in_features, out_features): + super(CustomFlaxLinear, self).__init__() + self.weight = nnx.Param(jax.random.normal(jax.random.key(0), [out_features,in_features])) + self.bias = nnx.Param(jax.random.normal(jax.random.key(0),[out_features])) + + def call(self, x): + return x @ self.weight + bias + + def named_parameters(self): + return [("weight", self.weight), ("bias", self.bias)] + + def named_buffers(self): + return [] + + def eval(self): + return False + + #`NativeFlaxModel` is a subclass of nnx.Module and does NOT exposes a similar + # interface to torch.nn.Module (with named_parameters and named_buffers). + class NativeFlaxModel(nnx.Module): + def __init__(self): + super(NativeFlaxModel, self).__init__() + self.linear = CustomFlaxLinear(10, 5) + + def call(self, x): + return self.linear(x) + + class PyTorchModel(nn.Module): + def __init__(self): + super(PyTorchModel, self).__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + # Instantiate both models + model_pt = PyTorchModel() # PyTorch model + model_flax = NativeFlaxModel() # Native Flax model inheriting from nnx.Module + + # Sync all submodules between the PyTorch and Keras models + sync_models_torch_and_jax(model_pt, model_flax) + ``` + """ + + def _compute_module_dict_jax(model, prefix=""): + _module_dict = dict() + for key, value in model.__dict__.items(): + if isinstance(value, nnx.Module): + if not hasattr(value, "named_parameters"): + _module_dict.update( + _compute_module_dict_jax(value, prefix=f"{key}.") + ) + else: + _module_dict[prefix + key] = value + return _module_dict + + try: + import flax # noqa + + version = parse(flax.__version__) + if version < parse("0.8.0"): + raise ImportError( + "Flax version 0.8.0 or higher is required. Please update your Flax installation." 
+ ) + import flax.nnx as nnx # noqa + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`flax` was not found installed on your system. Please proceed " + "to install it and restart your interpreter to see the changes." + ) from exc + + try: + import jax # noqa + import jax.numpy as jnp # noqa + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`jax` was not found installed on your system. Please proceed " + "to install it and restart your interpreter to see the changes." + ) from exc + + try: + assert isinstance( + model_jax, nnx.Module + ), "The second model must be an instance of `nnx.Module`." + except AssertionError as e: + raise TypeError("The second model must be a Flax model.") from e + + if hasattr(model_jax, "named_parameters"): + _sync_models_torch_and_jax(model_pt, model_jax) + + else: + all_submods_jax = _compute_module_dict_jax(model_jax) + all_submods_pt = _compute_module_dict_pt( + model_pt, keychains=list(all_submods_jax.keys()) + ) + + for pt_model, jax_model in zip( + all_submods_pt.values(), all_submods_jax.values() + ): + pt_model.eval() + jax_model.eval() + _sync_models_torch_and_jax(pt_model, jax_model) + + +def sync_models_torch( + original_model: "nn.Module", translated_model: Union["keras.Model", "KerasModel", "nnx.Module", "FlaxModel"], *, target: str +): + """ + Synchronizes the weights and buffers between a native PyTorch model (`torch.nn.Module`) + and it's translated version in TensorFlow or Flax. + + Args: + original_model (torch.nn.Module): The PyTorch model to synchronize from. + translated_model (tf.keras.Model or nnx.Module): The target model to synchronize to, + either a TensorFlow or Flax model. + target (str): The framework of the translated model, either 'tensorflow' or 'jax'. + """ + try: + import torch # noqa + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`torch` was not found installed on your system. 
Please proceed " + "to install it and restart your interpreter to see the changes." + ) from exc + + try: + assert isinstance( + original_model, torch.nn.Module + ), "The first model must be an instance of `torch.nn.Module` (PyTorch)." + except AssertionError as e: + raise TypeError("PyTorch model is required as the first argument.") from e + + if target == "tensorflow": + sync_models_torch_and_tf(original_model, translated_model) + + elif target == "jax": + sync_models_torch_and_jax(original_model, translated_model) + else: + raise ivy.utils.exceptions.IvyException( + "target must be either 'tensorflow' or 'jax'." + ) + + print("All parameters and buffers are now synced!") From 27dc9263015bd073a6109709fc64720add378be2 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 16:23:04 +0000 Subject: [PATCH 04/10] fix: renaming the `USE_NATIVE_KERAS_LAYERS` env variable --- ivy/stateful/utilities.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index a36414f43e95..78f91c889b13 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -89,9 +89,9 @@ def _maybe_update_flax_layer_weights(layer, weight_name, new_weight): import flax.nnx as nnx import jax.numpy as jnp - has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + has_flax_layers = os.environ.get("USE_NATIVE_FW_LAYERS", None) == "true" transpose_weights = ( - has_keras_layers + has_flax_layers or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" ) @@ -286,7 +286,7 @@ def _maybe_update_keras_layer_weights(layer, weight_name, new_weight): else: KerasVariable = tf.Variable - has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + has_keras_layers = os.environ.get("USE_NATIVE_FW_LAYERS", None) == "true" transpose_weights = ( has_keras_layers or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" From ee6694dc73e5c551991e9b4f349a629fd679c078 Mon 
Sep 17 00:00:00 2001 From: Yusha Arif <101613943+YushaArif99@users.noreply.github.com> Date: Thu, 19 Sep 2024 09:55:48 +0500 Subject: [PATCH 05/10] Update ivy/functional/backends/tensorflow/module.py Co-authored-by: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> --- ivy/functional/backends/tensorflow/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/tensorflow/module.py b/ivy/functional/backends/tensorflow/module.py index dce43ed7671c..60c2e10b5451 100644 --- a/ivy/functional/backends/tensorflow/module.py +++ b/ivy/functional/backends/tensorflow/module.py @@ -26,7 +26,7 @@ import torch.nn as nn -if parse(keras.__version__).major > 2: +if keras.__version__ >= "3.0.0": KerasVariable = keras.src.backend.Variable else: KerasVariable = tf.Variable From 5f2feb49322e4e59e955c32d71468025ff1dcf69 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 19 Sep 2024 05:08:08 +0000 Subject: [PATCH 06/10] fix (stateful)(utilities): adding try-except blocks for `torch` imports --- ivy/stateful/utilities.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index 78f91c889b13..06da4fef6e20 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -538,6 +538,14 @@ def _compute_module_dict_tf(model, prefix=""): _module_dict[prefix + key] = value return _module_dict + try: + import torch + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`torch` was not found installed on your system. Please proceed " + "to install it and restart your interpreter to see the changes." + ) from exc + try: import tensorflow as tf except ModuleNotFoundError as exc: @@ -546,6 +554,13 @@ def _compute_module_dict_tf(model, prefix=""): "to install it and restart your interpreter to see the changes." 
) from exc + try: + assert isinstance( + model_pt, torch.nn.Module + ), "The original model must be an instance of `torch.nn.Module` (PyTorch)." + except AssertionError as e: + raise TypeError("PyTorch model is required as the first argument.") from e + try: assert isinstance( model_tf, tf.keras.Model @@ -652,6 +667,14 @@ def _compute_module_dict_jax(model, prefix=""): _module_dict[prefix + key] = value return _module_dict + try: + import torch # noqa + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "`torch` was not found installed on your system. Please proceed " + "to install it and restart your interpreter to see the changes." + ) from exc + try: import flax # noqa @@ -676,6 +699,13 @@ def _compute_module_dict_jax(model, prefix=""): "to install it and restart your interpreter to see the changes." ) from exc + try: + assert isinstance( + model_pt, torch.nn.Module + ), "The original model must be an instance of `torch.nn.Module` (PyTorch)." + except AssertionError as e: + raise TypeError("PyTorch model is required as the first argument.") from e + try: assert isinstance( model_jax, nnx.Module From 6c69f3aca0797b1f931d70bfc6696b19df25bcf6 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 19 Sep 2024 05:09:46 +0000 Subject: [PATCH 07/10] feat (stateful)(utilities): renaming `sync_models_torch` with `sync_models` and adding a `source` kwarg to route to the appropriate helper. 
--- ivy/stateful/utilities.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index 06da4fef6e20..28cbd8798d33 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -730,8 +730,12 @@ def _compute_module_dict_jax(model, prefix=""): _sync_models_torch_and_jax(pt_model, jax_model) -def sync_models_torch( - original_model: "nn.Module", translated_model: Union["keras.Model", "KerasModel", "nnx.Module", "FlaxModel"], *, target: str +def sync_models( + original_model: "nn.Module", + translated_model: Union["keras.Model", "KerasModel", "nnx.Module", "FlaxModel"], + *, + source: str = "torch", + target: str = "tensorflow", ): """ Synchronizes the weights and buffers between a native PyTorch model (`torch.nn.Module`) @@ -741,31 +745,26 @@ def sync_models_torch( original_model (torch.nn.Module): The PyTorch model to synchronize from. translated_model (tf.keras.Model or nnx.Module): The target model to synchronize to, either a TensorFlow or Flax model. - target (str): The framework of the translated model, either 'tensorflow' or 'jax'. + source (str): The framework of the original model, Defaults to 'torch'. + target (str): The framework of the translated model. Defaults to 'tensorflow'. """ - try: - import torch # noqa - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "`torch` was not found installed on your system. Please proceed " - "to install it and restart your interpreter to see the changes." - ) from exc - - try: - assert isinstance( - original_model, torch.nn.Module - ), "The first model must be an instance of `torch.nn.Module` (PyTorch)." - except AssertionError as e: - raise TypeError("PyTorch model is required as the first argument.") from e + if source != "torch": + raise ivy.utils.exceptions.IvyNotImplementedException( + "sync_models is not implemented for source other than 'torch'. 
got {}".format( + source + ) + ) if target == "tensorflow": sync_models_torch_and_tf(original_model, translated_model) elif target == "jax": sync_models_torch_and_jax(original_model, translated_model) else: - raise ivy.utils.exceptions.IvyException( - "target must be either 'tensorflow' or 'jax'." + raise ivy.utils.exceptions.IvyNotImplementedException( + "sync_models is not implemented for target other than 'tensorflow' or 'jax'. got {}".format( + source + ) ) print("All parameters and buffers are now synced!") From 10153ef36f77ca89cc71de51e6cc548ccd16dc59 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 19 Sep 2024 05:18:50 +0000 Subject: [PATCH 08/10] fix invalid import --- ivy/stateful/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/stateful/__init__.py b/ivy/stateful/__init__.py index dcc7827bd341..d02cb39fe860 100644 --- a/ivy/stateful/__init__.py +++ b/ivy/stateful/__init__.py @@ -17,4 +17,4 @@ from . import sequential from .sequential import * from . import utilities -from .utilities import sync_models_torch +from .utilities import sync_models From bd576902b2e5f0e811da050e39f6b6267a675e37 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 19 Sep 2024 16:36:21 +0000 Subject: [PATCH 09/10] feat (stateful)(utilities): adding a helper function to check for instances of native modules by traversing through the `mro` chain. 
--- ivy/stateful/utilities.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index 28cbd8798d33..f5e810397c88 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -12,6 +12,21 @@ import flax.nnx as nnx +def _is_submodule(obj, kw): + cls_str = { + "torch": ("torch.nn.modules.module.Module",), + "keras": ("keras.engine.training.Model", "keras.src.models.model.Model"), + "flax": ("flax.nnx.nnx.module.Module",), + }[kw] + try: + for bc in type(obj).mro(): + if any(cls in str(bc) for cls in cls_str): + return True + except TypeError: + pass + return False + + def _compute_module_dict_pt(model, keychains): _module_dict = dict() for keychain in keychains: From 27b91f5c9519594160cc6a7d309fa0914420d288 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 19 Sep 2024 16:37:32 +0000 Subject: [PATCH 10/10] feat(stateful)(utilities): removing the `source` and `target` kwargs and instead using `_is_submodule` --- ivy/stateful/utilities.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index f5e810397c88..50245ec1812f 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -748,9 +748,6 @@ def _compute_module_dict_jax(model, prefix=""): def sync_models( original_model: "nn.Module", translated_model: Union["keras.Model", "KerasModel", "nnx.Module", "FlaxModel"], - *, - source: str = "torch", - target: str = "tensorflow", ): """ Synchronizes the weights and buffers between a native PyTorch model (`torch.nn.Module`) @@ -760,25 +757,23 @@ def sync_models( original_model (torch.nn.Module): The PyTorch model to synchronize from. translated_model (tf.keras.Model or nnx.Module): The target model to synchronize to, either a TensorFlow or Flax model. - source (str): The framework of the original model, Defaults to 'torch'. - target (str): The framework of the translated model. 
Defaults to 'tensorflow'. """ - if source != "torch": - raise ivy.utils.exceptions.IvyNotImplementedException( - "sync_models is not implemented for source other than 'torch'. got {}".format( - source + if not _is_submodule(original_model, "torch"): + raise ivy.utils.exceptions.IvyException( + "sync_models expected an instance of `nn.Module` as the first argument. got {}".format( + original_model ) ) - if target == "tensorflow": + if _is_submodule(translated_model, "keras"): sync_models_torch_and_tf(original_model, translated_model) - elif target == "jax": + elif _is_submodule(translated_model, "flax"): sync_models_torch_and_jax(original_model, translated_model) else: raise ivy.utils.exceptions.IvyNotImplementedException( - "sync_models is not implemented for target other than 'tensorflow' or 'jax'. got {}".format( - source + "sync_models expected an instance of a `keras.Model` or `nnx.Module` as the second argument. got {}".format( + translated_model ) )