Skip to content

Commit

Permalink
__getitem__ works
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 17, 2024
1 parent 78733cb commit 4e956b3
Show file tree
Hide file tree
Showing 8 changed files with 385 additions and 204 deletions.
90 changes: 72 additions & 18 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class _ObjKind(Enum):
KERNEL = auto()


def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType], func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue], props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType],
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue],
props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
# func_sig = func_sig_parser.parsed_func
Expand All @@ -91,7 +93,8 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
mapped_implicit_type_params: Dict[str,
hir.Type] = dict()
assert func_sig is not None
type_parser = parse.TypeParser(func_name, func_globals, type_var_ns, self_type, 'instantiate')
type_parser = parse.TypeParser(
func_name, func_globals, type_var_ns, self_type, 'instantiate')
for (tv, t) in func_sig.env.items():
type_var_ns[tv] = unwrap(type_parser.parse_type(t))
if is_generic:
Expand All @@ -115,7 +118,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
mapped_type = mapping[gp]
assert isinstance(mapped_type, hir.Type)
mapped_implicit_type_params[name] = mapped_type

func_sig_instantiated, _p = parse.convert_func_signature(
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
# print(func_name, func_sig)
Expand All @@ -124,10 +127,10 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
assert not isinstance(
func_sig_instantiated.return_type, hir.SymbolicType)
func_parser = parse.FuncParser(
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type, props.returning_ref)
ret = func_parser.parse_body()
ret.inline_hint = props.inline
ret.export = props.export
ret.export = props.export
return ret
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig_converted.generic_params) > 0
Expand Down Expand Up @@ -162,17 +165,23 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
# return cast(_T, f)


def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | None = None) -> type[_TT]:
ctx = hir.GlobalContext.get()
_MakeTemplateFn = Callable[[List[hir.GenericParameter]], hir.Type]
_InstantiateFn = Callable[[List[Any]], hir.Type]


def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | Tuple[_MakeTemplateFn, _InstantiateFn] | None = None, opqaue_override: str | None = None) -> type[_TT]:
ctx = hir.GlobalContext.get()
register_class(cls)
assert not (ir_ty_override is not None and opqaue_override is not None)
cls_info = class_typeinfo(cls)
globalns = _get_cls_globalns(cls)
globalns[cls.__name__] = cls
type_var_to_generic_param: Dict[TypeVar, hir.GenericParameter] = {}
for type_var in cls_info.type_vars:
type_var_to_generic_param[type_var] = hir.GenericParameter(
type_var.__name__, cls.__qualname__)
generic_params = [type_var_to_generic_param[tv]
for tv in cls_info.type_vars]

def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
fields: List[Tuple[str, hir.Type]] = []
Expand All @@ -182,13 +191,14 @@ def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
raise hir.TypeInferenceError(
None, f"Cannot infer type for field {name} of {cls.__name__}")
fields.append((name, field_ty))
if isinstance(self_ty, hir.StructType):
self_ty.fields = fields
elif isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated, hir.StructType)
self_ty.instantiated.fields = fields
else:
raise NotImplementedError()
if len(fields) > 0:
if isinstance(self_ty, hir.StructType):
self_ty.fields = fields
elif isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated, hir.StructType)
self_ty.instantiated.fields = fields
else:
raise NotImplementedError()

