diff --git a/src/fastertransformer/models/decoding/DecodingWeight.h b/src/fastertransformer/models/decoding/DecodingWeight.h index 5b58c1e2c..5a3fba7eb 100644 --- a/src/fastertransformer/models/decoding/DecodingWeight.h +++ b/src/fastertransformer/models/decoding/DecodingWeight.h @@ -74,7 +74,7 @@ struct DecodingWeight { mem_hidden_units_(other.mem_hidden_units_) { mallocWeights(); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * hidden_units_); cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], vocab_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); @@ -98,7 +98,7 @@ struct DecodingWeight { mem_hidden_units_ = other.mem_hidden_units_; mallocWeights(); - cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * hidden_units_); cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], vocab_size_ * hidden_units_); cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_); @@ -115,7 +115,7 @@ struct DecodingWeight { void mallocWeights() { - deviceMalloc(&weights_ptr[0], max_seq_len_ * vocab_size_); + deviceMalloc(&weights_ptr[0], max_seq_len_ * hidden_units_); deviceMalloc(&weights_ptr[1], vocab_size_ * hidden_units_); deviceMalloc(&weights_ptr[2], hidden_units_); deviceMalloc(&weights_ptr[3], hidden_units_);