Skip to content

Commit

Permalink
Fix QAT 2.0 TypeError for transformer network (#1877)
Browse files Browse the repository at this point in the history
Signed-off-by: Sagar Hathwar <quic_shathwar@quicinc.com>
  • Loading branch information
quic-shathwar authored Jan 19, 2023
1 parent 840c25b commit eda99b2
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _compute_dloss_by_dmin_dmax_and_dx(inputs: tf.Tensor, encoding_min: tf.Varia
dloss_by_dmin = tf.cast(_compute_dloss_by_dmin_using_dmax(dloss_by_dmax), tf.float64)

# Pass through gradient for skipped ops
dloss_by_dx = tf.cond(tf.equal(op_mode, 3), lambda: grad, lambda: dloss_by_dx)
dloss_by_dx = tf.cond(tf.equal(op_mode, 3), lambda: tf.convert_to_tensor(grad), lambda: dloss_by_dx)

return dloss_by_dmin, dloss_by_dmax, dloss_by_dx

Expand Down

0 comments on commit eda99b2

Please sign in to comment.