def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,):
for name in cls_info.methods:
Expand All @@ -198,16 +208,24 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
props = getattr(method_object, '__luisa_func_props__')
else:
props = hir.FuncProperties()
if name == '__getitem__':
props.returning_ref = True
template = _make_func_template(
method_object, get_full_name(method_object), cls_info.methods[name], globalns, type_var_ns, props, self_type=self_ty)
if isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated, hir.StructType)
assert isinstance(self_ty.instantiated,
(hir.StructType, hir.OpaqueType))
self_ty.instantiated.methods[name] = template
else:
self_ty.methods[name] = template
ir_ty: hir.Type
if ir_ty_override is not None:
ir_ty = ir_ty_override
if isinstance(ir_ty_override, hir.Type):
ir_ty = ir_ty_override
else:
ir_ty = ir_ty_override[0](generic_params)
elif opqaue_override is not None:
ir_ty = hir.OpaqueType(opqaue_override)
else:
ir_ty = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, [])
Expand All @@ -226,8 +244,15 @@ def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
for i, arg in enumerate(args):
type_var_ns[cls_info.type_vars[i]] = arg
hash_s = unique_hash(f'{cls.__qualname__}_{args}')
inner_ty = hir.StructType(
f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', [])
inner_ty: hir.Type
if ir_ty_override is not None:
assert isinstance(ir_ty_override, tuple)
inner_ty = ir_ty_override[1](args)
elif opqaue_override:
inner_ty = hir.OpaqueType(opqaue_override, args[:])
else:
inner_ty = hir.StructType(
f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', [])
mono_self_ty = hir.BoundType(ir_ty, args, inner_ty)
mono_type_parser = parse.TypeParser(
cls.__qualname__, globalns, type_var_ns, mono_self_ty, 'instantiate')
Expand All @@ -253,6 +278,22 @@ def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
raise NotImplementedError()


def opaque(name: str) -> Callable[[type[_TT]], type[_TT]]:
"""
Mark a class as a DSL opaque type.
Example:
```python
@luisa.opaque("Buffer")
class Buffer(Generic[T]):
pass
```
"""
def wrapper(cls: type[_TT]) -> type[_TT]:
return _dsl_struct_impl(cls, {}, opqaue_override=name)
return wrapper


def struct(cls: type[_TT]) -> type[_TT]:
"""
Mark a class as a DSL struct.
Expand All @@ -277,6 +318,12 @@ def decorator(cls: type[_TT]) -> type[_TT]:
return decorator


def builtin_generic_type(make_template: _MakeTemplateFn, instantiate: _InstantiateFn) -> Callable[[type[_TT]], type[_TT]]:
def decorator(cls: type[_TT]) -> type[_TT]:
return typing.cast(type[_TT], _dsl_struct_impl(cls, {}, ir_ty_override=(make_template, instantiate)))
return decorator


_KernelType = TypeVar("_KernelType", bound=Callable[..., None])


Expand Down Expand Up @@ -310,6 +357,13 @@ def __init__(self, value: str):
def _parse_func_kwargs(kwargs: Dict[str, Any]) -> hir.FuncProperties:
props = hir.FuncProperties()
props.byref = set()
return_ = kwargs.get("return", None)
if return_ is not None:
if return_ == 'ref':
props.returning_ref = True
else:
raise ValueError(
f"invalid value for return: {return_}, expected 'ref'")
inline = kwargs.get("inline", False)
if isinstance(inline, bool):
props.inline = inline
Expand Down
31 changes: 29 additions & 2 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ def gen(self, ty: hir.Type) -> str:
def gen_impl(self, ty: hir.Type) -> str:
match ty:
case hir.IntType(bits=bits, signed=signed):
int_names = {
'8':'byte',
'16':'short',
'32':'int',
'64':'long',
}
if signed:
return f"i{bits}"
return f"lc_{int_names[str(bits)]}"
else:
return f"u{bits}"
return f"lc_u{int_names[str(bits)]}"
case hir.FloatType(bits=bits):
match bits:
case 16:
Expand Down Expand Up @@ -77,6 +83,15 @@ def do():
return ''
case hir.TypeConstructorType():
return ''
case hir.OpaqueType():
def do():
match ty.name:
case 'Buffer':
elem_ty = self.gen(ty.extra_args[0])
return f'__builtin__Buffer<{elem_ty}>'
case _:
raise NotImplementedError(f"unsupported opaque type: {ty.name}")
return do()
case _:
raise NotImplementedError(f"unsupported type: {ty}")

Expand Down Expand Up @@ -167,6 +182,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
case hir.BoundType():
assert obj.instantiated
return self.mangle(obj.instantiated)
case hir.OpaqueType():
return obj.name
case _:
raise NotImplementedError(f"unsupported object: {obj}")

Expand Down Expand Up @@ -263,6 +280,16 @@ def gen_ref(self, ref: hir.Ref) -> str:
base = self.gen_ref(index.base)
idx = self.gen_expr(index.index)
return f"{base}[{idx}]"
case hir.IntrinsicRef() as intrin:
def do():
intrin_name = intrin.name
gened_args = [self.gen_value_or_ref(
arg) for arg in intrin.args]
if intrin_name == 'buffer_ref':
return f"{gened_args[0]}[{gened_args[1]}]"
else:
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
return do()
case _:
raise NotImplementedError(f"unsupported reference: {ref}")

Expand Down
2 changes: 1 addition & 1 deletion luisa_lang/codegen/cpp_lib.py

Large diffs are not rendered by default.

Loading

0 comments on commit 4e956b3

Please sign in to comment.