This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1404] Implement storage tagging, the first half of the memory profiler #17656

Merged
merged 1 commit into apache:master from bojian/Storage_Tagging_ii on Mar 10, 2020

Conversation

@ArmageddonKnight (Contributor) commented Feb 21, 2020

Description

Implement storage tagging, the first half of the GPU memory profiler.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • This PR is the first half of the GPU memory profiler. It implements storage tagging which adds profiler scope, name, and data structure information to each allocated storage handle.
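The naming convention the tags follow can be seen in the CSV output: each entry is the enclosing profiler scope joined to the storage entity's name with colons. A minimal illustrative sketch of that convention (the function name is hypothetical, not the PR's actual C++ API):

```python
# Hypothetical sketch of the "scope:name" tagging convention used in the
# gpu_memory_profile.csv output; the real implementation lives in the
# C++ storage allocator, not in Python.
def storage_tag(profiler_scope, entity_name):
    # e.g. scope "net:dense0" plus name "net_dense0_fwd"
    # yields the tag "net:dense0:net_dense0_fwd"
    return f"{profiler_scope}:{entity_name}"
```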

Example

    import mxnet as mx
    from mxnet import profiler
    from mxnet.gluon import nn

    # Enable the profiler before building and running the model.
    profiler.set_state('run')

    model = nn.HybridSequential(prefix='net_')
    with model.name_scope():
        model.add(nn.Dense(128, activation='tanh'))
        model.add(nn.Dropout(0.5))
        model.add(nn.Dense(64, activation='tanh'),
                  nn.Dense(32, in_units=64))
        model.add(nn.Activation('relu'))
    model.initialize(ctx=mx.gpu())
    model.hybridize()

    inputs = mx.sym.var('data')

    # Run one forward/backward pass so allocations are recorded.
    with mx.autograd.record():
        out = model(mx.nd.zeros((16, 10), ctx=mx.gpu()))
    out.backward()
    mx.nd.waitall()

    # Stop profiling and dump the results to disk.
    profiler.set_state('stop')
    profiler.dump(True)

The code snippet above will generate the following gpu_memory_profile.csv:

| Attribute Name | Requested Size | Device | Actual Size | Reuse? |
| --- | --- | --- | --- | --- |
| `<unk>:in_arg:data` | 640 | 0 | 4096 | 0 |
| `net:arg_grad:net_dense0_bias` | 512 | 0 | 4096 | 0 |
| `net:arg_grad:net_dense0_weight` | 5120 | 0 | 8192 | 0 |
| `net:arg_grad:net_dense1_bias` | 256 | 0 | 4096 | 0 |
| `net:arg_grad:net_dense1_weight` | 32768 | 0 | 32768 | 0 |
| `net:arg_grad:net_dense2_bias` | 128 | 0 | 4096 | 0 |
| `net:arg_grad:net_dense2_weight` | 8192 | 0 | 8192 | 0 |
| `net:dense0:net_dense0_fwd` | 8192 | 0 | 8192 | 0 |
| `net:dense0:tanh:net_dense0_tanh_fwd` | 8192 | 0 | 8192 | 0 |
| `net:dense1:net_dense1_fwd` | 4096 | 0 | 4096 | 0 |
| `net:dense1:tanh:net_dense1_tanh_fwd` | 4096 | 0 | 4096 | 0 |
| `net:dense2:net_dense2_fwd` | 2048 | 0 | 4096 | 0 |
| `net:dense2:net_dense2_fwd_backward` | 4096 | 0 | 4096 | 0 |
| `net:dropout0:net_dropout0_fwd` | 8192 | 0 | 8192 | 0 |
| `net:dropout0:net_dropout0_fwd` | 8192 | 0 | 8192 | 0 |
| `net:in_arg:net_dense0_bias` | 512 | 0 | 4096 | 0 |
| `net:in_arg:net_dense0_weight` | 5120 | 0 | 8192 | 0 |
| `net:in_arg:net_dense1_bias` | 256 | 0 | 4096 | 0 |
| `net:in_arg:net_dense1_weight` | 32768 | 0 | 32768 | 0 |
| `net:in_arg:net_dense2_bias` | 128 | 0 | 4096 | 0 |
| `net:in_arg:net_dense2_weight` | 8192 | 0 | 8192 | 0 |
| `net:relu0:net_relu0_fwd` | 2048 | 0 | 4096 | 0 |
| `net:relu0:net_relu0_fwd_backward` | 8192 | 0 | 8192 | 0 |
| `net:relu0:net_relu0_fwd_head_grad` | 2048 | 0 | 4096 | 0 |
| `resource:cudnn_dropout_state (dropout-inl.h +258)` | 1671168 | 0 | 1671168 | 0 |
| `resource:temp_space (fully_connected-inl.h +316)` | 34816 | 0 | 36864 | 0 |
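Because the output is plain CSV, it is straightforward to post-process. A minimal sketch (not part of the PR) that groups entries by their top-level profiler scope and sums the actual allocation sizes, using a few rows from the table above as inline sample data:

```python
import csv
import io
from collections import defaultdict

# A few sample rows in the gpu_memory_profile.csv format shown above.
sample = """\
Attribute Name,Requested Size,Device,Actual Size,Reuse?
net:arg_grad:net_dense0_bias,512,0,4096,0
net:dense0:net_dense0_fwd,8192,0,8192,0
resource:temp_space (fully_connected-inl.h +316),34816,0,36864,0
"""

# Sum the actual allocated bytes per top-level scope ("net", "resource", ...).
totals = defaultdict(int)
for row in csv.DictReader(io.StringIO(sample)):
    scope = row["Attribute Name"].split(":", 1)[0]
    totals[scope] += int(row["Actual Size"])

print(dict(totals))  # {'net': 12288, 'resource': 36864}
```

For a real dump, replace the inline string with `open('gpu_memory_profile.csv')`.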

@szha szha merged commit 4dddb08 into apache:master Mar 10, 2020
MoisesHer pushed a commit to MoisesHer/incubator-mxnet that referenced this pull request Apr 10, 2020
@ArmageddonKnight ArmageddonKnight deleted the bojian/Storage_Tagging_ii branch May 6, 2020 08:05