Commit

Refactor conversion function (huggingface#19799)
* Refactor conversion function

* Remove dupe line

* Fixes

* Fixes

* Use the right variable...

* Fix last test
sgugger authored and amyeroberts committed Nov 1, 2022
1 parent 54b04a9 commit 24f8c50
Showing 4 changed files with 316 additions and 75 deletions.
155 changes: 81 additions & 74 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -21,7 +21,8 @@

import numpy

from .utils import ExplicitEnum, logging
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
from .utils import transpose as transpose_func


logger = logging.get_logger(__name__)
@@ -66,10 +67,12 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
if len(tf_name) > 1:
tf_name = tf_name[1:] # Remove level zero

tf_weight_shape = list(tf_weight_shape)

# When should we transpose the weights
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
transpose = TransposeType.CONV2D
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
transpose = TransposeType.CONV1D
elif bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
@@ -98,6 +101,43 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
return tf_name, transpose
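
A quick illustration of the shape-check change above (editor's sketch, not part of the commit): converting the shape to a list first means the rank test no longer depends on `tf.TensorShape.rank`, so a plain tuple from a NumPy array's `.shape` passes through the same code path.

import numpy as np
import tensorflow as tf

# Both shape types survive list() and give the same length-based rank.
tf_shape = tf.TensorShape([3, 3, 64, 128])    # has a .rank attribute
np_shape = np.zeros((3, 3, 64, 128)).shape    # plain tuple, no .rank
assert len(list(tf_shape)) == len(list(np_shape)) == 4  # -> TransposeType.CONV2D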


def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
"""
Apply a transpose to a weight, then try to reshape it to a given target shape, all in a framework-agnostic
way.
"""
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
weight = transpose_func(weight, axes=axes)
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# PT: (num_out_channel, num_in_channel, kernel)
# -> TF: (kernel, num_in_channel, num_out_channel)
weight = transpose_func(weight, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
weight = transpose_func(weight)

if match_shape is None:
return weight

if len(match_shape) < len(weight.shape):
weight = squeeze(weight)
elif len(match_shape) > len(weight.shape):
weight = expand_dims(weight, axis=0)

if list(match_shape) != list(weight.shape):
try:
weight = reshape(weight, match_shape)
except AssertionError as e:
e.args += (match_shape, weight.shape)
raise e

return weight
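
A minimal usage sketch for the helper above (editor's illustration; the array values are arbitrary): a PyTorch Conv2D kernel laid out as (out, in, kH, kW) comes back in the TF layout (kH, kW, in, out).

import numpy as np

pt_kernel = np.zeros((128, 64, 3, 3))  # PT layout: (out, in, kH, kW)
tf_kernel = apply_transpose(TransposeType.CONV2D, pt_kernel, match_shape=(3, 3, 64, 128))
assert tf_kernel.shape == (3, 3, 64, 128)  # TF layout: (kH, kW, in, out)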


#####################
# PyTorch => TF 2.0 #
#####################
@@ -155,14 +195,29 @@ def load_pytorch_weights_in_tf2_model(
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
from tensorflow.python.keras import backend as K
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise

pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
return load_pytorch_state_dict_in_tf2_model(
tf_model,
pt_state_dict,
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=output_loading_info,
)
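
How the split composes (editor's sketch; `pt_model` and `tf_model` stand for matching PyTorch/TF model instances): the public entry point converts torch tensors to NumPy once, then defers to the new state-dict loader, so a NumPy-only state dict can also be loaded directly.

pt_state_dict = {k: v.numpy() for k, v in pt_model.state_dict().items()}
tf_model = load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)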


def load_pytorch_state_dict_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load a pytorch state_dict in a TF 2.0 model."""
from tensorflow.python.keras import backend as K

if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs

Expand Down Expand Up @@ -216,41 +271,9 @@ def load_pytorch_weights_in_tf2_model(
continue
raise AttributeError(f"{name} not found in PyTorch model")

array = pt_state_dict[name].numpy()

if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 3, 1, 0))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# PT: (num_out_channel, num_in_channel, kernel)
# -> TF: (kernel, num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)

if len(symbolic_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(symbolic_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)

if list(symbolic_weight.shape) != list(array.shape):
try:
array = numpy.reshape(array, symbolic_weight.shape)
except AssertionError as e:
e.args += (symbolic_weight.shape, array.shape)
raise e

try:
assert list(symbolic_weight.shape) == list(array.shape)
except AssertionError as e:
e.args += (symbolic_weight.shape, array.shape)
raise e
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)

tf_loaded_numel += array.size
# logger.warning(f"Initialize TF weight {symbolic_weight.name}")

weight_value_tuples.append((symbolic_weight, array))
all_pytorch_weights.discard(name)
@@ -370,6 +393,15 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
)
raise

tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
return load_tf2_state_dict_in_pytorch_model(
pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
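
The TF->PT direction mirrors the split above (editor's sketch; `tf_model` and `pt_model` stand for matching model instances): TF variables are flattened into a name-to-array dict first.

tf_state_dict = {w.name: w.numpy() for w in tf_model.weights}
pt_model = load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict)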


def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
import torch

new_pt_params_dict = {}
current_pt_params_dict = dict(pt_model.named_parameters())

@@ -381,11 +413,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F

# Build a map from potential PyTorch weight names to TF 2.0 Variables
tf_weights_map = {}
for tf_weight in tf_weights:
for name, tf_weight in tf_state_dict.items():
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
)
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
tf_weights_map[pt_name] = (tf_weight, transpose)

all_tf_weights = set(list(tf_weights_map.keys()))
loaded_pt_weights_data_ptr = {}
@@ -406,43 +438,18 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F

array, transpose = tf_weights_map[pt_weight_name]

if transpose is TransposeType.CONV2D:
# Conv2D weight:
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
array = numpy.transpose(array, axes=(3, 2, 0, 1))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# TF: (kernel, num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)

if len(pt_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(pt_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)

if list(pt_weight.shape) != list(array.shape):
try:
array = numpy.reshape(array, pt_weight.shape)
except AssertionError as e:
e.args += (pt_weight.shape, array.shape)
raise e

try:
assert list(pt_weight.shape) == list(array.shape)
except AssertionError as e:
e.args += (pt_weight.shape, array.shape)
raise e
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)

# logger.warning(f"Initialize PyTorch weight {pt_weight_name}")
# Make sure we have a proper numpy array
if numpy.isscalar(array):
array = numpy.array(array)
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
if not is_torch_tensor(array) and not is_numpy_array(array):
array = array.numpy()
if is_numpy_array(array):
# Convert to torch tensor
array = torch.from_numpy(array)

new_pt_params_dict[pt_weight_name] = array
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
all_tf_weights.discard(pt_weight_name)

missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
4 changes: 4 additions & 0 deletions src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
PaddingStrategy,
TensorType,
cached_property,
expand_dims,
find_labels,
flatten_dict,
is_jax_tensor,
@@ -46,8 +47,11 @@
is_tf_tensor,
is_torch_device,
is_torch_tensor,
reshape,
squeeze,
to_numpy,
to_py_obj,
transpose,
working_or_temp_dir,
)
from .hub import (
75 changes: 75 additions & 0 deletions src/transformers/utils/generic.py
@@ -29,6 +29,13 @@
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy


if is_tf_available():
import tensorflow as tf

if is_flax_available():
import jax.numpy as jnp


class cached_property(property):
"""
Descriptor that mimics @property but caches output in member variable.
@@ -370,3 +377,71 @@ def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
yield tmp_dir
else:
yield working_dir


def transpose(array, axes=None):
"""
Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.transpose(array, axes=axes)
elif is_torch_tensor(array):
return array.T if axes is None else array.permute(*axes)
elif is_tf_tensor(array):
return tf.transpose(array, perm=axes)
elif is_jax_tensor(array):
return jnp.transpose(array, axes=axes)
else:
raise ValueError(f"Type not supported for transpose: {type(array)}.")


def reshape(array, newshape):
"""
Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.reshape(array, newshape)
elif is_torch_tensor(array):
return array.reshape(*newshape)
elif is_tf_tensor(array):
return tf.reshape(array, newshape)
elif is_jax_tensor(array):
return jnp.reshape(array, newshape)
else:
raise ValueError(f"Type not supported for reshape: {type(array)}.")


def squeeze(array, axis=None):
"""
Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.squeeze(array, axis=axis)
elif is_torch_tensor(array):
return array.squeeze() if axis is None else array.squeeze(dim=axis)
elif is_tf_tensor(array):
return tf.squeeze(array, axis=axis)
elif is_jax_tensor(array):
return jnp.squeeze(array, axis=axis)
else:
raise ValueError(f"Type not supported for squeeze: {type(array)}.")


def expand_dims(array, axis):
"""
Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.expand_dims(array, axis)
elif is_torch_tensor(array):
return array.unsqueeze(dim=axis)
elif is_tf_tensor(array):
return tf.expand_dims(array, axis=axis)
elif is_jax_tensor(array):
return jnp.expand_dims(array, axis=axis)
else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")