Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopPartition] Fix a bug of LoopPartition in single point scenarioes #16104

Merged
merged 1 commit into from
Dec 15, 2023

Conversation

lightzhan-intellif
Copy link
Contributor

This PR tries to fix a bug of the pass LoopPartiton. When there are one or more tensors containing a shape 1 in the concat dim, the pass will unroll the loops wrongly after partitioning. For example:

@T.prim_func
def concat_func_single_point(
    placeholder: T.Buffer((28, 64), "int8"),
    placeholder_1: T.Buffer((28, 1), "int8"),
    placeholder_2: T.Buffer((28, 63), "int8"),
    T_concat: T.Buffer((28, 128), "int8"),
) -> None:
    for i0 in range(28):
        for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
            if i1 > 63:
                T_concat[i0, i1] = placeholder[i0, i1 - 64]
            elif i1 == 63:
                T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
            else:
                T_concat[i0, i1] = placeholder_2[i0, i1]

after LoopPartition:

@T.prim_func
def expected_partitioned_concat_single_point(
    placeholder: T.Buffer((28, 64), "int8"),
    placeholder_1: T.Buffer((28, 1), "int8"),
    placeholder_2: T.Buffer((28, 63), "int8"),
    T_concat: T.Buffer((28, 128), "int8"),
):
    for i0 in range(28):
        T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
        for i1 in T.unroll(63): # Note here, it is unrolled.
            placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
            T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
        placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
        T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
        for i1 in T.unroll(64): # here too.
            placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
            T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]

cc @wrongtest-intellif @tqchen

@lightzhan-intellif lightzhan-intellif changed the title [LoopPartition] Fix a bug of LoopPartition in single point scenarioes. [LoopPartition] Fix a bug of LoopPartition in single point scenarioes Nov 10, 2023
@tqchen
Copy link
Member

tqchen commented Dec 7, 2023

Thanks @lightzhan-intellif do you mind to fix the ci?

@lightzhan-intellif
Copy link
Contributor Author

@tvm-bot rerun

@lightzhan-intellif
Copy link
Contributor Author

Thanks @lightzhan-intellif do you mind to fix the ci?

done

@tqchen tqchen merged commit 870246a into apache:main Dec 15, 2023
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants