Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 6, 2023
1 parent 492d8c3 commit 6166f59
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from returnn.frontend.tensor_array import TensorArray
from returnn.frontend.encoder.conformer import ConformerEncoder, ConformerConvSubsample

from i6_experiments.users.zeyer.utils.dict_update import dict_update_deep, dict_update_delete_deep
from i6_experiments.users.zeyer.utils.dict_update import dict_update_deep
from i6_experiments.users.zeyer.lr_schedules.lin_warmup_invsqrt_decay import dyn_lr_lin_warmup_invsqrt_decay
from i6_experiments.users.zeyer.lr_schedules.combine_eval import dyn_lr_combine_eval
from i6_experiments.users.zeyer.lr_schedules.piecewise_linear import dyn_lr_piecewise_linear
Expand Down Expand Up @@ -417,8 +417,7 @@ def train_exp(
prefix = _sis_prefix + "/" + name
task = _get_ls_task()
config = config.copy()
config = dict_update_deep(config, config_updates)
config = dict_update_delete_deep(config, config_deletes)
config = dict_update_deep(config, config_updates, config_deletes)
if "__num_epochs" in config:
num_epochs = config.pop("__num_epochs")
if "__gpu_mem" in config:
Expand Down Expand Up @@ -594,17 +593,14 @@ def _get_ls_task():
],
"rf_att_dropout_broadcast": False, # attdropfixbc
},
)
config_24gb_v5 = dict_update_delete_deep(
config_24gb_v5,
[
# specaugorig
"specaugment_num_spatial_mask_factor",
"specaugment_max_consecutive_feature_dims",
],
)

config_24gb_v6 = dict_update_delete_deep(config_24gb_v5, ["pretrain_opts"])
config_24gb_v6 = dict_update_deep(config_24gb_v5, None, ["pretrain_opts"])

_cfg_lrlin1e_5_295k = { # for bs15k, mgpu4
"learning_rate": 1.0,
Expand All @@ -622,9 +618,6 @@ def _get_ls_task():
"torch_distributed": {}, # multi-GPU
"__gpu_mem": 11,
},
)
config_11gb_v6_f32_bs15k_accgrad4_mgpu = dict_update_delete_deep(
config_11gb_v6_f32_bs15k_accgrad4_mgpu,
[
"torch_amp", # f32
],
Expand Down
6 changes: 3 additions & 3 deletions users/zeyer/experiments/exp2023_04_25_rf/rz.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
config_24gb_v6,
_batch_size_factor,
)
from i6_experiments.users.zeyer.utils.dict_update import dict_update_deep, dict_update_delete_deep
from i6_experiments.users.zeyer.utils.dict_update import dict_update_deep

if TYPE_CHECKING:
from i6_experiments.users.zeyer.model_with_checkpoints import ModelWithCheckpoints
Expand Down Expand Up @@ -92,14 +92,14 @@ def py():
)


config_v4_f32 = dict_update_delete_deep(config_24gb_v4, ["torch_amp"])
config_v4_f32 = dict_update_deep(config_24gb_v4, None, ["torch_amp"])
config_v4_f32_bs20k = dict_update_deep(
config_v4_f32,
{
"batch_size": 20_000 * _batch_size_factor, # 30k gives OOM on the 16GB GPU
},
)
config_v6_f32 = dict_update_delete_deep(config_24gb_v6, ["torch_amp"])
config_v6_f32 = dict_update_deep(config_24gb_v6, None, ["torch_amp"])
config_v6_f32_bs20k = dict_update_deep(
config_v6_f32,
{
Expand Down
10 changes: 7 additions & 3 deletions users/zeyer/utils/dict_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
from typing import Optional, Any, Dict, Sequence


def dict_update_deep(d: Dict[str, Any], deep_updates: Optional[Dict[str, Any]]) -> Dict[str, Any]:
def dict_update_deep(
d: Dict[str, Any], deep_updates: Optional[Dict[str, Any]], deep_deletes: Optional[Sequence[str]] = None
) -> Dict[str, Any]:
"""
:param d: dict to update
:param deep_updates: might also contain "." in the key, for nested dicts
:param deep_deletes: might also contain "." in the key, for nested dicts
:return: updated dict
"""
d = _dict_update_delete_deep(d, deep_deletes)
if not deep_updates:
return d
d = d.copy()
Expand All @@ -25,7 +29,7 @@ def dict_update_deep(d: Dict[str, Any], deep_updates: Optional[Dict[str, Any]])
return d


def dict_update_delete_deep(d: Dict[str, Any], deep_deletes: Optional[Sequence[str]]) -> Dict[str, Any]:
def _dict_update_delete_deep(d: Dict[str, Any], deep_deletes: Optional[Sequence[str]]) -> Dict[str, Any]:
"""
:param d: dict to update (to delete from)
:param deep_deletes: might also contain "." in the key, for nested dicts
Expand All @@ -38,7 +42,7 @@ def dict_update_delete_deep(d: Dict[str, Any], deep_deletes: Optional[Sequence[s
assert isinstance(k, str)
if "." in k:
k1, k2 = k.split(".", 1)
d[k1] = dict_update_delete_deep(d[k1], [k2])
d[k1] = _dict_update_delete_deep(d[k1], [k2])
else:
del d[k]
return d

0 comments on commit 6166f59

Please sign in to comment.