-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
646 lines (542 loc) · 26.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
# Copyright 2023 - Valeo Comfort and Driving Assistance
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import yaml
import os
import time
import datetime
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from option import Option
import dataset
import utils
import utils.tools as tools
from utils.metrics.eval_results import eval_results
from utils.metrics.tensorboard_logger import tensorboard_logger
from utils.inference.inference_utils import inference
from utils.tools import Recorder
class Trainer(object):
def __init__(self, settings: Option, model: nn.Module, recorder=None):
# Init params
self.settings = settings
self.recorder = recorder
self.model = model.cuda()
self.remain_time = tools.RemainTime(self.settings.n_epochs)
# Init data loader
self.train_loader, self.val_loader, self.train_sampler, self.val_sampler = self._initDataloader()
# Init criterion
self.criterion = self._initCriterion()
# Init optimizer
self.optimizer = self._initOptimizer()
if tools.is_dist_avail_and_initialized():
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model).cuda()
self.model = nn.parallel.DistributedDataParallel(
self.model, device_ids=[self.settings.gpu],
find_unused_parameters=True)
# Get metrics
self.metrics = utils.metrics.IOUEval(
n_classes=self.settings.n_classes, device=torch.device('cpu'),
ignore=self.ignore_class, is_distributed=self.settings.distributed)
self.metrics.reset()
# Define scheduler
self.scheduler = utils.optim.WarmupCosineLR(
optimizer=self.optimizer,
lr=self.settings.lr,
warmup_steps=self.settings.warmup_epochs * len(self.train_loader),
momentum=0.9,
max_steps=len(self.train_loader) * (self.settings.n_epochs - self.settings.warmup_epochs))
# For mixed precision training
self.fp16_scaler = None
if self.settings.use_fp16:
self.fp16_scaler = torch.cuda.amp.GradScaler()
def _initOptimizer(self):
params = self.model.parameters()
adamw_optimizer = torch.optim.AdamW(params=params,
lr=self.settings.lr,
weight_decay=0.01)
return adamw_optimizer
def _initDataloader(self):
# NuScenes dataset
if self.settings.dataset == 'nuScenes':
print('----Using nuScenes dataset----')
version = 'v1.0-mini' if self.settings.use_mini_version else 'v1.0-trainval'
assert self.settings.use_trainval is False
trainset = dataset.nuScenes.Nuscenes(
dataroot=self.settings.data_root, version=version, split='train')
valset = dataset.nuScenes.Nuscenes(
dataroot=self.settings.data_root, version=version, split='val')
self.mapped_cls_name = trainset.mapped_cls_name
self.ignore_class = [0]
self.cls_weight = np.ones((self.settings.n_classes))
self.cls_weight[0] = 0
assert self.settings.test_split is False
self.data_split = 'test' if self.settings.test_split else 'val'
# SemanticKitti dataset
elif self.settings.dataset == 'SemanticKitti':
data_config_path = 'dataset/semantic_kitti/semantic-kitti.yaml'
data_config = yaml.safe_load(open(data_config_path, 'r'))
if self.settings.use_mini_version:
train_sequences = [0]
elif self.settings.use_trainval:
print('Train with the train+val set.')
train_sequences = data_config['split']['train'] + data_config['split']['valid']
else:
train_sequences = data_config['split']['train']
trainset = dataset.semantic_kitti.SemanticKitti(
root=self.settings.data_root,
sequences=train_sequences,
config_path=data_config_path)
self.cls_weight = 1 / (trainset.cls_freq + 1e-3)
self.ignore_class = []
for cl, _ in enumerate(self.cls_weight):
if trainset.data_config['learning_ignore'][cl]:
self.cls_weight[cl] = 0
if self.cls_weight[cl] < 1e-10:
self.ignore_class.append(cl)
if self.recorder is not None:
self.recorder.logger.info('weight: {}'.format(self.cls_weight))
self.mapped_cls_name = trainset.mapped_cls_name
test_sequences = (
data_config['split']['test'] if self.settings.test_split else
data_config['split']['valid'])
valset = dataset.semantic_kitti.SemanticKitti(
root=self.settings.data_root,
sequences=test_sequences,
config_path=data_config_path,
has_label=(self.settings.test_split is False),
)
else:
raise ValueError(
'invalid dataset: {}'.format(self.settings.dataset))
self.train_range_loader = dataset.RangeViewLoader(
dataset=trainset,
config=self.settings.config,
use_kpconv=self.settings.use_kpconv)
self.val_range_loader = dataset.RangeViewLoader(
dataset=valset,
config=self.settings.config,
is_train=False,
use_kpconv=self.settings.use_kpconv)
collate_fn = dataset.custom_collate_kpconv_fn if self.settings.use_kpconv else None
if tools.is_dist_avail_and_initialized():
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)
train_loader = torch.utils.data.DataLoader(
self.train_range_loader,
batch_size=self.settings.batch_size,
num_workers=self.settings.num_workers,
drop_last=True,
sampler=train_sampler,
collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(
self.val_range_loader,
batch_size=self.settings.batch_size_val,
num_workers=self.settings.num_workers,
drop_last=False,
sampler=val_sampler,
collate_fn=collate_fn)
return train_loader, val_loader, train_sampler, val_sampler
else:
train_loader = torch.utils.data.DataLoader(
self.train_range_loader,
batch_size=self.settings.batch_size,
num_workers=self.settings.num_workers,
shuffle=True,
drop_last=True,
collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(
self.val_range_loader,
batch_size=self.settings.batch_size_val,
num_workers=self.settings.num_workers,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
return train_loader, val_loader, None, None
def _initCriterion(self):
criterion = {}
criterion['lovasz'] = utils.optim.Lovasz_softmax(ignore=0)
if self.settings.dataset == 'SemanticKitti':
alpha = np.log(1+self.cls_weight)
alpha = alpha / alpha.max()
elif self.settings.dataset == 'nuScenes':
alpha = np.ones((self.settings.n_classes))
alpha[0] = 0
if self.recorder is not None:
self.recorder.logger.info('focal_loss alpha: {}'.format(alpha))
criterion['focal_loss'] = utils.optim.FocalSoftmaxLoss(
self.settings.n_classes, gamma=2, alpha=alpha, softmax=False)
# Set device
for _, v in criterion.items():
v.cuda()
return criterion
def compute_losses(self, output, output_softmax, label, mask):
loss_lovasz = self.criterion['lovasz'](output_softmax, label)
loss_focal = self.criterion['focal_loss'](output_softmax, label, mask=mask)
total_loss = loss_focal + loss_lovasz
return total_loss, loss_lovasz, loss_focal
def run(self, epoch, mode='Train', print_results=False, save_results_path=None):
if self.settings.use_kpconv:
# Training and validation when using the KPConv layer
return self.run_with_kpconv(
epoch=epoch, mode=mode,
print_results=print_results,
save_results_path=save_results_path)
else:
# Training and validation when not using the KPConv layer
return self.run_without_kpconv(
epoch=epoch, mode=mode,
print_results=print_results,
save_results_path=save_results_path)
# Method for training when using the KPConv layer
def run_without_kpconv(self, epoch, mode='Train', print_results=False, save_results_path=None):
if mode == 'Train':
dataloader = self.train_loader
self.model.train()
if self.train_sampler is not None:
self.train_sampler.set_epoch(epoch)
elif mode == 'Validation':
dataloader = self.val_loader
self.model.eval()
else:
raise ValueError('invalid mode: {}'.format(mode))
model_without_ddp = self.model
if hasattr(self.model, 'module'):
model_without_ddp = self.model.module
# Init metrics
loss_meter = tools.AverageMeter()
self.metrics.reset()
total_iter = len(dataloader)
t_start = time.time()
for i, (input_feature, input_label, input_mask) in enumerate(dataloader):
t_process_start = time.time()
# Feature: range, x, y, z, intensity
input_feature = input_feature.cuda() # shape: B x 5 x H x W
input_label = input_label.cuda().long()
input_label = input_label * input_label.ge(1).long()
input_mask = input_mask.cuda() * input_label.ge(1).float()
# Forward propagation
if mode == 'Train':
with torch.cuda.amp.autocast(self.fp16_scaler is not None):
output = self.model(input_feature)
output_softmax = F.softmax(output, dim=1)
# Loss calculation
total_loss, loss_lovasz, loss_focal = self.compute_losses(
output, output_softmax, input_label, input_mask)
# Backward
self.optimizer.zero_grad()
if self.fp16_scaler is None:
total_loss.backward()
self.optimizer.step()
else:
self.fp16_scaler.scale(total_loss).backward()
self.fp16_scaler.step(self.optimizer)
self.fp16_scaler.update()
# Update lr after backward (required by pytorch)
self.scheduler.step()
else:
with torch.no_grad():
assert input_feature.shape[0] == 1 # validation batch size has to be 1
# Validation
im_meta = dict(flip=False)
with torch.cuda.amp.autocast(self.fp16_scaler is not None):
lidar_pred = inference(
model_without_ddp.rangevit,
[input_feature],
[im_meta],
ori_shape=input_feature.shape[2:4],
window_size=self.settings.window_size,
window_stride=self.settings.window_stride,
batch_size=input_feature.shape[0],
use_kpconv=False)
output = lidar_pred.unsqueeze(0) # [C, H, W] ==> [1, C, H, W]
output_softmax = F.softmax(output, dim=1)
# Loss calculation
total_loss, loss_lovasz, loss_focal = self.compute_losses(
output, output_softmax, input_label, input_mask)
# Measure IoU and record loss
loss = total_loss.mean()
with torch.no_grad():
argmax = output.argmax(dim=1)
self.metrics.addBatch(argmax, input_label) # 2D predictions
loss_meter.update(loss.item(), input_feature.size(0))
# Timer logger
t_process_end = time.time()
data_cost_time = t_process_start - t_start
process_cost_time = t_process_end - t_process_start
self.remain_time.update(cost_time=(time.time() - t_start), mode=mode)
remain_time = datetime.timedelta(
seconds=self.remain_time.getRemainTime(
epoch=epoch, iters=i, total_iter=total_iter, mode=mode
))
t_start = time.time()
# Logging
if (i % self.settings.log_frequency == 0) or (i == total_iter-1):
with torch.no_grad():
mean_iou, _, mean_acc, _ = self.metrics.getIoUnAcc()
if self.recorder is not None:
for g in self.optimizer.param_groups:
lr = g['lr']
break
log_str = '>>> {} E[{:03d}|{:03d}] I[{:04d}|{:04d}] DT[{:.3f}] PT[{:.3f}] '.format(
mode, self.settings.n_epochs, epoch+1, total_iter, i+1, data_cost_time, process_cost_time)
log_str += 'LR {} Loss {:0.4f} Acc {:0.4f} IOU {:0.4F} '.format(
lr, loss.item(), mean_acc.item(), mean_iou.item())
log_str += 'RT {}'.format(remain_time)
self.recorder.logger.info(log_str)
with torch.no_grad():
mean_acc, class_acc = self.metrics.getAcc()
mean_recall, class_recall = self.metrics.getRecall()
mean_iou, class_iou = self.metrics.getIoU()
metrics_dict = {
'mean_acc': mean_acc,
'class_acc': class_acc,
'mean_recall': mean_recall,
'class_recall': class_recall,
'mean_iou': mean_iou,
'class_iou': class_iou,
'conf_matrix': self.metrics.conf_matrix.clone().cpu(),
}
loss_dict = {
'loss_meter_avg': loss_meter.avg,
'loss_focal': loss_focal,
'loss_lovasz': loss_lovasz,
}
# Print results
if self.recorder is not None:
# Print train pixel-wise evaluation results
if mode == 'Train':
if (epoch % self.settings.train_result_frequency == 0) or (epoch == self.settings.n_epochs-1):
eval_results(pixel_or_point='Pixel',
settings=self.settings,
recorder=self.recorder,
metrics_dict=metrics_dict,
dataloader=self.train_range_loader,
print_data_distribution=True)
# Print validation pixel-wise evaluation results
if mode == 'Validation' and (print_results or epoch == self.settings.n_epochs-1):
eval_results(pixel_or_point='Pixel',
settings=self.settings,
recorder=self.recorder,
metrics_dict=metrics_dict,
dataloader=self.val_range_loader,
print_data_distribution=True)
# Tensorboard logger
tensorboard_logger(epoch=epoch,
mode=mode,
recorder=self.recorder,
metrics_dict=metrics_dict,
loss_dict=loss_dict,
lr=lr,
mapped_cls_name=self.mapped_cls_name)
# Results at the end of the epoch
log_str = '>>> {} Loss {:0.4f} Acc {:0.4f} IOU {:0.4F} Recall {:0.4f}'.format(
mode, loss_meter.avg, mean_acc.item(), mean_iou.item(), mean_recall.item())
self.recorder.logger.info(log_str)
result_metrics = {
'Acc': mean_acc.item(),
'IOU': mean_iou.item(),
'Recall': mean_recall.item()
}
return result_metrics
# Method for training and validation when using the KPConv layer
def run_with_kpconv(self, epoch, mode='Train', print_results=False, save_results_path=None):
if mode == 'Train':
dataloader = self.train_loader
self.model.train()
if self.train_sampler is not None:
self.train_sampler.set_epoch(epoch)
elif mode == 'Validation':
dataloader = self.val_loader
self.model.eval()
else:
raise ValueError('invalid mode: {}'.format(mode))
track_remain_time_1epoch = tools.RemainTime(1)
model_without_ddp = self.model
if hasattr(self.model, 'module'):
model_without_ddp = self.model.module
# Init metrics
loss_meter = tools.AverageMeter()
self.metrics.reset()
total_iter = len(dataloader)
t_start = time.time()
for i, batch_dict in enumerate(dataloader):
t_process_start = time.time()
# 2D inputs
input_feature = batch_dict['input2d'].cuda(non_blocking=True)
assert self.settings.in_channels == 5
# 3D inputs
py = batch_dict['py'].cuda(non_blocking=True)
px = batch_dict['px'].cuda(non_blocking=True)
pxyz = batch_dict['points_xyz'].cuda(non_blocking=True)
knns = batch_dict['knns'].cuda(non_blocking=True)
labels3d = batch_dict['labels'].cuda(non_blocking=True).unsqueeze(1).unsqueeze(2)
labels3d = labels3d * labels3d.ge(1).long()
mask_3d = labels3d.ge(1).float()
num_points = batch_dict['num_points']
# Forward propagation
if mode == 'Train':
with torch.cuda.amp.autocast(self.fp16_scaler is not None):
output3d = self.model(input_feature, px, py, pxyz, knns, num_points)
output3d_softmax = F.softmax(output3d, dim=1)
# Loss calculation
total_loss, loss_lovasz, loss_focal = self.compute_losses(
output3d, output3d_softmax, labels3d, mask_3d)
# Backward
self.optimizer.zero_grad()
if self.fp16_scaler is None:
total_loss.backward()
self.optimizer.step()
else:
self.fp16_scaler.scale(total_loss).backward()
self.fp16_scaler.step(self.optimizer)
self.fp16_scaler.update()
# Update lr after backward (required by pytorch)
self.scheduler.step()
else:
with torch.no_grad():
assert input_feature.shape[0] == 1 # validation batch size has to be 1
# Validation
im_meta = dict(flip=False)
with torch.cuda.amp.autocast(self.fp16_scaler is not None):
output_features2d = inference(
model_without_ddp.rangevit,
[input_feature],
[im_meta],
ori_shape=input_feature.shape[2:4],
window_size=self.settings.window_size,
window_stride=self.settings.window_stride,
batch_size=input_feature.shape[0],
use_kpconv=True)
output_features2d = output_features2d.unsqueeze(0) # [C, H, W] ==> [1, C, H, W]
# Apply KPConv layer
output3d = model_without_ddp.rangevit.kpclassifier(
output_features2d, px, py, pxyz, knns, num_points)
output3d_softmax = F.softmax(output3d, dim=1)
# Loss calculation
total_loss, loss_lovasz, loss_focal = self.compute_losses(
output3d, output3d_softmax, labels3d, mask_3d)
# Measure IoU and record loss
loss = total_loss.mean()
with torch.no_grad():
argmax3d = output3d.argmax(dim=1)
self.metrics.addBatch(argmax3d, labels3d) # 3D predictions
loss_meter.update(loss.item(), input_feature.size(0))
# Save the predictions
if (mode == 'Validation' and save_results_path is not None):
pred_np = argmax3d.cpu().numpy()
pred_np = pred_np.reshape((-1)).astype(np.int32)
index = batch_dict['index']
assert index.shape[0] == 1
index = index.item()
if self.settings.dataset == 'nuScenes':
pred_path = os.path.join(save_results_path, 'lidarseg', self.data_split)
nu_dataset = self.val_loader.dataset.dataset
lidar_token = nu_dataset.token_list[index]
if not os.path.isdir(pred_path):
os.makedirs(pred_path)
pred_result_path = os.path.join(pred_path, '{}_lidarseg.bin'.format(lidar_token))
pred_np.tofile(pred_result_path)
elif self.settings.dataset == 'SemanticKitti':
sk_dataset = self.val_loader.dataset.dataset
pred_np_origin = sk_dataset.class_map_lut_inv[pred_np]
seq_id, frame_id = sk_dataset.parsePathInfoByIndex(index)
pred_path = os.path.join(save_results_path, 'sequences', seq_id, 'predictions')
if not os.path.isdir(pred_path):
os.makedirs(pred_path)
pred_result_path = os.path.join(pred_path, '{}.label'.format(frame_id))
pred_np_origin.tofile(pred_result_path)
# Timer logger
t_process_end = time.time()
data_cost_time = t_process_start - t_start
process_cost_time = t_process_end - t_process_start
self.remain_time.update(cost_time=(time.time() - t_start), mode=mode)
remain_time = datetime.timedelta(
seconds=self.remain_time.getRemainTime(
epoch=epoch, iters=i, total_iter=total_iter, mode=mode))
track_remain_time_1epoch.update(cost_time=(time.time() - t_start), mode=mode)
remain_time_1epoch = datetime.timedelta(
seconds=track_remain_time_1epoch.getRemainTime(
epoch=0, iters=i, total_iter=total_iter, mode=mode))
t_start = time.time()
# Logging
if (i % self.settings.log_frequency == 0) or (i == total_iter-1):
with torch.no_grad():
mean_iou, _, mean_acc, _ = self.metrics.getIoUnAcc()
if self.recorder is not None:
for g in self.optimizer.param_groups:
lr = g['lr']
break
log_str = '>>> {} E[{:03d}|{:03d}] I[{:04d}|{:04d}] DT[{:.3f}] PT[{:.3f}] '.format(
mode, self.settings.n_epochs, epoch+1, total_iter, i+1, data_cost_time, process_cost_time)
log_str += 'LR {} Loss {:0.4f} Acc {:0.4f} IOU {:0.4F} '.format(
lr, loss.item(), mean_acc.item(), mean_iou.item())
log_str += 'RT {} '.format(remain_time)
log_str += 'RT PER EPOCH {}'.format(remain_time_1epoch)
self.recorder.logger.info(log_str)
with torch.no_grad():
mean_acc, class_acc = self.metrics.getAcc()
mean_recall, class_recall = self.metrics.getRecall()
mean_iou, class_iou = self.metrics.getIoU()
metrics_dict = {
'mean_acc': mean_acc,
'class_acc': class_acc,
'mean_recall': mean_recall,
'class_recall': class_recall,
'mean_iou': mean_iou,
'class_iou': class_iou,
'conf_matrix': self.metrics.conf_matrix.clone().cpu(),
}
loss_dict = {
'loss_meter_avg': loss_meter.avg,
'loss_focal': loss_focal,
'loss_lovasz': loss_lovasz,
}
# Print results
if self.recorder is not None:
# Print train point-wise results
if mode == 'Train':
if (epoch % self.settings.train_result_frequency == 0) or (epoch == self.settings.n_epochs-1):
eval_results(pixel_or_point='Point',
settings=self.settings,
recorder=self.recorder,
metrics_dict=metrics_dict,
dataloader=self.train_range_loader,
print_data_distribution=True)
# Print validation point-wise results
if mode == 'Validation' and (print_results or epoch == self.settings.n_epochs-1):
eval_results(pixel_or_point='Point',
settings=self.settings,
recorder=self.recorder,
metrics_dict=metrics_dict,
dataloader=self.val_range_loader,
print_data_distribution=True)
# Tensorboard logger
tensorboard_logger(epoch=epoch,
mode=mode,
recorder=self.recorder,
metrics_dict=metrics_dict,
loss_dict=loss_dict,
lr=lr,
mapped_cls_name=self.mapped_cls_name)
# Results at the end of the epoch
log_str = '>>> {} Loss {:0.4f} Acc {:0.4f} IOU {:0.4F} Recall {:0.4f}'.format(
mode, loss_meter.avg, mean_acc.item(), mean_iou.item(), mean_recall.item())
self.recorder.logger.info(log_str)
result_metrics = {
'Acc': mean_acc.item(),
'IOU': mean_iou.item(),
'Recall': mean_recall.item()
}
return result_metrics