Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ming1753 committed Sep 24, 2023
1 parent df1d708 commit e143621
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions test/cpp/inference/api/analysis_predictor_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,61 @@ TEST(Predictor, Streams) {
CHECK_NE(stream, stream2);
}
}

TEST(Tensor, RunWithExternalStream) {
Config config;
config.SetModel(FLAGS_dirname);
config.EnableUseGpu(100, 0);
cudaStream_t stream;
cudaStreamCreate(&stream);
config.SetExecStream(stream);
auto predictor = CreatePredictor(config);

auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw");
auto w2 = predictor->GetInputHandle("thirdw");
auto w3 = predictor->GetInputHandle("forthw");

std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
std::vector<int64_t*> input_gpu(4, nullptr);

for (size_t i = 0; i < 4; ++i) {
cudaMalloc(reinterpret_cast<void**>(&input_gpu[i]), 4 * sizeof(int64_t));
cudaMemcpy(input_gpu[i],
input_data[i].data(),
4 * sizeof(int64_t),
cudaMemcpyHostToDevice);
}

w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1}, PlaceType::kGPU);
w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1}, PlaceType::kGPU);
w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1}, PlaceType::kGPU);
w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1}, PlaceType::kGPU);

auto out = predictor->GetOutputHandle("fc_1.tmp_2");
auto out_shape = out->shape();
float* out_data = nullptr;
auto out_size =
std::accumulate(
out_shape.begin(), out_shape.end(), 1, std::multiplies<int>()) *
sizeof(float);
cudaMalloc(reinterpret_cast<void**>(out_data), out_size * sizeof(float));
out->ShareExternalData<float>(out_data, out_shape, PlaceType::kGPU);

cudaStream_t external_stream;
cudaStreamCreate(&external_stream);
Config tmp_config(config);
tmp_config.SetExecStream(external_stream);
predictor->Run();
paddle_infer::experimental::InternalUtils::RunWithExternalStream(
predictor.get(), external_stream);

PlaceType place;
int size = 0;
out->data<float>(&place, &size);
LOG(INFO) << "output size: " << size / sizeof(float);
predictor->TryShrinkMemory();
}
#endif

TEST(AnalysisPredictor, OutputTensorHookFunc) {
Expand Down

0 comments on commit e143621

Please sign in to comment.