playground.py
import tensorflow_datasets as tfds  # TFDS for image datasets
import tensorflow as tf # TensorFlow operations
from typing import Callable, Any, Protocol, Tuple, NamedTuple
from dataclasses import dataclass
import jax
import jax.numpy as jnp # JAX NumPy
from flax import linen as nn # Linen API
from clu import metrics
from flax import struct
import optax # Common loss functions and optimizers
import chex
jax.config.update("jax_traceback_filtering", "off")
#jax.config.update("jax_debug_nans", True)
import optimizers
import training
def get_image_datasets(name, batch_size):
    """Load the train and test splits of a TFDS image dataset into memory."""
    (train_ds, test_ds) = tfds.load(name, split=['train', 'test'], as_supervised=True)

    def normalize_img(image, label):
        """Normalizes images: `uint8` -> `float32` in [0, 1]."""
        return tf.cast(image, tf.float32) / 255., tf.cast(label, tf.int32)

    train_ds = train_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    test_ds = test_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    # Shuffle with a 1024-element buffer, group into batches of batch_size
    # (dropping the incomplete final batch), and prefetch the next batch to
    # hide input latency.
    train_ds = train_ds.shuffle(1024)
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.shuffle(1024)
    test_ds = test_ds.batch(batch_size, drop_remainder=True)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
    return train_ds, test_ds
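
# Illustrative sketch (not invoked by this script): peek at one batch to
# confirm the (batch, height, width, channels) layout the models expect.
# The 'cifar10' name and batch size here are assumptions for the example.
def _peek_batch():
    train_ds, _ = get_image_datasets('cifar10', 32)
    images, labels = next(iter(train_ds.as_numpy_iterator()))
    print(images.shape, images.dtype)  # (32, 32, 32, 3) float32
    print(labels.shape, labels.dtype)  # (32,) int32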
class CNN(nn.Module):
    """A simple CNN model."""
    num_classes: int

    def setup(self):
        # Submodule names are derived from the attributes you assign to: here
        # "conv1", "conv2", "dense1" and "dense2". This mirrors PyTorch's
        # attribute-based module naming.
        self.conv1 = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=jax.nn.initializers.glorot_normal())
        self.conv2 = nn.Conv(features=64, kernel_size=(3, 3), kernel_init=jax.nn.initializers.glorot_normal())
        self.dense1 = nn.Dense(features=256, kernel_init=jax.nn.initializers.glorot_normal())
        self.dense2 = nn.Dense(features=self.num_classes, kernel_init=jax.nn.initializers.glorot_normal())

    def __call__(self, x):
        x = self.conv1(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = self.conv2(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x
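
# Illustrative sketch (not invoked by this script): the Linen API separates
# parameter creation (init) from the forward pass (apply). The MNIST-shaped
# dummy input is an assumption for the example.
def _demo_cnn_forward():
    model = CNN(num_classes=10)
    x = jnp.ones((1, 28, 28, 1))
    variables = model.init(jax.random.PRNGKey(0), x)
    logits = model.apply(variables, x)
    print(logits.shape)  # (1, 10)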
class VGG16(nn.Module):
    num_classes: int
    dropout_rate: float = 0.2
    output: str = 'linear'
    dtype: str = 'float32'

    class ConvBlock(nn.Module):
        features: int
        num_layers: int
        dtype: str

        def setup(self):
            layers = []
            for _ in range(self.num_layers):
                layers.append(nn.Conv(features=self.features, kernel_size=(3, 3), padding='same',
                                      dtype=self.dtype, kernel_init=jax.nn.initializers.glorot_normal()))
            self.layers = layers

        def __call__(self, x):
            for l in self.layers:
                x = l(x)
                x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            return x

    def setup(self):
        self.conv1 = self.ConvBlock(features=32, num_layers=2, dtype=self.dtype)
        self.conv2 = self.ConvBlock(features=64, num_layers=2, dtype=self.dtype)
        self.conv3 = self.ConvBlock(features=128, num_layers=2, dtype=self.dtype)
        self.dense1 = nn.Dense(features=128, kernel_init=jax.nn.initializers.glorot_normal())
        self.dense2 = nn.Dense(features=self.num_classes, kernel_init=jax.nn.initializers.glorot_normal())

    @nn.compact
    def __call__(self, x, training=False):
        if self.output not in ['softmax', 'log_softmax', 'sigmoid', 'linear', 'log_sigmoid']:
            raise ValueError('Wrong argument. Possible choices for output are '
                             '"softmax", "log_softmax", "sigmoid", "log_sigmoid", "linear".')
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        # Fully connected head
        #x = jnp.mean(x, axis=(2, 3))
        x = self.dense1(x)
        x = nn.relu(x)
        #x = nn.BatchNorm()(x, use_running_average=not training)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
        return self.dense2(x)
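
# Illustrative sketch (not invoked by this script): because VGG16 applies
# nn.Dropout when training=True, apply() then needs a dedicated 'dropout'
# PRNG stream. The input shape is an assumption for the example.
def _demo_vgg_training_step():
    model = VGG16(num_classes=10)
    x = jnp.ones((1, 32, 32, 3))
    variables = model.init(jax.random.PRNGKey(0), x)  # training=False: no dropout rng needed
    logits = model.apply(variables, x, training=True,
                         rngs={'dropout': jax.random.PRNGKey(1)})
    print(logits.shape)  # (1, 10)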
@struct.dataclass
class Metrics(metrics.Collection):
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')

def crossentropy_loss(y_pred, y):
    return optax.softmax_cross_entropy_with_integer_labels(logits=y_pred, labels=y).sum()
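
# Worked example (assumed values, not invoked by this script): the loss sums
# the per-example cross-entropies rather than averaging them, so its scale
# grows with batch size.
def _demo_loss():
    logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
    labels = jnp.array([0, 2])
    print(crossentropy_loss(logits, labels))  # scalar: loss(example 0) + loss(example 1)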
def create_img_train_state(module, dims, rng, state_class, tx):
    """Creates an initial `TrainState`."""
    # Initialize parameters by passing a template image of the right shape.
    params = module.init(rng, jnp.empty(dims))['params']
    opt_state = tx.init(params)
    return state_class(
        apply_fn=module.apply, params=params, tx=tx,
        opt_state=opt_state,
        loss_fn=crossentropy_loss,
        loss_hessian_fn=optimizers.sample_crossentropy_hessian,
        rng_key=rng,
        initial_metrics=Metrics)
num_epochs = 10
batch_size = 32
tf.random.set_seed(0)
#module = CNN(10)
module = VGG16(10)
rng = jax.random.PRNGKey(0)
dims = [1, 32, 32, 3]
print(module.tabulate(rng, jnp.ones(dims)))
tx = optimizers.kalman_blockwise_spectral_transformation(1.0, 1.0, 16, 48, jax.random.PRNGKey(0))
#tx = optimizers.kalman_blockwise_trace_transformation(1.0, 1.0)
#tx = optax.sgd(0.005 / batch_size, 0.9)
train_ds, test_ds = get_image_datasets('cifar10', batch_size)
state = create_img_train_state(module, dims, jax.random.PRNGKey(1), training.NaturalTrainState, tx)
#state = create_img_train_state(module, dims, jax.random.PRNGKey(1), training.TrainState, tx)
enable_tracing = False
for epoch in range(num_epochs):
    # Skip epoch 0 so JIT compilation does not dominate the trace; start and
    # stop under the same guard so the profiler calls always pair up.
    if epoch > 0 and enable_tracing:
        jax.profiler.start_trace("./jax-trace", create_perfetto_trace=True)
    state = training.train(train_ds.as_numpy_iterator(), state, train_ds.cardinality().numpy())
    training.test(test_ds.as_numpy_iterator(), state)
    if epoch > 0 and enable_tracing:
        jax.profiler.stop_trace()
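# Note: to my understanding, the trace written to ./jax-trace can be opened in
# TensorBoard's profiler plugin, and the Perfetto file produced by
# create_perfetto_trace=True at https://ui.perfetto.dev.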