```python
import torch.nn as nn

def fit(self):
    cfg = self.cfg
    refiner = nn.DataParallel(self.refiner, device_ids=range(cfg.num_gpu))
    learning_rate = cfg.lr
    while True:
        for inputs in self.train_loader:
            self.refiner.train()
            if cfg.scale > 0:
                scale = cfg.scale
                # Only the last element of `inputs` is unpacked here
                hr, lr = inputs[-1][0], inputs[-1][1]
            # ... (remainder of fit() not shown)
```
Suppose I set `batch_size = 10` in the `DataLoader`. Then in `fit()`, on each iteration of the `for` loop, the variable `inputs` will contain data for 10 image pairs. But the code seems to take only 1 pair out of the whole batch of 10 image pairs. Am I missing something here?
`self.train_loader` is defined as follows in the same file:

```python
self.train_data = TrainDataset(cfg.train_data_path, scale=cfg.scale, size=cfg.patch_size)
self.train_loader = DataLoader(self.train_data, batch_size=cfg.batch_size, num_workers=1, shuffle=True, drop_last=True)
```
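The indexing `inputs[-1][0]` suggests that each dataset item may be a list of `(hr, lr)` pairs, one per scale, rather than a single pair. If that assumption holds, PyTorch's default collation batches each pair separately: `inputs[-1]` then selects a scale, not an image, and `hr` and `lr` each still hold all 10 images along a leading batch dimension. The sketch below checks this with a hypothetical `ToyPairDataset`; its scales (2, 3, 4), patch size, and channel count are illustrative assumptions, not taken from `TrainDataset`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairDataset(Dataset):
    """Hypothetical stand-in for TrainDataset (an assumption, not the repo's code):
    each item is a list holding one (hr, lr) pair per scale."""

    def __init__(self, n=20, scales=(2, 3, 4), patch=8):
        self.n = n
        self.scales = scales
        self.patch = patch  # LR patch size; HR patch is patch * scale

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        pairs = []
        for s in self.scales:
            hr = torch.rand(3, self.patch * s, self.patch * s)
            lr = torch.rand(3, self.patch, self.patch)
            pairs.append((hr, lr))
        return pairs


loader = DataLoader(ToyPairDataset(), batch_size=10, shuffle=True, drop_last=True)
inputs = next(iter(loader))

# The default collate function keeps the per-item nesting but stacks each
# tensor across the batch: inputs[i] is [hr_batch, lr_batch] for scale i.
hr, lr = inputs[-1][0], inputs[-1][1]
print(len(inputs))  # 3 (one entry per scale)
print(hr.shape)     # torch.Size([10, 3, 32, 32])
print(lr.shape)     # torch.Size([10, 3, 8, 8])
```

Under this assumption, the first dimension of `hr` and `lr` is the batch, so the whole batch of 10 pairs reaches the model; only one scale is selected.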