From 57814989dc951f510840acf87dc973d025f002cb Mon Sep 17 00:00:00 2001 From: wangcheng Date: Fri, 28 Jun 2019 17:13:03 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug(=E5=B0=86reshape=E4=B8=8E?= =?UTF-8?q?transpose=E8=AF=AF=E7=94=A8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/TaxiBJ/TaxiBJ.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/data/TaxiBJ/TaxiBJ.py b/data/TaxiBJ/TaxiBJ.py index 19055ab..0ccf8b3 100644 --- a/data/TaxiBJ/TaxiBJ.py +++ b/data/TaxiBJ/TaxiBJ.py @@ -274,10 +274,12 @@ def load_dataset(T=48, nb_flow=2, len_closeness=None, len_period=None, len_trend XP = np.vstack(XP) # shape = [15072,2,32,32] XT = np.vstack(XT) # shape = [15072,2,32,32] Y = np.vstack(Y) # shape = [15072,2,32,32] - XC = XC.reshape(XC.shape[0], 32, 32, -1) - XP = XP.reshape(XP.shape[0], 32, 32, -1) - XT = XT.reshape(XT.shape[0], 32, 32, -1) - Y = Y.reshape(Y.shape[0], 32, 32, -1) + + XC=np.transpose(XC,[0,2,3,1]) + XP=np.transpose(XP,[0,2,3,1]) + XT=np.transpose(XT,[0,2,3,1]) + Y=np.transpose(Y,[0,2,3,1]) + print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape) XC_train, XP_train, XT_train, Y_train = XC[:-len_test], XP[:-len_test], XT[:-len_test], Y[:-len_test]