Skip to content

Commit

Permalink
Address review points
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaolu committed Nov 24, 2021
1 parent bed7933 commit 23ba4cb
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 54 deletions.
3 changes: 1 addition & 2 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,7 @@ class SpirvContext {

const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy);
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);

SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);

Expand Down
27 changes: 20 additions & 7 deletions tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,22 +413,35 @@ class RayQueryTypeKHR : public SpirvType {
}
};

class SpirvConstant;
class SpirvInstruction;
struct SpvIntrinsicTypeOperand {
SpvIntrinsicTypeOperand(SpirvType *type_operand)
: operand_as_type(type_operand), isTypeOperand(true) {}
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
: operand_as_inst(inst_operand), isTypeOperand(false) {}
union {
SpirvType *operand_as_type;
SpirvInstruction *operand_as_inst;
};
bool isTypeOperand;
};

class SpirvIntrinsicType : public SpirvType {
public:
SpirvIntrinsicType(unsigned typeOp, llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy);
SpirvIntrinsicType(unsigned typeOp,
llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps);

static bool classof(const SpirvType *t) {
return t->getKind() == TK_SpirvIntrinsicType;
}
unsigned getOpCode() const { return typeOpCode; }
llvm::ArrayRef<SpirvConstant *> getLiterals() const { return literals; }
SpirvIntrinsicType *getElemType() const { return elementType; }
llvm::ArrayRef<SpvIntrinsicTypeOperand> getOperands() const {
return operands;
}

private:
unsigned typeOpCode;
llvm::SmallVector<SpirvConstant *, 3> literals;
SpirvIntrinsicType *elementType;
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;
};

class HybridType : public SpirvType {
Expand Down
78 changes: 48 additions & 30 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2450,15 +2450,22 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
initTypeInstruction(spv::Op::OpTypeRayQueryKHR);
curTypeInst.push_back(id);
finalizeTypeInstruction();
}
else if (const auto *spvIntrinsicType = dyn_cast<SpirvIntrinsicType>(type)) {
} else if (const auto *spvIntrinsicType =
dyn_cast<SpirvIntrinsicType>(type)) {
initTypeInstruction(static_cast<spv::Op>(spvIntrinsicType->getOpCode()));
curTypeInst.push_back(id);
if (spvIntrinsicType->getElemType()) {
curTypeInst.push_back(emitType(spvIntrinsicType->getElemType()));
}
for (auto& literal : spvIntrinsicType->getLiterals()) {
emitLiteral(literal, curTypeInst);
for (const SpvIntrinsicTypeOperand &operand :
spvIntrinsicType->getOperands()) {
if (operand.isTypeOperand) {
curTypeInst.push_back(emitType(operand.operand_as_type));
} else {
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
if (literal && literal->isLiteral()) {
emitLiteral(literal, curTypeInst);
} else {
curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst));
}
}
}
finalizeTypeInstruction();
}
Expand All @@ -2477,36 +2484,47 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
return id;
}

