diff --git a/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py b/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py index 94529392204..a998980ccde 100644 --- a/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py +++ b/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py @@ -36,13 +36,14 @@ def expand_vector(v: np.ndarray) -> np.ndarray: def expand_1_axis(w: np.ndarray, epsilon: float, axis: int) -> np.ndarray: - """Expands either the first dimension or the last dimension of w. + """Expands either the first or last dimension of w. - If `axis = 0`, the following constraint will be satisfied: + If `axis = 0`, the following expression will be satisfied: matmul(x, w) == - matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0)) + matmul(expand_vector(x), expand_1_axis(w, axis=0)) - If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`: + If `axis = -1` and `epsilon = 0.0`, the following constraint will be + satisfied: expand_vector(matmul(x, w)) == 2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1)) @@ -54,9 +55,12 @@ def expand_1_axis(w: np.ndarray, Returns: Expanded numpy array. """ - assert axis in (0, -1), ( - "Only support expanding the first or the last dimension. " - "Got: {}".format(axis)) + + if axis not in (0, -1): + raise ValueError( + "Only support expanding the first or the last dimension. " + "Got: {}".format(axis) + ) rank = len(w.shape) @@ -65,7 +69,7 @@ def expand_1_axis(w: np.ndarray, sign_flip = np.array([1, -1]) for _ in range(rank - 1): - sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0) + sign_flip = np.expand_dims(sign_flip, axis=axis - 1) sign_flip = np.tile(sign_flip, [w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]]) @@ -76,9 +80,9 @@ def expand_1_axis(w: np.ndarray, def expand_2_axes(w: np.ndarray, epsilon: float) -> np.ndarray: - """Expands the first dimension and the last dimension of w. + """Expands the first and last dimension of w. - The following constraint will be satisfied: + This operation satisfies the following expression: expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w)) Args: @@ -109,8 +113,8 @@ def var_to_var(var_from: tf.Variable, epsilon: float): """Expands a variable to another variable. - Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to` - can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z) + Assuming the shape of `var_from` is (a, b, ..., y, z), then shape of `var_to` + must be one of (a, ..., z * 2), (a * 2, ..., z * 2), or (a * 2, ..., z). If the shape of `var_to` is (a, ..., 2 * z): For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2 @@ -131,21 +135,30 @@ def var_to_var(var_from: tf.Variable, if shape_from == shape_to: var_to.assign(var_from) + return + + var_from_np = var_from.numpy() + + if len(shape_from) == len(shape_to) == 1: + var_to.assign(expand_vector(var_from_np)) + return - elif len(shape_from) == 1 and len(shape_to) == 1: - var_to.assign(expand_vector(var_from.numpy())) + a_from, z_from = shape_from[0], shape_from[-1] + a_to, z_to = shape_to[0], shape_to[-1] - elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]: - var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0)) + if a_to == 2 * a_from and z_to == z_from: + var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=0)) + return - elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]: - var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1)) + if a_to == a_from and z_to == 2 * z_from: + var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=-1)) + return - elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]: - var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon)) + if a_to == 2 * a_from and z_to == 2 * z_from: + var_to.assign(expand_2_axes(var_from_np, epsilon=epsilon)) + return - else: - raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to)) + raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to)) def model_to_model_2x_wide(model_from: tf.Module, @@ -170,8 +183,9 @@ def model_to_model_2x_wide(model_from: tf.Module, assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]]) ``` - We assume that `model_from` and `model_to` has the same architecture and only - widths of them differ. + We assume that `model_from` and `model_to` have the same architecture and + differ + only in widths. Args: model_from: input model to expand. diff --git a/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py b/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py index 6fc5f300f0e..ecc1c4ef452 100644 --- a/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py +++ b/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py @@ -71,6 +71,31 @@ def test_expand_3d_tensor_axis_2(self): o1 = np.matmul(x, w1) self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1)) + def test_relations(self): + x = np.array([10, 11]) + w = np.random.rand(2, 2) + # matmul(x, w) == matmul(expand_vector(x), expand_1_axis(w, axis=0)) + lhs = np.matmul(x, w) + rhs = np.matmul( + tf2_utils_2x_wide.expand_vector(x), + tf2_utils_2x_wide.expand_1_axis(w, epsilon=0.1, axis=0), + ) + self.assertAllClose(lhs, rhs) + # expand_vector(matmul(x, w)) == + # 2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1)) + lhs = tf2_utils_2x_wide.expand_vector(np.matmul(x, w)) + rhs = 2 * np.matmul( + x, tf2_utils_2x_wide.expand_1_axis(w, epsilon=0.0, axis=-1) + ) + self.assertAllClose(lhs, rhs) + # expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w)) + lhs = tf2_utils_2x_wide.expand_vector(np.matmul(x, w)) + rhs = np.matmul( + tf2_utils_2x_wide.expand_vector(x), + tf2_utils_2x_wide.expand_2_axes(w, epsilon=0.1), + ) + self.assertAllClose(lhs, rhs) + def test_end_to_end(self): """Covers expand_vector, expand_2_axes, and expand_1_axis.""" model_narrow = tf_keras.Sequential()