bug: memory for position_encoding_table is not allocated correctly #790

Open
johnson-magic opened this issue Mar 27, 2024 · 0 comments · May be fixed by #791
Labels
bug Something isn't working

johnson-magic commented Mar 27, 2024

Branch/Tag/Commit

main

Docker Image Version

nvcr.io/nvidia/pytorch:22.12-py3

GPU name

A10

CUDA Driver

535.54.03

Reproduced Steps

1. docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:22.12-py3 bash
2. git clone --recursive https://github.com/NVIDIA/FasterTransformer.git
3. cd FasterTransformer
4. mkdir build
5. cd build
6. cmake -DSM=86 -DCMAKE_BUILD_TYPE=Release ..
7. make -j14
8. CUDA_VISIBLE_DEVICES=0 ./satrn 1 1 8 64 2048 4022 3 100 576 512 0 0.0 0

Observed problem:

In the line

val = val + position_encoding[step_offset + col_index];

step_offset is computed in multiples of hidden_units, so the position-encoding table is indexed as max_seq_len rows of hidden_units values each.
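A minimal host-side sketch of that lookup, assuming step_offset = step * hidden_units and a row-major table. The function names addPositionEncoding and lastIndexTouched are hypothetical and only illustrate the indexing; they are not the actual FasterTransformer kernel. The last index the lookup can touch is max_seq_len * hidden_units - 1, which is why the table needs exactly max_seq_len * hidden_units elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of the quoted lookup (illustration only).
// Assumption: step_offset = step * hidden_units, i.e. one row of
// hidden_units values per sequence position, stored row-major.
float addPositionEncoding(float val, const std::vector<float>& position_encoding,
                          std::size_t step, std::size_t col_index, std::size_t hidden_units) {
    assert(col_index < hidden_units);
    const std::size_t step_offset = step * hidden_units;
    // Must hold for every step < max_seq_len, so the table needs
    // at least max_seq_len * hidden_units elements.
    assert(step_offset + col_index < position_encoding.size());
    return val + position_encoding[step_offset + col_index];
}

// Largest index the lookup can touch for steps in [0, max_seq_len):
// (max_seq_len - 1) * hidden_units + (hidden_units - 1) == max_seq_len * hidden_units - 1.
std::size_t lastIndexTouched(std::size_t max_seq_len, std::size_t hidden_units) {
    return (max_seq_len - 1) * hidden_units + (hidden_units - 1);
}
```

Note that vocab_size does not appear anywhere in the index arithmetic.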

So I think

cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_);

should be

cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * hidden_units_);

since the table holds max_seq_len_ * hidden_units_ elements, not max_seq_len_ * vocab_size_.
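To make the size difference concrete, here is a small hypothetical helper (positionTableSizes is an illustrative name, not part of FasterTransformer) that computes the element count under the current expression and under the proposed one.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (illustration only): element counts for the
// position-encoding buffer under both size expressions.
struct PositionTableSizes {
    std::size_t current;   // max_seq_len_ * vocab_size_   (what is allocated/copied today)
    std::size_t proposed;  // max_seq_len_ * hidden_units_ (what the lookup actually indexes)
};

PositionTableSizes positionTableSizes(std::size_t max_seq_len, std::size_t hidden_units,
                                      std::size_t vocab_size) {
    assert(max_seq_len > 0 && hidden_units > 0 && vocab_size > 0);
    return {max_seq_len * vocab_size, max_seq_len * hidden_units};
}
```

With illustrative values max_seq_len = 128, hidden_units = 256, vocab_size = 1000, the current expression gives 128000 elements against the 32768 actually needed, so the buffer is oversized. If vocab_size were smaller than hidden_units (say 100), the current expression would give only 12800 elements and the lookup would read past the end of the buffer.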

The same size expression appears in two other places:

cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], max_seq_len_ * vocab_size_);

deviceMalloc(&weights_ptr[0], max_seq_len_ * vocab_size_);
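One way to keep the allocation and every copy in agreement is to derive the element count from a single helper. The sketch below is hypothetical: PositionEncodingTableSketch and tableElems() are illustrative names, and std::vector stands in for the device buffer that deviceMalloc and cudaD2Dcpy manage in the real weight class.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch (not FasterTransformer code). std::vector
// stands in for the device buffer; allocation and both copy paths use
// tableElems(), so their sizes cannot drift apart.
class PositionEncodingTableSketch {
public:
    PositionEncodingTableSketch(std::size_t max_seq_len, std::size_t hidden_units)
        : max_seq_len_(max_seq_len), hidden_units_(hidden_units) {
        table_.resize(tableElems(), 0.0f);  // role of deviceMalloc(&weights_ptr[0], ...)
    }

    PositionEncodingTableSketch(const PositionEncodingTableSketch& other)
        : max_seq_len_(other.max_seq_len_), hidden_units_(other.hidden_units_) {
        table_.resize(tableElems());
        copyFrom(other);  // role of cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], ...)
    }

    PositionEncodingTableSketch& operator=(const PositionEncodingTableSketch& other) {
        if (this != &other) {
            max_seq_len_ = other.max_seq_len_;
            hidden_units_ = other.hidden_units_;
            table_.resize(tableElems());
            copyFrom(other);
        }
        return *this;
    }

    // One row of hidden_units_ values per position, matching step_offset = step * hidden_units.
    std::size_t tableElems() const { return max_seq_len_ * hidden_units_; }

    std::size_t allocatedElems() const { return table_.size(); }

    float& at(std::size_t step, std::size_t col) {
        assert(step < max_seq_len_ && col < hidden_units_);
        return table_[step * hidden_units_ + col];
    }

private:
    void copyFrom(const PositionEncodingTableSketch& other) {
        for (std::size_t i = 0; i < tableElems(); ++i) {
            table_[i] = other.table_[i];
        }
    }

    std::size_t max_seq_len_;
    std::size_t hidden_units_;
    std::vector<float> table_;
};
```

In this sketch vocab_size plays no part in the table size, which matches the indexing in the lookup above.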

I have opened a PR to try to fix it. @byshiue

johnson-magic added the bug (Something isn't working) label Mar 27, 2024
johnson-magic linked a pull request Mar 27, 2024 that will close this issue