Skip to content

Commit

Permalink
merge develop to resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
iq136boy committed Oct 11, 2023
2 parents d0f57da + b45e54d commit 936282c
Show file tree
Hide file tree
Showing 62 changed files with 265 additions and 387 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
include(ROCMHeaderWrapper)

# Build library with Beta APIs
add_definitions("-DMIOPEN_BETA_API=1")

set(MIOPEN_ENABLE_AI_IMMED_MODE_FALLBACK On CACHE BOOL "Enable AI-based fallback for Immediate Mode")
set(MIOPEN_ENABLE_AI_KERNEL_TUNING On CACHE BOOL "Enable AI heuristic for kernel tuning")
set(MIOPEN_ENABLE_SQLITE On CACHE BOOL "")
Expand Down
2 changes: 1 addition & 1 deletion docs/.sphinx/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fastjsonschema==2.16.3
# via rocm-docs-core
gitdb==4.0.10
# via gitpython
gitpython==3.1.35
gitpython==3.1.37
# via rocm-docs-core
idna==3.4
# via requests
Expand Down
3 changes: 1 addition & 2 deletions docs/datatypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ typedef enum {
miopenFloat = 1,
miopenInt32 = 2,
miopenInt8 = 3,
miopenInt8x4 = 4,
/* Value 4 is reserved. */
miopenBFloat16 = 5,
} miopenDataType_t;
```
Expand All @@ -22,7 +22,6 @@ Type descriptions:
* `miopenFloat` - 32-bit floating point
* `miopenInt32` - 32-bit integer, used primarily for `int8` convolution outputs
* `miopenInt8` - 8-bit integer, currently only supported by `int8` convolution forward path, tensor set, tensor copy, tensor cast, tensor transform, tensor transpose, and im2col.
* `miopenInt8x4` - 8-bit 4 element vector type used primarily with `int8` convolutions forward path.
* `miopenBFloat16` - brain float fp-16 (8-bit exponent, 7-bit fraction), currently only supported by convolutions, tensor set, and tensor copy.


Expand Down
6 changes: 3 additions & 3 deletions driver/layernorm_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,19 @@ int LayerNormDriver<Tgpu, Tref>::AllocateBuffersAndCopy()

for(int i = 0; i < in_sz; i++)
{
in[i] = RAN_GEN<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
in[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
}
status = in_dev->ToGPU(q, in.data());

for(int i = 0; i < weight_sz; i++)
{
weight[i] = RAN_GEN<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
weight[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
}
status = weight_dev->ToGPU(q, weight.data());

for(int i = 0; i < bias_sz; i++)
{
bias[i] = RAN_GEN<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
bias[i] = prng::gen_A_to_B<Tgpu>(static_cast<Tgpu>(0.0), static_cast<Tgpu>(1.0));
}
status = bias_dev->ToGPU(q, bias.data());

Expand Down
2 changes: 1 addition & 1 deletion fin
Submodule fin updated from b2f3f4 to 26b5c3
28 changes: 20 additions & 8 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ typedef enum
miopenStatusVersionMismatch = 10, /*!< Version mismatch of the supplied binary data argument. */
} miopenStatus_t;

#ifdef MIOPEN_BETA_API
typedef enum
{
miopenF8RoundingModeStandard = 0,
miopenF8RoundingModeStochastic = 1,
} miopenF8RoundingMode_t;
#endif

/*! @brief Get character string for an error code.
*
Expand Down Expand Up @@ -346,17 +348,21 @@ MIOPEN_DECLARE_OBJECT(miopenReduceTensorDescriptor);
*/
typedef enum
{
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
miopenInt8x4 =
4, /*!< Pack of four 8-bit int points in NCHW_VECT_C format (Partially supported) */
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
miopenInt8x4 = 4, /*!< Pack of four Int8 in NCHW_VECT_C format (Support discontinued) */
miopenBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction)
(Partially supported) */
miopenDouble = 6, /*!< 64-bit floating point (Partially supported) */
miopenDouble = 6, /*!< 64-bit floating point (Partially supported) */
#ifdef MIOPEN_BETA_API
miopenFloat8 = 7,
miopenBFloat8 = 8
miopenBFloat8 = 8,
#else
// miopenReserved1 = 7,
// miopenReserved2 = 8,
#endif
} miopenDataType_t;

/*! @ingroup tensor
Expand Down Expand Up @@ -601,11 +607,15 @@ typedef enum
MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC =
1, /*!< Restrict MIOpen convolutions to kernels which produce numerically deterministic
results. 0 - disabled (default), 1 - enabled >*/
#ifdef MIOPEN_BETA_API
MIOPEN_CONVOLUTION_ATTRIB_FP8_ROUNDING_MODE =
2, /*!<Specifies the rounding mode for the 8-bit floating data types. Currently, two
rounding modes are supported miopenF8RoundingModeStandard and
miopenF8RoundingModeStochastic. These are listed as part of the miopenF8RoundingMode_t
enum.>*/
#else
// miopenReserved1 = 2,
#endif
} miopenConvolutionAttrib_t;

/** @addtogroup tensor
Expand Down Expand Up @@ -723,6 +733,7 @@ MIOPEN_EXPORT miopenStatus_t miopenSetTensorDescriptor(miopenTensorDescriptor_t
const int* dimsA,
const int* stridesA);

#ifdef MIOPEN_BETA_API
/*! @brief Set the tensor cast type
*
* For tensors where the cast_type attribute is set, the tensor elements would be converted to the
Expand All @@ -734,6 +745,7 @@ MIOPEN_EXPORT miopenStatus_t miopenSetTensorDescriptor(miopenTensorDescriptor_t
*/
MIOPEN_EXPORT miopenStatus_t miopenSetTensorCastType(miopenTensorDescriptor_t tensorDesc,
miopenDataType_t cast_type);
#endif

/*! @brief Set shape of N-dimensional tensor
*
Expand Down
2 changes: 2 additions & 0 deletions src/activ/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ NetworkConfig ProblemDescription::MakeNetworkConfig() const

ss << ((packed) ? "11" : "10"); // + lite bit
ss << xDesc.GetType();
if(const auto ct = xDesc.GetCastType())
ss << GetDataTypeName(*ct);
ss << activDesc.GetMode();
ss << read_unit;
ss << MAP_RD;
Expand Down
2 changes: 1 addition & 1 deletion src/check_numerics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ std::string GetKernelName(miopenDataType_t data_type)
case miopenBFloat8: return {"check_numerics_bf8"};
case miopenInt32:
case miopenInt8:
case miopenInt8x4:
case miopenInt8x4: // Support discontinued.
case miopenDouble:
default: return {""};
}
Expand Down
24 changes: 22 additions & 2 deletions src/conv/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ void ProblemDescription::BuildConfKey(std::string& conf_key) const
ss << 'x' << GetOutLayout();
}
ss << 'x' << EncodeDataTypesForKey(GetInDataType(), GetWeightsDataType(), GetOutDataType());

std::ostringstream optional;
if(const auto ct = GetInCastType())
optional << "ci" << GetDataTypeName(*ct);
if(const auto ct = GetWeightsCastType())
optional << "cw" << GetDataTypeName(*ct);
if(const auto ct = GetOutCastType())
optional << "co" << GetDataTypeName(*ct);
if(!optional.str().empty())
{
ss << 'x' << optional.str();
}

ss << 'x' << PrintDHW('x', GetSpatialDims(), GetPadD(), GetPadH(), GetPadW());
ss << 'x'
<< PrintDHW(
Expand Down Expand Up @@ -175,11 +188,18 @@ void ProblemDescription::Serialize(std::ostream& stream) const
{
// Group count > 1 identifies Group/Depthwise modes.
if(GetGroupCount() != 1)
optional << 'g' << GetGroupCount();
optional << "_g" << GetGroupCount();

if(const auto ct = GetInCastType())
optional << "_ci" << GetDataTypeName(*ct);
if(const auto ct = GetWeightsCastType())
optional << "_cw" << GetDataTypeName(*ct);
if(const auto ct = GetOutCastType())
optional << "_co" << GetDataTypeName(*ct);
}
if(!optional.str().empty())
{
stream << '_' << optional.str();
stream << optional.str();
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ ConvolutionDescriptor::GetForwardOutputTensorWithLayout(const TensorDescriptor&
std::vector<std::size_t> out_strides;
tensor_layout_to_strides(
out_lens, default_layout, yLayout, xDesc.GetVectorLength(), out_strides);
return {(xDesc.GetType() == miopenInt8 || xDesc.GetType() == miopenInt8x4
return {(xDesc.GetType() == miopenInt8
? (yType)
: xDesc.GetType()), // TODO: This function overrides the output type with
// essentially the input which is incorrect.
Expand Down
55 changes: 18 additions & 37 deletions src/gemm_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@
/// "disabled expansion of recursive macro" injected by rocblas headers.
#define AVOID_ROCBLAS_WRAPPERS_204 (MIOPEN_ROCBLAS_VERSION_FLAT >= 2004000)

/// Maintain API compatibility with various rocBLAS version
#define USE_GEMM_FLAGS_PACK_INT8X4 \
((MIOPEN_ROCBLAS_VERSION_FLAT >= 2038000) && (MIOPEN_ROCBLAS_VERSION_FLAT < 4000000))

/// Maintain API compatibility for versions not supporting FP16 alternate implementations
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (MIOPEN_ROCBLAS_VERSION_FLAT >= 2043000)
/// Some 2.42 versions have rocblas_gemm_flags_fp16_alt_impl, but
Expand Down Expand Up @@ -110,7 +106,7 @@ static inline rocblas_datatype rocBlasComputeType(const miopen::GemmDescriptor&
{
// Complex compute types are only supported in newer version of the API
assert(desc.dataType == desc.a_cast_type && desc.dataType == desc.b_cast_type);
if(desc.dataType == miopenInt8 || desc.dataType == miopenInt8x4)
if(desc.dataType == miopenInt8)
return rocblas_datatype::rocblas_datatype_i32_r;
else
return rocblas_datatype::rocblas_datatype_f32_r;
Expand Down Expand Up @@ -413,6 +409,7 @@ miopenStatus_t CallGemm(const Handle& handle,
gemm_desc.isColMajor = !gemm_desc.isColMajor;
std::swap(A, B);
std::swap(a_offset, b_offset);
std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type);
std::swap(gemm_desc.transA, gemm_desc.transB);
std::swap(gemm_desc.m, gemm_desc.n);
std::swap(gemm_desc.lda, gemm_desc.ldb);
Expand Down Expand Up @@ -440,7 +437,6 @@ miopenStatus_t CallGemm(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -472,12 +468,7 @@ miopenStatus_t CallGemm(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -621,9 +612,9 @@ miopenStatus_t CallGemm(const Handle& handle,
};
break;

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}
Expand Down Expand Up @@ -665,6 +656,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
gemm_desc.isColMajor = !gemm_desc.isColMajor;
std::swap(A, B);
std::swap(a_offset, b_offset);
std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type);
std::swap(gemm_desc.transA, gemm_desc.transB);
std::swap(gemm_desc.m, gemm_desc.n);
std::swap(gemm_desc.lda, gemm_desc.ldb);
Expand Down Expand Up @@ -693,7 +685,6 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -729,12 +720,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -893,10 +879,10 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -938,6 +924,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
gemm_desc.isColMajor = !gemm_desc.isColMajor;
std::swap(A, B);
std::swap(a_offset, b_offset);
std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type);
std::swap(gemm_desc.transA, gemm_desc.transB);
std::swap(gemm_desc.m, gemm_desc.n);
std::swap(gemm_desc.lda, gemm_desc.ldb);
Expand Down Expand Up @@ -968,7 +955,6 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -1002,12 +988,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
}
break;
Expand Down Expand Up @@ -1163,10 +1144,10 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -1196,7 +1177,7 @@ GemmDescriptor CreateGemmDescriptorConvFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1351,7 +1332,7 @@ GemmDescriptor CreateGemmDescriptorConvCNHWFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1455,7 +1436,7 @@ GemmDescriptor CreateGemmStridedBatchedDescriptorConv1x1Fwd(const TensorDescript
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#else
(void)yDesc;
Expand Down
6 changes: 5 additions & 1 deletion src/hip/batched_transpose_sol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ BatchedTransposeSolution::BatchedTransposeSolution(const ExecutionContext& ctx,
uint32_t width_)
: data_type(data_type_), batch(batch_), height(height_), width(width_)
{
if(data_type == miopenInt8x4 || data_type == miopenDouble)
if(!(data_type == miopenHalf //
|| data_type == miopenFloat //
|| data_type == miopenInt32 //
|| data_type == miopenInt8 //
|| data_type == miopenBFloat16))
MIOPEN_THROW("These data type are not supported");
num_cu = ctx.GetStream().GetMaxComputeUnits();
std::size_t data_size = miopen::GetTypeSize(data_type);
Expand Down
Loading

0 comments on commit 936282c

Please sign in to comment.