Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upstream static #807

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions static/csrc/model_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ void ModelContainer::SetConstantImpl(
". Check that the provided tensor's shape is correct.");
}
} else {
throw std::runtime_error(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change? We want to throw an error in this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there are some additional variables won't be loaded to the ait model, only warning them is a better solution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any examples? This error means that a constant is neither bound nor unbound, which is unexpected. It means that there is something wrong in this constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted

std::string("Called SetConstant on ") + name +
std::string(" but can't find in either bound or unbound constant set"));
LOG(WARNING) << "Called SetConstant on " << name
<< " but can't find in either bound or unbound constant set";
return;
}

auto* src = tensor.ptr;
Expand Down
2 changes: 1 addition & 1 deletion static/include/cuda_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ inline DeviceError QueryEvent(EventType event) {
return cudaEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return cudaGetErrorString(err);
}

Expand Down
1 change: 0 additions & 1 deletion static/include/debug_utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "device_functions-generated.h"

Expand Down
2 changes: 1 addition & 1 deletion static/include/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#define DEVICE_CHECK(call) \
if ((call) != GetDeviceSuccess()) { \
throw std::runtime_error( \
#call " API call failed: " + GetLastErrorString() + " at " + \
#call " API call failed: " + GetErrorString(call) + " at " + \
__FILE__ + ", line" + std::to_string(__LINE__)); \
}

Expand Down
29 changes: 25 additions & 4 deletions static/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ namespace ait {
inline void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") +
cudaGetErrorString(device_error) +
std::string msg = std::string("Got error: ") + GetErrorString(device_error) +
" enum: " + std::to_string(device_error) + " at " + file + ": " +
std::to_string(line);
LOG(ERROR) << msg;
Expand Down Expand Up @@ -217,6 +216,29 @@ class ModelBase {
}

void RunAsGraph(StreamType stream) {
#ifdef __HIP_PLATFORM_HCC__
if (graph_exec_ == nullptr) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING)
<< "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#else
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
Expand All @@ -230,13 +252,11 @@ class ModelBase {
}
throw;
}

// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });

if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (
Expand All @@ -247,6 +267,7 @@ class ModelBase {
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#endif

DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
Expand Down
15 changes: 8 additions & 7 deletions static/include/rocm_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace ait {

inline thread_local bool target_has_graph_mode = false;
inline thread_local bool target_has_graph_mode = true;

using DeviceError = hipError_t;
using DevicePropertyType = hipDeviceProp_t;
Expand Down Expand Up @@ -57,7 +57,7 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has 32-bit integer atomics for shared memory: "
<< (arch.hasSharedInt32Atomics ? "yes" : "no")
<< "\n Has 32-bit float atomic exch for shared memory: "
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no"
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no")
<< "\n Has 32-bit float atomic add in global and shared memory: "
<< (arch.hasFloatAtomicAdd ? "yes" : "no")
<< "\n Has 64-bit integer atomics for global memory: "
Expand All @@ -67,9 +67,9 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has double-precision floating point: "
<< (arch.hasDoubles ? "yes" : "no")
<< "\n Has warp vote instructions (__any, __all): "
<< (arch.hasWarpVote: ? "yes" : "no")
<< (arch.hasWarpVote ? "yes" : "no")
<< "\n Has warp ballot instructions (__ballot): "
<< (arch.hasWarpBallot: ? "yes" : "no")
<< (arch.hasWarpBallot ? "yes" : "no")
<< "\n Has warp shuffle operations. (__shfl_*): "
<< (arch.hasWarpShuffle ? "yes" : "no")
<< "\n Has funnel two words into one with shift&mask caps: "
Expand Down Expand Up @@ -187,7 +187,7 @@ inline DeviceError StreamDestroy(StreamType stream) {
}

inline DeviceError StreamWaitEvent(StreamType stream, EventType event) {
return hipStreamWaitEvent(stream, event);
return hipStreamWaitEvent(stream, event, 0);
}

inline DeviceError GraphInstantiate(
Expand All @@ -202,7 +202,8 @@ inline DeviceError GraphDestroy(GraphType graph) {

inline DeviceError GraphExecUpdate(GraphExecType graph_exec, GraphType graph) {
// We don't have hipGraphExecUpdate in some versions of rocm
return hipErrorUnknown;
hipGraphExecUpdateResult update;
return hipGraphExecUpdate(graph_exec, graph, nullptr, &update);
}

inline DeviceError GraphExecDestroy(GraphExecType graph_exec) {
Expand Down Expand Up @@ -314,7 +315,7 @@ inline DeviceError QueryEvent(EventType event) {
return hipEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return hipGetErrorString(err);
}

Expand Down