Skip to content

Commit

Permalink
basic template function works
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 23, 2024
1 parent a0b6d2f commit 62eb535
Show file tree
Hide file tree
Showing 11 changed files with 594 additions and 95 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
pyo3 = { version = "0.22.0", optional = true }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.118"
pyo3 = { version = "0.22", optional = true }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

[features]
jit = ["pyo3"]
3 changes: 2 additions & 1 deletion luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ def _builtin(func: _F) -> _F:

def _intrinsic_impl(*args, **kwargs) -> Any:
raise NotImplementedError(
"intrinsic functions should not be called in normal Python code"
"intrinsic functions should not be called in host-side Python code. "
"Did you mistakenly called a DSL function?"
)
18 changes: 17 additions & 1 deletion luisa_lang/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,26 @@ def get_union_args(union: Any) -> List[type]:
return list(union.__args__)
return []


def get_typevar_constrains_and_bounds(t: TypeVar) -> Tuple[List[Any], Optional[Any]]:
"""
Find the constraints and bounds of a TypeVar.
Only one of the two can be present.
"""
constraints = []
bound = None
if hasattr(t, "__constraints__"):
constraints = list(t.__constraints__)
if hasattr(t, "__bound__"):
bound = t.__bound__
return constraints, bound


def checked_cast(t: type[T], obj: Any) -> T:
if not isinstance(obj, t):
raise TypeError(f"expected {t}, got {type(obj)}")
return obj


def unique_hash(s: str) -> str:
return sha256(s.encode()).hexdigest().upper()[:8]
return sha256(s.encode()).hexdigest().upper()[:8]
9 changes: 6 additions & 3 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from luisa_lang.hir.defs import GlobalContext
from luisa_lang.hir import get_dsl_func
from luisa_lang.hir.infer import run_inference_on_function


class TypeCodeGenCache:
Expand Down Expand Up @@ -146,11 +147,13 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
if callable(func):
dsl_func = get_dsl_func(func)
assert dsl_func is not None
func = dsl_func
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {func}"
func_tmp = dsl_func.resolve([])
assert isinstance(func_tmp, hir.Function), f"Expected function, got {func_tmp}"
func = func_tmp
if id(func) in self.func_cache:
return self.func_cache[id(func)][1]
inferencer = hir.FuncTypeInferencer(func)
inferencer.infer()
run_inference_on_function(func)
func_code_gen = FuncCodeGen(self, func)
name = func_code_gen.name
self.func_cache[id(func)] = (func, name)
Expand Down
Loading

0 comments on commit 62eb535

Please sign in to comment.