Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement getitem backward #2883

Merged
merged 204 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
204 commits
Select commit Hold shift + click to select a range
188c339
init driver and gtest
seungmanhan Mar 28, 2024
e0ee983
add getitem driver and gtest, init host api and kernel
seungmanhan Apr 6, 2024
d68d1c3
add host API and kernel, fix build error
seungmanhan Apr 7, 2024
b48c73d
fix driver build error
seungmanhan Apr 8, 2024
6a219fa
fix kernel build error
seungmanhan Apr 8, 2024
7c48ef5
fix driver error
seungmanhan Apr 8, 2024
170059a
fix error, add atomic add for half and bfloat16
seungmanhan Apr 8, 2024
b1e2173
change tref to float
seungmanhan Apr 8, 2024
222e267
clang format
seungmanhan Apr 8, 2024
391ce83
remove unused value
seungmanhan Apr 8, 2024
214f1cb
fix gtest error
seungmanhan Apr 8, 2024
186230c
add applicable function, remove unused function, 2023->2024
seungmanhan Apr 8, 2024
cfbbb84
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 11, 2024
ae00e7b
add getitem driver
seungmanhan Apr 11, 2024
bf4f195
add doc
seungmanhan Apr 15, 2024
6fbcfbf
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 16, 2024
349fc17
fix namespace typo
seungmanhan Apr 22, 2024
5369205
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Apr 22, 2024
d377def
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 22, 2024
d97ee71
remove const, remove push_back in for loop, remove pop_back
seungmanhan Apr 22, 2024
ea798c4
apply make analyze
seungmanhan Apr 22, 2024
bd25a19
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 22, 2024
ab15284
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 23, 2024
b0f3379
add tensor view include, add driver input check, remove unused value,…
seungmanhan Apr 26, 2024
5b17930
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 29, 2024
243c1fa
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Apr 30, 2024
abc86ab
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 2, 2024
a7f67e0
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 7, 2024
68a7da6
fix error
seungmanhan May 7, 2024
aab0e30
clang format
seungmanhan May 7, 2024
017ff48
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 7, 2024
434026e
add comment and remove unused macro
seungmanhan May 7, 2024
5d38709
fix build error
seungmanhan May 7, 2024
a95218a
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 8, 2024
db9d629
change macro to constexpr
seungmanhan May 10, 2024
6c438c2
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 10, 2024
226265e
fix build error, add comment
seungmanhan May 10, 2024
595d3ca
clang format
seungmanhan May 10, 2024
c29fb0a
remove duplicate code, add newtwork config
seungmanhan May 10, 2024
c800521
add comment
seungmanhan May 10, 2024
4d8360b
remove unused function, modify comment
seungmanhan May 10, 2024
d552950
add comment
seungmanhan May 10, 2024
4f0e849
change c style to C++, remove unnecessary code and add atomic add for…
seungmanhan May 10, 2024
9066d11
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 11, 2024
1cb3612
fix build error
seungmanhan May 11, 2024
08b8bb6
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 14, 2024
ff17728
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 16, 2024
637cf3b
add uint64_t function i InputFlags, remove unnecessary code
seungmanhan May 17, 2024
d16f029
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 17, 2024
cacb376
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 17, 2024
523d952
fix build error
seungmanhan May 17, 2024
754eb50
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 20, 2024
f45cc61
layerout -> layout
seungmanhan May 20, 2024
2d8868f
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 20, 2024
5f40374
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 22, 2024
54983ac
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 22, 2024
69d3446
remove unnecessary workspace error logic
seungmanhan May 23, 2024
6be79f0
add standalone run gtest
seungmanhan May 23, 2024
ebed155
fix build error in gtest
seungmanhan May 23, 2024
a37b79d
remove GetitemBackward::GetWorkspaceSize
seungmanhan May 23, 2024
7ec4b31
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 24, 2024
74e16c6
remove unused value
seungmanhan May 24, 2024
2cd6e37
remove printf
seungmanhan May 24, 2024
c8c6024
fix sum gtest error
seungmanhan May 24, 2024
de9276d
fix HIP tidy issue
seungmanhan May 24, 2024
020a1bc
fix warning
seungmanhan May 24, 2024
6810b81
Merge branch 'develop' into impl_getitem_bwd
junliume May 25, 2024
d8223bd
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 27, 2024
7dea455
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 27, 2024
c48ec89
revert ab_case
seungmanhan May 27, 2024
de7c9d2
fix tensor view error
seungmanhan May 27, 2024
b063b7c
revert gtest except getitem
seungmanhan May 27, 2024
4bab1e5
revert getitem gtest
seungmanhan May 28, 2024
eca01db
revert get item workspcae
seungmanhan May 28, 2024
4f5f447
fix build error
seungmanhan May 28, 2024
fcff9c3
Change GetWorkspaceSizes logic
seungmanhan May 28, 2024
3c42e24
revert gtest change
seungmanhan May 28, 2024
d80bbea
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 28, 2024
40440f5
remove unused variable
seungmanhan May 28, 2024
88455c8
fix get inner expanded tv error
seungmanhan May 28, 2024
3b41ae9
change file name item to getitem
seungmanhan May 28, 2024
46d608d
Change GetWorkspaceSizes logic in t5layernorm
seungmanhan May 28, 2024
d144993
change file name in cmake list
seungmanhan May 28, 2024
bf2a313
item to getitem
seungmanhan May 28, 2024
d07c883
clang format
seungmanhan May 28, 2024
f9eeae2
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 29, 2024
5ed364f
make tensor view uilts header file
seungmanhan May 29, 2024
040be9d
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 29, 2024
46aaf9e
cuto to onst auto&
seungmanhan May 29, 2024
5b104de
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 29, 2024
0440875
modify problem_description
seungmanhan May 29, 2024
1f8298a
add MIOPEN_TEST_ALL check in getitem gtest
seungmanhan May 29, 2024
363dbe2
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 29, 2024
f882c13
revert test all check
seungmanhan May 30, 2024
05e1775
int32_t -> uint32_t
seungmanhan May 30, 2024
a1eb5cc
modify error code
seungmanhan May 30, 2024
4d06fcc
add indexDescs check, modify problem desc
seungmanhan May 30, 2024
5f46dc3
add nullptr check
seungmanhan May 30, 2024
32fad05
fix warning
seungmanhan May 30, 2024
de65023
clang format
seungmanhan May 30, 2024
dc5fed2
fix build error
seungmanhan May 30, 2024
0977e22
move valid functions to ctor
seungmanhan May 30, 2024
8509e39
fix typo error
seungmanhan May 30, 2024
548bd9a
revert MIOPEN_TEST_ALL
seungmanhan May 30, 2024
f27cce6
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 30, 2024
dc42916
clang format
seungmanhan May 30, 2024
0a39c92
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan May 30, 2024
7bf43df
Merge branch 'develop' into impl_getitem_bwd
seungmanhan May 31, 2024
2da1447
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 1, 2024
3374ea2
add MIOPEN_TEST_ALL check
seungmanhan Jun 3, 2024
4f07357
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 5, 2024
9fb7ef8
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 5, 2024
494a84c
revert MIOPEN_TEST_ALL check
seungmanhan Jun 5, 2024
808b4ac
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jun 5, 2024
80be44b
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 5, 2024
3ac4f16
Merge branch 'develop' into impl_getitem_bwd
junliume Jun 5, 2024
51e2eca
Merge branch 'develop' into impl_getitem_bwd
junliume Jun 5, 2024
0256ef5
fix build error
seungmanhan Jun 6, 2024
c87aaec
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 7, 2024
9cab437
size_t->uint64, fix type error
seungmanhan Jun 7, 2024
504f8b1
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jun 7, 2024
c628f4c
fix profile error
seungmanhan Jun 7, 2024
caaaff1
add bool check
seungmanhan Jun 7, 2024
064ae03
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 7, 2024
f089e75
fix build error
seungmanhan Jun 7, 2024
8cd13b1
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jun 7, 2024
f7d923f
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 10, 2024
6255e87
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 11, 2024
deebd72
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 11, 2024
93da019
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 12, 2024
c3f6ab4
remove unused varialbe
seungmanhan Jun 12, 2024
8a21de5
remove unused variable
seungmanhan Jun 12, 2024
bd05a6e
\n->std::endl, modify comment, adjust tolerance
seungmanhan Jun 12, 2024
4237d24
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 12, 2024
98fe8a1
debug getitem gtest
seungmanhan Jun 13, 2024
9e69f7f
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jun 13, 2024
35ac2c0
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 13, 2024
2f76d96
miopen::IsEnabled(ENV) to env::enabled
seungmanhan Jun 13, 2024
6e71d03
miopen::GetStringEnv(ENV) to env::value
seungmanhan Jun 13, 2024
b141946
add MIOPEN_TEST_ALL check
seungmanhan Jun 13, 2024
133d416
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 14, 2024
8f7bfbe
revert other op change
seungmanhan Jun 14, 2024
26627fe
revert other op change2
seungmanhan Jun 14, 2024
e58ec3d
github action debug
seungmanhan Jun 15, 2024
44a9b6b
fix t5layernorm driver default
seungmanhan Jun 15, 2024
f1e6912
modify threshild
seungmanhan Jun 15, 2024
3745d94
clang format
seungmanhan Jun 15, 2024
4561d66
error debug
seungmanhan Jun 16, 2024
989cf69
fix warning
seungmanhan Jun 16, 2024
e7a24a9
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 16, 2024
7ecec02
fix warning
seungmanhan Jun 16, 2024
e063776
adjust threshold
seungmanhan Jun 16, 2024
151a987
adjust threshold in driver
seungmanhan Jun 16, 2024
098421b
remove getitem gtest for debug
seungmanhan Jun 17, 2024
11fdae9
clang format
seungmanhan Jun 17, 2024
b0de59c
revert debug
seungmanhan Jun 17, 2024
27c00b6
clang format
seungmanhan Jun 17, 2024
46a94d4
fix doxygen error
seungmanhan Jun 17, 2024
e3a0d72
fix build error
seungmanhan Jun 18, 2024
d619bc2
add comment
seungmanhan Jun 18, 2024
34b5ae0
modify initilization
seungmanhan Jun 18, 2024
eda199d
change order
seungmanhan Jun 18, 2024
44c4da4
remove half, bfloat16 test for debug
seungmanhan Jun 18, 2024
6c7105f
revert debut, fix typo error
seungmanhan Jun 18, 2024
1349d01
revert debut
seungmanhan Jun 18, 2024
c66c4e1
remove unused if
seungmanhan Jun 18, 2024
0c91a2e
modify threshold
seungmanhan Jun 18, 2024
35eef25
fix build error
seungmanhan Jun 18, 2024
341b331
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 24, 2024
3f4ea40
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jun 26, 2024
79eac38
Merge branch 'develop' into impl_getitem_bwd
junliume Jun 27, 2024
01ee523
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 2, 2024
1830461
fix type error
seungmanhan Jul 2, 2024
f474a65
modify tolerance
seungmanhan Jul 2, 2024
86696c4
modify t5layernorm driver defalut
seungmanhan Jul 2, 2024
3c93547
change layernorm mode type bool to int
seungmanhan Jul 2, 2024
de5e413
add MIOPEN_TEST_ALL in layernorms
seungmanhan Jul 2, 2024
e11f9e3
Modify cat driver defalut
seungmanhan Jul 3, 2024
019ab9f
add device kernel in groupnorm, change mean and rstd type, update tol…
seungmanhan Jul 3, 2024
e3e37ba
revert layernorm tolerance calculation
seungmanhan Jul 3, 2024
a8c1df1
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 3, 2024
17364fc
remove failed driver for debug
seungmanhan Jul 3, 2024
332e9a9
remove failed driver test
seungmanhan Jul 3, 2024
57d08ec
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jul 3, 2024
825d22a
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 5, 2024
cefa17d
remove CBAInfer test for debug
seungmanhan Jul 8, 2024
18b55d0
Merge branch 'impl_getitem_bwd' of https://github.com/ROCm/MIOpen int…
seungmanhan Jul 8, 2024
297f46c
fix comment
seungmanhan Jul 8, 2024
8033e39
fix comment
seungmanhan Jul 8, 2024
8084379
fix MIOPEN_BETA_API
seungmanhan Jul 8, 2024
a0fb548
add MIOPEN_TEST_ALL check
seungmanhan Jul 9, 2024
0f61a62
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 9, 2024
a4c741f
Merge branch 'develop' into impl_getitem_bwd
junliume Jul 9, 2024
66d0e87
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 10, 2024
55e215b
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 10, 2024
fe8676c
Merge branch 'develop' into impl_getitem_bwd
junliume Jul 12, 2024
1127557
add MIOPEN_USE
seungmanhan Jul 17, 2024
aea7ce5
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 17, 2024
d224fc7
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 17, 2024
c86f2e9
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 18, 2024
943941f
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 22, 2024
105bc59
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 22, 2024
2e966b1
add MIOPEN_INTERNALS_EXPORT
seungmanhan Jul 24, 2024
eb34256
Merge branch 'develop' into impl_getitem_bwd
seungmanhan Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ The MIOpen API library is structured as follows:
* :doc:`Cat <../doxygen/html/group__cat>` (experimental)
* :doc:`SGD <../doxygen/html/group___s_g_d>` (experimental)
* :doc:`ReduceExtreme <../doxygen/html/group__ReduceExtreme>` (experimental)
* :doc:`Getitem <../doxygen/html/group__getitem>` (experimental)
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ add_executable(MIOpenDriver
dm_dropout.cpp
dm_fusion.cpp
dm_gemm.cpp
dm_getitem.cpp
dm_groupnorm.cpp
dm_layernorm.cpp
dm_lrn.cpp
Expand Down
159 changes: 159 additions & 0 deletions driver/InputFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,165 @@ TensorParameters InputFlags::GetValueTensor(const std::string& long_name) const

MIOPEN_THROW("Too many tensor descriptor parameters.");
}

