Adding: Numpy Backend (keras-team#483)
* chore: adding numpy backend

* review comments

* review comments

* chore: adding math

* chore: adding random module

* chore: adding random in init

* review comments

* chore: adding numpy and nn for numpy backend

* chore: adding generic pool, max, and average pool

* chore: adding the conv ops

* chore: reformat code and using jax for conv and pool

* chore: added self value

* chore: activation tests pass

* chore: adding post build method

* chore: adding necessary methods to the numpy trainer

* chore: fixing utils test

* chore: fixing losses test suite

* chore: fix backend tests

* chore: fixing initializers test

* chore: fixing accuracy metrics test

* chore: fixing ops test

* chore: review comments

* chore: init with image and fixing random tests

* chore: skipping random seed set for numpy backend

* chore: adding single resize image method

* chore: skipping tests for applications and layers

* chore: skipping tests for models

* chore: skipping tests for saving

* chore: skipping tests for trainers

* chore: fixing one hot

* chore: fixing vmap in numpy and metrics test

* chore: adding a wrapper to numpy sum, started fixing layer tests

* fix: is_tensor now accepts numpy scalars

* chore: adding draw seed

* fix: warn message for numpy masking

* fix: checking whether kernel are tensors

* chore: adding rnn

* chore: adding dynamic backend for numpy

* fix: axis cannot be None for normalize

* chore: adding jax resize for numpy image

* chore: adding rnn implementation in numpy

* chore: using pytest fixtures

* change: numpy import string

* chore: review comments

* chore: adding numpy to backend list of github actions

* chore: remove debug print statements
ariG23498 authored and adi-kmt committed Jul 21, 2023
1 parent 20d088c commit 66976ae
Showing 105 changed files with 2,088 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
@@ -11,7 +11,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        backend: [tensorflow, jax, torch]
        backend: [tensorflow, jax, torch, numpy]
    name: Run tests
    runs-on: ubuntu-latest
    env:
21 changes: 21 additions & 0 deletions conftest.py
@@ -5,3 +5,24 @@
    import torch  # noqa: F401
except ImportError:
    pass

import pytest

from keras_core.backend import backend


def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "requires_trainable_backend: mark test for trainable backend only",
    )


def pytest_collection_modifyitems(config, items):
    requires_trainable_backend = pytest.mark.skipif(
        backend() == "numpy",
        reason="Trainer not implemented for NumPy backend.",
    )
    for item in items:
        if "requires_trainable_backend" in item.keywords:
            item.add_marker(requires_trainable_backend)
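
For context (not part of the diff): a test opts into this behavior by applying the new marker, and conftest.py then attaches a skipif when the active backend is "numpy". A minimal hedged sketch; the test class and test name here are hypothetical:

import pytest

from keras_core import testing


class ExampleTrainerTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_fit_runs(self):
        # Skipped automatically under KERAS_BACKEND=numpy, because the
        # trainer is not implemented for the NumPy backend.
        ...

Running the suite with KERAS_BACKEND=numpy pytest keras_core/ should report such tests as skipped.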
1 change: 1 addition & 0 deletions keras_core/applications/applications_test.py
@@ -107,6 +107,7 @@ def _get_elephant(target_size):
    os.environ.get("SKIP_APPLICATIONS_TESTS"),
    reason="Env variable set to skip.",
)
@pytest.mark.requires_trainable_backend
class ApplicationsTest(testing.TestCase, parameterized.TestCase):
    @parameterized.named_parameters(MODEL_LIST)
    def test_application_notop_variable_input_channels(self, app, last_dim, _):
2 changes: 2 additions & 0 deletions keras_core/applications/imagenet_utils_test.py
@@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized

import keras_core as keras
@@ -74,6 +75,7 @@ def test_preprocess_input(self):
            {"testcase_name": "mode_caffe", "mode": "caffe"},
        ]
    )
    @pytest.mark.requires_trainable_backend
    def test_preprocess_input_symbolic(self, mode):
        # Test image batch
        x = np.random.uniform(0, 255, (2, 10, 10, 3))
7 changes: 7 additions & 0 deletions keras_core/backend/__init__.py
@@ -37,5 +37,12 @@
elif backend() == "torch":
    print_msg("Using PyTorch backend.")
    from keras_core.backend.torch import *  # noqa: F403
elif backend() == "numpy":
    print_msg(
        "Using NumPy backend.\nThe NumPy backend does not support "
        "training. It should only be used for inference, evaluation, "
        "and debugging."
    )
    from keras_core.backend.numpy import *  # noqa: F403
else:
    raise ValueError(f"Unable to import backend : {backend()}")
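
For reference (not part of the diff), the new backend is selected like the existing ones, by setting the KERAS_BACKEND environment variable before keras_core is first imported. A minimal sketch assuming the usual keras_core layer API; calling fit() is expected to fail since the NumPy backend has no trainer:

import os

import numpy as np

# Must be set before the first import of keras_core.
os.environ["KERAS_BACKEND"] = "numpy"

import keras_core

model = keras_core.Sequential([keras_core.layers.Dense(4)])
preds = model(np.ones((2, 8)))  # inference works; training does not
print(preds.shape)  # (2, 4)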
19 changes: 19 additions & 0 deletions keras_core/backend/numpy/__init__.py
@@ -1 +1,20 @@
from keras_core.backend.numpy import core
from keras_core.backend.numpy import image
from keras_core.backend.numpy import math
from keras_core.backend.numpy import nn
from keras_core.backend.numpy import numpy
from keras_core.backend.numpy import random
from keras_core.backend.numpy.core import DYNAMIC_SHAPES_OK
from keras_core.backend.numpy.core import Variable
from keras_core.backend.numpy.core import cast
from keras_core.backend.numpy.core import compute_output_spec
from keras_core.backend.numpy.core import cond
from keras_core.backend.numpy.core import convert_to_numpy
from keras_core.backend.numpy.core import convert_to_tensor
from keras_core.backend.numpy.core import is_tensor
from keras_core.backend.numpy.core import name_scope
from keras_core.backend.numpy.core import shape
from keras_core.backend.numpy.core import vectorized_map
from keras_core.backend.numpy.rnn import gru
from keras_core.backend.numpy.rnn import lstm
from keras_core.backend.numpy.rnn import rnn
212 changes: 212 additions & 0 deletions keras_core/backend/numpy/core.py
@@ -0,0 +1,212 @@
from contextlib import nullcontext

import numpy as np
from tensorflow import nest

from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope

DYNAMIC_SHAPES_OK = True


class Variable(KerasVariable):
    def _initialize(self, value):
        self._value = np.array(value, dtype=self._dtype)

    def _direct_assign(self, value):
        self._value = np.array(value, dtype=self._dtype)

    def _convert_to_tensor(self, value, dtype=None):
        return convert_to_tensor(value, dtype=dtype)

    # Overload native accessor.
    def __array__(self):
        return self.value


def convert_to_tensor(x, dtype=None):
    if dtype is not None:
        dtype = standardize_dtype(dtype)
    if isinstance(x, Variable):
        if dtype and dtype != x.dtype:
            return x.value.astype(dtype)
        return x.value
    return np.array(x, dtype=dtype)


def convert_to_numpy(x):
    return np.array(x)


def is_tensor(x):
    if isinstance(x, (np.generic, np.ndarray)):
        return True
    return False


def shape(x):
    return x.shape


def cast(x, dtype):
    return convert_to_tensor(x, dtype=dtype)


def cond(pred, true_fn, false_fn):
    if pred:
        return true_fn()
    return false_fn()


def name_scope(name):
    # There is no need for a named context for NumPy.
    return nullcontext()


def vectorized_map(function, elements):
    if len(elements) == 1:
        return function(elements)
    else:
        batch_size = elements[0].shape[0]
        output_store = list()
        for index in range(batch_size):
            output_store.append(function([x[index] for x in elements]))
        return np.stack(output_store)


# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
    with StatelessScope():

        def has_none_shape(x):
            if isinstance(x, KerasTensor):
                return None in x.shape
            return False

        none_in_shape = any(map(has_none_shape, nest.flatten((args, kwargs))))

        def convert_keras_tensor_to_numpy(x, fill_value=None):
            if isinstance(x, KerasTensor):
                shape = list(x.shape)
                if fill_value:
                    for i, e in enumerate(shape):
                        if e is None:
                            shape[i] = fill_value
                return np.empty(
                    shape=shape,
                    dtype=x.dtype,
                )
            return x

        args_1, kwargs_1 = nest.map_structure(
            lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),
            (args, kwargs),
        )
        outputs_1 = fn(*args_1, **kwargs_1)

        outputs = outputs_1

        if none_in_shape:
            args_2, kwargs_2 = nest.map_structure(
                lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),
                (args, kwargs),
            )
            outputs_2 = fn(*args_2, **kwargs_2)

            flat_out_1 = nest.flatten(outputs_1)
            flat_out_2 = nest.flatten(outputs_2)

            flat_out = []
            for x1, x2 in zip(flat_out_1, flat_out_2):
                shape = list(x1.shape)
                for i, e in enumerate(x2.shape):
                    if e != shape[i]:
                        shape[i] = None
                flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))
            outputs = nest.pack_sequence_as(outputs_1, flat_out)

        def convert_numpy_to_keras_tensor(x):
            if is_tensor(x):
                return KerasTensor(x.shape, standardize_dtype(x.dtype))
            return x

        output_spec = nest.map_structure(convert_numpy_to_keras_tensor, outputs)
        return output_spec


def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = np.zeros(shape, dtype=values.dtype)

    index_length = indices.shape[-1]
    value_shape = shape[index_length:]
    indices = np.reshape(indices, [-1, index_length])
    values = np.reshape(values, [-1] + list(value_shape))

    for i in range(indices.shape[0]):
        index = indices[i]
        zeros[tuple(index)] += values[i]
    return zeros


def scatter_update(inputs, indices, updates):
    indices = np.array(indices)
    indices = np.transpose(indices)
    inputs[tuple(indices)] = updates
    return inputs


def slice(inputs, start_indices, lengths):
    # Validate inputs
    assert len(start_indices) == len(lengths)

    # Generate list of indices arrays for each dimension
    indices = [
        np.arange(start, start + length)
        for start, length in zip(start_indices, lengths)
    ]

    # Use np.ix_ to create a multidimensional index array
    mesh = np.ix_(*indices)

    return inputs[mesh]


def slice_update(inputs, start_indices, updates):
    # Generate list of indices arrays for each dimension
    indices = [
        np.arange(start, start + length)
        for start, length in zip(start_indices, updates.shape)
    ]

    # Use np.ix_ to create a multidimensional index array
    mesh = np.ix_(*indices)
    inputs[mesh] = updates
    return inputs


def while_loop(
    cond,
    body,
    loop_vars,
    maximum_iterations=None,
):
    current_iter = 0
    iteration_check = (
        lambda iter: maximum_iterations is None or iter < maximum_iterations
    )
    loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
    while cond(*loop_vars) and iteration_check(current_iter):
        loop_vars = body(*loop_vars)
        if not isinstance(loop_vars, (list, tuple)):
            loop_vars = (loop_vars,)
        loop_vars = tuple(loop_vars)
        current_iter += 1
    return loop_vars


def stop_gradient(x):
    pass
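
As a quick sanity check (not part of the diff), a small sketch of how a few of the ops above behave on plain NumPy inputs, assuming the module is importable as keras_core.backend.numpy.core:

import numpy as np

from keras_core.backend.numpy import core

# scatter accumulates `values` into a zero array of `shape` at `indices`.
print(core.scatter(np.array([[0], [2]]), np.array([1.0, 2.0]), (4,)))
# [1. 0. 2. 0.]

# slice takes per-dimension start indices and lengths.
x = np.arange(12).reshape(3, 4)
print(core.slice(x, (1, 1), (2, 2)))
# [[ 5  6]
#  [ 9 10]]

# while_loop keeps applying `body` until `cond` returns False.
print(core.while_loop(lambda i: i < 5, lambda i: i + 1, (0,)))
# a one-element tuple containing 5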
45 changes: 45 additions & 0 deletions keras_core/backend/numpy/image.py
@@ -0,0 +1,45 @@
import jax
import numpy as np

RESIZE_METHODS = (
    "bilinear",
    "nearest",
    "lanczos3",
    "lanczos5",
    "bicubic",
)


def resize(
    image, size, method="bilinear", antialias=False, data_format="channels_last"
):
    if method not in RESIZE_METHODS:
        raise ValueError(
            "Invalid value for argument `method`. Expected of one "
            f"{RESIZE_METHODS}. Received: method={method}"
        )
    if not len(size) == 2:
        raise ValueError(
            "Argument `size` must be a tuple of two elements "
            f"(height, width). Received: size={size}"
        )
    size = tuple(size)
    if len(image.shape) == 4:
        if data_format == "channels_last":
            size = (image.shape[0],) + size + (image.shape[-1],)
        else:
            size = (image.shape[0], image.shape[1]) + size
    elif len(image.shape) == 3:
        if data_format == "channels_last":
            size = size + (image.shape[-1],)
        else:
            size = (image.shape[0],) + size
    else:
        raise ValueError(
            "Invalid input rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    return np.array(
        jax.image.resize(image, size, method=method, antialias=antialias)
    )
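
A short usage sketch for the resize wrapper above (not part of the diff). It assumes JAX is installed, since this backend delegates the actual resizing to jax.image.resize and converts the result back to a NumPy array:

import numpy as np

from keras_core.backend.numpy import image

batch = np.random.uniform(size=(2, 10, 10, 3))  # channels_last batch
print(image.resize(batch, (4, 4), method="bilinear").shape)  # (2, 4, 4, 3)

single = np.random.uniform(size=(10, 10, 3))  # single image
print(image.resize(single, (4, 4)).shape)  # (4, 4, 3)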
3 changes: 3 additions & 0 deletions keras_core/backend/numpy/layer.py
@@ -0,0 +1,3 @@
class NumpyLayer:
    def _post_build(self):
        pass