Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St][NO.2] pir dy2st unittest fix test_bert - Part 1 #60164

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 78 additions & 68 deletions test/dygraph_to_static/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,21 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt_and_pir,
test_sot_only,
)
from predictor_utils import PredictorTools

import paddle
from paddle import base
from paddle.base import core
from paddle.base.framework import unique_name
from paddle.framework import use_pir_api
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

place = base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
place = (
paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
)
SEED = 2020
STEP_NUM = 10
PRINT_STEP = 2
Expand Down Expand Up @@ -95,7 +100,7 @@ def tearDown(self):
self.temp_dir.cleanup()

def train(self, bert_config, data_reader, to_static):
with base.dygraph.guard(place):
with unique_name.guard():
base.default_main_program().random_seed = SEED
base.default_startup_program().random_seed = SEED

Expand Down Expand Up @@ -158,7 +163,9 @@ def train(self, bert_config, data_reader, to_static):
step_idx += 1
if step_idx == STEP_NUM:
if to_static:
paddle.jit.save(bert, self.model_save_prefix)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(bert, self.model_save_prefix)
else:
paddle.save(
bert.state_dict(),
Expand All @@ -172,8 +179,7 @@ def train_dygraph(self, bert_config, data_reader):
return self.train(bert_config, data_reader, False)

def train_static(self, bert_config, data_reader):
with enable_to_static_guard(True):
return self.train(bert_config, data_reader, True)
return self.train(bert_config, data_reader, True)

def predict_static(self, data):
paddle.enable_static()
Expand All @@ -195,11 +201,12 @@ def predict_static(self, data):
fetch_list=fetch_targets,
)

paddle.disable_static()
return pred_res

def predict_dygraph(self, bert_config, data):
with enable_to_static_guard(False):
with base.dygraph.guard(place):
with unique_name.guard():
bert = PretrainModelLayer(
config=bert_config, weight_sharing=False, use_fp16=False
)
Expand All @@ -210,7 +217,7 @@ def predict_dygraph(self, bert_config, data):
bert.set_dict(model_dict)
bert.eval()

input_vars = [base.dygraph.to_variable(x) for x in data]
input_vars = [paddle.to_tensor(x) for x in data]
(
src_ids,
pos_ids,
Expand All @@ -234,31 +241,30 @@ def predict_dygraph(self, bert_config, data):
return pred_res

def predict_dygraph_jit(self, data):
with base.dygraph.guard(place):
bert = paddle.jit.load(self.model_save_prefix)
bert.eval()

(
src_ids,
pos_ids,
sent_ids,
input_mask,
mask_label,
mask_pos,
labels,
) = data
pred_res = bert(
src_ids,
pos_ids,
sent_ids,
input_mask,
mask_label,
mask_pos,
labels,
)
pred_res = [var.numpy() for var in pred_res]
bert = paddle.jit.load(self.model_save_prefix)
bert.eval()

(
src_ids,
pos_ids,
sent_ids,
input_mask,
mask_label,
mask_pos,
labels,
) = data
pred_res = bert(
src_ids,
pos_ids,
sent_ids,
input_mask,
mask_label,
mask_pos,
labels,
)
pred_res = [var.numpy() for var in pred_res]

return pred_res
return pred_res

def predict_analysis_inference(self, data):
output = PredictorTools(
Expand All @@ -267,6 +273,7 @@ def predict_analysis_inference(self, data):
out = output()
return out

@test_legacy_and_pt_and_pir
def test_train(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
Expand All @@ -280,6 +287,7 @@ def test_train(self):
self.verify_predict()

@test_sot_only
@test_legacy_and_pt_and_pir
def test_train_composite(self):
core._set_prim_backward_enabled(True)
# core._add_skip_comp_ops("layer_norm")
Expand All @@ -297,43 +305,45 @@ def test_train_composite(self):
def verify_predict(self):
for data in self.data_reader.data_generator()():
dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
static_pred_res = self.predict_static(data)
dygraph_jit_pred_res = self.predict_dygraph_jit(data)
predictor_pred_res = self.predict_analysis_inference(data)

for dy_res, st_res, dy_jit_res, predictor_res in zip(
dygraph_pred_res,
static_pred_res,
dygraph_jit_pred_res,
predictor_pred_res,
):
np.testing.assert_allclose(
st_res,
dy_res,
rtol=1e-05,
err_msg='dygraph_res: {},\n static_res: {}'.format(
dy_res[~np.isclose(st_res, dy_res)],
st_res[~np.isclose(st_res, dy_res)],
),
)
np.testing.assert_allclose(
st_res,
dy_jit_res,
rtol=1e-05,
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
dy_jit_res[~np.isclose(st_res, dy_jit_res)],
st_res[~np.isclose(st_res, dy_jit_res)],
),
)
np.testing.assert_allclose(
st_res,
predictor_res,
rtol=1e-05,
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
predictor_res[~np.isclose(st_res, predictor_res)],
st_res[~np.isclose(st_res, predictor_res)],
),
)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
static_pred_res = self.predict_static(data)
dygraph_jit_pred_res = self.predict_dygraph_jit(data)
predictor_pred_res = self.predict_analysis_inference(data)

for dy_res, st_res, dy_jit_res, predictor_res in zip(
dygraph_pred_res,
static_pred_res,
dygraph_jit_pred_res,
predictor_pred_res,
):
np.testing.assert_allclose(
st_res,
dy_res,
rtol=1e-05,
err_msg='dygraph_res: {},\n static_res: {}'.format(
dy_res[~np.isclose(st_res, dy_res)],
st_res[~np.isclose(st_res, dy_res)],
),
)
np.testing.assert_allclose(
st_res,
dy_jit_res,
rtol=1e-05,
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
dy_jit_res[~np.isclose(st_res, dy_jit_res)],
st_res[~np.isclose(st_res, dy_jit_res)],
),
)
np.testing.assert_allclose(
st_res,
predictor_res,
rtol=1e-05,
err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
predictor_res[~np.isclose(st_res, predictor_res)],
st_res[~np.isclose(st_res, predictor_res)],
),
)
break


Expand Down