TensorParametersUint64 InputFlags::GetValueTensorUint64(const std::string& long_name) const
{
const auto& input = MapInputs.at(FindShortName(long_name));
const auto components = miopen::SplitDelim(input.value.c_str(), ',');

if(components.size() < 1)
return {};

auto parse = [](auto line) {
auto ret = std::vector<uint64_t>{};
const auto strs = miopen::SplitDelim(line, 'x');
seungmanhan marked this conversation as resolved.
Show resolved Hide resolved
for(auto&& str : strs)
{
auto elem = uint64_t{};
auto ss = std::istringstream{str};
ss >> elem;

if(ss.bad() || ss.fail())
MIOPEN_THROW("Invalid tensor component " + str + " in " + line + ".");

ret.push_back(elem);
}
return ret;
};

auto lens = parse(components[0]);

if(components.size() == 1)
return {lens};

auto layout = std::string{};
auto strides = std::vector<uint64_t>{};

if(std::isdigit(components[1][0]))
strides = parse(components[1]);
else
layout = components[1];

if(components.size() == 2)
return {lens, strides, layout};

MIOPEN_THROW("Too many tensor descriptor parameters.");
}

std::vector<int32_t> InputFlags::GetValueVectorInt(const std::string& long_name) const
{
const auto& input = MapInputs.at(FindShortName(long_name));

auto ret = std::vector<int32_t>{};
const auto strs = miopen::SplitDelim(input.value.c_str(), ',');

for(auto&& str : strs)
{
auto elem = int32_t{};
auto ss = std::istringstream{str};
ss >> elem;

if(ss.bad() || ss.fail())
MIOPEN_THROW("Invalid tensor component " + str + " in " + input.value.c_str() + ".");

ret.push_back(elem);
}

return ret;
}

