Skip to content

Commit

Permalink
[SPIR-V] Cast ptr kernel args to i8* when used as Store's value opera…
Browse files Browse the repository at this point in the history
…nd (#78603)

Handle a special case when StoreInst's value operand is a kernel
argument of a pointer type. Since these arguments could have either a
basic element type (e.g. float*) or OpenCL builtin type (sampler_t),
bitcast the StoreInst's value operand to default pointer element type
(i8).

This pull request addresses the issue
#72864
  • Loading branch information
michalpaszkowski authored Jan 29, 2024
1 parent 07dfa61 commit 0fbaf03
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 64 deletions.
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
SPIRVPreLegalizer.cpp
SPIRVPrepareFunctions.cpp
Expand Down
63 changes: 5 additions & 58 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "SPIRVBuiltins.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
#include "SPIRVMetadata.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
Expand Down Expand Up @@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}

static MDString *getKernelArgAttribute(const Function &KernelFunction,
unsigned ArgIdx,
const StringRef AttributeName) {
assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to kernel functions");

// Lookup the argument attribute in metadata attached to the kernel function.
MDNode *Node = KernelFunction.getMetadata(AttributeName);
if (Node && ArgIdx < Node->getNumOperands())
return cast<MDString>(Node->getOperand(ArgIdx));

// Sometimes metadata containing kernel attributes is not attached to the
// function, but can be found in the named module-level metadata instead.
// For example:
// !opencl.kernels = !{!0}
// !0 = !{void ()* @someKernelFunction, !1, ...}
// !1 = !{!"kernel_arg_addr_space", ...}
// In this case the actual index of searched argument attribute is ArgIdx + 1,
// since the first metadata node operand is occupied by attribute name
// ("kernel_arg_addr_space" in the example above).
unsigned MDArgIdx = ArgIdx + 1;
NamedMDNode *OpenCLKernelsMD =
KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
return nullptr;

// KernelToMDNodeList contains kernel function declarations followed by
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
// to the currently lowered kernel function.
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
bool FoundLoweredKernelFunction = false;
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
KernelFunction.getName()) {
FoundLoweredKernelFunction = true;
continue;
}
if (MaybeValue && FoundLoweredKernelFunction)
return nullptr;

MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
if (FoundLoweredKernelFunction && MaybeNode &&
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
AttributeName &&
MDArgIdx < MaybeNode->getNumOperands())
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
}
return nullptr;
}

static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
return SPIRV::AccessQualifier::ReadWrite;

MDString *ArgAttribute =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
if (!ArgAttribute)
return SPIRV::AccessQualifier::ReadWrite;

Expand All @@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
}

static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
MDString *ArgAttribute =
getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
return {SPIRV::Decoration::Volatile};
return {};
Expand All @@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
isSpecialOpaqueType(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

MDString *MDKernelArgType =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
!MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
Expand Down
34 changes: 28 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVMetadata.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
Expand Down Expand Up @@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
bool AllowCastingToChar = false;

StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
if (!ArgType || ArgType->getString().starts_with("uchar*"))
return;

// Handle special case when StoreInst's value operand is a kernel argument
// of a pointer type. Since these arguments could have either a basic
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
// the StoreInst's value operand to default pointer element type (i8).
Pointer = Arg;
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
OperandToReplace = 0;
AllowCastingToChar = true;
} else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
OperandToReplace = 1;
Expand Down Expand Up @@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {

// Do not emit spv_ptrcast if it would cast to the default pointer element
// type (i8) of the same address space.
if (ExpectedElementType->isIntegerTy(8))
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
return;

// If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
// for this pointer before, do not emit spv_ptrcast but emit
// spv_assign_ptr_type instead.
if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
// If this would be the first spv_ptrcast, the pointer's defining instruction
// requires spv_assign_ptr_type and does not already have one, do not emit
// spv_ptrcast and emit spv_assign_ptr_type instead.
Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
requireAssignPtrType(PointerDefInst)) {
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
ExpectedElementTypeConst, Pointer,
{IRB->getInt32(AddressSpace)});
Expand Down
92 changes: 92 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions needed for parsing LLVM IR metadata relevant
// to the SPIR-V target.
//
//===----------------------------------------------------------------------===//

#include "SPIRVMetadata.h"

using namespace llvm;

static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
const StringRef AttributeName) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");

// Lookup the argument attribute in metadata attached to the kernel function.
MDNode *Node = F.getMetadata(AttributeName);
if (Node && ArgIdx < Node->getNumOperands())
return cast<MDString>(Node->getOperand(ArgIdx));

// Sometimes metadata containing kernel attributes is not attached to the
// function, but can be found in the named module-level metadata instead.
// For example:
// !opencl.kernels = !{!0}
// !0 = !{void ()* @someKernelFunction, !1, ...}
// !1 = !{!"kernel_arg_addr_space", ...}
// In this case the actual index of searched argument attribute is ArgIdx + 1,
// since the first metadata node operand is occupied by attribute name
// ("kernel_arg_addr_space" in the example above).
unsigned MDArgIdx = ArgIdx + 1;
NamedMDNode *OpenCLKernelsMD =
F.getParent()->getNamedMetadata("opencl.kernels");
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
return nullptr;

// KernelToMDNodeList contains kernel function declarations followed by
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
// to the currently lowered kernel function.
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
bool FoundLoweredKernelFunction = false;
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
if (MaybeValue &&
dyn_cast<Function>(MaybeValue->getValue())->getName() == F.getName()) {
FoundLoweredKernelFunction = true;
continue;
}
if (MaybeValue && FoundLoweredKernelFunction)
return nullptr;

MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
if (FoundLoweredKernelFunction && MaybeNode &&
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
AttributeName &&
MDArgIdx < MaybeNode->getNumOperands())
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
}
return nullptr;
}

namespace llvm {

MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
}

MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
}

MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
}

} // namespace llvm
31 changes: 31 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMetadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions needed for parsing LLVM IR metadata relevant
// to the SPIR-V target.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H

#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

namespace llvm {

//===----------------------------------------------------------------------===//
// OpenCL Metadata
//

MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);

} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_METADATA_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(ptr addrspace(1) %arg) {
ret void
}

; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(i8 %a, ptr addrspace(1) %p) {
store i8 %a, ptr addrspace(1) %p
ret void
}

; CHECK: %[[#A:]] = OpFunctionParameter %[[#CHAR]]
; CHECK: %[[#P:]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
; CHECK: OpStore %[[#P]] %[[#A]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
%var = alloca ptr addrspace(1), align 8
; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
; CHECK-NOT: %[[#]] = OpBitcast %[[#]] %[[#]]
store ptr addrspace(1) %arg, ptr %var, align 8
ret void
}

!1 = !{i32 1}
!2 = !{!"none"}
!3 = !{!"char*"}
!4 = !{!""}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
%var = alloca ptr addrspace(1), align 8
; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
store ptr addrspace(1) %arg, ptr %var, align 8
; The test itends to verify that OpStore uses OpVariable result directly (without a bitcast).
; Other type checking is done by spirv-val.
; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
%lod = load ptr addrspace(1), ptr %var, align 8
%idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
ret void
}

!1 = !{i32 1}
!2 = !{!"none"}
!3 = !{!"ulong*"}
!4 = !{!""}

0 comments on commit 0fbaf03

Please sign in to comment.