diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py index ed921784..8664b7b7 100644 --- a/demo/Diffusion/models.py +++ b/demo/Diffusion/models.py @@ -345,9 +345,9 @@ def insert_seq2spatial_plugin(self): inputTensor = biasNode.inputs[inputIndex] residualInput = residualNode.inputs[1] outputTensor = transposeNode.outputs[0] + outputShapeTensor = transposeNode.i().i().i(1).i(1).i(1).i().inputs[0] seqLen2SpatialNode = gs.Node("SeqLen2Spatial", "AddAddSeqLen2Spatial-" + str(nSeqLen2SpatialPlugin), - attrs=OrderedDict([("height", outputTensor.shape[2]), ("width", outputTensor.shape[3])]), - inputs=[inputTensor, biasInput, residualInput], outputs=[outputTensor]) + inputs=[inputTensor, biasInput, residualInput, outputShapeTensor], outputs=[outputTensor]) self.graph.nodes.append(seqLen2SpatialNode) biasNode.inputs.clear() transposeNode.outputs.clear() @@ -841,8 +841,7 @@ def optimize(self, onnx_graph, minimal_optimization=False): # Insert Split+GeLU Plugin bSplitGeLUPlugin = True # Replace BiasAdd+ResidualAdd+SeqLen2Spatial with plugin - # TODO - support dynamic shapes in plugin - bSeqLen2SpatialPlugin = False + bSeqLen2SpatialPlugin = True opt = Optimizer(onnx_graph, verbose=self.verbose) opt.info('UNet: original') diff --git a/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.cpp b/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.cpp index a25a8c08..23b0a384 100644 --- a/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.cpp +++ b/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.cpp @@ -28,29 +28,17 @@ namespace { static std::string const kSEQLEN2SPATIAL_PLUGIN_NAME{"SeqLen2Spatial"}; static std::string const kSEQLEN2SPATIAL_PLUGIN_VERSION{"1"}; -size_t constexpr kSERIALIZATION_SIZE{2 * sizeof(int32_t)}; +size_t constexpr kSERIALIZATION_SIZE{0}; } // namespace -SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name, int32_t height, int32_t width) +SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name) : mName(name) - , mHeight(height) - , mWidth(width) { } SeqLen2SpatialPlugin::SeqLen2SpatialPlugin(std::string const& name, void const* buffer, size_t length) : mName(name) { - PLUGIN_VALIDATE(buffer != nullptr); - PLUGIN_VALIDATE(length == kSERIALIZATION_SIZE); - - auto const* d = static_cast(buffer); - auto const* a = d; - - mHeight = read(d); - mWidth = read(d); - - PLUGIN_VALIDATE(d == a + length); } IPluginV2DynamicExt* SeqLen2SpatialPlugin::clone() const noexcept @@ -93,25 +81,7 @@ DataType SeqLen2SpatialPlugin::getOutputDataType( DimsExprs SeqLen2SpatialPlugin::getOutputDimensions( int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept { - DimsExprs ret{}; - try - { - PLUGIN_VALIDATE(inputs != nullptr); - PLUGIN_VALIDATE(nbInputs > 0); - DimsExprs inputDims = inputs[0]; - DimsExprs outputDims; - outputDims.nbDims = 4; - outputDims.d[0] = inputDims.d[0]; - outputDims.d[1] = inputDims.d[2]; - outputDims.d[2] = exprBuilder.constant(mHeight); - outputDims.d[3] = exprBuilder.constant(mWidth); - ret = outputDims; - } - catch (std::exception const& e) - { - caughtError(e); - } - return ret; + return inputs[3]; } bool SeqLen2SpatialPlugin::supportsFormatCombination( @@ -122,7 +92,7 @@ bool SeqLen2SpatialPlugin::supportsFormatCombination( PLUGIN_VALIDATE(inOut != nullptr); PLUGIN_VALIDATE(nbInputs + nbOutputs > 0); PLUGIN_VALIDATE(pos < nbInputs + nbOutputs); - PLUGIN_VALIDATE(pos >= 0 && pos <= 3); + PLUGIN_VALIDATE(pos >= 0 && pos <= 4); if (pos == 0) { @@ -136,6 +106,11 @@ bool SeqLen2SpatialPlugin::supportsFormatCombination( } if (pos == 3) + { + return inOut[pos].type == DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR; + } + + if (pos == 4) { return inOut[pos].type == inOut[0].type && ((inOut[pos].type == DataType::kFLOAT && inOut[pos].format == TensorFormat::kHWC) @@ -204,19 +179,6 @@ size_t SeqLen2SpatialPlugin::getSerializationSize() const noexcept void SeqLen2SpatialPlugin::serialize(void* buffer) const noexcept { - try - { - PLUGIN_VALIDATE(buffer != nullptr); - auto* d = static_cast(buffer); - auto* a = d; - write(d, mHeight); // int32_t - write(d, mWidth); // int32_t - PLUGIN_VALIDATE(d == a + getSerializationSize()); - } - catch (std::exception const& e) - { - caughtError(e); - } } void SeqLen2SpatialPlugin::setPluginNamespace(char const* pluginNamespace) noexcept @@ -245,8 +207,6 @@ std::vector SeqLen2SpatialPluginCreator::mPluginAttributes; SeqLen2SpatialPluginCreator::SeqLen2SpatialPluginCreator() { mPluginAttributes.clear(); - mPluginAttributes.emplace_back(PluginField("height", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("width", nullptr, PluginFieldType::kINT32, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); } @@ -257,28 +217,7 @@ IPluginV2* SeqLen2SpatialPluginCreator::createPlugin(char const* name, PluginFie { try { - PLUGIN_VALIDATE(fc != nullptr); - PluginField const* fields = fc->fields; - - // default values - int32_t mHeight{0}; - int32_t mWidth{0}; - - for (int32_t i = 0; i < fc->nbFields; ++i) - { - char const* attrName = fields[i].name; - if (!strcmp(attrName, "height")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); - mHeight = static_cast(*(static_cast(fields[i].data))); - } - else if (!strcmp(attrName, "width")) - { - PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); - mWidth = static_cast(*(static_cast(fields[i].data))); - } - } - return new SeqLen2SpatialPlugin(name, mHeight, mWidth); + return new SeqLen2SpatialPlugin(name); } catch (std::exception const& e) { @@ -292,7 +231,6 @@ IPluginV2* SeqLen2SpatialPluginCreator::deserializePlugin( { try { - PLUGIN_VALIDATE(serialData != nullptr); return new SeqLen2SpatialPlugin(name, serialData, serialLength); } catch (std::exception const& e) diff --git a/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.h b/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.h index 955ff2fd..612d231c 100644 --- a/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.h +++ b/plugin/seqLen2SpatialPlugin/seqLen2SpatialPlugin.h @@ -34,7 +34,7 @@ class SeqLen2SpatialPlugin : public IPluginV2DynamicExt { public: SeqLen2SpatialPlugin() = delete; - SeqLen2SpatialPlugin(std::string const& name, int32_t height, int32_t width); + SeqLen2SpatialPlugin(std::string const& name); SeqLen2SpatialPlugin(std::string const& name, void const* buffer, size_t length); ~SeqLen2SpatialPlugin() override = default; @@ -74,8 +74,6 @@ class SeqLen2SpatialPlugin : public IPluginV2DynamicExt private: std::string mName; std::string mNameSpace; - int32_t mHeight{}; - int32_t mWidth{}; }; class SeqLen2SpatialPluginCreator : public nvinfer1::pluginInternal::BaseCreator