std::vector<uint64_t> InputFlags::GetValueVectorUint64(const std::string& long_name) const
{
const auto& input = MapInputs.at(FindShortName(long_name));

auto ret = std::vector<uint64_t>{};
const auto strs = miopen::SplitDelim(input.value.c_str(), ',');

for(auto&& str : strs)
{
auto elem = uint64_t{};
auto ss = std::istringstream{str};
ss >> elem;

if(ss.bad() || ss.fail())
MIOPEN_THROW("Invalid tensor component " + str + " in " + input.value.c_str() + ".");

ret.push_back(elem);
}

return ret;
}

std::vector<std::vector<int32_t>>
InputFlags::GetValue2dVectorInt(const std::string& long_name) const
{
const auto& input = MapInputs.at(FindShortName(long_name));
const auto components = miopen::SplitDelim(input.value.c_str(), ',');
auto output = std::vector<std::vector<int32_t>>{};

if(components.size() < 1)
return {};

auto parse = [](auto line) {
auto ret = std::vector<int32_t>{};
const auto strs = miopen::SplitDelim(line, 'x');
for(auto&& str : strs)
{
auto elem = int32_t{};
auto ss = std::istringstream{str};
ss >> elem;

if(ss.bad() || ss.fail())
MIOPEN_THROW("Invalid tensor component " + str + " in " + line + ".");

ret.push_back(elem);
}
return ret;
};

for(auto&& component : components)
{
output.push_back(parse(component));
}

return output;
}

