Merge pull request #36 from YacobBY/patch-1
Nvidia Apex for FP16 calculations
ku21fan committed Jul 24, 2019
2 parents 40a4100 + 2d45ba2 commit 5d4ed38
Showing 1 changed file with 23 additions and 6 deletions.
train.py
@@ -17,6 +17,13 @@
from model import Model
from test import validation

+try:
+    from apex import amp
+    from apex import fp16_utils
+    APEX_AVAILABLE = True
+    amp_handle = amp.init(enabled=True)
+except ModuleNotFoundError:
+    APEX_AVAILABLE = False

def train(opt):
    """ dataset preparation """
@@ -42,7 +49,7 @@ def train(opt):

    if opt.rgb:
        opt.input_channel = 3
-    model = Model(opt)
+    model = Model(opt).cuda()
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
@@ -62,9 +69,7 @@ def train(opt):
                param.data.fill_(1)
            continue

-    # data parallel for multi-GPU
-    model = torch.nn.DataParallel(model).cuda()
-    model.train()

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
@@ -118,6 +123,13 @@ def train(opt):
    best_norm_ED = 1e+6
    i = start_iter

+    if APEX_AVAILABLE:
+        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+
+    # data parallel for multi-GPU
+    model = torch.nn.DataParallel(model).cuda()
+    model.train()
+
    while(True):
        # train part
        for p in model.parameters():
@@ -140,8 +152,13 @@ def train(opt):
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
-        cost.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
+        if APEX_AVAILABLE:
+            with amp.scale_loss(cost, optimizer) as scaled_loss:
+                scaled_loss.backward()
+            fp16_utils.clip_grad_norm(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
+        else:
+            cost.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
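
For readers who want to reproduce the same pattern outside this repository, here is a minimal, self-contained sketch of the Apex mixed-precision recipe this diff adds to train.py: guard the apex import, wrap the model and optimizer with amp.initialize (opt_level "O2") before applying DataParallel, and scale the loss before backward. The linear placeholder model, dummy batch, and hyper-parameters below are illustrative assumptions, not code from this repository; the clipping call uses amp.master_params, which the Apex documentation recommends when the optimizer holds FP32 master weights (the diff above clips model.parameters() via fp16_utils instead).

# Sketch only: placeholder model and random data, not repository code.
import torch
import torch.nn as nn

try:
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False

model = nn.Linear(32, 10).cuda()                      # placeholder model (assumption)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)

if APEX_AVAILABLE:
    # Cast model/optimizer for mixed precision; "O2" keeps FP32 master weights.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# As in the diff, DataParallel is applied after amp.initialize.
model = torch.nn.DataParallel(model).cuda()
model.train()

for _ in range(10):
    images = torch.randn(8, 32).cuda()                # dummy batch (assumption)
    labels = torch.randint(0, 10, (8,)).cuda()

    preds = model(images)
    cost = criterion(preds, labels)

    model.zero_grad()
    if APEX_AVAILABLE:
        # Scale the loss so FP16 gradients do not underflow, then clip master grads.
        with amp.scale_loss(cost, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 5)
    else:
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    optimizer.step()

One small observation on the diff itself: amp_handle = amp.init(enabled=True) belongs to the older Apex AMP API and appears unused in the hunks shown here; the training loop exercises only amp.initialize and amp.scale_loss.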
