Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable tuning in Batch norm CK solver #3326

Merged
merged 13 commits into from
Oct 24, 2024
6 changes: 6 additions & 0 deletions src/batchnorm/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "c" << c;
}
else
Expand All @@ -154,6 +155,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "n" << n;
ss << "c" << c;
Expand All @@ -172,6 +174,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "gx" << xgridsize;
ss << "gy" << ygridsize;
ss << "lx" << xlocalsize;
Expand Down Expand Up @@ -201,6 +204,7 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "mode" << bn_mode;
ss << "HWdims" << in_cstride;
ss << "C" << c;
Expand Down Expand Up @@ -308,6 +312,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "gcn" << ldsgcn;
}
Expand All @@ -330,6 +335,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "nhw" << in_nhw;
}
ss << "layout" << in_layout;
Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/batchnorm/invoke_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
namespace miopen {
namespace batchnorm {

struct InvokeParams : public miopen::InvokeParams
struct FwdTrainInvokeParams : public miopen::InvokeParams
{
InvokeParams() = default;
FwdTrainInvokeParams() = default;

ConstData_t x = nullptr;
Data_t y = nullptr;
Expand Down
72 changes: 72 additions & 0 deletions src/include/miopen/batchnorm/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,49 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
bool IsFp64() const { return xDesc.GetType() == miopenDouble; }
bool IsFp32() const { return xDesc.GetType() == miopenFloat; }
bool IsFp16() const { return xDesc.GetType() == miopenHalf; }
bool IsMix() const
{
return xDesc.GetType() == miopenHalf && sMeanDesc.GetType() == miopenFloat;
}
bool IsBfp16() const { return xDesc.GetType() == miopenBFloat16; }

void Serialize(std::ostream& stream) const { stream << MakeNetworkConfig().ToString(); }

NetworkConfig MakeNetworkConfig() const override;

template <class Self>
static void Visit(Self&& self, std::function<void(int64_t, std::string)> f)
{
// The column names match the driver command line argument names
f(self.spatial_dim, "spatial_dim");
f(self.GetBatchSize(), "batchsize");
f(self.GetChannel(), "in_channels");
f(self.GetHeight(), "in_h");
f(self.GetWidth(), "in_w");
f(self.GetDepth(), "in_d");

f(self.resultsave, "resultsave");
f(self.resultrunning, "resultrunning");
f(self.useSaved, "useSaved");
}

template <class Self>
static void Visit(Self&& self, std::function<void(std::string, std::string)> f)
{
f(self.ComputeInLayout(), "layout");
f(self.GetDirectionStr(), "direction");
f(GetDataTypeName(self.xDesc.GetType()), "data_type");
f(self.GetModeStr(), "mode");
}

template <class Self, class Visitor>
static void VisitAll(Self&& self, const Visitor& f)
{
Visit(std::forward<Self>(self), [&](int64_t value, std::string name) { f(value, name); });
Visit(std::forward<Self>(self),
[&](std::string value, std::string name) { f(value, name); });
}

// This declaration marks batchnorm as a primitive with tuning enabled.
// Any tunable solver would be able pick it and fetch a db instance in ExecutePrimitive.
// It has to be discoverable via ADL from problem description.
Expand Down Expand Up @@ -267,6 +306,39 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
std::string ComputeInLayout() const { return ComputeLayout(xDesc); }
std::string ComputeOutLayout() const { return ComputeLayout(yOrDyDesc); }
std::string ComputeDinLayout() const { return ComputeLayout(dxDesc); }

size_t GetSpatialDims() const { return spatial_dim; }

std::size_t GetBatchSize() const { return GetN5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetChannel() const { return GetC5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetHeight() const { return GetH5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetWidth() const { return GetW5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetDepth() const { return GetD5(GetSpatialDims(), xDesc.GetLengths()); }

std::string GetDirectionStr() const
{
std::string s;

switch(direction)
{
case Direction::ForwardInference: s = "Inf"; break;
case Direction::ForwardTraining: s = "Trn"; break;
case Direction::Backward: s = "Bwd"; break;
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Direction provided");
}

return s;
}

std::string GetModeStr() const
{
switch(bn_mode)
{
case miopenBNPerActivation: return "0";
case miopenBNSpatial: return "1";
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Direction provided");
}
}
bghimireamd marked this conversation as resolved.
Show resolved Hide resolved
};

} // namespace batchnorm
Expand Down
211 changes: 196 additions & 15 deletions src/include/miopen/batchnorm/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ namespace batchnorm {
using BatchnormSolver =
NonTunableSolverBase<ExecutionContext, miopen::batchnorm::ProblemDescription>;

template <class PerformanceConfig>
using BatchNormTunableSolver =
TunableSolverMixin<ExecutionContext, miopen::batchnorm::ProblemDescription, PerformanceConfig>;
;

struct BnFwdTrainingSpatialSingle final : BatchnormSolver
{
const std::string& SolverDbId() const override
Expand Down Expand Up @@ -132,34 +137,210 @@ struct BnFwdInference final : BatchnormSolver
const miopen::batchnorm::ProblemDescription& problem) const override;
};

struct BnCKFwdInference final : BatchnormSolver
struct PerformanceConfigBnCKFwdInference : PerfConfigBase<PerformanceConfigBnCKFwdInference>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigBnCKFwdInference(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}
PerformanceConfigBnCKFwdInference() : PerformanceConfigBnCKFwdInference(0, "") {}
PerformanceConfigBnCKFwdInference(bool) : PerformanceConfigBnCKFwdInference(0, "") {}
MIOPEN_INTERNALS_EXPORT void
HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool
SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
bghimireamd marked this conversation as resolved.
Show resolved Hide resolved
MIOPEN_INTERNALS_EXPORT bool
IsValid(const ExecutionContext&,
const miopen::batchnorm::ProblemDescription& problem_desc) const;

template <typename Self, typename F>
static void Visit(Self&& s, F f)
{
f(s.kernel_id, "kernel_id");
}
MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKFwdInference& other) const;

private:
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
void Init(const miopen::batchnorm::ProblemDescription&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const;
};

struct BnCKFwdInference final : BatchNormTunableSolver<PerformanceConfigBnCKFwdInference>
{
const std::string& SolverDbId() const override { return GetSolverDbId<BnCKFwdInference>(); }

bool IsApplicable(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
ConvSolution GetSolution(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdInference GetDefaultPerformanceConfig(
const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT bool
IsValidPerformanceConfig(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKFwdInference& config) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdInference
Search(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKFwdInference& config) const override;
};

struct PerformanceConfigBnCKBwdBackward : PerfConfigBase<PerformanceConfigBnCKBwdBackward>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigBnCKBwdBackward(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}
PerformanceConfigBnCKBwdBackward() : PerformanceConfigBnCKBwdBackward(0, "") {}
PerformanceConfigBnCKBwdBackward(bool) : PerformanceConfigBnCKBwdBackward(0, "") {}
MIOPEN_INTERNALS_EXPORT void
HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool
SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
MIOPEN_INTERNALS_EXPORT bool
IsValid(const ExecutionContext&,
const miopen::batchnorm::ProblemDescription& problem_desc) const;

template <typename Self, typename F>
static void Visit(Self&& s, F f)
{
f(s.kernel_id, "kernel_id");
}
MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKBwdBackward& other) const;

private:
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType>
void Init(const miopen::batchnorm::ProblemDescription&);
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType>
bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const;
};

struct BnCKBwdBackward final : BatchnormSolver
struct BnCKBwdBackward final : BatchNormTunableSolver<PerformanceConfigBnCKBwdBackward>
{
const std::string& SolverDbId() const override { return GetSolverDbId<BnCKBwdBackward>(); }

bool IsApplicable(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
ConvSolution GetSolution(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKBwdBackward GetDefaultPerformanceConfig(
const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT bool
IsValidPerformanceConfig(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKBwdBackward& config) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKBwdBackward
Search(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKBwdBackward& config) const override;
};

struct PerformanceConfigBnCKFwdTraining : PerfConfigBase<PerformanceConfigBnCKFwdTraining>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigBnCKFwdTraining(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}
PerformanceConfigBnCKFwdTraining() : PerformanceConfigBnCKFwdTraining(0, "") {}
PerformanceConfigBnCKFwdTraining(bool) : PerformanceConfigBnCKFwdTraining(0, "") {}
MIOPEN_INTERNALS_EXPORT void
HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool
SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
MIOPEN_INTERNALS_EXPORT bool
IsValid(const ExecutionContext&,
const miopen::batchnorm::ProblemDescription& problem_desc) const;

template <typename Self, typename F>
static void Visit(Self&& s, F f)
{
f(s.kernel_id, "kernel_id");
}
MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKFwdTraining& other) const;

private:
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
void Init(const miopen::batchnorm::ProblemDescription&);
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const;
};

struct BnCKFwdTraining final : BatchnormSolver
struct BnCKFwdTraining final : BatchNormTunableSolver<PerformanceConfigBnCKFwdTraining>
{
const std::string& SolverDbId() const override { return GetSolverDbId<BnCKFwdTraining>(); }

bool IsApplicable(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
ConvSolution GetSolution(const ExecutionContext& context,
const miopen::batchnorm::ProblemDescription& problem) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdTraining GetDefaultPerformanceConfig(
const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT bool
IsValidPerformanceConfig(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKFwdTraining& config) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdTraining
Search(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& ctx,
const miopen::batchnorm::ProblemDescription& problem_desc,
const PerformanceConfigBnCKFwdTraining& config) const override;
};

} // namespace batchnorm
Expand Down
Loading
Loading