def load_pretrain(self):
    r"""Initialise this model from separately pretrained MF and MLP checkpoints.

    Reads ``self.mf_pretrain_path`` and ``self.mlp_pretrain_path`` with
    ``torch.load`` (mapped to CPU).  Each checkpoint may be

    * a plain state dict (parameter name -> tensor),
    * a dict that stores the state dict under the ``"state_dict"`` key, or
    * a whole pickled ``nn.Module`` (the format the previous implementation
      consumed), whose ``state_dict()`` is used.

    The MF/MLP embedding tables and every ``nn.Linear`` inside
    ``self.mlp_layers.mlp_layers`` are copied in place.  The prediction
    layer weight is the concatenation (along dim 1) of the two pretrained
    prediction weights; its bias is half the sum of the two pretrained
    biases.

    Raises:
        KeyError: a required parameter is missing from a checkpoint.
        ValueError: a checkpoint tensor's shape differs from the shape of
            the parameter it should initialise.
    """

    def _as_state_dict(checkpoint):
        # Accept a pickled module, a {"state_dict": ...} wrapper, or a raw
        # state dict.  The module check must come first: ``in`` is not
        # supported on an nn.Module.
        if isinstance(checkpoint, nn.Module):
            return checkpoint.state_dict()
        if "state_dict" in checkpoint:
            return checkpoint["state_dict"]
        return checkpoint

    def _copy_checked(param, tensor, key):
        # ``copy_`` broadcasts, so e.g. a size-1 bias would silently fill a
        # whole vector.  Validate explicitly, and with a real exception so
        # the check is not stripped under ``python -O`` like an assert.
        if param.shape != tensor.shape:
            raise ValueError(
                f"pretrained parameter '{key}' has shape {tuple(tensor.shape)}, "
                f"expected {tuple(param.shape)}"
            )
        param.data.copy_(tensor)

    # NOTE(security): torch.load unpickles the file; only point the pretrain
    # paths at checkpoints from a trusted source.
    mf = _as_state_dict(torch.load(self.mf_pretrain_path, map_location="cpu"))
    mlp = _as_state_dict(torch.load(self.mlp_pretrain_path, map_location="cpu"))

    embedding_sources = (
        ("user_mf_embedding", mf),
        ("item_mf_embedding", mf),
        ("user_mlp_embedding", mlp),
        ("item_mlp_embedding", mlp),
    )
    for attr, source in embedding_sources:
        key = f"{attr}.weight"
        _copy_checked(getattr(self, attr).weight, source[key], key)

    # Build each checkpoint key from the Linear child's own name rather than
    # from its position in the state-dict key list, so children that carry
    # parameters or buffers of their own cannot shift the pairing.
    for child_name, layer in self.mlp_layers.mlp_layers.named_children():
        if not isinstance(layer, nn.Linear):
            continue
        weight_key = f"mlp_layers.mlp_layers.{child_name}.weight"
        bias_key = f"mlp_layers.mlp_layers.{child_name}.bias"
        _copy_checked(layer.weight, mlp[weight_key], weight_key)
        _copy_checked(layer.bias, mlp[bias_key], bias_key)

    predict_weight = torch.cat(
        [mf["predict_layer.weight"], mlp["predict_layer.weight"]], dim=1
    )
    predict_bias = mf["predict_layer.bias"] + mlp["predict_layer.bias"]
    # NOTE(review): only the bias is scaled by 0.5; the concatenated weight is
    # copied unscaled.  Kept as-is to preserve behaviour -- confirm intended.
    _copy_checked(self.predict_layer.weight, predict_weight, "predict_layer.weight")
    _copy_checked(self.predict_layer.bias, 0.5 * predict_bias, "predict_layer.bias")