Training an FPN model using grad_req="add" causes rapid divergence, while manually implemented gradient accumulation works fine #16708
Comments
Thanks for reporting this! It looks like some operators haven't implemented …
There are also some bugs in grad accumulation as well.
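Presumably the two comments above refer to backward kernels that don't honor the "add" gradient request. As a plain-NumPy illustration (not MXNet internals, just a sketch of the semantics) of why a kernel that always overwrites its gradient buffer silently breaks grad_req="add" accumulation:

```python
import numpy as np

# A backward kernel receives a request ("req") telling it whether to overwrite
# the gradient buffer or to accumulate into it.
def backward_correct(grad_in, grad_buf, req):
    if req == "write":
        grad_buf[:] = grad_in      # overwrite: normal single-batch training
    elif req == "add":
        grad_buf[:] += grad_in     # accumulate: what grad_req="add" relies on

def backward_buggy(grad_in, grad_buf, req):
    grad_buf[:] = grad_in          # always overwrites, ignoring req="add"

buf_ok = np.zeros(3)
buf_bad = np.zeros(3)
for grad in (np.ones(3), np.ones(3)):   # two accumulation micro-batches
    backward_correct(grad, buf_ok, "add")
    backward_buggy(grad, buf_bad, "add")
print(buf_ok)    # [2. 2. 2.] -- gradients from both micro-batches are summed
print(buf_bad)   # [1. 1. 1.] -- the second micro-batch clobbered the first
```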
I've confirmed that this issue does exist and I'm able to reproduce it.
@samskalicky Actually, @zhreshold is investigating this issue.
@zhreshold any update on this issue? Can you assign this issue to yourself if you're looking into it?
At first glance it looked related to the implementation of a contrib operator; however, when I dug into it, it turned out to be more complicated than I thought. I am still investigating.
Interestingly, disabling …
As long as …

Minimum reproducible code without GluonCV, run with … Note that I am using an empty feature-extraction layer in the debug model:

import argparse
from enum import Enum
from typing import Tuple, Iterable
import mxnet as mx
import mxnet.gluon.nn as nn
from mxnet import autograd
from mxnet.gluon import HybridBlock, Trainer
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
from mxnet.initializer import Xavier
from mxnet.test_utils import assert_almost_equal

# GluonCV provides the resnet18_v1b backbone and the FPN feature expander used by
# FPNModel; these imports are needed only for --model fpn. The module paths below
# are assumed from GluonCV 0.x.
try:
    from gluoncv.model_zoo.resnetv1b import resnet18_v1b
    from gluoncv.nn.feature import FPNFeatureExpander
except ImportError:
    resnet18_v1b = FPNFeatureExpander = None  # --model simple / --model debug still run
class SimpleModel(HybridBlock):
    def __init__(self, context):
        super().__init__()
        with self.name_scope():
            self.glob_avg_pool = nn.GlobalAvgPool2D()
            self.linear = nn.Dense(2)
            self.loss = SoftmaxCrossEntropyLoss()

    def hybrid_forward(self, F, x, label, *args, **kwargs):
        res = self.linear(self.glob_avg_pool(x))
        return self.loss(res, label)


class FPNModel(HybridBlock):
    def __init__(self, context):
        super().__init__()
        with self.name_scope():
            base_network = resnet18_v1b(pretrained=False, dilated=False, use_global_stats=True, ctx=context)
            features = FPNFeatureExpander(
                network=base_network,
                outputs=['layers1_relu3_fwd', 'layers2_relu3_fwd', 'layers3_relu3_fwd',
                         'layers4_relu3_fwd'], num_filters=[256, 256, 256, 256], use_1x1=True,
                use_upsample=True, use_elewadd=True, use_p6=False, no_bias=False, pretrained=False, ctx=context)
            self.features = features
            self.loss = SoftmaxCrossEntropyLoss()
            self.conv2d = nn.Conv2D(256, kernel_size=(3, 3), weight_initializer=Xavier(magnitude=2))
            self.glob_avg_pool = nn.GlobalAvgPool2D()
            self.linear = nn.Dense(2)

    def hybrid_forward(self, F, x, label, *args, **kwargs):
        feat = self.features(x)
        res = [self.linear(self.glob_avg_pool(F.relu(self.conv2d(y)))) for y in feat]
        res = F.ElementWiseSum(*res) / len(res)
        return self.loss(res, label)


class DebugModel(HybridBlock):
    def __init__(self, context):
        super().__init__()
        self.repeat = 4
        with self.name_scope():
            self.features = nn.HybridSequential()
            self.glob_avg_pool = nn.GlobalAvgPool2D()
            self.linear = nn.Dense(2)
            self.loss = SoftmaxCrossEntropyLoss()
            self.conv2d = nn.Conv2D(256, kernel_size=(3, 3), weight_initializer=Xavier(magnitude=2))

    def hybrid_forward(self, F, x, label, *args, **kwargs):
        feat = self.features(x)
        feat = [feat] * self.repeat
        res = [self.linear(self.glob_avg_pool(F.relu(self.conv2d(y)))) for y in feat]
        res = F.ElementWiseSum(*res) / len(res)
        return self.loss(res, label)


steps = 256
accumulate_over = 2
batch_size = 4


def create_trainer(model):
    pattern = '|'.join(['.*dense', 'P', '.*down(2|3|4)_conv', '.*layers(2|3|4)_conv'])
    trainer = Trainer(model.collect_params(pattern), "sgd", {"learning_rate": 0.01, "wd": 1e-2})
    params = [x for x in model.collect_params(pattern).values() if x.grad_req != "null"]
    return params, trainer


class TrainUsingGradAdd:
    def __init__(self, context):
        self.context = context

    def train(self, model: FPNModel, data_iter: Iterable[Tuple[mx.nd.NDArray, mx.nd.NDArray]]):
        params, trainer = create_trainer(model)
        for p in params:
            if p.grad_req != "null":
                p.grad_req = "add"
                p.zero_grad()
        for i, (data, label) in enumerate(data_iter):
            data = mx.nd.array(data.asnumpy(), ctx=self.context)
            label = mx.nd.array(label.asnumpy(), ctx=self.context)
            with autograd.record():
                loss = model(data, label)
            loss_mean = loss.mean().asscalar()
            yield None, loss_mean
            loss.backward()
            trainer.allreduce_grads()
            if i % accumulate_over == accumulate_over - 1:
                yield params, loss
                trainer.update(batch_size)
                for p in params:
                    p.zero_grad()


class TrainManually:
    def __init__(self, context):
        self.context = context

    def train(self, model: FPNModel, data_iter: Iterable[Tuple[mx.nd.NDArray, mx.nd.NDArray]]):
        params, trainer = create_trainer(model)
        # Create an external memory for accumulated gradients
        params_memory = []
        for p in params:
            if p.grad_req != "null":
                p.grad_req = "write"
                params_memory.append(mx.nd.zeros_like(p.grad()))
        for i, (data, label) in enumerate(data_iter):
            # Create a copy of the input array in order to prevent any chance of sharing arrays
            data = mx.nd.array(data.asnumpy(), ctx=self.context)
            label = mx.nd.array(label.asnumpy(), ctx=self.context)
            with autograd.record():
                loss = model(data, label)
            loss_mean = loss.mean().asscalar()
            yield None, loss_mean
            loss.backward()
            trainer.allreduce_grads()
            # Commit the gradients to our external memory
            for j, p in enumerate(params):
                params_memory[j] = params_memory[j] + p.grad()
                p.zero_grad()
            if i % accumulate_over == accumulate_over - 1:
                # If it's time to accumulate, copy the external memory into the model's grads and update.
                for j, p in enumerate(params):
                    for g in p._grad:
                        params_memory[j].copyto(g)
                yield params, loss_mean
                trainer.update(batch_size)
                # Zero out the external memory.
                for j, p in enumerate(params):
                    params_memory[j] = mx.nd.zeros_like(p.grad())
                    p.zero_grad()


class ModelType(Enum):
    SIMPLE = "simple"
    FPN = "fpn"
    DEBUG = "debug"

    def __str__(self):
        return self.value


class ContextType(Enum):
    CPU = "cpu"
    GPU = "gpu"

    def __str__(self):
        return self.value


def main():
    parser = argparse.ArgumentParser(description="Demonstrate MXNet gradient accumulation divergence.")
    parser.add_argument("--hybridize", dest="hybridize", action="store_true", default=False)
    parser.add_argument("--check-grads", dest="check_grads", action="store_true", default=False)
    parser.add_argument("--set-repeat", dest="set_repeat", default=4, type=int)
    parser.add_argument("--model", dest="model", type=ModelType, choices=list(ModelType), required=True)
    parser.add_argument("--ctx", dest="ctx", type=ContextType, choices=list(ContextType), default=ContextType.GPU)
    args = parser.parse_args()
    model_class = {ModelType.SIMPLE: SimpleModel, ModelType.FPN: FPNModel, ModelType.DEBUG: DebugModel}[args.model]
    context = {ContextType.CPU: mx.cpu(0), ContextType.GPU: mx.gpu(0)}[args.ctx]
    # Create a prototype model
    model_proto = model_class(mx.cpu(0))
    model_proto.initialize()
    dummy = mx.nd.zeros((batch_size, 3, 224, 224))
    model_proto(dummy, mx.nd.zeros((batch_size, 1)))
    # Save the prototype model and create two independent copies.
    model_proto.save_parameters("tmp.dat")
    model1 = model_class(context)
    model2 = model_class(context)
    if args.model == ModelType.DEBUG:
        model1.repeat = args.set_repeat
        model2.repeat = args.set_repeat
    model1.load_parameters("tmp.dat", ctx=[context])
    model2.load_parameters("tmp.dat", ctx=[context])
    # s = model1(mx.sym.var('data', mx.sym.var('label')))
    # s.to_json()
    if args.hybridize:
        model1.hybridize(static_shape=True, static_alloc=True)
        model2.hybridize(static_shape=True, static_alloc=True)
    # Create a synthetic, meaningless dataset.
    data = [
        (mx.nd.random_uniform(-0.1, 0.1, (batch_size, 3, 224, 224)), mx.nd.random_randint(0, 2, (batch_size, 1)))
        for i in range(steps)
    ]
    trainer1 = TrainUsingGradAdd(context)
    trainer2 = TrainManually(context)
    results = zip(trainer1.train(model1, data), trainer2.train(model2, data))
    i = 0
    for (params1, loss1), (params2, loss2) in results:
        if params1 is None:
            # Received a message containing the current step's losses.
            print(f"[Step {i}] Loss using grad_req=add: {loss1:15.4f}, Loss using manual accumulation: {loss2:15.4f}")
            i += 1
            continue
        elif args.check_grads:
            # Received a message indicating that the trainers are about to perform an optimizer update.
            print(f"[Step {i}] Checking grads...")
            for j, (p1, p2) in enumerate(zip(params1, params2)):
                assert_almost_equal(p1.grad(), p2.grad(), atol=1e-3, rtol=1e-3)
                print(f"[Step {i}] Param {j} passed check.")
            print(f"[Step {i}] Grads passed the check.")


if __name__ == '__main__':
    main()
@zhreshold I don't think that …

from functools import reduce

def elementwise_sum(F, *args):
    return reduce(lambda x, y: F.broadcast_plus(x, y), args)
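For completeness, a small self-contained check (a sketch; the tensors and their values are made up) showing that the manual reduction matches the built-in ElementWiseSum op in the forward pass:

```python
from functools import reduce
import mxnet as mx

def elementwise_sum(F, *args):
    # pairwise broadcast_plus instead of the fused ElementWiseSum op
    return reduce(lambda x, y: F.broadcast_plus(x, y), args)

xs = [mx.nd.ones((2, 3)) * i for i in range(1, 5)]
manual = elementwise_sum(mx.nd, *xs)
fused = mx.nd.ElementWiseSum(*xs)
print((manual - fused).abs().sum())  # 0.0 -- identical forward results
```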
After digging for a while, I found several confusing facts about this bug. …
@zhreshold any luck?
Traversing the autograd module, no luck yet.
Fixed by #17995 |
Description
While working with FPNs and gradient accumulation, I've discovered that using grad_req=add with certain models utilizing FPNs causes almost immediate divergence. This issue is similar to #16686, but the results are much more extreme: the toy model I have provided in this issue diverges in just a couple of steps. A proprietary model (which I cannot share) diverges immediately.
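For context, the standard Gluon gradient-accumulation pattern that grad_req=add is meant to support looks roughly like this (a minimal sketch with a toy dense layer, not the FPN model from this issue; the hyperparameters are arbitrary):

```python
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(1)
net.initialize()
# Ask autograd to accumulate into the gradient buffers instead of overwriting them.
for p in net.collect_params().values():
    p.grad_req = 'add'
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

accumulate_over = 2
batch_size = 4
for i in range(4):
    x = mx.nd.random.uniform(shape=(batch_size, 8))
    with autograd.record():
        loss = net(x).sum()
    loss.backward()                                  # adds on top of the stored gradients
    if i % accumulate_over == accumulate_over - 1:
        trainer.step(batch_size * accumulate_over)   # one update per accumulation window
        for p in net.collect_params().values():
            p.zero_grad()                            # 'add' never clears the buffers itself
```

The manual-accumulation trainer in the repro script reproduces the same arithmetic with grad_req="write" and an external buffer, which is why the two runs are expected to produce (almost) identical losses.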
To Reproduce
Steps to reproduce
1. Run the provided script with --model fpn; the loss diverges within a couple of steps.
2. Running with --model simple --check-grads will demonstrate that this bug doesn't exhibit itself on simple models. It seems that it's pretty difficult to satisfy the conditions necessary for the model to diverge: I've tried multiple toy models before finding one that can demonstrate the issue.
3. Running on the CPU (--ctx cpu) does not change the results.
What have you tried to solve it?
1. Replacing ElementWiseSum with a manual reduction, since the test for ElementWiseSum is flaky. This seems to have no effect at all.
2. Running with MXNET_ENGINE_TYPE=NaiveEngine and MXNET_EXEC_ENABLE_INPLACE=false, both on the CPU and GPU.
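For reference, those engine settings can also be applied from inside the script rather than the shell (a sketch; the values are exactly the ones listed above, and they must be set before mxnet is imported):

```python
import os

# Must be set before importing mxnet for the engine choice to take effect.
os.environ["MXNET_ENGINE_TYPE"] = "NaiveEngine"      # run operators synchronously, one at a time
os.environ["MXNET_EXEC_ENABLE_INPLACE"] = "false"    # disable in-place memory optimizations

import mxnet as mx  # noqa: E402
```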
Environment