std::vector<std::vector<uint64_t>>
InputFlags::GetValue2dVectorUint64(const std::string& long_name) const
{
const auto& input = MapInputs.at(FindShortName(long_name));
const auto components = miopen::SplitDelim(input.value.c_str(), ',');
auto output = std::vector<std::vector<uint64_t>>{};

if(components.size() < 1)
return {};

auto parse = [](auto line) {
auto ret = std::vector<uint64_t>{};
const auto strs = miopen::SplitDelim(line, 'x');
for(auto&& str : strs)
{
auto elem = uint64_t{};
auto ss = std::istringstream{str};
ss >> elem;

if(ss.bad() || ss.fail())
MIOPEN_THROW("Invalid tensor component " + str + " in " + line + ".");

ret.push_back(elem);
}
return ret;
};

for(auto&& component : components)
{
output.push_back(parse(component));
}

return output;
}

void InputFlags::SetValue(const std::string& long_name, const std::string& new_value)
{
char short_name = FindShortName(long_name);
Expand Down
24 changes: 24 additions & 0 deletions driver/InputFlags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ struct TensorParameters
void CalculateStrides();
};

struct TensorParametersUint64
{
std::vector<uint64_t> lengths = {};
std::vector<uint64_t> strides = {};
std::string layout = "";

TensorParametersUint64 FillMissing(const TensorParametersUint64& other) const
{
return {
(lengths.empty() ? other.lengths : lengths),
(strides.empty() ? other.strides : strides),
(layout.empty() ? other.layout : layout),
};
}

uint64_t SetTensordDescriptor(miopenTensorDescriptor_t result, miopenDataType_t data_type);
void CalculateStrides();
};

