Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 13, 2024
1 parent d416c34 commit bb335a3
Show file tree
Hide file tree
Showing 14 changed files with 546 additions and 100 deletions.
59 changes: 27 additions & 32 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@
_F = TypeVar("_F", bound=Callable[..., Any])


@functools.lru_cache(maxsize=None)
def _retrieve_generic_params(cls: type) -> Set[TypeVar]:
if hasattr(cls, "__orig_bases__"):
orig_bases = cls.__orig_bases__
for base in orig_bases:
print(base, typing.get_args(base))
_retrieve_generic_params(base)
return set()


def _builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[_T], _T]:
def decorator(cls: _T) -> _T:
cls_name = _get_full_name(cls)
Expand All @@ -42,48 +32,51 @@ def make_type_rule(
parameters = signature.parameters
return_type = method.return_type
if not isinstance(return_type, type):
raise hir.TypeInferenceError(
f"Valid return type annotation required for {cls_name}.{name}"
)
raise hir.TypeInferenceError(None,
f"Valid return type annotation required for {cls_name}.{name}"
)

def type_rule(args: List[hir.Type]) -> hir.Type:
if len(args) > len(parameters):
raise hir.TypeInferenceError(
f"Too many arguments for {cls_name}.{name} expected at most {len(parameters)} but got {len(args)}"
)

parameters_list = list(parameters.values())
if name == '__init__':
parameters_list = parameters_list[1:]
if len(args) > len(parameters_list):
raise hir.TypeInferenceError(None,
f"Too many arguments for {cls_name}.{name} expected at most {len(parameters_list)} but got {len(args)}"
)
for i, arg in enumerate(args):
param = parameters_list[i]
param_ty = type_hints.get(param.name)
if param.name == "self":
if arg != ty:
if i != 0:
raise hir.TypeInferenceError(
f"Expected {cls_name}.{name} to be called with an instance of {cls_name} as the first argument but got {arg}"
)
raise hir.TypeInferenceError(
f"Expected {cls_name}.{name} to be called with an instance of {cls_name} but got {arg}"
)
raise hir.TypeInferenceError(None,
f"Expected {cls_name}.{name} to be called with an instance of {cls_name} as the first argument but got {arg}"
)
raise hir.TypeInferenceError(None,
f"Expected {cls_name}.{name} to be called with an instance of {cls_name} but got {arg}"
)
continue
if param_ty is None:
raise hir.TypeInferenceError(
f"Parameter type annotation required for {cls_name}.{name}"
)
raise hir.TypeInferenceError(None,
f"Parameter type annotation required for {cls_name}.{name}"
)

def check(anno_tys: List[type | Any]):
possible_failed_reasons: List[str] = []
for anno_ty in anno_tys:
if anno_ty == float:
# match all hir.FloatType
if isinstance(arg, hir.FloatType):
if isinstance(arg, hir.FloatType) or isinstance(arg, hir.GenericFloatType):
return
else:
possible_failed_reasons.append(
f"Expected {cls_name}.{name} to be called with {anno_ty} but got {arg}"
)
continue
if anno_ty == int:
if isinstance(arg, hir.IntType):
if isinstance(arg, hir.IntType) or isinstance(arg, hir.GenericIntType):
return
else:
possible_failed_reasons.append(
Expand All @@ -110,14 +103,16 @@ def check(anno_tys: List[type | Any]):
possible_failed_reasons.append(
f"Expected {cls_name}.{name} to be called with {anno_ty} but got {arg}"
)
raise hir.TypeInferenceError(
f"Expected {cls_name}.{name} to be called with one of {possible_failed_reasons}"
)
raise hir.TypeInferenceError(None,
f"Expected {cls_name}.{name} to be called with one of {possible_failed_reasons}"
)

union_args = get_union_args(param_ty)
if union_args == []:
union_args = [param_ty]
check(union_args)
if name == '__init__':
return ty
if return_type:
return ctx.types[return_type]
else:
Expand All @@ -140,7 +135,7 @@ def make_builtin():
return decorator


def _builtin(func: _F, *args, **kwargs) -> _F:
def _builtin(func: _F) -> _F:
return func


Expand Down
78 changes: 59 additions & 19 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from functools import cache
from luisa_lang import hir
from luisa_lang._utils import unwrap
from luisa_lang.codegen import CodeGen, ScratchBuffer
from typing import Any, Callable, Dict, Tuple, Union
from typing import Any, Callable, Dict, Set, Tuple, Union

from luisa_lang.hir.defs import GlobalContext
from luisa_lang.hir import get_dsl_func
Expand All @@ -25,32 +26,66 @@ def gen_impl(self, ty: hir.Type) -> str:
match ty:
case hir.IntType(bits=bits, signed=signed):
if signed:
return f"int{bits}_t"
return f"i{bits}"
else:
return f"uint{bits}_t"
return f"u{bits}"
case hir.FloatType(bits=bits):
return f"float{bits}_t"
return f"f{bits}"
case hir.BoolType():
return "bool"
case hir.VectorType(element=element, count=count):
return f"vec<{self.gen(element)}, {count}>"
case _:
raise NotImplementedError(f"unsupported type: {ty}")


_OPERATORS: Set[str] = set([
'__add__',
'__sub__',
'__mul__',
'__truediv__',
'__floordiv__',
'__mod__',
'__pow__',
'__and__',
'__or__',
'__xor__',
'__lshift__',
'__rshift__',
'__eq__',
'__ne__',
'__lt__',
'__le__',
'__gt__',
'__ge__',
])
@cache
def map_builtin_to_cpp_func(name: str) -> str:
comps = name.split(".")
if comps[0] == "luisa_lang" and comps[1] == "math_types":
if comps[3] in _OPERATORS:
return f'{comps[3]}<{comps[2]}>'
return f'{comps[2]}_{comps[3]}'

else:
raise NotImplementedError(f"unsupported builtin function: {name}")


class Mangling:
cache: Dict[hir.Type | hir.Function, str]
cache: Dict[hir.Type | hir.FunctionLike, str]

def __init__(self) -> None:
self.cache = {}

def mangle(self, obj: Union[hir.Type, hir.Function]) -> str:
def mangle(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
if obj in self.cache:
return self.cache[obj]
else:
res = self.mangle_impl(obj)
self.cache[obj] = res
return res

def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
def mangle_name(name: str) -> str:
comps = name.split(".")
mangled = "N"
Expand All @@ -75,9 +110,14 @@ def mangle_name(name: str) -> str:
return f"P{self.mangle(element)}"
case hir.ArrayType(element=element, count=count):
return f"A{count}{self.mangle(element)}"
case hir.VectorType(element=element, count=count):
return f"V{count}{self.mangle(element)}"
case hir.Function(name=name, params=params, return_type=ret):
name = mangle_name(name)
return f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}"
case hir.BuiltinFunction(name=name):
name = map_builtin_to_cpp_func(name)
return f"__builtin_{name}"
case _:
raise NotImplementedError(f"unsupported object: {obj}")

Expand Down Expand Up @@ -140,7 +180,10 @@ def gen_ref(self, ref: hir.Ref) -> str:
case hir.Var() as var:
return var.name
case hir.Member() as member:
base = self.gen_ref(member.base)
if isinstance(member.base, hir.Ref):
base = self.gen_ref(member.base)
else:
base = self.gen_expr(member.base)
return f"{base}.{member.field}"
case hir.ValueRef() as value_ref:
return self.gen_expr(value_ref.value)
Expand All @@ -154,16 +197,10 @@ def gen_expr(self, expr: hir.Value) -> str:
case hir.Call() as call:
assert call.resolved, f"unresolved call: {call}"
kind = call.kind
match kind:
case hir.CallOpKind.BINARY_OP:
return f"{self.gen_expr(call.args[0])} {call.op} {self.gen_expr(call.args[1])}"
case hir.CallOpKind.UNARY_OP:
return f"{call.op}{self.gen_expr(call.args[0])}"
case hir.CallOpKind.FUNC:
# TODO: fix this
assert not isinstance(call.op, str)
op = self.gen_expr(call.op)
return f"{op}({','.join(self.gen_expr(arg) for arg in call.args)})"
assert kind == hir.CallOpKind.FUNC and isinstance(
call.op, hir.Value)
op = self.gen_expr(call.op)
return f"{op}({','.join(self.gen_expr(arg) for arg in call.args)})"
case hir.Constant() as constant:
value = constant.value
if isinstance(value, int):
Expand All @@ -176,8 +213,11 @@ def gen_expr(self, expr: hir.Value) -> str:
return f'"{value}"'
elif isinstance(value, hir.Function):
return self.base.gen_function(value)
elif isinstance(value, hir.BuiltinFunction):
return self.base.mangling.mangle(value)
else:
raise NotImplementedError(f"unsupported constant: {constant}")
raise NotImplementedError(
f"unsupported constant: {constant}")
case _:
raise NotImplementedError(f"unsupported expression: {expr}")

Expand Down
1 change: 1 addition & 0 deletions luisa_lang/codegen/cpp_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CPP_LIB_COMPRESSED = '''QlpoOTFBWSZTWX4ZiUsAAGvfgEQSSX79TxplxS6/598qMAF5QAaqbUzRGT0EaYjI0PQJghkG0mg5gAAAAAAAAAASJEhkyARNpqe0oNDTTyj002pPKemkmCmkl2Hz6xh92PxRRFW93kNhepBBIFznr48Hbut1j2TlKV6z5ajo3MyWLRi9Jm2XIZewKmV+4Q4giRrdmTC5NRho0pkAYHpIqRg+AVobwGdzQznOeRVHHNKKr+EJEsZ+V6ELYvIhHhszWztY2R5yJ/plepNCoDeXHA80cUSJF4gaLoRrLkXFwSOF97VEA5HrmT2MmjYVmWFvPUpK0aBHcG7kWLcHPIva0JCMpn8IKKAgKuY4rjo6eoQFao8sDVi7MK+JXLpUJQGc6lC2vUjFZQHIYijQq9DbbcjIxHW+oC3LmTm8h4gxGkKgYQJYD4JZwgxNU9q7zPSvdrvZkMBA8aZpxqZVM0hMQwh9rWOFBSxsHnhuEcWp4V9MN/uCoK1SIjXiibCTJ/8XckU4UJB+GYlL'''
51 changes: 45 additions & 6 deletions luisa_lang/hir/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ def __repr__(self) -> str:
def __hash__(self) -> int:
return hash((VectorType, self.element, self.count))

@override
def member(self, field: Any) -> Optional['Type']:
comps = 'xyzw'[:self.count]
if isinstance(field, str) and field in comps:
return self.element
return Type.member(self, field)


class ArrayType(Type):
element: Type
Expand Down Expand Up @@ -303,7 +310,7 @@ def member(self, field: Any) -> Optional['Type']:
if isinstance(field, str):
if field in self._field_dict:
return self._field_dict[field]
return None
return Type.member(self, field)


class SymbolicType(Type):
Expand Down Expand Up @@ -570,10 +577,10 @@ def __init__(


class Member(Ref):
base: Ref
base: Ref | Value
field: str

def __init__(self, base: Ref, field: str, span: Optional[Span]) -> None:
def __init__(self, base: Ref | Value, field: str, span: Optional[Span]) -> None:
super().__init__(None, span)
self.base = base
self.field = field
Expand All @@ -590,10 +597,10 @@ def children(self) -> List[Node]:


class Index(Ref):
base: Ref
base: Ref | Value
index: Value

def __init__(self, base: Ref, index: Value, span: Optional[Span]) -> None:
def __init__(self, base: Ref | Value, index: Value, span: Optional[Span]) -> None:
super().__init__(None, span)
self.base = base
self.index = index
Expand Down Expand Up @@ -643,9 +650,17 @@ def __init__(self, value: Any, span: Optional[Span] = None) -> None:
super().__init__(None, span)
self.value = value

def __eq__(self, value: object) -> bool:
return isinstance(value, Constant) and value.value == self.value

def __hash__(self) -> int:
return hash(self.value)


class Call(Value):
op: Value | str
"""After type inference, op should be a Value."""

args: List[Value]
kind: CallOpKind
resolved: bool
Expand Down Expand Up @@ -684,7 +699,17 @@ def children(self) -> List[Node]:


class TypeInferenceError(Exception):
pass
node: Node | None
message: str

def __init__(self, node: Node | None, message: str) -> None:
self.node = node
self.message = message

def __str__(self) -> str:
if self.node is None:
return f"Type inference error: {self.message}"
return f"Type inference error at {self.node.span}: {self.message}"


class TypeRule(ABC):
Expand Down Expand Up @@ -912,3 +937,17 @@ def get_dsl_func(func: Callable[..., Any]) -> Optional[Function]:
return None
assert func_
return func_


def get_dsl_type(cls: type) -> Optional[Type]:
return GlobalContext.get().types.get(cls)


def is_type_compatible_to(ty:Type, target:Type)->bool:
if ty == target:
return True
if isinstance(target, FloatType):
return isinstance(ty, GenericFloatType)
if isinstance(target, IntType):
return isinstance(ty, GenericIntType)
return False
Loading

0 comments on commit bb335a3

Please sign in to comment.