Skip to content

Commit

Permalink
Fix loading checkpoint in hubert preprocessing (pytorch#2310)
Browse files Browse the repository at this point in the history
Summary:
When the checkpoint is on a GPU device and preprocessing runs on CPU, the script will raise an exception. Fix it to load the model state dictionary onto the CPU by default.

Pull Request resolved: pytorch#2310

Reviewed By: mthrok

Differential Revision: D35316903

Pulled By: nateanl

fbshipit-source-id: d3e7183400ba133240aa6d205f5c671a421a9fed
  • Loading branch information
nateanl authored and xiaohui-zhang committed May 4, 2022
1 parent a50fa7e commit 3cf05e4
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/hubert/utils/feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .common_utils import _get_feat_lens_paths

# Module-level logger for this preprocessing utility module.
_LG = logging.getLogger(__name__)
# Default map_location for checkpoint loading. Mapping to CPU by default lets a
# checkpoint saved on a GPU machine be loaded during CPU-only preprocessing.
_DEFAULT_DEVICE = torch.device("cpu")


def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
Expand Down Expand Up @@ -105,16 +106,17 @@ def extract_feature_hubert(
return feat


def _load_state(model: Module, checkpoint_path: Path, device=_DEFAULT_DEVICE) -> Module:
    """Load weights from a HuBERTPretrainModel checkpoint into the hubert_pretrain_base model.

    Args:
        model (Module): The hubert_pretrain_base model.
        checkpoint_path (Path): The model checkpoint.
        device (torch.device, optional): The device to map the loaded tensors onto.
            Loading onto CPU by default avoids a device-mismatch error when the
            checkpoint was saved on GPU but preprocessing runs on a CPU-only host.
            (Default: ``torch.device("cpu")``)

    Returns:
        (Module): The same ``model`` instance with the pretrained weights loaded
        (the model is modified in place and also returned for convenience).
    """
    # map_location remaps all tensors in the checkpoint to the requested device.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # The checkpoint nests parameters under a "state_dict" key, with each key
    # prefixed by "model."; strip the prefix so keys match the bare model.
    # NOTE(review): str.replace removes every occurrence of "model.", not just a
    # leading prefix — assumed safe for these parameter names; confirm if reused.
    state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
    model.load_state_dict(state_dict)
    return model
Expand Down Expand Up @@ -169,8 +171,8 @@ def dump_features(
from torchaudio.models import hubert_pretrain_base

model = hubert_pretrain_base()
model = _load_state(model, checkpoint_path)
model.to(device)
model = _load_state(model, checkpoint_path, device)

with open(tsv_file, "r") as f:
root = f.readline().rstrip()
Expand Down

0 comments on commit 3cf05e4

Please sign in to comment.