import ConformerEncoder _self_att_v2
albertz committed Dec 8, 2023
1 parent ad54797 commit b2f19b0
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions users/zeyer/experiments/exp2023_04_25_rf/_chunked_aed_import.py
@@ -24,7 +24,7 @@
_returnn_tf_ckpt_filename = "i6_core/returnn/training/AverageTFCheckpointsJob.5r6TB06ypiVq/output/model/average.index"
_load_existing_ckpt_in_test = True

-_ParamMapping = {}  # type: Dict[str,str]
+_ParamMapping = {}  # type: Dict[str,str]  # used by map_param_func_v2


def _get_tf_checkpoint_path() -> tk.Path:
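The `# used by map_param_func_v2` note refers to the converter's per-parameter hook defined further down. As a minimal sketch of the simple rename-only path (not the repo's exact code; it assumes `reader` is a TF checkpoint reader and `var` an RF parameter, as in the map_param_func_v2 signature below):

if name in _ParamMapping:
    tf_var_name = _ParamMapping[name]
    value = reader.get_tensor(tf_var_name)  # numpy array straight from the TF checkpoint
    assert value.shape == tuple(d.dimension for d in var.dims), name  # shapes must already match
    assert value.dtype.name == var.dtype, name
    return value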
@@ -53,7 +53,8 @@ def _get_pt_checkpoint_path(*, run_if_not_exists: bool) -> tk.Path:
    return converter.out_checkpoint


-def _add_params():
+def _add_param_mappings():
+    # Used by map_param_func_v2. This covers the simple cases; the more complex cases are handled in map_param_func_v2 itself.
    # frontend
    for layer_idx in [0, 1, 2]:
        orig_name = "conv0" if layer_idx == 0 else f"subsample_conv{layer_idx - 1}"
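For concreteness, the three frontend iterations of this loop yield (the values follow directly from the `orig_name` expression above):

# layer_idx = 0 -> orig_name = "conv0"
# layer_idx = 1 -> orig_name = "subsample_conv0"
# layer_idx = 2 -> orig_name = "subsample_conv1"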
@@ -127,9 +128,6 @@ def _add_params():
f"encoder.layers.{layer_idx}.conv_layer_norm.bias"
] = f"conformer_block_{layer_idx + 1:02d}_conv_mod_ln/bias"
# self-att
_ParamMapping[
f"encoder.layers.{layer_idx}.self_att.qkv.weight"
] = f"conformer_block_{layer_idx + 1:02d}_self_att/QKV"
_ParamMapping[
f"encoder.layers.{layer_idx}.self_att.proj.weight"
] = f"conformer_block_{layer_idx + 1:02d}_self_att_linear/W"
@@ -151,7 +149,7 @@ def _add_params():
        ] = f"conformer_block_{layer_idx + 1:02d}_ln/bias"


-_add_params()
+_add_param_mappings()


def map_param_func_v2(reader, name: str, var: rf.Parameter) -> numpy.ndarray:
@@ -215,6 +213,31 @@ def map_param_func_v2(reader, name: str, var: rf.Parameter) -> numpy.ndarray:
        assert value.dtype.name == var.dtype, name
        return value

+    if name.endswith(".self_att.qkv.weight"):
+        # We used ConformerEncoder._self_att_v2 here, where Q, K, V were split into 3 linear layers.
+        assert name.startswith("encoder.layers.")
+        layer_idx = int(name.split(".")[2])
+        tf_prefix = f"conformer_block_{layer_idx + 1:02d}_self_att"
+        # The rel pos enc matrix has shape (..., key_dim_per_head), thus we can use it to derive num_heads.
+        tf_enc_matrix_var_name = f"conformer_block_{layer_idx + 1:02d}_self_att_ln_rel_pos_enc/encoding_matrix"
+        key_dim_per_head_dim = reader.get_tensor(tf_enc_matrix_var_name).shape[1]
+        # Note: see also the config for the reverse direction, func load_qkv_mats.
+        # Want: (in_dim, num_heads * 3 * att_dim_per_head).
+        # We get, for each of Q, K, V: (in_dim, num_heads * att_dim_per_head).
+        tf_var_names = [f"{tf_prefix}_ln_{part}/W" for part in ["Q", "K", "V"]]
+        values = [reader.get_tensor(tf_var_name) for tf_var_name in tf_var_names]
+        key_dim_total = values[0].shape[1]
+        assert key_dim_total % key_dim_per_head_dim == 0
+        num_heads = key_dim_total // key_dim_per_head_dim
+        concat = []
+        for v in values:
+            assert isinstance(v, numpy.ndarray)
+            assert v.shape[0] == var.dims[0].dimension
+            concat.append(v.reshape((v.shape[0], num_heads, -1)))
+        out = numpy.concatenate(concat, axis=2)
+        out = out.reshape(var.batch_shape)
+        return out

    raise NotImplementedError(f"cannot map {name!r} {var}")


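To make the Q/K/V merge above concrete, here is a standalone numpy sketch with made-up sizes (512 input dim, 8 heads, 64 per head are assumptions, not values from the checkpoint): three separate (in_dim, num_heads * head_dim) Q/K/V matrices are merged into one combined matrix with a per-head [Q, K, V] layout, following the same reshape/concatenate sequence as the added code.

import numpy

in_dim, num_heads, head_dim = 512, 8, 64  # hypothetical sizes, not from the checkpoint
q, k, v = (numpy.random.randn(in_dim, num_heads * head_dim) for _ in range(3))

# num_heads would be inferred as in the diff: total key dim / per-head dim
# (the per-head dim coming from the rel pos encoding matrix's last axis).
assert q.shape[1] // head_dim == num_heads

# Split each matrix per head, then concatenate Q, K, V within each head.
parts = [m.reshape(in_dim, num_heads, head_dim) for m in (q, k, v)]
qkv = numpy.concatenate(parts, axis=2)  # (in_dim, num_heads, 3 * head_dim)
qkv = qkv.reshape(in_dim, num_heads * 3 * head_dim)  # target combined layout

# Check the layout: per head, the first head_dim columns are Q, then K, then V.
qkv4 = qkv.reshape(in_dim, num_heads, 3, head_dim)
assert numpy.array_equal(qkv4[:, :, 0, :].reshape(in_dim, -1), q)
assert numpy.array_equal(qkv4[:, :, 1, :].reshape(in_dim, -1), k)
assert numpy.array_equal(qkv4[:, :, 2, :].reshape(in_dim, -1), v)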
