Skip to content

Commit

Permalink
add rt-detr-ssod
Browse files Browse the repository at this point in the history
  • Loading branch information
wjm202 committed Jun 8, 2023
1 parent eb66b90 commit da3d600
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ppdet.core.workspace import load_config, merge_config

from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env
from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL
from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL, Trainer_Semi_detr

from ppdet.slim import build_slim_model

Expand Down Expand Up @@ -134,10 +134,11 @@ def run(FLAGS, cfg):
trainer = Trainer_DenseTeacher(cfg, mode='train')
elif ssod_method == 'ARSL':
trainer = Trainer_ARSL(cfg, mode='train')
elif ssod_method == 'Trainer_Semi_detr':
trainer = Trainer_Semi_detr(cfg, mode='train')
else:
raise ValueError(
"Semi-Supervised Object Detection only support DenseTeacher and ARSL now."
)
"Semi-Supervised Object Detection only no support this method.")
elif cfg.get('use_cot', False):
trainer = TrainerCot(cfg, mode='train')
else:
Expand All @@ -146,6 +147,10 @@ def run(FLAGS, cfg):
# load weights
if FLAGS.resume is not None:
trainer.resume_weights(FLAGS.resume)
elif 'pretrain_student_weights' in cfg and 'pretrain_teacher_weights' in cfg \
and cfg.pretrain_teacher_weights and cfg.pretrain_student_weights:
trainer.load_semi_weights(cfg.pretrain_teacher_weights,
cfg.pretrain_student_weights)
elif 'pretrain_weights' in cfg and cfg.pretrain_weights:
trainer.load_weights(cfg.pretrain_weights)

Expand Down

0 comments on commit da3d600

Please sign in to comment.