Skip to content

Commit

Permalink
Removed the function body of __devicelib_assert_fail. (triton-lang#1808)
Browse files Browse the repository at this point in the history
This PR fixes triton-lang#1176 
IGC detects calls to `__devicelib_assert_fail` and replaces them with a
'safe' implementation.
However, the SYCL library contains a 'fallback' implementation of
assertion, which does not work in our setup.
If we mark the function with `InternalLinkage`, the fallback
implementation is inlined and IGC cannot replace it with the safe
implementation.
Declaring `__devicelib_assert_fail` as an external function in the SYCL
library lets IGC correctly insert its own implementation.
The diff between the old and new `libsycl-spir64-unknown-unknown.ll` is
as follows:
```diff
@@ -5424,149 +5424,7 @@ declare extern_weak dso_local spir_func noundef i32 @_Z18__spirv_AtomicLoadPU3AS
 declare void @llvm.memcpy.p4.p1.i64(ptr addrspace(4) noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg) triton-lang#16
 
 ; Function Attrs: convergent mustprogress norecurse nounwind
-define weak dso_local spir_func void @__devicelib_assert_fail(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %1, i32 noundef %2, ptr addrspace(4) noundef %3, i64 noundef %4, i64 noundef %5, i64 noundef %6, i64 noundef %7, i64 noundef %8, i64 noundef %9) local_unnamed_addr triton-lang#14 !srcloc !720 {
-  %11 = tail call spir_func noundef i32 @_Z29__spirv_AtomicCompareExchangePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagES5_ii(ptr addrspace(1) noundef @SPIR_AssertHappenedMem, i32 noundef 1, i32 noundef 16, i32 noundef 16, i32 noundef 1, i32 noundef 0) triton-lang#54
-  %12 = icmp eq i32 %11, 0
-  br i1 %12, label %13, label %92
-
-13:                                               ; preds = %10
-  store i32 %2, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 4), align 8, !tbaa !721
-  store i64 %4, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 5), align 8, !tbaa !722
-  store i64 %5, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 6), align 8, !tbaa !723
-  store i64 %6, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 7), align 8, !tbaa !724
-  store i64 %7, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 8), align 8, !tbaa !725
-  store i64 %8, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 9), align 8, !tbaa !726
-  store i64 %9, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 10), align 8, !tbaa !727
-  %14 = icmp eq ptr addrspace(4) %0, null
-  br i1 %14, label %23, label %15
-
-15:                                               ; preds = %20, %13
-  %16 = phi i32 [ %22, %20 ], [ 0, %13 ]
-  %17 = phi ptr addrspace(4) [ %21, %20 ], [ %0, %13 ]
-  %18 = load i8, ptr addrspace(4) %17, align 1, !tbaa !718
-  %19 = icmp eq i8 %18, 0
-  br i1 %19, label %23, label %20
-
-20:                                               ; preds = %15
-  %21 = getelementptr inbounds i8, ptr addrspace(4) %17, i64 1
-  %22 = add nuw nsw i32 %16, 1
-  br label %15, !llvm.loop !728
-
-23:                                               ; preds = %15, %13
-  %24 = phi i32 [ 0, %13 ], [ %16, %15 ]
-  %25 = icmp eq ptr addrspace(4) %1, null
-  br i1 %25, label %34, label %26
-
-26:                                               ; preds = %31, %23
-  %27 = phi i32 [ %33, %31 ], [ 0, %23 ]
-  %28 = phi ptr addrspace(4) [ %32, %31 ], [ %1, %23 ]
-  %29 = load i8, ptr addrspace(4) %28, align 1, !tbaa !718
-  %30 = icmp eq i8 %29, 0
-  br i1 %30, label %34, label %31
-
-31:                                               ; preds = %26
-  %32 = getelementptr inbounds i8, ptr addrspace(4) %28, i64 1
-  %33 = add nuw nsw i32 %27, 1
-  br label %26, !llvm.loop !729
-
-34:                                               ; preds = %26, %23
-  %35 = phi i32 [ 0, %23 ], [ %27, %26 ]
-  %36 = icmp eq ptr addrspace(4) %3, null
-  br i1 %36, label %37, label %40
-
-37:                                               ; preds = %34
-  %38 = tail call i32 @llvm.umin.i32(i32 %24, i32 256)
-  %39 = tail call i32 @llvm.umin.i32(i32 %35, i32 256)
-  br label %52
-
-40:                                               ; preds = %45, %34
-  %41 = phi i32 [ %47, %45 ], [ 0, %34 ]
-  %42 = phi ptr addrspace(4) [ %46, %45 ], [ %3, %34 ]
-  %43 = load i8, ptr addrspace(4) %42, align 1, !tbaa !718
-  %44 = icmp eq i8 %43, 0
-  br i1 %44, label %48, label %45
-
-45:                                               ; preds = %40
-  %46 = getelementptr inbounds i8, ptr addrspace(4) %42, i64 1
-  %47 = add i32 %41, 1
-  br label %40, !llvm.loop !730
-
-48:                                               ; preds = %40
-  %49 = tail call i32 @llvm.umin.i32(i32 %24, i32 256)
-  %50 = tail call i32 @llvm.umin.i32(i32 %35, i32 256)
-  %51 = tail call i32 @llvm.umin.i32(i32 %41, i32 128)
-  br label %52
-
-52:                                               ; preds = %48, %37
-  %53 = phi i32 [ %39, %37 ], [ %50, %48 ]
-  %54 = phi i32 [ %38, %37 ], [ %49, %48 ]
-  %55 = phi i32 [ 0, %37 ], [ %51, %48 ]
-  br label %56
-
-56:                                               ; preds = %62, %52
-  %57 = phi i32 [ 0, %52 ], [ %67, %62 ]
-  %58 = icmp ult i32 %57, %54
-  br i1 %58, label %62, label %59
-
-59:                                               ; preds = %56
-  %60 = zext nneg i32 %54 to i64
-  %61 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 1, i64 %60
-  store i8 0, ptr addrspace(1) %61, align 1, !tbaa !718
-  br label %68
-
-62:                                               ; preds = %56
-  %63 = sext i32 %57 to i64
-  %64 = getelementptr inbounds i8, ptr addrspace(4) %0, i64 %63
-  %65 = load i8, ptr addrspace(4) %64, align 1, !tbaa !718
-  %66 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 1, i64 %63
-  store i8 %65, ptr addrspace(1) %66, align 1, !tbaa !718
-  %67 = add nuw nsw i32 %57, 1
-  br label %56, !llvm.loop !731
-
-68:                                               ; preds = %74, %59
-  %69 = phi i32 [ 0, %59 ], [ %79, %74 ]
-  %70 = icmp ult i32 %69, %53
-  br i1 %70, label %74, label %71
-
-71:                                               ; preds = %68
-  %72 = zext nneg i32 %53 to i64
-  %73 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 2, i64 %72
-  store i8 0, ptr addrspace(1) %73, align 1, !tbaa !718
-  br label %80
-
-74:                                               ; preds = %68
-  %75 = sext i32 %69 to i64
-  %76 = getelementptr inbounds i8, ptr addrspace(4) %1, i64 %75
-  %77 = load i8, ptr addrspace(4) %76, align 1, !tbaa !718
-  %78 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 2, i64 %75
-  store i8 %77, ptr addrspace(1) %78, align 1, !tbaa !718
-  %79 = add nuw nsw i32 %69, 1
-  br label %68, !llvm.loop !732
-
-80:                                               ; preds = %86, %71
-  %81 = phi i32 [ 0, %71 ], [ %91, %86 ]
-  %82 = icmp ult i32 %81, %55
-  br i1 %82, label %86, label %83
-
-83:                                               ; preds = %80
-  %84 = sext i32 %55 to i64
-  %85 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 3, i64 %84
-  store i8 0, ptr addrspace(1) %85, align 1, !tbaa !718
-  tail call spir_func void @_Z19__spirv_AtomicStorePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEi(ptr addrspace(1) noundef @SPIR_AssertHappenedMem, i32 noundef 1, i32 noundef 16, i32 noundef 2) triton-lang#54
-  br label %92
-
-86:                                               ; preds = %80
-  %87 = sext i32 %81 to i64
-  %88 = getelementptr inbounds i8, ptr addrspace(4) %3, i64 %87
-  %89 = load i8, ptr addrspace(4) %88, align 1, !tbaa !718
-  %90 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 3, i64 %87
-  store i8 %89, ptr addrspace(1) %90, align 1, !tbaa !718
-  %91 = add nuw nsw i32 %81, 1
-  br label %80, !llvm.loop !733
-
-92:                                               ; preds = %83, %10
-  ret void
-}
+declare extern_weak dso_local spir_func void @__devicelib_assert_fail(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %1, i32 noundef %2, ptr addrspace(4) noundef %3, i64 noundef %4, i64 noundef %5, i64 noundef %6, i64 noundef %7, i64 noundef %8, i64 noundef %9) local_unnamed_addr triton-lang#14
 
 ; Function Attrs: convergent nounwind
 declare extern_weak dso_local spir_func noundef i32 @_Z29__spirv_AtomicCompareExchangePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagES5_ii(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr triton-lang#15

```
  • Loading branch information
hwnam831 authored Aug 15, 2024
1 parent 21232be commit ff74649
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
5 changes: 0 additions & 5 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,6 @@ void init_triton_llvm(py::module &&m) {
// linkage as a signifier of kernel functions.
for (llvm::Function &fn : dstMod->functions()) {
if (externalFns.count(fn.getName().str())) {
// FIXME: Temporary workaround to avoid __devicelib_assert_fail
// optimization with InternalLinkage, which causes
// test_subprocess.py::test_assert to fail.
if (fn.getName().str() == "__devicelib_assert_fail")
continue;
fn.setLinkage(llvm::GlobalValue::InternalLinkage);
}
}
Expand Down
Binary file modified third_party/intel/backend/lib/libsycl-spir64-unknown-unknown.bc
Binary file not shown.
10 changes: 10 additions & 0 deletions third_party/intel/lib/Target/LLVMIR/PostProcess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ void postProcessLLVMIR(llvm::Module &mod) {
}
};

// __devicelib_assert_fail must be a declaration so that
// IGC can replace it with a runtime assert function.
// If a 'fallback' implementation is defined in the SYCL library, the
// assertion does not work correctly.
for (auto &f : mod) {
if (f.getName().str() == "__devicelib_assert_fail") {
assert(f.isDeclaration() &&
"__devicelib_assert_fail must be a declaration!");
}
}
print("PostProcessing: Before SLPVectorizer", mod);
SLPVectorizer(mod, trace);
print("PostProcessing: After SLPVectorizer", mod);
Expand Down

0 comments on commit ff74649

Please sign in to comment.