From c4933479540703d67b912f7ac801d11993084840 Mon Sep 17 00:00:00 2001 From: johnson-magic Date: Wed, 27 Mar 2024 18:26:25 +0800 Subject: [PATCH] fix: size weights_ptr[0] by hidden_units_ in DecodingWeight.h --- src/fastertransformer/models/decoding/DecodingWeight.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/models/decoding/DecodingWeight.h b/src/fastertransformer/models/decoding/DecodingWeight.h index 5b58c1e2c..5a3fba7eb 100644 --- a/src/fastertransformer/models/decoding/DecodingWeight.h +++ b/src/fastertransformer/models/decoding/DecodingWeight.h @@ -74,7 +74,7 @@ struct DecodingWeight { mem_hidden_units_(other.mem_hidden_units_) { mallocWeights(); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * hidden_units_); cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], vocab_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); @@ -98,7 +98,7 @@ struct DecodingWeight { mem_hidden_units_ = other.mem_hidden_units_; mallocWeights(); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * hidden_units_); cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], vocab_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); @@ -115,7 +115,7 @@ struct DecodingWeight { void mallocWeights() { - deviceMalloc(&weights_ptr[0], max_seq_len_ * vocab_size_); + deviceMalloc(&weights_ptr[0], max_seq_len_ * hidden_units_); deviceMalloc(&weights_ptr[1], vocab_size_ * hidden_units_); deviceMalloc(&weights_ptr[2], hidden_units_); deviceMalloc(&weights_ptr[3], hidden_units_);