Skip to content

Commit

Permalink
kernel fix: hidden_state batch bigger than input batch
Browse files Browse the repository at this point in the history
  • Loading branch information
shurale-nkn committed Nov 7, 2023
1 parent 34c1094 commit ef84a11
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
20 changes: 20 additions & 0 deletions src/ocl/rnnocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,26 @@ void RNNDescriptor::RNNForwardTraining_MS(Handle& handle,
const std::vector<size_t> hcy_dst_stride{
static_cast<size_t>(hidden_size * max_batch), static_cast<size_t>(hidden_size), 1};

if(in_n.at(0) < max_batch)
{
float beta = 0.;
const std::vector<size_t> zero_set_size{1,
static_cast<size_t>(max_batch - in_n.at(0)),
static_cast<size_t>(hidden_size)};
auto set_batch_offset = in_n.at(0) * hidden_size;

auto set_desc =
miopen::TensorDescriptor(wDesc.GetType(), zero_set_size, hcy_dst_stride);
if(hy != nullptr)
{
SetTensor(handle, set_desc, hy, &beta, hcy_layer_offset + set_batch_offset);
}
if(cy != nullptr)
{
SetTensor(handle, set_desc, cy, &beta, hcy_layer_offset + set_batch_offset);
}
}

for(int time_i = seq_len - 1; time_i >= 0; time_i--)
{
auto copy_batch = (time_i == seq_len - 1) ? in_n.at(time_i)
Expand Down
2 changes: 1 addition & 1 deletion test/rnn_seq_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,7 @@ struct rnn_seq_api_test_driver : test_driver
batchSize,
padding_val);

std::vector<int> new_seqLenArray(seqLength);
std::vector<int> new_seqLenArray(batchSize);

std::copy_n(seqLenArray.begin(), seqLenArray.size(), new_seqLenArray.begin());
std::fill_n(new_seqLenArray.begin() + seqLenArray.size(),
Expand Down

0 comments on commit ef84a11

Please sign in to comment.