Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable tuning in Batch norm CK solver #3326

Merged
merged 13 commits into from
Oct 24, 2024
6 changes: 6 additions & 0 deletions src/batchnorm/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "c" << c;
}
else
Expand All @@ -154,6 +155,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "n" << n;
ss << "c" << c;
Expand All @@ -172,6 +174,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "gx" << xgridsize;
ss << "gy" << ygridsize;
ss << "lx" << xlocalsize;
Expand Down Expand Up @@ -201,6 +204,7 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "mode" << bn_mode;
ss << "HWdims" << in_cstride;
ss << "C" << c;
Expand Down Expand Up @@ -308,6 +312,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "gcn" << ldsgcn;
}
Expand All @@ -330,6 +335,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "nhw" << in_nhw;
}
ss << "layout" << in_layout;
Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/batchnorm/invoke_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
namespace miopen {
namespace batchnorm {

struct InvokeParams : public miopen::InvokeParams
struct FwdTrainInvokeParams : public miopen::InvokeParams
{
InvokeParams() = default;
FwdTrainInvokeParams() = default;

ConstData_t x = nullptr;
Data_t y = nullptr;
Expand Down
79 changes: 78 additions & 1 deletion src/include/miopen/batchnorm/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ struct ProblemDescriptionTag
{
};

struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, ProblemDescriptionTag
struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase,
ProblemDescriptionTag
#if MIOPEN_ENABLE_SQLITE
,
SQLiteSerializable<ProblemDescription>
#endif
{
// Forward Training
ProblemDescription(miopenBatchNormMode_t bn_mode_,
Expand Down Expand Up @@ -218,10 +223,49 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
bool IsFp64() const { return xDesc.GetType() == miopenDouble; }
bool IsFp32() const { return xDesc.GetType() == miopenFloat; }
bool IsFp16() const { return xDesc.GetType() == miopenHalf; }
bool IsMix() const
{
return xDesc.GetType() == miopenHalf && sMeanDesc.GetType() == miopenFloat;
}
bool IsBfp16() const { return xDesc.GetType() == miopenBFloat16; }

void Serialize(std::ostream& stream) const { stream << MakeNetworkConfig().ToString(); }

NetworkConfig MakeNetworkConfig() const override;

template <class Self>
static void Visit(Self&& self, std::function<void(int64_t, std::string)> f)
{
// The column names match the driver command line argument names
f(self.spatial_dim, "spatial_dim");
f(self.GetBatchSize(), "batchsize");
f(self.GetChannel(), "in_channels");
f(self.GetHeight(), "in_h");
f(self.GetWidth(), "in_w");
f(self.GetDepth(), "in_d");

f(self.resultsave, "resultsave");
f(self.resultrunning, "resultrunning");
f(self.useSaved, "useSaved");
}

template <class Self>
static void Visit(Self&& self, std::function<void(std::string, std::string)> f)
{
f(self.ComputeInLayout(), "layout");
f(self.GetDirectionStr(), "direction");
f(GetDataTypeName(self.xDesc.GetType()), "data_type");
f(self.GetModeStr(), "mode");
}

template <class Self, class Visitor>
static void VisitAll(Self&& self, const Visitor& f)
{
Visit(std::forward<Self>(self), [&](int64_t value, std::string name) { f(value, name); });
Visit(std::forward<Self>(self),
[&](std::string value, std::string name) { f(value, name); });
}

// This declaration marks batchnorm as a primitive with tuning enabled.
// Any tunable solver would be able pick it and fetch a db instance in ExecutePrimitive.
// It has to be discoverable via ADL from problem description.
Expand Down Expand Up @@ -267,6 +311,39 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
std::string ComputeInLayout() const { return ComputeLayout(xDesc); }
std::string ComputeOutLayout() const { return ComputeLayout(yOrDyDesc); }
std::string ComputeDinLayout() const { return ComputeLayout(dxDesc); }

size_t GetSpatialDims() const { return spatial_dim; }

std::size_t GetBatchSize() const { return GetN5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetChannel() const { return GetC5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetHeight() const { return GetH5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetWidth() const { return GetW5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetDepth() const { return GetD5(GetSpatialDims(), xDesc.GetLengths()); }

std::string GetDirectionStr() const
{
std::string s;

switch(direction)
{
case Direction::ForwardInference: return "Inf"; ;
case Direction::ForwardTraining: return "Trn";
case Direction::Backward: return "Bwd";
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Direction provided");
}

return s;
}

std::string GetModeStr() const
{
switch(bn_mode)
{
case miopenBNPerActivation: return "0";
case miopenBNSpatial: return "1";
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Mode provided");
}
}
};

} // namespace batchnorm
Expand Down
Loading
Loading