template <typename vecType>
void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral,
vecType &outInst) {
const auto &literalVal = intLiteral->getValue();
bool positive = !literalVal.isNegative();
if (literalVal.getBitWidth() <= 32) {
outInst.push_back(positive ? literalVal.getZExtValue()
: literalVal.getSExtValue());
} else {
assert(literalVal.getBitWidth() == 64);
uint64_t val =
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename vecType>
void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral,
vecType &outInst) {
const auto &literalVal = fLiteral->getValue();
const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics());
if (bitwidth <= 32) {
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
} else {
assert(bitwidth == 64);
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename VecType>
void EmitTypeHandler::emitLiteral(const SpirvConstant *literal,
VecType &outInst) {
if (auto boolLiteral = dyn_cast<SpirvConstantBoolean>(literal)) {
outInst.push_back(static_cast<unsigned>(boolLiteral->getValue()));
} else if (auto intLiteral = dyn_cast<SpirvConstantInteger>(literal)) {
const auto &literalVal = intLiteral->getValue();
bool positive = !literalVal.isNegative();
if (literalVal.getBitWidth() <= 32) {
outInst.push_back(positive ? literalVal.getZExtValue()
: literalVal.getSExtValue());
} else {
assert(literalVal.getBitWidth() == 64);
uint64_t val =
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
emitIntLiteral(intLiteral, outInst);
} else if (auto fLiteral = dyn_cast<SpirvConstantFloat>(literal)) {
const auto &literalVal = fLiteral->getValue();
const auto bitwidth =
llvm::APFloat::getSizeInBits(literalVal.getSemantics());
if (bitwidth <= 32) {
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
} else {
assert(bitwidth == 64);
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
emitFloatLiteral(fLiteral, outInst);
}
}

Expand Down
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class EmitTypeHandler {
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
template <typename vecType>
void emitLiteral(const SpirvConstant *, vecType &outInst);
template <typename vecType>
void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst);
template <typename vecType>
void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst);

private:
void initTypeInstruction(spv::Op op);
Expand Down
9 changes: 4 additions & 5 deletions tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,12 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
typeTemplateParams.clear();
}

const SpirvIntrinsicType *
SpirvContext::getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy) {
const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands) {
if (spirvIntrinsicTypes[typeId] == nullptr) {
spirvIntrinsicTypes[typeId] =
new (this) SpirvIntrinsicType(typeOpCode, constants, elementTy);
new (this) SpirvIntrinsicType(typeOpCode, operands);
}
return spirvIntrinsicTypes[typeId];
}
Expand Down
20 changes: 14 additions & 6 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12609,7 +12609,6 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
auto funcDecl = expr->getDirectCallee();
auto typeDefAttr = funcDecl->getAttr<VKTypeDefExtAttr>();
SpirvIntrinsicType *elementType = nullptr;
SmallVector<SpirvConstant *, 3> constants;
llvm::SmallVector<uint32_t, 2> capbilities;
llvm::SmallVector<llvm::StringRef, 2> extensions;

Expand All @@ -12621,23 +12620,32 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
}
}

SmallVector<SpvIntrinsicTypeOperand, 3> operands;
const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
auto param = funcDecl->getParamDecl(i);
const Expr *arg = args[i]->IgnoreParenLValueCasts();
if (param->hasAttr<VKReferenceExtAttr>()) {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType());
elementType = spvContext.getCreatedSpirvIntrinsicType(typeId);
auto *recType = param->getType()->getAs<RecordType>();
if (recType && recType->getDecl()->getName() == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType());
auto *typeArg = spvContext.getCreatedSpirvIntrinsicType(typeId);
operands.emplace_back(typeArg);
} else {
operands.emplace_back(doExpr(arg));
}
} else if (param->hasAttr<VKLiteralExtAttr>()) {
SpirvInstruction *argInst = doExpr(arg);
auto constArg = dyn_cast<SpirvConstant>(argInst);
assert(constArg != nullptr);
constArg->setLiteral();
constants.push_back(constArg);
operands.emplace_back(constArg);
} else {
operands.emplace_back(loadIfGLValue(arg));
}
}
spvContext.getSpirvIntrinsicType(
typeDefAttr->getId(), typeDefAttr->getOpcode(), constants, elementType);
spvContext.getSpirvIntrinsicType(typeDefAttr->getId(),
typeDefAttr->getOpcode(), operands);

// Emit dummy OpNop with no semantic meaning, with possible extension and
// capabilities
Expand Down
6 changes: 2 additions & 4 deletions tools/clang/lib/SPIRV/SpirvType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,9 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const {
}

SpirvIntrinsicType::SpirvIntrinsicType(
unsigned typeOp, llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *eleTy)
unsigned typeOp, llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps)
: SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"),
typeOpCode(typeOp), literals(constants.begin(), constants.end()),
elementType(eleTy) {}
typeOpCode(typeOp), operands(inOps.begin(), inOps.end()) {}

StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
llvm::StringRef name, bool isReadOnly,
Expand Down

0 comments on commit 23ba4cb

Please sign in to comment.