Commit

Fix some testing on the latest version of keras (#1663)
mattdangerw committed Jun 7, 2024
1 parent 9c5f267 commit 30b34d3
Showing 3 changed files with 12 additions and 61 deletions.
10 changes: 0 additions & 10 deletions keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py
@@ -80,16 +80,6 @@ def __init__(
         max_audio_length=30,
         **kwargs,
     ):
-        # Check dtype and provide a default.
-        if "dtype" not in kwargs or kwargs["dtype"] is None:
-            kwargs["dtype"] = "float32"
-        else:
-            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
-            if not dtype.is_floating:
-                raise ValueError(
-                    f"dtype must be a floating type. Received: dtype={dtype}"
-                )
-
         super().__init__(**kwargs)
 
         self._convert_input_args = False
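
For reference, the guard removed above defaulted dtype to float32 and rejected non-floating dtypes via TensorFlow's dtype utilities. A standalone sketch of that check, for illustration only (the function name is mine, not from the repository; assumes TensorFlow is installed):

import tensorflow as tf

def resolve_feature_dtype(dtype=None):
    # Mirrors the deleted guard: default to float32, reject non-floating dtypes.
    if dtype is None:
        return "float32"
    dtype = tf.dtypes.as_dtype(dtype)
    if not dtype.is_floating:
        raise ValueError(f"dtype must be a floating type. Received: dtype={dtype}")
    return dtype.name

print(resolve_feature_dtype())           # float32
print(resolve_feature_dtype("float16"))  # float16
# resolve_feature_dtype("int32") raises ValueError
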
17 changes: 6 additions & 11 deletions keras_nlp/src/tokenizers/byte_tokenizer_test.py
@@ -208,6 +208,7 @@ def test_load_model_with_config(self):
         )
 
     def test_config(self):
+        input_data = ["hello", "fun", "▀▁▂▃", "haha"]
         tokenizer = ByteTokenizer(
             name="byte_tokenizer_config_test",
             lowercase=False,
@@ -216,14 +217,8 @@ def test_config(self):
             errors="ignore",
             replacement_char=0,
         )
-        exp_config = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "byte_tokenizer_config_test",
-            "normalization_form": "NFC",
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "trainable": True,
-        }
-        self.assertEqual(tokenizer.get_config(), exp_config)
+        cloned_tokenizer = ByteTokenizer.from_config(tokenizer.get_config())
+        self.assertAllEqual(
+            tokenizer(input_data),
+            cloned_tokenizer(input_data),
+        )
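
The rewritten test above round-trips the tokenizer through get_config()/from_config() and compares tokenization output, rather than pinning the exact config dict, which is brittle when base-layer keys change across Keras versions. A minimal standalone sketch of the same pattern, outside the test suite (assumes keras_nlp is installed; variable names are illustrative):

import numpy as np
import keras_nlp

tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=8)
clone = keras_nlp.tokenizers.ByteTokenizer.from_config(tokenizer.get_config())

data = ["hello", "fun", "haha"]
# The clone should tokenize identically to the original.
np.testing.assert_array_equal(np.array(tokenizer(data)), np.array(clone(data)))
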
46 changes: 6 additions & 40 deletions keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py
@@ -263,6 +263,7 @@ def test_load_model_with_config(self):
         )
 
     def test_config(self):
+        input_data = ["ninja", "samurai", "▀▁▂▃"]
         tokenizer = UnicodeCodepointTokenizer(
             name="unicode_character_tokenizer_config_gen",
             lowercase=False,
@@ -272,45 +273,10 @@ def test_config(self):
             replacement_char=0,
             vocabulary_size=100,
         )
-        exp_config = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "unicode_character_tokenizer_config_gen",
-            "normalization_form": "NFC",
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "input_encoding": "UTF-8",
-            "output_encoding": "UTF-8",
-            "trainable": True,
-            "vocabulary_size": 100,
-        }
-        self.assertEqual(tokenizer.get_config(), exp_config)
-
-        tokenize_different_encoding = UnicodeCodepointTokenizer(
-            name="unicode_character_tokenizer_config_gen",
-            lowercase=False,
-            sequence_length=8,
-            errors="ignore",
-            replacement_char=0,
-            input_encoding="UTF-16",
-            output_encoding="UTF-16",
-            vocabulary_size=None,
+        cloned_tokenizer = UnicodeCodepointTokenizer.from_config(
+            tokenizer.get_config()
         )
-        exp_config_different_encoding = {
-            "dtype": "int32",
-            "errors": "ignore",
-            "lowercase": False,
-            "name": "unicode_character_tokenizer_config_gen",
-            "normalization_form": None,
-            "replacement_char": 0,
-            "sequence_length": 8,
-            "input_encoding": "UTF-16",
-            "output_encoding": "UTF-16",
-            "trainable": True,
-            "vocabulary_size": None,
-        }
-        self.assertEqual(
-            tokenize_different_encoding.get_config(),
-            exp_config_different_encoding,
+        self.assertAllEqual(
+            tokenizer(input_data),
+            cloned_tokenizer(input_data),
         )
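
The removed block above also exercised a UTF-16 input/output encoding configuration, which this commit drops. A hypothetical follow-up in the same round-trip style could still cover it without hard-coding the full config dict (a sketch, not code from this commit; assumes keras_nlp is installed):

import keras_nlp

tokenizer = keras_nlp.tokenizers.UnicodeCodepointTokenizer(
    lowercase=False,
    sequence_length=8,
    errors="ignore",
    replacement_char=0,
    input_encoding="UTF-16",
    output_encoding="UTF-16",
)
clone = keras_nlp.tokenizers.UnicodeCodepointTokenizer.from_config(
    tokenizer.get_config()
)
# Spot-check that the encoding settings survive serialization.
for key in ("input_encoding", "output_encoding", "sequence_length"):
    assert clone.get_config()[key] == tokenizer.get_config()[key]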
