Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SWDEV-373814 - Updated to enqueue graph launches and then syncronize #3307

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions tests/src/runtimeApi/graph/hipSimpleGraphWithKernel.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include <stdio.h>
#include <iostream>
#include "hip/hip_runtime.h"
#include <test_common.h>
#include <chrono>
#include <unistd.h>

#include <chrono>
#include <iostream>

#include "hip/hip_runtime.h"
/* HIT_START
* BUILD: %t %s ../../test_common.cpp
* TEST: %t
Expand All @@ -19,7 +21,7 @@ __global__ void simpleKernel(float* out_d, float* in_d) {
if (idx < N) out_d[idx] = CONSTANT * in_d[idx];
}

bool hipTestWithGraph() {
bool hipTestWithGraph(int nstep, int nkernel) {
int deviceId;
HIPCHECK(hipGetDevice(&deviceId));
hipDeviceProp_t props;
Expand Down Expand Up @@ -47,17 +49,18 @@ bool hipTestWithGraph() {
hipGraphExec_t instance;

hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal);
for (int ikrnl = 0; ikrnl < NKERNEL; ikrnl++) {
for (int ikrnl = 0; ikrnl < nkernel; ikrnl++) {
simpleKernel<<<dim3(N / 512, 1, 1), dim3(512, 1, 1), 0, stream>>>(out_d, in_d);
}
hipStreamEndCapture(stream, &graph);
hipGraphInstantiate(&instance, graph, NULL, NULL, 0);

auto start1 = std::chrono::high_resolution_clock::now();
for (int istep = 0; istep < NSTEP; istep++) {
for (int istep = 0; istep < nstep; istep++) {
hipGraphLaunch(instance, stream);
hipStreamSynchronize(stream);
}
hipStreamSynchronize(stream);

auto stop = std::chrono::high_resolution_clock::now();
auto resultWithInit = std::chrono::duration<double, std::milli>(stop - start);
auto resultWithoutInit = std::chrono::duration<double, std::milli>(stop - start1);
Expand All @@ -80,12 +83,13 @@ bool hipTestWithGraph() {
return true;
}

bool hipTestWithoutGraph() {
bool hipTestWithoutGraph(int nstep, int nkernel) {
int deviceId;
HIPCHECK(hipGetDevice(&deviceId));
hipDeviceProp_t props;
HIPCHECK(hipGetDeviceProperties(&props, deviceId));
printf("info: running on device #%d %s\n", deviceId, props.name);
printf("info: running on device #%d %s with graph size & launches:%d %d \n", deviceId, props.name,
nkernel, nstep);

hipStream_t stream;
HIPCHECK(hipStreamCreate(&stream));
Expand All @@ -104,12 +108,12 @@ bool hipTestWithoutGraph() {

// start CPU wallclock timer
auto start = std::chrono::high_resolution_clock::now();
for (int istep = 0; istep < NSTEP; istep++) {
for (int ikrnl = 0; ikrnl < NKERNEL; ikrnl++) {
for (int istep = 0; istep < nstep; istep++) {
for (int ikrnl = 0; ikrnl < nkernel; ikrnl++) {
simpleKernel<<<dim3(N / 512, 1, 1), dim3(512, 1, 1), 0, stream>>>(out_d, in_d);
}
HIPCHECK(hipStreamSynchronize(stream));
}
HIPCHECK(hipStreamSynchronize(stream));
auto stop = std::chrono::high_resolution_clock::now();
auto result = std::chrono::duration<double, std::milli>(stop - start);
std::cout << "Time taken for test without graph: "
Expand All @@ -130,13 +134,18 @@ bool hipTestWithoutGraph() {

int main(int argc, char* argv[]) {
bool status1, status2;
status1 = hipTestWithoutGraph();
status2 = hipTestWithGraph();
if (argc == 3) {
status1 = hipTestWithoutGraph(atoi(argv[1]), atoi(argv[2]));
status2 = hipTestWithGraph(atoi(argv[1]), atoi(argv[2]));
} else {
status1 = hipTestWithoutGraph(NSTEP, NKERNEL);
status2 = hipTestWithGraph(NSTEP, NKERNEL);
}
if (!status1) {
failed("Failed during test with hip graph\n");
failed("Failed during test without hip graph\n");
}
if (!status2) {
failed("Failed during test without graph\n");
failed("Failed during test with graph\n");
}
passed();
}