Skip to content

Commit

Permalink
first commit
Browse files: browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Aug 19, 2022
1 parent 61b76c5 commit b6fe413
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions examples/machine_translation/transformer/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -201,6 +201,7 @@ def do_train(args):
(args.trg_vocab_size - 1) + 1e-20))

step_idx = 0
tokens_sum = 0

# For benchmark
reader_cost_avg = AverageStatistical()
Expand All @@ -217,6 +218,7 @@ def do_train(args):
train_reader_cost = time.time() - batch_start
(src_word, trg_word, lbl_word) = input_data

token_num = 0
if args.use_amp:
with paddle.amp.auto_cast(custom_black_list={
'scale', 'reduce_sum', 'elementwise_div'
Expand All @@ -225,7 +227,6 @@ def do_train(args):
logits = transformer(src_word=src_word, trg_word=trg_word)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word)

tokens_per_cards = token_num.numpy()
scaled = scaler.scale(avg_cost) # scale the loss
scaled.backward() # do backward

Expand All @@ -238,7 +239,6 @@ def do_train(args):
else:
logits = transformer(src_word=src_word, trg_word=trg_word)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
tokens_per_cards = token_num.numpy()

avg_cost.backward()

Expand All @@ -248,7 +248,9 @@ def do_train(args):
train_batch_cost = time.time() - batch_start
reader_cost_avg.record(train_reader_cost)
batch_cost_avg.record(train_batch_cost)
batch_ips_avg.record(train_batch_cost, tokens_per_cards)
batch_ips_avg.record(train_batch_cost, 0)

tokens_sum += token_num

# Profile for model benchmark
if args.profiler_options is not None:
Expand All @@ -258,6 +260,9 @@ def do_train(args):
if step_idx % args.print_step == 0 and (args.benchmark
or rank == 0):
total_avg_cost = avg_cost.numpy()
tokens_sum_val = tokens_sum.numpy()
batch_ips_avg.record(0, tokens_sum_val)
tokens_sum = 0

if step_idx == 0:
logger.info(
Expand Down

0 comments on commit b6fe413

Please sign in to comment.