Skip to content

Commit

Permalink
Remove dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Jun 10, 2024
1 parent 82f90d4 commit e05bdb0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
6 changes: 3 additions & 3 deletions keras_nlp/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def assertAllEqual(self, x1, x2, msg=None):
super().assertAllEqual(x1, x2, msg=msg)

def assertDTypeEqual(self, x, expected_dtype, msg=None):
input_dtype = keras.utils.standardize_dtype(x.dtype)
input_dtype = keras.backend.standardize_dtype(x.dtype)
super().assertEqual(input_dtype, expected_dtype, msg=msg)

def run_layer_test(
Expand Down Expand Up @@ -127,8 +127,8 @@ def run_output_asserts(layer, output, eager=False):
)
output_dtype = tree.flatten(output)[0].dtype
self.assertEqual(
keras.utils.standardize_dtype(layer.dtype),
keras.utils.standardize_dtype(output_dtype),
keras.backend.standardize_dtype(layer.dtype),
keras.backend.standardize_dtype(output_dtype),
msg="Unexpected output dtype",
)
if eager and expected_output_data is not None:
Expand Down
8 changes: 4 additions & 4 deletions keras_nlp/src/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def convert_to_backend_tensor_or_python_list(x):
if isinstance(x, tf.RaggedTensor) or getattr(x, "dtype", None) == tf.string:
return tensor_to_list(x)
dtype = getattr(x, "dtype", "float32")
dtype = keras.utils.standardize_dtype(dtype)
dtype = keras.backend.standardize_dtype(dtype)
return ops.convert_to_tensor(x, dtype=dtype)


Expand Down Expand Up @@ -160,15 +160,15 @@ def is_tensor_type(x):


def is_float_dtype(dtype):
return "float" in keras.utils.standardize_dtype(dtype)
return "float" in keras.backend.standardize_dtype(dtype)


def is_int_dtype(dtype):
return "int" in keras.utils.standardize_dtype(dtype)
return "int" in keras.backend.standardize_dtype(dtype)


def is_string_dtype(dtype):
return "string" in keras.utils.standardize_dtype(dtype)
return "string" in keras.backend.standardize_dtype(dtype)


def any_equal(inputs, values, padding_mask):
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,11 @@ def get_version(rel_path):
author_email="keras-nlp@google.com",
license="Apache License 2.0",
install_requires=[
"keras-core",
"absl-py",
"numpy",
"packaging",
"regex",
"rich",
"dm-tree",
"kagglehub",
# Don't require tensorflow-text on MacOS, there are no binaries for ARM.
# Also, we rely on tensorflow *transitively* through tensorflow-text.
Expand Down

0 comments on commit e05bdb0

Please sign in to comment.