Skip to content

Commit

Permalink
Reapply "[MLIR][Python] add ctype python binding support for bf16" (#…
Browse files Browse the repository at this point in the history
…101271)

Reapply the PR which was reverted due to built-bots, and now the bots
get updated.
https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437
original PR: #92489, reverted
in #93771
  • Loading branch information
xurui1995 authored Jul 31, 2024
1 parent 8300eaa commit 5ef087b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
19 changes: 19 additions & 0 deletions mlir/python/mlir/runtime/np_to_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import numpy as np
import ctypes

try:
import ml_dtypes
except ModuleNotFoundError:
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
ml_dtypes = None


class C128(ctypes.Structure):
"""A ctype representation for MLIR's Double Complex."""
Expand All @@ -26,6 +32,12 @@ class F16(ctypes.Structure):
_fields_ = [("f16", ctypes.c_int16)]


class BF16(ctypes.Structure):
"""A ctype representation for MLIR's BFloat16."""

_fields_ = [("bf16", ctypes.c_int16)]


# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
"""Converts dtype to ctype."""
Expand All @@ -35,6 +47,8 @@ def as_ctype(dtp):
return C64
if dtp == np.dtype(np.float16):
return F16
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
return BF16
return np.ctypeslib.as_ctypes_type(dtp)


Expand All @@ -46,6 +60,11 @@ def to_numpy(array):
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
assert not (
array.dtype == BF16 and ml_dtypes is None
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
if array.dtype == BF16:
return array.view("bfloat16")
return array


Expand Down
3 changes: 2 additions & 1 deletion mlir/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.19.5, <=1.26
pybind11>=2.9.0, <=2.10.3
PyYAML>=5.3.1, <=6.0.1
PyYAML>=5.3.1, <=6.0.1
ml_dtypes # provides several NumPy dtype extensions, including the bf16
40 changes: 40 additions & 0 deletions mlir/test/python/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
from ml_dtypes import bfloat16


# Log everything to stderr and flush so that we have a unified stream to match
Expand Down Expand Up @@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd():
run(testComplexUnrankedMemrefAdd)


# Test bf16 memrefs
# CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
with Context():
module = Module.parse(
"""
module {
func.func @main(%arg0: memref<1xbf16>,
%arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xbf16>
memref.store %1, %arg1[%0] : memref<1xbf16>
return
}
} """
)

arg1 = np.array([0.5]).astype(bfloat16)
arg2 = np.array([0.0]).astype(bfloat16)

arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg1))
)
arg2_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg2))
)

execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

# test to-numpy utility
# CHECK: [0.5]
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
log(npout)


run(testBF16Memref)


# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
Expand Down

0 comments on commit 5ef087b

Please sign in to comment.