Skip to content

Commit

Permalink
Update seq2len plugin to support dynamic shapes
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
jdemouth-nvidia authored and rajeevsrao committed Dec 8, 2022
1 parent e82aa6a commit 85f59aa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 79 deletions.
7 changes: 3 additions & 4 deletions demo/Diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ def insert_seq2spatial_plugin(self):
inputTensor = biasNode.inputs[inputIndex]
residualInput = residualNode.inputs[1]
outputTensor = transposeNode.outputs[0]
outputShapeTensor = transposeNode.i().i().i(1).i(1).i(1).i().inputs[0]
seqLen2SpatialNode = gs.Node("SeqLen2Spatial", "AddAddSeqLen2Spatial-" + str(nSeqLen2SpatialPlugin),
attrs=OrderedDict([("height", outputTensor.shape[2]), ("width", outputTensor.shape[3])]),
inputs=[inputTensor, biasInput, residualInput], outputs=[outputTensor])
inputs=[inputTensor, biasInput, residualInput, outputShapeTensor], outputs=[outputTensor])
self.graph.nodes.append(seqLen2SpatialNode)
biasNode.inputs.clear()
transposeNode.outputs.clear()
Expand Down Expand Up @@ -841,8 +841,7 @@ def optimize(self, onnx_graph, minimal_optimization=False):
# Insert Split+GeLU Plugin
bSplitGeLUPlugin = True
# Replace BiasAdd+ResidualAdd+SeqLen2Spatial with plugin
# TODO - support dynamic shapes in plugin
bSeqLen2SpatialPlugin = False
bSeqLen2SpatialPlugin = True

opt = Optimizer(onnx_graph, verbose=self.verbose)
opt.info('UNet: original')
Expand Down
82 changes: 10 additions & 72 deletions plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,17 @@ namespace
{
static std::string const kSEQLEN2SPATIAL_PLUGIN_NAME{"SeqLen2Spatial"};
static std::string const kSEQLEN2SPATIAL_PLUGIN_VERSION{"1"};
size_t constexpr kSERIALIZATION_SIZE{2 * sizeof(int32_t)};
size_t constexpr kSERIALIZATION_SIZE{0};
} // namespace

SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name, int32_t height, int32_t width)
SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name)
: mName(name)
, mHeight(height)
, mWidth(width)
{
}

SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name, void const* buffer, size_t length)
: mName(name)
{
PLUGIN_VALIDATE(buffer != nullptr);
PLUGIN_VALIDATE(length == kSERIALIZATION_SIZE);

auto const* d = static_cast<char const*>(buffer);
auto const* a = d;

mHeight = read<int32_t>(d);
mWidth = read<int32_t>(d);

PLUGIN_VALIDATE(d == a + length);
}

IPluginV2DynamicExt* SeqLen2SpatialPlugin::clone() const noexcept
Expand Down Expand Up @@ -93,25 +81,7 @@ DataType SeqLen2SpatialPlugin::getOutputDataType(
DimsExprs SeqLen2SpatialPlugin::getOutputDimensions(
int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
{
DimsExprs ret{};
try
{
PLUGIN_VALIDATE(inputs != nullptr);
PLUGIN_VALIDATE(nbInputs > 0);
DimsExprs inputDims = inputs[0];
DimsExprs outputDims;
outputDims.nbDims = 4;
outputDims.d[0] = inputDims.d[0];
outputDims.d[1] = inputDims.d[2];
outputDims.d[2] = exprBuilder.constant(mHeight);
outputDims.d[3] = exprBuilder.constant(mWidth);
ret = outputDims;
}
catch (std::exception const& e)
{
caughtError(e);
}
return ret;
return inputs[3];
}

bool SeqLen2SpatialPlugin::supportsFormatCombination(
Expand All @@ -122,7 +92,7 @@ bool SeqLen2SpatialPlugin::supportsFormatCombination(
PLUGIN_VALIDATE(inOut != nullptr);
PLUGIN_VALIDATE(nbInputs + nbOutputs > 0);
PLUGIN_VALIDATE(pos < nbInputs + nbOutputs);
PLUGIN_VALIDATE(pos >= 0 && pos <= 3);
PLUGIN_VALIDATE(pos >= 0 && pos <= 4);

if (pos == 0)
{
Expand All @@ -136,6 +106,11 @@ bool SeqLen2SpatialPlugin::supportsFormatCombination(
}

if (pos == 3)
{
return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}

if (pos == 4)
{
return inOut[pos].type == inOut[0].type
&& ((inOut[pos].type == DataType::kFLOAT && inOut[pos].format == TensorFormat::kHWC)
Expand Down Expand Up @@ -204,19 +179,6 @@ size_t SeqLen2SpatialPlugin::getSerializationSize() const noexcept

void SeqLen2SpatialPlugin::serialize(void* buffer) const noexcept
{
try
{
PLUGIN_VALIDATE(buffer != nullptr);
auto* d = static_cast<char*>(buffer);
auto* a = d;
write(d, mHeight); // int32_t
write(d, mWidth); // int32_t
PLUGIN_VALIDATE(d == a + getSerializationSize());
}
catch (std::exception const& e)
{
caughtError(e);
}
}

void SeqLen2SpatialPlugin::setPluginNamespace(char const* pluginNamespace) noexcept
Expand Down Expand Up @@ -245,8 +207,6 @@ std::vector<PluginField> SeqLen2SpatialPluginCreator::mPluginAttributes;
SeqLen2SpatialPluginCreator::SeqLen2SpatialPluginCreator()
{
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("height", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("width", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
Expand All @@ -257,28 +217,7 @@ IPluginV2* SeqLen2SpatialPluginCreator::createPlugin(char const* name, PluginFie
{
try
{
PLUGIN_VALIDATE(fc != nullptr);
PluginField const* fields = fc->fields;

// default values
int32_t mHeight{0};
int32_t mWidth{0};

for (int32_t i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
if (!strcmp(attrName, "height"))
{
PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32);
mHeight = static_cast<int32_t>(*(static_cast<int32_t const*>(fields[i].data)));
}
else if (!strcmp(attrName, "width"))
{
PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32);
mWidth = static_cast<int32_t>(*(static_cast<int32_t const*>(fields[i].data)));
}
}
return new SeqLen2SpatialPlugin(name, mHeight, mWidth);
return new SeqLen2SpatialPlugin(name);
}
catch (std::exception const& e)
{
Expand All @@ -292,7 +231,6 @@ IPluginV2* SeqLen2SpatialPluginCreator::deserializePlugin(
{
try
{
PLUGIN_VALIDATE(serialData != nullptr);
return new SeqLen2SpatialPlugin(name, serialData, serialLength);
}
catch (std::exception const& e)
Expand Down
4 changes: 1 addition & 3 deletions plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SeqLen2SpatialPlugin : public IPluginV2DynamicExt
{
public:
SeqLen2SpatialPlugin() = delete;
SeqLen2SpatialPlugin(std::string const& name, int32_t height, int32_t width);
SeqLen2SpatialPlugin(std::string const& name);
SeqLen2SpatialPlugin(std::string const& name, void const* buffer, size_t length);
~SeqLen2SpatialPlugin() override = default;

Expand Down Expand Up @@ -74,8 +74,6 @@ class SeqLen2SpatialPlugin : public IPluginV2DynamicExt
private:
std::string mName;
std::string mNameSpace;
int32_t mHeight{};
int32_t mWidth{};
};

class SeqLen2SpatialPluginCreator : public nvinfer1::pluginInternal::BaseCreator
Expand Down

0 comments on commit 85f59aa

Please sign in to comment.