From c87d481be1a654776cfd1897fefee6b3d36efd0c Mon Sep 17 00:00:00 2001
From: zbw <18735382001@163.com>
Date: Thu, 20 Oct 2022 23:50:09 +0800
Subject: [PATCH 1/2] FIX: fix load_pretrain function in NeuMF

---
 recbole/model/general_recommender/neumf.py | 35 ++++++++++++++--------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/recbole/model/general_recommender/neumf.py b/recbole/model/general_recommender/neumf.py
index 991fff8bc..6c1b6aad9 100644
--- a/recbole/model/general_recommender/neumf.py
+++ b/recbole/model/general_recommender/neumf.py
@@ -80,22 +80,31 @@ def __init__(self, config, dataset):
 
     def load_pretrain(self):
         r"""A simple implementation of loading pretrained parameters."""
-        mf = torch.load(self.mf_pretrain_path)
-        mlp = torch.load(self.mlp_pretrain_path)
-        self.user_mf_embedding.weight.data.copy_(mf.user_mf_embedding.weight)
-        self.item_mf_embedding.weight.data.copy_(mf.item_mf_embedding.weight)
-        self.user_mlp_embedding.weight.data.copy_(mlp.user_mlp_embedding.weight)
-        self.item_mlp_embedding.weight.data.copy_(mlp.item_mlp_embedding.weight)
-
-        for (m1, m2) in zip(self.mlp_layers.mlp_layers, mlp.mlp_layers.mlp_layers):
-            if isinstance(m1, nn.Linear) and isinstance(m2, nn.Linear):
-                m1.weight.data.copy_(m2.weight)
-                m1.bias.data.copy_(m2.bias)
+        mf = torch.load(self.mf_pretrain_path,map_location="cpu")
+        mlp = torch.load(self.mlp_pretrain_path,map_location="cpu")
+        mf = mf if "state_dict" not in mf else mf["state_dict"]
+        mlp = mlp if "state_dict" not in mlp else mlp["state_dict"]
+        self.user_mf_embedding.weight.data.copy_(mf["user_mf_embedding.weight"])
+        self.item_mf_embedding.weight.data.copy_(mf["item_mf_embedding.weight"])
+        self.user_mlp_embedding.weight.data.copy_(mlp["user_mlp_embedding.weight"])
+        self.item_mlp_embedding.weight.data.copy_(mlp["item_mlp_embedding.weight"])
+
+        mlp_layers = list(self.mlp_layers.state_dict().keys())
+        index = 0
+        for layer in self.mlp_layers.mlp_layers:
+            if isinstance(layer, nn.Linear):
+                weight_key ="mlp_layers."+ mlp_layers[index]
+                bias_key = "mlp_layers."+ mlp_layers[index+1]
+                assert layer.weight.shape == mlp[weight_key].shape, f'mlp layer parameter shape mismatch'
+                assert layer.bias.shape == mlp[bias_key].shape, f'mlp layer parameter shape mismatch'
+                layer.weight.data.copy_(mlp[weight_key])
+                layer.bias.data.copy_(mlp[bias_key])
+                index += 2
 
         predict_weight = torch.cat(
-            [mf.predict_layer.weight, mlp.predict_layer.weight], dim=1
+            [mf["predict_layer.weight"], mlp["predict_layer.weight"]], dim=1
         )
-        predict_bias = mf.predict_layer.bias + mlp.predict_layer.bias
+        predict_bias = mf["predict_layer.bias"] + mlp["predict_layer.bias"]
 
         self.predict_layer.weight.data.copy_(predict_weight)
         self.predict_layer.bias.data.copy_(0.5 * predict_bias)

From 4e5630cfa86cab575596ee4a23cd2d06fc4829d2 Mon Sep 17 00:00:00 2001
From: zhengbw0324
Date: Thu, 20 Oct 2022 15:50:43 +0000
Subject: [PATCH 2/2] Format Python code according to PEP8

---
 recbole/model/general_recommender/neumf.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/recbole/model/general_recommender/neumf.py b/recbole/model/general_recommender/neumf.py
index 6c1b6aad9..29bfbdc76 100644
--- a/recbole/model/general_recommender/neumf.py
+++ b/recbole/model/general_recommender/neumf.py
@@ -80,8 +80,8 @@ def __init__(self, config, dataset):
 
     def load_pretrain(self):
         r"""A simple implementation of loading pretrained parameters."""
-        mf = torch.load(self.mf_pretrain_path,map_location="cpu")
-        mlp = torch.load(self.mlp_pretrain_path,map_location="cpu")
+        mf = torch.load(self.mf_pretrain_path, map_location="cpu")
+        mlp = torch.load(self.mlp_pretrain_path, map_location="cpu")
         mf = mf if "state_dict" not in mf else mf["state_dict"]
         mlp = mlp if "state_dict" not in mlp else mlp["state_dict"]
         self.user_mf_embedding.weight.data.copy_(mf["user_mf_embedding.weight"])
@@ -93,10 +93,14 @@ def load_pretrain(self):
         index = 0
         for layer in self.mlp_layers.mlp_layers:
             if isinstance(layer, nn.Linear):
-                weight_key ="mlp_layers."+ mlp_layers[index]
-                bias_key = "mlp_layers."+ mlp_layers[index+1]
-                assert layer.weight.shape == mlp[weight_key].shape, f'mlp layer parameter shape mismatch'
-                assert layer.bias.shape == mlp[bias_key].shape, f'mlp layer parameter shape mismatch'
+                weight_key = "mlp_layers." + mlp_layers[index]
+                bias_key = "mlp_layers." + mlp_layers[index + 1]
+                assert (
+                    layer.weight.shape == mlp[weight_key].shape
+                ), f"mlp layer parameter shape mismatch"
+                assert (
+                    layer.bias.shape == mlp[bias_key].shape
+                ), f"mlp layer parameter shape mismatch"
                 layer.weight.data.copy_(mlp[weight_key])
                 layer.bias.data.copy_(mlp[bias_key])
                 index += 2
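Note (not part of the patches): a minimal standalone sketch of the loading
pattern these commits adopt. RecBole checkpoints store parameters in a dict
under a "state_dict" key rather than as a pickled module object, so tensors
must be read out by key; the old attribute-access style
(mf.user_mf_embedding.weight) is what broke. The file name "demo.pth" and
the toy model below are hypothetical stand-ins, not RecBole code.

    import torch
    import torch.nn as nn

    # Toy stand-in for a pretrained model; parameter keys follow
    # nn.Sequential naming ("0.weight", "0.bias", "2.weight", "2.bias").
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))

    # Save in the dict-with-"state_dict" layout that load_pretrain handles.
    torch.save({"state_dict": model.state_dict()}, "demo.pth")

    ckpt = torch.load("demo.pth", map_location="cpu")
    state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

    # Copy tensors by key into a fresh model, mirroring the patched
    # load_pretrain, with the same shape check before each copy.
    target = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
    for name, param in target.named_parameters():
        assert param.shape == state[name].shape, "parameter shape mismatch"
        param.data.copy_(state[name])

Loading onto CPU first (map_location="cpu") also makes the pretrained
checkpoints usable regardless of which device they were trained on.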