Skip to content

Commit

Permalink
修复bug(将reshape与transpose误用)
Browse files Browse the repository at this point in the history
  • Loading branch information
TolicWang committed Jun 28, 2019
1 parent 4c5938d commit 5781498
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions data/TaxiBJ/TaxiBJ.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,12 @@ def load_dataset(T=48, nb_flow=2, len_closeness=None, len_period=None, len_trend
XP = np.vstack(XP) # shape = [15072,2,32,32]
XT = np.vstack(XT) # shape = [15072,2,32,32]
Y = np.vstack(Y) # shape = [15072,2,32,32]
XC = XC.reshape(XC.shape[0], 32, 32, -1)
XP = XP.reshape(XP.shape[0], 32, 32, -1)
XT = XT.reshape(XT.shape[0], 32, 32, -1)
Y = Y.reshape(Y.shape[0], 32, 32, -1)

XC=np.transpose(XC,[0,2,3,1])
XP=np.transpose(XP,[0,2,3,1])
XT=np.transpose(XT,[0,2,3,1])
Y=np.transpose(Y,[0,2,3,1])

print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape)

XC_train, XP_train, XT_train, Y_train = XC[:-len_test], XP[:-len_test], XT[:-len_test], Y[:-len_test]
Expand Down

0 comments on commit 5781498

Please sign in to comment.