Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
Add Float16 support (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok committed Jul 20, 2022
1 parent ad6c81c commit e0497e2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::Bool;
if (type.isBF16())
return torch_upstream::ScalarType::BFloat16;
if (type.isF16())
return torch_upstream::ScalarType::Half;
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}

Expand All @@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType(
return IntegerType::get(context, 1);
case torch_upstream::ScalarType::BFloat16:
return mlir::FloatType::getBF16(context);
case torch_upstream::ScalarType::Half:
return mlir::FloatType::getF16(context);
default:
return Type();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
case ScalarType::BFloat16:
return mlirDenseElementsAttrBFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
case ScalarType::Half:
return mlirDenseElementsAttrFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData));

default:
throwUnsupportedTensorError();
}
Expand Down

0 comments on commit e0497e2

Please sign in to comment.