diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 73d6a7613fda2b..864e34b016781f 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -21,7 +21,8 @@
 
 import numpy
 
-from .utils import ExplicitEnum, logging
+from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
+from .utils import transpose as transpose_func
 
 
 logger = logging.get_logger(__name__)
@@ -66,10 +67,12 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
     if len(tf_name) > 1:
         tf_name = tf_name[1:]  # Remove level zero
 
+    tf_weight_shape = None if tf_weight_shape is None else list(tf_weight_shape)
+
     # When should we transpose the weights
-    if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
+    if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
         transpose = TransposeType.CONV2D
-    elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
+    elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
         transpose = TransposeType.CONV1D
     elif bool(
         tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
@@ -98,6 +101,43 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
     return tf_name, transpose
 
 
+def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
+    """
+    Apply a transpose to a weight, then try to reshape the weight to match a given target shape, all in a
+    framework-agnostic way.
+    """
+    if transpose is TransposeType.CONV2D:
+        # Conv2D weight:
+        #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
+        # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
+        axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
+        weight = transpose_func(weight, axes=axes)
+    elif transpose is TransposeType.CONV1D:
+        # Conv1D weight:
+        #    PT: (num_out_channel, num_in_channel, kernel)
+        # -> TF: (kernel, num_in_channel, num_out_channel)
+        weight = transpose_func(weight, axes=(2, 1, 0))
+    elif transpose is TransposeType.SIMPLE:
+        weight = transpose_func(weight)
+
+    if match_shape is None:
+        return weight
+
+    if len(match_shape) < len(weight.shape):
+        weight = squeeze(weight)
+    elif len(match_shape) > len(weight.shape):
+        weight = expand_dims(weight, axis=0)
+
+    if list(match_shape) != list(weight.shape):
+        try:
+            weight = reshape(weight, match_shape)
+        except AssertionError as e:
+            e.args += (match_shape, weight.shape)
+            raise e
+
+    return weight
+
+
 #####################
 # PyTorch => TF 2.0 #
 #####################
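For reference, the Conv2D layout convention that `apply_transpose` encodes can be checked in isolation. A minimal NumPy sketch, independent of the patch; the shapes are illustrative only:

    import numpy as np

    # A PyTorch-style Conv2D kernel: (num_out_channel, num_in_channel, kernel_h, kernel_w)
    pt_kernel = np.zeros((8, 3, 5, 5))

    # PT -> TF uses axes (2, 3, 1, 0): (kernel_h, kernel_w, num_in_channel, num_out_channel)
    tf_kernel = np.transpose(pt_kernel, axes=(2, 3, 1, 0))
    assert tf_kernel.shape == (5, 5, 3, 8)

    # TF -> PT applies the inverse permutation (3, 2, 0, 1)
    pt_again = np.transpose(tf_kernel, axes=(3, 2, 0, 1))
    assert pt_again.shape == pt_kernel.shape

The `pt_to_tf` flag only has to choose between these two permutations in the Conv2D case: the Conv1D permutation (2, 1, 0) is its own inverse, and the simple 2D transpose needs no axes at all.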
@@ -155,7 +195,6 @@ def load_pytorch_weights_in_tf2_model(
     try:
         import tensorflow as tf  # noqa: F401
         import torch  # noqa: F401
-        from tensorflow.python.keras import backend as K
     except ImportError:
         logger.error(
             "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -163,6 +202,22 @@ def load_pytorch_weights_in_tf2_model(
         )
         raise
 
+    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+    return load_pytorch_state_dict_in_tf2_model(
+        tf_model,
+        pt_state_dict,
+        tf_inputs=tf_inputs,
+        allow_missing_keys=allow_missing_keys,
+        output_loading_info=output_loading_info,
+    )
+
+
+def load_pytorch_state_dict_in_tf2_model(
+    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+):
+    """Load a PyTorch state_dict into a TF 2.0 model."""
+    from tensorflow.python.keras import backend as K
+
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs
 
@@ -216,41 +271,9 @@ def load_pytorch_weights_in_tf2_model(
                 continue
             raise AttributeError(f"{name} not found in PyTorch model")
 
-        array = pt_state_dict[name].numpy()
-
-        if transpose is TransposeType.CONV2D:
-            # Conv2D weight:
-            #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
-            # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
-            array = numpy.transpose(array, axes=(2, 3, 1, 0))
-        elif transpose is TransposeType.CONV1D:
-            # Conv1D weight:
-            #    PT: (num_out_channel, num_in_channel, kernel)
-            # -> TF: (kernel, num_in_channel, num_out_channel)
-            array = numpy.transpose(array, axes=(2, 1, 0))
-        elif transpose is TransposeType.SIMPLE:
-            array = numpy.transpose(array)
-
-        if len(symbolic_weight.shape) < len(array.shape):
-            array = numpy.squeeze(array)
-        elif len(symbolic_weight.shape) > len(array.shape):
-            array = numpy.expand_dims(array, axis=0)
-
-        if list(symbolic_weight.shape) != list(array.shape):
-            try:
-                array = numpy.reshape(array, symbolic_weight.shape)
-            except AssertionError as e:
-                e.args += (symbolic_weight.shape, array.shape)
-                raise e
-
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
+        array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
 
         tf_loaded_numel += array.size
-        # logger.warning(f"Initialize TF weight {symbolic_weight.name}")
 
         weight_value_tuples.append((symbolic_weight, array))
         all_pytorch_weights.discard(name)
@@ -370,6 +393,15 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
         )
         raise
 
+    tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
+    return load_tf2_state_dict_in_pytorch_model(
+        pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+    )
+
+
+def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
+    import torch
+
     new_pt_params_dict = {}
     current_pt_params_dict = dict(pt_model.named_parameters())
 
@@ -381,11 +413,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
 
     # Build a map from potential PyTorch weight names to TF 2.0 Variables
    tf_weights_map = {}
-    for tf_weight in tf_weights:
+    for name, tf_weight in tf_state_dict.items():
         pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
-            tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
+            name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
         )
-        tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
+        tf_weights_map[pt_name] = (tf_weight, transpose)
 
     all_tf_weights = set(list(tf_weights_map.keys()))
     loaded_pt_weights_data_ptr = {}
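The map built above delegates both the renaming and the transpose decision to `convert_tf_weight_name_to_pt_weight_name`. A small sketch of the expected behaviour for a rank-2 dense kernel; the weight name is illustrative, and the expected outputs in the comments follow the renaming rules and the `len(tf_weight_shape)` checks in the hunks above:

    from transformers.modeling_tf_pytorch_utils import TransposeType, convert_tf_weight_name_to_pt_weight_name

    # Dense kernels are stored transposed between the two frameworks, so a
    # rank-2 "kernel" maps to TransposeType.SIMPLE; rank-4 and rank-3 kernels
    # would hit the CONV2D and CONV1D branches instead.
    pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
        "bert/encoder/layer_._0/attention/self/query/kernel:0", tf_weight_shape=[768, 768]
    )
    # Expected: pt_name == "encoder.layer.0.attention.self.query.weight"
    #           transpose is TransposeType.SIMPLE

Passing a plain list for `tf_weight_shape` is exactly what this patch enables: the old code required an object with a `.rank` attribute (a `tf.TensorShape`), while the new `len()`-based checks accept any sequence.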
@@ -406,43 +438,18 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
 
         array, transpose = tf_weights_map[pt_weight_name]
 
-        if transpose is TransposeType.CONV2D:
-            # Conv2D weight:
-            #    TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
-            # -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
-            array = numpy.transpose(array, axes=(3, 2, 0, 1))
-        elif transpose is TransposeType.CONV1D:
-            # Conv1D weight:
-            #    TF: (kernel, num_in_channel, num_out_channel)
-            # -> PT: (num_out_channel, num_in_channel, kernel)
-            array = numpy.transpose(array, axes=(2, 1, 0))
-        elif transpose is TransposeType.SIMPLE:
-            array = numpy.transpose(array)
-
-        if len(pt_weight.shape) < len(array.shape):
-            array = numpy.squeeze(array)
-        elif len(pt_weight.shape) > len(array.shape):
-            array = numpy.expand_dims(array, axis=0)
-
-        if list(pt_weight.shape) != list(array.shape):
-            try:
-                array = numpy.reshape(array, pt_weight.shape)
-            except AssertionError as e:
-                e.args += (pt_weight.shape, array.shape)
-                raise e
-
-        try:
-            assert list(pt_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (pt_weight.shape, array.shape)
-            raise e
+        array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
 
-        # logger.warning(f"Initialize PyTorch weight {pt_weight_name}")
-        # Make sure we have a proper numpy array
         if numpy.isscalar(array):
             array = numpy.array(array)
-        new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
-        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
+        if not is_torch_tensor(array) and not is_numpy_array(array):
+            array = array.numpy()
+        if is_numpy_array(array):
+            # Convert to torch tensor
+            array = torch.from_numpy(array)
+
+        new_pt_params_dict[pt_weight_name] = array
+        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
         all_tf_weights.discard(pt_weight_name)
 
     missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 7857339379324c..03aa17bc8332d0 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
     PaddingStrategy,
     TensorType,
     cached_property,
+    expand_dims,
     find_labels,
     flatten_dict,
     is_jax_tensor,
@@ -46,8 +47,11 @@
     is_tf_tensor,
     is_torch_device,
     is_torch_tensor,
+    reshape,
+    squeeze,
     to_numpy,
     to_py_obj,
+    transpose,
     working_or_temp_dir,
 )
 from .hub import (
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index a53f769f05c2cb..334141bd55c3e2 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -29,6 +29,13 @@
 from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
 
 
+if is_tf_available():
+    import tensorflow as tf
+
+if is_flax_available():
+    import jax.numpy as jnp
+
+
 class cached_property(property):
     """
     Descriptor that mimics @property but caches output in member variable.
@@ -370,3 +377,71 @@ def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
             yield tmp_dir
     else:
         yield working_dir
+
+
+def transpose(array, axes=None):
+    """
+    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
+    arrays.
+    """
+    if is_numpy_array(array):
+        return np.transpose(array, axes=axes)
+    elif is_torch_tensor(array):
+        return array.T if axes is None else array.permute(*axes)
+    elif is_tf_tensor(array):
+        return tf.transpose(array, perm=axes)
+    elif is_jax_tensor(array):
+        return jnp.transpose(array, axes=axes)
+    else:
+        raise ValueError(f"Type not supported for transpose: {type(array)}.")
+
+
+def reshape(array, newshape):
+    """
+    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
+    arrays.
+    """
+    if is_numpy_array(array):
+        return np.reshape(array, newshape)
+    elif is_torch_tensor(array):
+        return array.reshape(*newshape)
+    elif is_tf_tensor(array):
+        return tf.reshape(array, newshape)
+    elif is_jax_tensor(array):
+        return jnp.reshape(array, newshape)
+    else:
+        raise ValueError(f"Type not supported for reshape: {type(array)}.")
+
+
+def squeeze(array, axis=None):
+    """
+    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
+    arrays.
+    """
+    if is_numpy_array(array):
+        return np.squeeze(array, axis=axis)
+    elif is_torch_tensor(array):
+        return array.squeeze() if axis is None else array.squeeze(dim=axis)
+    elif is_tf_tensor(array):
+        return tf.squeeze(array, axis=axis)
+    elif is_jax_tensor(array):
+        return jnp.squeeze(array, axis=axis)
+    else:
+        raise ValueError(f"Type not supported for squeeze: {type(array)}.")
+
+
+def expand_dims(array, axis):
+    """
+    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
+    arrays.
+    """
+    if is_numpy_array(array):
+        return np.expand_dims(array, axis)
+    elif is_torch_tensor(array):
+        return array.unsqueeze(dim=axis)
+    elif is_tf_tensor(array):
+        return tf.expand_dims(array, axis=axis)
+    elif is_jax_tensor(array):
+        return jnp.expand_dims(array, axis=axis)
+    else:
+        raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
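What these four helpers buy the loading code above is a single call site that works unchanged whatever framework the weight arrived in. A minimal usage sketch, torch and NumPy only (the TF and JAX branches behave the same way); the shapes are arbitrary:

    import numpy as np
    import torch

    from transformers.utils import expand_dims, reshape, squeeze, transpose

    x_np = np.zeros((1, 3, 4))
    x_pt = torch.zeros(1, 3, 4)

    # One code path, two frameworks: each helper dispatches on the input type.
    for x in (x_np, x_pt):
        y = squeeze(x)              # (3, 4): drops the size-1 leading axis
        y = transpose(y)            # (4, 3)
        y = reshape(y, (2, 6))      # (2, 6)
        y = expand_dims(y, axis=0)  # (1, 2, 6)
        print(type(y), tuple(y.shape))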
+ """ + if is_numpy_array(array): + return np.transpose(array, axes=axes) + elif is_torch_tensor(array): + return array.T if axes is None else array.permute(*axes) + elif is_tf_tensor(array): + return tf.transpose(array, perm=axes) + elif is_jax_tensor(array): + return jnp.transpose(array, axes=axes) + else: + raise ValueError(f"Type not supported for transpose: {type(array)}.") + + +def reshape(array, newshape): + """ + Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy + arrays. + """ + if is_numpy_array(array): + return np.reshape(array, newshape) + elif is_torch_tensor(array): + return array.reshape(*newshape) + elif is_tf_tensor(array): + return tf.reshape(array, newshape) + elif is_jax_tensor(array): + return jnp.reshape(array, newshape) + else: + raise ValueError(f"Type not supported for reshape: {type(array)}.") + + +def squeeze(array, axis=None): + """ + Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy + arrays. + """ + if is_numpy_array(array): + return np.squeeze(array, axis=axis) + elif is_torch_tensor(array): + return array.squeeze() if axis is None else array.squeeze(dim=axis) + elif is_tf_tensor(array): + return tf.squeeze(array, axis=axis) + elif is_jax_tensor(array): + return jnp.squeeze(array, axis=axis) + else: + raise ValueError(f"Type not supported for squeeze: {type(array)}.") + + +def expand_dims(array, axis): + """ + Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy + arrays. + """ + if is_numpy_array(array): + return np.expand_dims(array, axis) + elif is_torch_tensor(array): + return array.unsqueeze(dim=axis) + elif is_tf_tensor(array): + return tf.expand_dims(array, axis=axis) + elif is_jax_tensor(array): + return jnp.expand_dims(array, axis=axis) + else: + raise ValueError(f"Type not supported for expand_dims: {type(array)}.") diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py index 6fbdbee4036070..3d864648120b56 100644 --- a/tests/utils/test_generic.py +++ b/tests/utils/test_generic.py @@ -15,7 +15,29 @@ import unittest -from transformers.utils import flatten_dict +import numpy as np + +from transformers.testing_utils import require_flax, require_tf, require_torch +from transformers.utils import ( + expand_dims, + flatten_dict, + is_flax_available, + is_tf_available, + is_torch_available, + reshape, + squeeze, + transpose, +) + + +if is_flax_available(): + import jax.numpy as jnp + +if is_tf_available(): + import tensorflow as tf + +if is_torch_available(): + import torch class GenericTester(unittest.TestCase): @@ -43,3 +65,136 @@ def test_flatten_dict(self): } self.assertEqual(flatten_dict(input_dict), expected_dict) + + def test_transpose_numpy(self): + x = np.random.randn(3, 4) + self.assertTrue(np.allclose(transpose(x), x.transpose())) + + x = np.random.randn(3, 4, 5) + self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0)))) + + @require_torch + def test_transpose_torch(self): + x = np.random.randn(3, 4) + t = torch.tensor(x) + self.assertTrue(np.allclose(transpose(x), transpose(t).numpy())) + + x = np.random.randn(3, 4, 5) + t = torch.tensor(x) + self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy())) + + @require_tf + def test_transpose_tf(self): + x = np.random.randn(3, 4) + t = tf.constant(x) + self.assertTrue(np.allclose(transpose(x), transpose(t).numpy())) + + x = np.random.randn(3, 4, 
+    @require_flax
+    def test_transpose_flax(self):
+        x = np.random.randn(3, 4)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(transpose(x), np.asarray(transpose(t))))
+
+        x = np.random.randn(3, 4, 5)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), np.asarray(transpose(t, axes=(1, 2, 0)))))
+
+    def test_reshape_numpy(self):
+        x = np.random.randn(3, 4)
+        self.assertTrue(np.allclose(reshape(x, (4, 3)), np.reshape(x, (4, 3))))
+
+        x = np.random.randn(3, 4, 5)
+        self.assertTrue(np.allclose(reshape(x, (12, 5)), np.reshape(x, (12, 5))))
+
+    @require_torch
+    def test_reshape_torch(self):
+        x = np.random.randn(3, 4)
+        t = torch.tensor(x)
+        self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
+
+        x = np.random.randn(3, 4, 5)
+        t = torch.tensor(x)
+        self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
+
+    @require_tf
+    def test_reshape_tf(self):
+        x = np.random.randn(3, 4)
+        t = tf.constant(x)
+        self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
+
+        x = np.random.randn(3, 4, 5)
+        t = tf.constant(x)
+        self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
+
+    @require_flax
+    def test_reshape_flax(self):
+        x = np.random.randn(3, 4)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(reshape(x, (4, 3)), np.asarray(reshape(t, (4, 3)))))
+
+        x = np.random.randn(3, 4, 5)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(reshape(x, (12, 5)), np.asarray(reshape(t, (12, 5)))))
+
+    def test_squeeze_numpy(self):
+        x = np.random.randn(1, 3, 4)
+        self.assertTrue(np.allclose(squeeze(x), np.squeeze(x)))
+
+        x = np.random.randn(1, 4, 1, 5)
+        self.assertTrue(np.allclose(squeeze(x, axis=2), np.squeeze(x, axis=2)))
+
+    @require_torch
+    def test_squeeze_torch(self):
+        x = np.random.randn(1, 3, 4)
+        t = torch.tensor(x)
+        self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
+
+        x = np.random.randn(1, 4, 1, 5)
+        t = torch.tensor(x)
+        self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
+
+    @require_tf
+    def test_squeeze_tf(self):
+        x = np.random.randn(1, 3, 4)
+        t = tf.constant(x)
+        self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
+
+        x = np.random.randn(1, 4, 1, 5)
+        t = tf.constant(x)
+        self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
+
+    @require_flax
+    def test_squeeze_flax(self):
+        x = np.random.randn(1, 3, 4)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(squeeze(x), np.asarray(squeeze(t))))
+
+        x = np.random.randn(1, 4, 1, 5)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(squeeze(x, axis=2), np.asarray(squeeze(t, axis=2))))
+
+    def test_expand_dims_numpy(self):
+        x = np.random.randn(3, 4)
+        self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
+
+    @require_torch
+    def test_expand_dims_torch(self):
+        x = np.random.randn(3, 4)
+        t = torch.tensor(x)
+        self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
+
+    @require_tf
+    def test_expand_dims_tf(self):
+        x = np.random.randn(3, 4)
+        t = tf.constant(x)
+        self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
+
+    @require_flax
+    def test_expand_dims_flax(self):
+        x = np.random.randn(3, 4)
+        t = jnp.array(x)
+        self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
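The tests above exercise the four helpers directly; the shape-matching tail of `apply_transpose` (squeeze, expand_dims, then reshape as a last resort) is only covered indirectly through the loaders. A minimal sketch of that behaviour in isolation, assuming the enum's `TransposeType.NO` member (not shown in the hunks above) and invented shapes:

    import numpy as np

    from transformers.modeling_tf_pytorch_utils import TransposeType, apply_transpose

    # A (1, 768) TF-style weight matched against a (768,) PT parameter:
    # no transpose is needed, and the extra leading axis is squeezed away.
    weight = apply_transpose(TransposeType.NO, np.zeros((1, 768)), match_shape=(768,))
    assert weight.shape == (768,)

    # The opposite direction: a (768,) weight matched against (1, 768)
    # gains a leading axis via expand_dims.
    weight = apply_transpose(TransposeType.NO, np.zeros((768,)), match_shape=(1, 768))
    assert weight.shape == (1, 768)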