[nnx] Sequential uses regular list #3909

Merged
merged 2 commits on May 9, 2024
1 change: 1 addition & 0 deletions flax/experimental/nnx/__init__.py
@@ -71,6 +71,7 @@
from .nnx.nn.attention import make_attention_mask as make_attention_mask
from .nnx.nn.attention import make_causal_mask as make_causal_mask
from .nnx.nn.linear import Conv as Conv
+from .nnx.nn.linear import ConvTranspose as ConvTranspose
from .nnx.nn.linear import Embed as Embed
from .nnx.nn.linear import Linear as Linear
from .nnx.nn.linear import LinearGeneral as LinearGeneral
21 changes: 10 additions & 11 deletions flax/experimental/nnx/nnx/helpers.py
@@ -48,14 +48,12 @@

class Dict(Module, tp.Mapping[str, A]):
@tp.overload
-  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /):
-    ...
+  def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...

@tp.overload
def __init__(
self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
-  ):
-    ...
+  ): ...

def __init__(self, *args, **kwargs):
for name, value in dict(*args, **kwargs).items():
@@ -126,15 +124,18 @@ def _graph_node_pop_key(self, key: Key):
return super()._graph_node_pop_key(key)


-class Sequential(List):
+class Sequential(Module):
+  def __init__(self, *fns: tp.Callable[..., tp.Any]):
+    self.layers = list(fns)
+
def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
output: tp.Any = None

-    for i, f in enumerate(self):
+    for i, f in enumerate(self.layers):
if not callable(f):
raise TypeError(f'Sequence[{i}] is not callable: {f}')
if i > 0:
-        if isinstance(output, tp.Tuple):
+        if isinstance(output, tuple):
args = output
kwargs = {}
elif isinstance(output, dict):
@@ -154,8 +155,7 @@ def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
class ModuleDefApply(tp.Protocol, tp.Generic[M]):
def __call__(
self, state: State, *states: State
-  ) -> ApplyCaller[tuple[State, GraphDef[M]]]:
-    ...
+  ) -> ApplyCaller[tuple[State, GraphDef[M]]]: ...


class TrainState(tp.Generic[M], struct.PyTreeNode):
@@ -186,8 +186,7 @@ def create(

if tp.TYPE_CHECKING:

-    def __getattr__(self, key: str) -> tp.Any:
-      ...
+    def __getattr__(self, key: str) -> tp.Any: ...

def apply(
self, state: tp.Union[State, str], *states: tp.Union[State, str]
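For context, a minimal usage sketch of the reworked `Sequential` (written against the `flax.experimental.nnx` API as of this PR; the seed, shapes, and layer sizes are made-up illustration values). Layers now live in a plain Python `list` on `model.layers` rather than being held by the `List` container:

```python
import jax
import jax.numpy as jnp
from flax.experimental import nnx

rngs = nnx.Rngs(0)  # hypothetical seed, for illustration only
model = nnx.Sequential(
    nnx.Linear(2, 8, rngs=rngs),
    jax.nn.relu,                 # plain callables can sit between modules
    nnx.Linear(8, 1, rngs=rngs),
)

x = jnp.ones((4, 2))
y = model(x)                     # callables are applied in order

print(type(model.layers))        # <class 'list'>
print(y.shape)                   # (4, 1)
```

The `TypeError` check in `__call__` still guards against non-callable entries; the only structural change is that the layers are kept in a regular list instead of subclassing `List`.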
219 changes: 212 additions & 7 deletions flax/experimental/nnx/nnx/nn/linear.py
@@ -41,7 +41,6 @@
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import dtypes, initializers
from flax.typing import (
-  Array,
Dtype,
Shape,
Initializer,
@@ -52,11 +51,13 @@
LaxPadding,
)

+Array = jax.Array
Axis = int
Size = int


default_kernel_init = initializers.lecun_normal()
+default_bias_init = initializers.zeros_init()


def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
@@ -158,7 +159,7 @@ def __init__(
dtype: Dtype | None = None,
param_dtype: Dtype = jnp.float32,
kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
precision: PrecisionLike = None,
# Deprecated. Will be removed.
dot_general: DotGeneralT | None = None,
@@ -316,7 +317,7 @@ def __init__(
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
dot_general: DotGeneralT = lax.dot_general,
rngs: rnglib.Rngs,
):
@@ -410,7 +411,7 @@ def __init__(
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
rngs: rnglib.Rngs,
):
einsum_str = einsum_str.replace(' ', '')
Expand Down Expand Up @@ -576,7 +577,7 @@ def __init__(
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = initializers.zeros_init(),
+    bias_init: Initializer = default_bias_init,
conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
rngs: rnglib.Rngs,
):
@@ -641,7 +642,7 @@ def __call__(self, inputs: Array) -> Array:

def maybe_broadcast(
x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
-    ) -> tp.Tuple[int, ...]:
+    ) -> tuple[int, ...]:
if x is None:
# backward compatibility with using None as sentinel for
# broadcast 1
@@ -670,7 +671,7 @@ def maybe_broadcast(
kernel_size_dilated = [
(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
]
-      zero_pad: tp.List[tp.Tuple[int, int]] = [(0, 0)]
+      zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
pads = (
zero_pad
+ [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
@@ -725,6 +726,210 @@ def maybe_broadcast(
y = jnp.reshape(y, output_shape)
return y

class ConvTranspose(Module):
# features: int
# kernel_size: Union[int, Sequence[int]]
# strides: Optional[Sequence[int]] = None
# padding: PaddingLike = 'SAME'
# kernel_dilation: Optional[Sequence[int]] = None
# use_bias: bool = True
# mask: Optional[Array] = None
# dtype: Optional[Dtype] = None
# param_dtype: Dtype = jnp.float32
# precision: PrecisionLike = None
# kernel_init: Initializer = default_kernel_init
# bias_init: Initializer = initializers.zeros_init()
# transpose_kernel: bool = False

def __init__(
self,
in_features: int,
out_features: int,
kernel_size: int | tp.Sequence[int],
strides: int | tp.Sequence[int] | None = None,
*,
padding: PaddingLike = 'SAME',
kernel_dilation: int | tp.Sequence[int] | None = None,
use_bias: bool = True,
mask: Array | None = None,
dtype: Dtype | None = None,
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike | None = None,
kernel_init: Initializer = default_kernel_init,
bias_init: Initializer = default_bias_init,
transpose_kernel: bool = False,
rngs: rnglib.Rngs,
):
if isinstance(kernel_size, int):
kernel_size = (kernel_size,)
else:
kernel_size = tuple(kernel_size)

self.kernel_size = kernel_size
self.in_features = in_features
self.out_features = out_features
self.strides = strides
self.padding = padding
self.kernel_dilation = kernel_dilation
self.use_bias = use_bias
self.mask = mask
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.kernel_init = kernel_init
self.bias_init = bias_init
self.transpose_kernel = transpose_kernel

if self.transpose_kernel:
kernel_shape = kernel_size + (self.out_features, in_features)
else:
kernel_shape = kernel_size + (in_features, self.out_features)

self.kernel_shape = kernel_shape
self.kernel = nnx.Param(
self.kernel_init(rngs.params(), kernel_shape, self.param_dtype)
)

if self.use_bias:
self.bias = nnx.Param(
self.bias_init(rngs.params(), (self.out_features,), self.param_dtype)
)
else:
self.bias = None

def __call__(self, inputs: Array) -> Array:
"""Applies a transposed convolution to the inputs.

Behaviour mirrors that of ``jax.lax.conv_transpose``.

Args:
inputs: input data with dimensions (*batch_dims, spatial_dims...,
features). This is the channels-last convention, i.e. NHWC for a 2d
convolution and NDHWC for a 3D convolution. Note: this is different from
the input convention used by ``lax.conv_general_dilated``, which puts the
spatial dimensions last.
Note: If the input has more than 1 batch dimension, all batch dimensions
are flattened into a single dimension for the convolution and restored
before returning. In some cases directly vmap'ing the layer may yield
better performance than this default flattening approach. If the input
lacks a batch dimension it will be added for the convolution and removed
on return, an allowance made to enable writing single-example code.

Returns:
The convolved data.
"""
kernel_size = self.kernel_size

def maybe_broadcast(
x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
) -> tuple[int, ...]:
if x is None:
# backward compatibility with using None as sentinel for
# broadcast 1
x = 1
if isinstance(x, int):
return (x,) * len(kernel_size)
return tuple(x)

# Combine all input batch dimensions into a single leading batch axis.
num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
if num_batch_dimensions != 1:
input_batch_shape = inputs.shape[:num_batch_dimensions]
total_batch_size = int(np.prod(input_batch_shape))
flat_input_shape = (total_batch_size,) + inputs.shape[
num_batch_dimensions:
]
inputs = jnp.reshape(inputs, flat_input_shape)

strides = maybe_broadcast(self.strides)
kernel_dilation = maybe_broadcast(self.kernel_dilation)

kernel_shape = self.kernel_shape

if self.mask is not None and self.mask.shape != kernel_shape:
raise ValueError(
'Mask needs to have the same shape as weights. '
f'Shapes are: {self.mask.shape}, {kernel_shape}'
)

kernel = self.kernel.value

if self.mask is not None:
kernel *= self.mask

padding_lax = canonicalize_padding(self.padding, len(kernel_size))
if padding_lax == 'CIRCULAR':
padding_lax = 'VALID'

bias = self.bias.value if self.bias is not None else None

inputs, kernel, bias = dtypes.promote_dtype(
inputs, kernel, bias, dtype=self.dtype
)

y = lax.conv_transpose(
inputs,
kernel,
strides,
padding_lax,
rhs_dilation=kernel_dilation,
transpose_kernel=self.transpose_kernel,
precision=self.precision,
)

if self.padding == 'CIRCULAR':
# For circular padding, we need to identify the size of the final output
# ("period") along each spatial dimension, pad each dimension to an
# integer number of periods, and wrap the array periodically around each
# dimension. Padding should be done in such a way that the start of the
# original input data inside the padded array is located at integer
# number of periods - otherwise the result would be circularly shifted.

# Compute period along each spatial dimension - it's input size scaled
# by the stride.
scaled_x_dims = [
x_dim * stride
for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides)
]
# Compute difference between the current size of y and the final output
# size, and complement this difference to 2 * period - that gives how
# much we need to pad.
size_diffs = [
-(y_dim - x_dim) % (2 * x_dim)
for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
]
if self.transpose_kernel:
# If the kernel is transposed, the "+1" is put on the right to
# mirror the regular convolution. If the same kernel parameters are used
# as for Conv, this layer then computes the proper transpose convolution.
total_pad = [
(size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
]
else:
# Divide the padding equally between left and right. The choice to put
# "+1" on the left (and not on the right) represents a convention for
# aligning even-sized kernels.
total_pad = [
((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
]
y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
# Wrap the result periodically around each spatial dimension,
# one by one.
for i in range(1, y.ndim - 1):
y = y.reshape(
y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]
)
y = y.sum(axis=i)

if self.use_bias:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) # type: ignore

if num_batch_dimensions != 1:
output_shape = input_batch_shape + y.shape[1:]
y = jnp.reshape(y, output_shape)

return y


default_embed_init = initializers.variance_scaling(
1.0, 'fan_in', 'normal', out_axis=0
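A hedged usage sketch of the newly exported `ConvTranspose` layer (the channel counts, kernel size, strides, and seed below are made-up values, not taken from the PR). With `'SAME'` padding, each spatial dimension scales by its stride:

```python
import jax.numpy as jnp
from flax.experimental import nnx

# 2-D transposed convolution: 3 input channels -> 8 output channels.
layer = nnx.ConvTranspose(
    in_features=3,
    out_features=8,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding='SAME',
    rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 16, 16, 3))     # NHWC: (batch, height, width, features)
y = layer(x)

print(y.shape)                   # (1, 32, 32, 8): height and width scale by the stride
print(layer.kernel.value.shape)  # (3, 3, 3, 8): kernel_size + (in_features, out_features)
```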
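The `'CIRCULAR'` branch of `ConvTranspose.__call__` above wraps the over-long `'VALID'` output back onto itself: it pads each spatial axis to an odd multiple of the period and then folds and sums. A toy 1-D sketch of that fold-and-sum step (the period of 4 and the length-10 array are arbitrary illustration values):

```python
import jax.numpy as jnp

period = 4            # scaled input size along one spatial axis
y = jnp.arange(10.0)  # stand-in for the over-long 'VALID' transposed-conv output

# Pad so the total length becomes an odd multiple of the period
# (mirrors the non-transpose_kernel branch: the "+1" padding goes on the left).
size_diff = -(y.shape[0] - period) % (2 * period)
y = jnp.pad(y, ((size_diff + 1) // 2, size_diff // 2))

# Fold into (num_periods, period) and sum the folds -> periodic wrap-around.
wrapped = y.reshape(-1, period).sum(axis=0)
print(wrapped.shape)  # (4,)
```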