tests: skip attention-related parameterized cases when attn_layer is 0
These tests make no sense when attn_layer is 0.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz committed May 15, 2024
1 parent 2bf0769 commit aeaeff2
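
For context, the helper added below flags a parameter combination as pointless when attention is disabled (attn_layer == 0) but attention-specific options are still set. A minimal standalone sketch of that predicate, with hypothetical example tuples that are not taken from the actual test matrix:

    from typing import Optional

    def is_meaningless_zero_attention_layer_tests(
        attn_layer: int,
        attn_dotr: bool,
        normalize: bool,
        temperature: Optional[float],
    ) -> bool:
        # With zero attention layers, dot-product gating, normalization,
        # and temperature have nothing to act on.
        return attn_layer == 0 and (attn_dotr or normalize or temperature is not None)

    assert is_meaningless_zero_attention_layer_tests(0, True, False, None)       # skipped
    assert not is_meaningless_zero_attention_layer_tests(0, False, False, None)  # kept
    assert not is_meaningless_zero_attention_layer_tests(2, True, True, 1.0)     # kept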
Showing 1 changed file with 37 additions and 8 deletions:

source/tests/consistent/descriptor/test_dpa1.py
@@ -2,6 +2,7 @@
 import unittest
 from typing import (
     Any,
+    Optional,
     Tuple,
 )
 

@@ -107,6 +108,15 @@ def data(self) -> dict:
             "seed": 1145141919810,
         }
 
+    def is_meaningless_zero_attention_layer_tests(
+        self,
+        attn_layer: int,
+        attn_dotr: bool,
+        normalize: bool,
+        temperature: Optional[float],
+    ) -> bool:
+        return attn_layer == 0 and (attn_dotr or normalize or temperature is not None)
+
     @property
     def skip_pt(self) -> bool:
         (
@@ -128,7 +138,12 @@ def skip_pt(self) -> bool:
             concat_output_tebd,
             precision,
         ) = self.param
-        return CommonTest.skip_pt
+        return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
 
     @property
     def skip_dp(self) -> bool:
@@ -151,7 +166,12 @@ def skip_dp(self) -> bool:
             concat_output_tebd,
             precision,
         ) = self.param
-        return CommonTest.skip_pt
+        return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
 
     @property
     def skip_tf(self) -> bool:
@@ -176,12 +196,21 @@ def skip_tf(self) -> bool:
         ) = self.param
         # TODO (excluded_types != [] and attn_layer > 0) need fix
         return (
-            env_protection != 0.0
-            or smooth_type_embedding
-            or not normalize
-            or temperature != 1.0
-            or (excluded_types != [] and attn_layer > 0)
-            or (type_one_side and tebd_input_mode == "strip")  # not consistent yet
+            CommonTest.skip_tf
+            or (
+                env_protection != 0.0
+                or smooth_type_embedding
+                or not normalize
+                or temperature != 1.0
+                or (excluded_types != [] and attn_layer > 0)
+                or (type_one_side and tebd_input_mode == "strip")  # not consistent yet
+            )
+            or self.is_meaningless_zero_attention_layer_tests(
+                attn_layer,
+                attn_dotr,
+                normalize,
+                temperature,
+            )
         )
 
     tf_class = DescrptDPA1TF
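
For illustration, a hedged sketch of how a skip property like the ones above is typically consumed by a parameterized consistency test; the class and test names here are hypothetical and do not reflect deepmd-kit's actual CommonTest API:

    import unittest

    class HypotheticalConsistencyTest(unittest.TestCase):
        # In the real suite, `param` would be injected per parameterized case.
        param = (0, True, False, None)  # attn_layer, attn_dotr, normalize, temperature

        @property
        def skip_pt(self) -> bool:
            attn_layer, attn_dotr, normalize, temperature = self.param
            return attn_layer == 0 and (
                attn_dotr or normalize or temperature is not None
            )

        def test_pt_consistency(self) -> None:
            if self.skip_pt:
                self.skipTest("attention options are meaningless when attn_layer == 0")
            # ... actual backend-vs-reference consistency checks would run here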
