Skip to content

Commit

Permalink
Merge pull request #11122 from redwrasse:redwrasse/tf2_utils_2x_wide_…
Browse files Browse the repository at this point in the history
…cleanup

PiperOrigin-RevId: 605642679
  • Loading branch information
tensorflower-gardener committed Feb 9, 2024
2 parents 468fde0 + 610f177 commit 6e3f77d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 24 deletions.
62 changes: 38 additions & 24 deletions official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ def expand_vector(v: np.ndarray) -> np.ndarray:
def expand_1_axis(w: np.ndarray,
epsilon: float,
axis: int) -> np.ndarray:
"""Expands either the first dimension or the last dimension of w.
"""Expands either the first or last dimension of w.
If `axis = 0`, the following constraint will be satisfied:
If `axis = 0`, the following expression will be satisfied:
matmul(x, w) ==
matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
matmul(expand_vector(x), expand_1_axis(w, axis=0))
If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
If `axis = -1` and `epsilon = 0.0`, the following constraint will be
satisfied:
expand_vector(matmul(x, w)) ==
2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
Expand All @@ -54,9 +55,12 @@ def expand_1_axis(w: np.ndarray,
Returns:
Expanded numpy array.
"""
assert axis in (0, -1), (
"Only support expanding the first or the last dimension. "
"Got: {}".format(axis))

if axis not in (0, -1):
raise ValueError(
"Only support expanding the first or the last dimension. "
"Got: {}".format(axis)
)

rank = len(w.shape)

Expand All @@ -65,7 +69,7 @@ def expand_1_axis(w: np.ndarray,

sign_flip = np.array([1, -1])
for _ in range(rank - 1):
sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0)
sign_flip = np.expand_dims(sign_flip, axis=axis - 1)
sign_flip = np.tile(sign_flip,
[w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]])

Expand All @@ -76,9 +80,9 @@ def expand_1_axis(w: np.ndarray,

def expand_2_axes(w: np.ndarray,
epsilon: float) -> np.ndarray:
"""Expands the first dimension and the last dimension of w.
"""Expands the first and last dimensions of w.
The following constraint will be satisfied:
This operation satisfies the following expression:
expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
Args:
Expand Down Expand Up @@ -109,8 +113,8 @@ def var_to_var(var_from: tf.Variable,
epsilon: float):
"""Expands a variable to another variable.
Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
Assuming the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
must be one of (a, ..., z * 2), (a * 2, ..., z * 2), or (a * 2, ..., z).
If the shape of `var_to` is (a, ..., 2 * z):
For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
Expand All @@ -131,21 +135,30 @@ def var_to_var(var_from: tf.Variable,

if shape_from == shape_to:
var_to.assign(var_from)
return

var_from_np = var_from.numpy()

if len(shape_from) == len(shape_to) == 1:
var_to.assign(expand_vector(var_from_np))
return

elif len(shape_from) == 1 and len(shape_to) == 1:
var_to.assign(expand_vector(var_from.numpy()))
a_from, z_from = shape_from[0], shape_from[-1]
a_to, z_to = shape_to[0], shape_to[-1]

elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0))
if a_to == 2 * a_from and z_to == z_from:
var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=0))
return

elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1))
if a_to == a_from and z_to == 2 * z_from:
var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=-1))
return

elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon))
if a_to == 2 * a_from and z_to == 2 * z_from:
var_to.assign(expand_2_axes(var_from_np, epsilon=epsilon))
return

else:
raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))
raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))


def model_to_model_2x_wide(model_from: tf.Module,
Expand All @@ -170,8 +183,9 @@ def model_to_model_2x_wide(model_from: tf.Module,
assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
```
We assume that `model_from` and `model_to` has the same architecture and only
widths of them differ.
We assume that `model_from` and `model_to` have the same architecture and
differ only in widths.
Args:
model_from: input model to expand.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_expand_3d_tensor_axis_2(self):
o1 = np.matmul(x, w1)
self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1))

def test_relations(self):
  """Checks the matmul identities of the expansion helpers on a 2x2 weight.

  Uses a fixed input vector and a random (2, 2) weight matrix, and verifies
  one identity each for `expand_1_axis` with `axis=0`, `expand_1_axis` with
  `axis=-1` (using `epsilon=0.0`), and `expand_2_axes`.
  """
  utils = tf2_utils_2x_wide
  x = np.array([10, 11])
  w = np.random.rand(2, 2)

  # Identity 1:
  #   matmul(x, w) == matmul(expand_vector(x), expand_1_axis(w, axis=0))
  narrow_out = np.matmul(x, w)
  x_wide = utils.expand_vector(x)
  w_first_axis = utils.expand_1_axis(w, epsilon=0.1, axis=0)
  self.assertAllClose(narrow_out, np.matmul(x_wide, w_first_axis))

  # Identity 2:
  #   expand_vector(matmul(x, w)) ==
  #       2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
  expected_wide = utils.expand_vector(np.matmul(x, w))
  w_last_axis = utils.expand_1_axis(w, epsilon=0.0, axis=-1)
  self.assertAllClose(expected_wide, 2 * np.matmul(x, w_last_axis))

  # Identity 3:
  #   expand_vector(matmul(x, w)) ==
  #       matmul(expand_vector(x), expand_2_axes(w))
  expected_wide = utils.expand_vector(np.matmul(x, w))
  x_wide = utils.expand_vector(x)
  w_both_axes = utils.expand_2_axes(w, epsilon=0.1)
  self.assertAllClose(expected_wide, np.matmul(x_wide, w_both_axes))

def test_end_to_end(self):
"""Covers expand_vector, expand_2_axes, and expand_1_axis."""
model_narrow = tf_keras.Sequential()
Expand Down

0 comments on commit 6e3f77d

Please sign in to comment.