From ef84a114fd4aca87116adba891cd22991c1e7909 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Tue, 7 Nov 2023 14:54:06 +0100 Subject: [PATCH] kernel fix: hidden_state batch bigger than input batch --- src/ocl/rnnocl.cpp | 20 ++++++++++++++++++++ test/rnn_seq_api.hpp | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index 5a5ca9a8d9..164cc94b2b 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -558,6 +558,26 @@ void RNNDescriptor::RNNForwardTraining_MS(Handle& handle, const std::vector hcy_dst_stride{ static_cast(hidden_size * max_batch), static_cast(hidden_size), 1}; + if(in_n.at(0) < max_batch) + { + float beta = 0.; + const std::vector zero_set_size{1, + static_cast(max_batch - in_n.at(0)), + static_cast(hidden_size)}; + auto set_batch_offset = in_n.at(0) * hidden_size; + + auto set_desc = + miopen::TensorDescriptor(wDesc.GetType(), zero_set_size, hcy_dst_stride); + if(hy != nullptr) + { + SetTensor(handle, set_desc, hy, &beta, hcy_layer_offset + set_batch_offset); + } + if(cy != nullptr) + { + SetTensor(handle, set_desc, cy, &beta, hcy_layer_offset + set_batch_offset); + } + } + for(int time_i = seq_len - 1; time_i >= 0; time_i--) { auto copy_batch = (time_i == seq_len - 1) ? in_n.at(time_i) diff --git a/test/rnn_seq_api.hpp b/test/rnn_seq_api.hpp index c9ec52065e..53c40709cc 100644 --- a/test/rnn_seq_api.hpp +++ b/test/rnn_seq_api.hpp @@ -1546,7 +1546,7 @@ struct rnn_seq_api_test_driver : test_driver batchSize, padding_val); - std::vector new_seqLenArray(seqLength); + std::vector new_seqLenArray(batchSize); std::copy_n(seqLenArray.begin(), seqLenArray.size(), new_seqLenArray.begin()); std::fill_n(new_seqLenArray.begin() + seqLenArray.size(),