Skip to content

Commit

Permalink
Fix category mapper for case when input value yields default string (l…
Browse files Browse the repository at this point in the history
…lvm#1271)

When an input tensor contains a int64_t value that is not in the CategoryMapper's category attribute, the default string value needs to be returned. Currently there is a bug in the implementation (segfaults), this patch fixes the problem and adds and end-to-end test.

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>
  • Loading branch information
Ettore Tiotto authored Mar 30, 2022
1 parent 5ecd858 commit f454310
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
krnl::populateLoweringKrnlGetRefOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMemcpyOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlPrintOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlPrintTensorOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlVectorTypeCastOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlRandomNormalOpPattern(typeConverter, patterns, ctx);
Expand Down
9 changes: 1 addition & 8 deletions src/Conversion/KrnlToLLVM/KrnlGlobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
krnlGlobalOp.value().getValue());
}

// LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
return global;
}

Expand Down Expand Up @@ -223,13 +223,6 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
Type i8Type = IntegerType::get(builder.getContext(), 8);
Type i8PtrType = LLVM::LLVMPointerType::get(i8Type);

int64_t numStrings = denseAttr.getValues<StringRef>().size();
if (numStrings == 1) {
StringRef str = *denseAttr.getValues<StringRef>().begin();
return krnl::getOrCreateGlobalString(
str, loc, builder, module, getTypeConverter());
}

// Generate LLVM GlobalOps for each string in the KrnlGlobalOp dense
// attribute.
SmallVector<LLVM::GlobalOp> globalOps;
Expand Down
4 changes: 2 additions & 2 deletions test/backend-cpp/TestCategoryMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ static bool testInt64ToStr() {

const CategoryMapperBuilder::CMAttributes attributes = {{1, 2, 3, 4, 5},
{"cat", "dog", "human", "tiger", "beaver"}, -1, "unknown"};
const ArrayRef<int64_t> input = {1, 2, 3, 4, 5};
const ArrayRef<int64_t> input = {1, 2, 3, 6, 4, 5};
const ArrayRef<const char *> expResult = {
"cat", "dog", "human", "tiger", "beaver"};
"cat", "dog", "human", "unknown", "tiger", "beaver"};

CategoryMapperBuilder categoryMapper(
SharedLibBaseName, attributes, input, expResult);
Expand Down

0 comments on commit f454310

Please sign in to comment.