forked from keras-team/keras
Refactor dtypes and add float8_* dtypes (keras-team#19401)
* Refactor dtypes in codebase and add float8_* dtypes
* Update comments

Fix for JAX export on GPU. (keras-team#19404)

Fix formatting in export_lib. (keras-team#19405)

`ops/numpy.py`: Support `key` as `list` in `GetItem` (keras-team#19310)

When loading a model that contains `GetItem` nodes with multidimensional indices/slices as `key`, the `key` argument is loaded from JSON as a `list`, not a `tuple` (because JSON does not distinguish the two). So, treat the `key` list as equivalent to the `key` tuple (sketched below). Copying is important: otherwise, the later `pop()` would remove the bound slice elements from the op itself.

`saving/serialization_lib_test.py`:
* Add `test_numpy_get_item_layer()`: tests consistent serialization/deserialization of a model which contains `ops.numpy.GetItem`.

feat(losses): add Dice loss implementation (keras-team#19409)
* feat(losses): add Dice loss implementation
* removed smooth parameter and type casting
* adjusted casting and dot operator
A reference sketch of the Dice formulation appears after the commit stats below.

Update casting

Bump the github-actions group with 1 update (keras-team#19412)

Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action). Updates `github/codeql-action` from 3.24.6 to 3.24.9
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@8a470fd...1b1aada)

updated-dependencies:
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Fix issue with shared layer deserialization

Remove dead code in saving lib (keras-team#19415)

Remove unused beta param for silu, use torch op directly (keras-team#19417)

The beta param was only accepted on the tensorflow/torch backends and not in the `keras.ops` API, nor was it tested. Best to drop it, since no one could be relying on it.

Fix print_fn for custom function (keras-team#19419)

Add fp8 to `EinsumDense`

Add test script
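As a rough illustration of the `GetItem` key fix described above (a hypothetical sketch, not the literal diff from `ops/numpy.py`):

```python
# A key that was a tuple of slices round-trips through JSON as a list.
key = [slice(0, 2), slice(None)]  # as deserialized from JSON
if isinstance(key, list):
    # tuple(key) converts and shallow-copies in one step, so the later
    # pop() cannot remove the bound slice elements from the op itself.
    key = tuple(key)
```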
1 parent 9eb9629 · commit 898db1d · Showing 25 changed files with 520 additions and 94 deletions.
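For reference, the Dice loss added in keras-team#19409 follows the standard formulation. A minimal numpy sketch, not the exact Keras implementation (note that the smooth parameter was removed per the notes above):

```python
import numpy as np

def dice_loss(y_true, y_pred):
    # Dice loss = 1 - 2|X ∩ Y| / (|X| + |Y|), on soft (probability) masks.
    intersection = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    return 1.0 - (2.0 * intersection) / total
```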
@@ -0,0 +1,73 @@
import argparse

import numpy as np

import keras
from keras import layers
from keras import models


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp8", action="store_true")
    parser.add_argument("--einsum", action="store_true")
    return parser.parse_args()


class Classifier(models.Model):
    # MNIST classifier built from `Dense` layers; `use_fp8` toggles the
    # fp8 path exercised by this test script.
    def __init__(self, use_fp8=False):
        super().__init__()
        inputs = layers.Input(shape=[28, 28, 1])
        x = layers.Flatten()(inputs)
        x = layers.Dense(
            64, activation="relu", use_bias=False, use_fp8=use_fp8
        )(x)
        x = layers.Dense(
            64, activation="relu", use_bias=False, use_fp8=use_fp8
        )(x)
        outputs = layers.Dense(
            10, activation="softmax", use_bias=False, use_fp8=use_fp8
        )(x)
        # Re-initialize as a functional model from the built graph.
        super().__init__(inputs, outputs)


class Classifier2(models.Model):
    # Same architecture, but using `EinsumDense` to exercise the fp8
    # support added to that layer in this commit.
    def __init__(self, use_fp8=False):
        super().__init__()
        inputs = layers.Input(shape=[28, 28, 1])
        x = layers.Flatten()(inputs)
        x = layers.EinsumDense(
            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
        )(x)
        x = layers.EinsumDense(
            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
        )(x)
        outputs = layers.EinsumDense(
            "ab,bc->ac",
            output_shape=[10],
            activation="softmax",
            use_fp8=use_fp8,
        )(x)
        super().__init__(inputs, outputs)


args = get_args()
if args.einsum:
    model = Classifier2(use_fp8=args.fp8)
else:
    model = Classifier(use_fp8=args.fp8)

# Standard MNIST preprocessing: scale to [0, 1], add a channel axis,
# and one-hot encode the labels.
num_classes = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)
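Per the argparse flags at the top of the script, `--einsum` selects the `EinsumDense` variant and `--fp8` enables the fp8 path in either model.

For the float8_* dtypes named in the commit title, a minimal sketch of their numerical behavior, assuming the `ml_dtypes` package that Keras 3 uses for extended dtypes:

```python
import ml_dtypes
import numpy as np

# The two common float8 variants: e4m3fn trades range for precision,
# e5m2 trades precision for range.
x = np.array([0.1, 0.5, 1.0], dtype=np.float32)
print(x.astype(ml_dtypes.float8_e4m3fn))  # values rounded to the e4m3 grid
print(x.astype(ml_dtypes.float8_e5m2))    # values rounded to the e5m2 grid
```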