diff --git a/train.py b/train.py index 9041b3c385..ebce2f29b0 100644 --- a/train.py +++ b/train.py @@ -132,7 +132,7 @@ def main(args): # Only support for the DeepLabv3+ model if args.data_format == 'NHWC': - if cfg.dic['model'] != 'DeepLabV3P': + if cfg.dic['model']['type'] != 'DeepLabV3P': raise ValueError( 'The "NHWC" data format only support the DeepLabV3P model!') cfg.dic['model']['data_format'] = args.data_format