Skip to content

Commit

Permalink
Fix naming in mobilenet_v3 and densenet (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Aug 21, 2023
1 parent 523c235 commit d7d93f8
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
6 changes: 3 additions & 3 deletions keras_core/applications/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def DenseNet(
bn_axis = 3 if backend.image_data_format() == "channels_last" else 1

x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1/conv")(x)
x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1_conv")(x)
x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, name="conv1/bn"
axis=bn_axis, epsilon=1.001e-5, name="conv1_bn"
)(x)
x = layers.Activation("relu", name="conv1/relu")(x)
x = layers.Activation("relu", name="conv1_relu")(x)
x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
x = layers.MaxPooling2D(3, strides=2, name="pool1")(x)

Expand Down
38 changes: 19 additions & 19 deletions keras_core/applications/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def MobileNetV3(
strides=(2, 2),
padding="same",
use_bias=False,
name="Conv",
name="conv",
)(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="Conv/BatchNorm"
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_bn"
)(x)
x = activation(x)

Expand All @@ -330,10 +330,10 @@ def MobileNetV3(
kernel_size=1,
padding="same",
use_bias=False,
name="Conv_1",
name="conv_1",
)(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="Conv_1/BatchNorm"
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_1_bn"
)(x)
x = activation(x)
if include_top:
Expand All @@ -343,19 +343,19 @@ def MobileNetV3(
kernel_size=1,
padding="same",
use_bias=True,
name="Conv_2",
name="conv_2",
)(x)
x = activation(x)

if dropout_rate > 0:
x = layers.Dropout(dropout_rate)(x)
x = layers.Conv2D(
classes, kernel_size=1, padding="same", name="Logits"
classes, kernel_size=1, padding="same", name="logits"
)(x)
x = layers.Flatten()(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Activation(
activation=classifier_activation, name="Predictions"
activation=classifier_activation, name="predictions"
)(x)
else:
if pooling == "avg":
Expand Down Expand Up @@ -559,23 +559,23 @@ def _depth(v, divisor=8, min_value=None):

def _se_block(inputs, filters, se_ratio, prefix):
x = layers.GlobalAveragePooling2D(
keepdims=True, name=prefix + "squeeze_excite/AvgPool"
keepdims=True, name=prefix + "squeeze_excite_avg_pool"
)(inputs)
x = layers.Conv2D(
_depth(filters * se_ratio),
kernel_size=1,
padding="same",
name=prefix + "squeeze_excite/Conv",
name=prefix + "squeeze_excite_conv",
)(x)
x = layers.ReLU(name=prefix + "squeeze_excite/Relu")(x)
x = layers.ReLU(name=prefix + "squeeze_excite_relu")(x)
x = layers.Conv2D(
filters,
kernel_size=1,
padding="same",
name=prefix + "squeeze_excite/Conv_1",
name=prefix + "squeeze_excite_conv_1",
)(x)
x = hard_sigmoid(x)
x = layers.Multiply(name=prefix + "squeeze_excite/Mul")([inputs, x])
x = layers.Multiply(name=prefix + "squeeze_excite_mul")([inputs, x])
return x


Expand All @@ -584,11 +584,11 @@ def _inverted_res_block(
):
channel_axis = 1 if backend.image_data_format() == "channels_first" else -1
shortcut = x
prefix = "expanded_conv/"
prefix = "expanded_conv_"
infilters = x.shape[channel_axis]
if block_id:
# Expand
prefix = f"expanded_conv_{block_id}/"
prefix = f"expanded_conv_{block_id}_"
x = layers.Conv2D(
_depth(infilters * expansion),
kernel_size=1,
Expand All @@ -600,14 +600,14 @@ def _inverted_res_block(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "expand/BatchNorm",
name=prefix + "expand_bn",
)(x)
x = activation(x)

if stride == 2:
x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, kernel_size),
name=prefix + "depthwise/pad",
name=prefix + "depthwise_pad",
)(x)
x = layers.DepthwiseConv2D(
kernel_size,
Expand All @@ -620,7 +620,7 @@ def _inverted_res_block(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "depthwise/BatchNorm",
name=prefix + "depthwise_bn",
)(x)
x = activation(x)

Expand All @@ -638,11 +638,11 @@ def _inverted_res_block(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "project/BatchNorm",
name=prefix + "project_bn",
)(x)

if stride == 1 and infilters == filters:
x = layers.Add(name=prefix + "Add")([shortcut, x])
x = layers.Add(name=prefix + "add")([shortcut, x])
return x


Expand Down
7 changes: 4 additions & 3 deletions keras_core/ops/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ class Operation:
def __init__(self, name=None):
if name is None:
name = auto_name(self.__class__.__name__)
if not isinstance(name, str):
if not isinstance(name, str) or "/" in name:
raise ValueError(
"Argument `name` should be a string. "
f"Received instead: name={name} (of type {type(name)})"
"Argument `name` must be a string and "
"cannot contain character `/`. "
f"Received: name={name} (of type {type(name)})"
)
self.name = name
self._inbound_nodes = []
Expand Down
8 changes: 8 additions & 0 deletions keras_core/ops/operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,11 @@ def test_input_conversion(self):
out = op(x, y, z)
self.assertTrue(backend.is_tensor(out))
self.assertAllClose(out, 6 * np.ones((2,)))

def test_valid_naming(self):
OpWithMultipleOutputs(name="test_op")

with self.assertRaisesRegex(
ValueError, "must be a string and cannot contain character `/`."
):
OpWithMultipleOutputs(name="test/op")

0 comments on commit d7d93f8

Please sign in to comment.