class InputFlags
{
std::map<char, Input> MapInputs;
Expand Down Expand Up @@ -90,6 +109,11 @@ class InputFlags
uint64_t GetValueUint64(const std::string& _long_name) const;
double GetValueDouble(const std::string& _long_name) const;
TensorParameters GetValueTensor(const std::string& long_name) const;
TensorParametersUint64 GetValueTensorUint64(const std::string& long_name) const;
std::vector<int32_t> GetValueVectorInt(const std::string& long_name) const;
std::vector<uint64_t> GetValueVectorUint64(const std::string& long_name) const;
std::vector<std::vector<int32_t>> GetValue2dVectorInt(const std::string& long_name) const;
std::vector<std::vector<uint64_t>> GetValue2dVectorUint64(const std::string& long_name) const;
void SetValue(const std::string& long_name, const std::string& new_value);
void StoreOptionalFlagValue(char short_name, const std::string& input_value);

Expand Down
14 changes: 7 additions & 7 deletions driver/addlayernorm_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ class AddLayerNormDriver : public Driver
std::vector<Tgpu> weight;
std::vector<Tgpu> bias;
std::vector<Tgpu> out;
std::vector<Tref> mean;
std::vector<Tref> rstd;
std::vector<Tgpu> mean;
std::vector<Tgpu> rstd;
std::vector<Tref> outhost;
std::vector<Tref> meanhost;
std::vector<Tref> rstdhost;
Expand Down Expand Up @@ -259,7 +259,7 @@ int AddLayerNormDriver<Tgpu, Tref>::AddCmdLineArgs()
inflags.AddInputFlag("eps", 'e', "0.00001", "Alpha (Default=0.00001)", "double");
inflags.AddInputFlag("normalized_dim", 'o', "3", "Nomalized Dim (Default=3)", "int");
inflags.AddInputFlag(
"mode", 'm', "0", "elemwise affine mode (0), weight and bias mode (1) (Default=0)", "int");
"mode", 'm', "2", "elemwise affine mode (2), weight and bias mode (3) (Default=0)", "int");

inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int");
inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int");
Expand Down Expand Up @@ -291,16 +291,16 @@ int AddLayerNormDriver<Tgpu, Tref>::AllocateBuffersAndCopy()
weight_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, weight_sz, sizeof(Tgpu)));
bias_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, bias_sz, sizeof(Tgpu)));
out_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, out_sz, sizeof(Tgpu)));
mean_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, mean_sz, sizeof(Tref)));
rstd_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, rstd_sz, sizeof(Tref)));
mean_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, mean_sz, sizeof(Tgpu)));
rstd_dev = std::unique_ptr<GPUMem>(new GPUMem(ctx, rstd_sz, sizeof(Tgpu)));

