Skip to content

Commit

Permalink
Updated LowerWarp unit tests to find Allocate in PrimFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Aug 15, 2023
1 parent 1c59bc9 commit 2538ac8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
import tvm
import tvm.testing
from tvm import te
from tvm import te, tir
from tvm.contrib.nvcc import have_fp16


Expand Down Expand Up @@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():

mod = _run_passes(mod)
fdevice = mod["f_kernel"]
allocate = fdevice.body.body

allocate = fdevice
while not isinstance(allocate, tir.Allocate):
allocate = allocate.body

assert allocate.buffer_var.type_annotation.storage_scope == "local"
assert fdevice.body.body.extents[0].value == 2
assert allocate.extents[0].value == 2


@tvm.testing.requires_cuda
Expand Down

0 comments on commit 2538ac8

Please sign in to comment.