in = std::vector<Tgpu>(in_sz, Tgpu0val);
in2 = std::vector<Tgpu>(in2_sz, Tgpu0val);
weight = std::vector<Tgpu>(weight_sz, Tgpu0val);
bias = std::vector<Tgpu>(bias_sz, Tgpu0val);
out = std::vector<Tgpu>(out_sz, Tgpu0val);
mean = std::vector<Tref>(mean_sz, Tref0val);
rstd = std::vector<Tref>(rstd_sz, Tref0val);
mean = std::vector<Tgpu>(mean_sz, Tgpu0val);
rstd = std::vector<Tgpu>(rstd_sz, Tgpu0val);
outhost = std::vector<Tref>(out_sz, Tref0val);
meanhost = std::vector<Tref>(mean_sz, Tref0val);
rstdhost = std::vector<Tref>(rstd_sz, Tref0val);
Expand Down
4 changes: 2 additions & 2 deletions driver/cat_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ template <typename Tgpu, typename Tref>
int CatDriver<Tgpu, Tref>::AddCmdLineArgs()
{
inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Cat (Default=1)", "int");
inflags.AddTensorFlag("input1", '1', "", "input1 tensor descriptor");
inflags.AddTensorFlag("input2", '2', "", "input2 tensor descriptor");
inflags.AddTensorFlag("input1", '1', "2x32x128x128x128", "input1 tensor descriptor");
inflags.AddTensorFlag("input2", '2', "2x32x128x128x128", "input2 tensor descriptor");
inflags.AddTensorFlag("input3", '3', "", "input3 tensor descriptor");
inflags.AddTensorFlag("input4", '4', "", "input4 tensor descriptor");
inflags.AddTensorFlag("input5", '5', "", "input5 tensor descriptor");
Expand Down
40 changes: 40 additions & 0 deletions driver/dm_getitem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "getitem_driver.hpp"
#include "registry_driver_maker.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "getitem")
return new GetitemDriver<float, float>();
if(base_arg == "getitemfp16")
return new GetitemDriver<float16, float>();
if(base_arg == "getitembfp16")
return new GetitemDriver<bfloat16, float>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
28 changes: 14 additions & 14 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
[[noreturn]] inline void Usage()
{
printf("Usage: ./driver *base_arg* *other_args*\n");
printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], "
"pool[fp16], lrn[fp16], "
printf("Supported Base Arguments: conv[fp16|int8|bfp16], pool[fp16], lrn[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
"tensorop, reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
"groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], "
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw\n");
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand All @@ -190,22 +190,22 @@ inline std::string ParseBaseArg(int argc, char* argv[])
std::string arg = argv[1];

if(arg != "conv" && arg != "convfp16" && arg != "convint8" && arg != "convbfp16" &&
arg != "convfp8" && arg != "convbfp8" && arg != "CBAInfer" && arg != "CBAInferfp16" &&
arg != "pool" && arg != "poolfp16" && arg != "lrn" && arg != "lrnfp16" && arg != "activ" &&
arg != "activfp16" && arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" &&
arg != "bnormfp16" && arg != "rnn" && arg != "rnnfp16" && arg != "rnn_seq" &&
arg != "rnn_seqfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" &&
arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" &&
arg != "sumbfp16" && arg != "groupnorm" && arg != "groupnormfp16" &&
arg != "groupnormbfp16" && arg != "cat" && arg != "catfp16" && arg != "catbfp16" &&
arg != "addlayernorm" && arg != "addlayernormfp16" && arg != "addlayernormbfp16" &&
arg != "t5layernorm" && arg != "t5layernormfp16" && arg != "t5layernormbfp16" &&
arg != "adam" && arg != "adamfp16" && arg != "ampadam" && arg != "reduceextreme" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "reduce" &&
arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" && arg != "layernormfp16" &&
arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && arg != "sumbfp16" &&
arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" && arg != "cat" &&
arg != "catfp16" && arg != "catbfp16" && arg != "addlayernorm" &&
arg != "addlayernormfp16" && arg != "addlayernormbfp16" && arg != "t5layernorm" &&
arg != "t5layernormfp16" && arg != "t5layernormbfp16" && arg != "adam" &&
arg != "adamfp16" && arg != "ampadam" && arg != "reduceextreme" &&
arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "adamw" &&
arg != "adamwfp16" && arg != "ampadamw" && arg != "transformersadamw" &&
arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "--version")
arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "getitem" &&
arg != "getitemfp16" && arg != "getitembfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
Loading