Crash on concatenate after latest update #114
Which model were you running?
So far it has happened on Phi 3.5 mini 4-bit and Llama 3.1 9B 4-bit. I haven't tested other models yet.
Can you provide more details on what exactly you are running? Most likely the line it's breaking at is in the new KV cache. It looks like one of the inputs to that function has the wrong order. Are you using custom model code or the same models as in the example?
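For readers hitting the same crash, a minimal sketch of how a key/value cache typically grows by concatenation, assuming a (batch, heads, seq, headDim) layout and MLX Swift's `concatenated` function; this is not the actual KVCache implementation in mlx-swift-examples:

```swift
import MLX

// Minimal sketch, not the library's actual KVCache: keys/values are stored as
// (batch, heads, seqLen, headDim) and each new step is appended on axis 2.
struct SimpleKVCache {
    var keys: MLXArray?
    var values: MLXArray?

    mutating func update(newKeys: MLXArray, newValues: MLXArray) -> (MLXArray, MLXArray) {
        if let keys, let values {
            // concatenated() requires every axis except the concatenation axis (2)
            // to match. If `newKeys` arrives with its batch and sequence axes
            // swapped -- e.g. (seq, heads, 1, headDim) instead of
            // (1, heads, seq, headDim) -- this is where the "[concatenate] All the
            // input array dimensions must match" error is raised.
            self.keys = concatenated([keys, newKeys], axis: 2)
            self.values = concatenated([values, newValues], axis: 2)
        } else {
            self.keys = newKeys
            self.values = newValues
        }
        return (self.keys!, self.values!)
    }
}
```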
I'm using the models from mlx-libraries. Generally this happens on the second or third prompt in a conversation. I'm still trying to investigate this on my end but wanted to open this issue in case others are having similar problems.
I was using mlx-community/Phi-3-mini-4k-instruct-4bit as the primary use case so I know that one works generally. Is there something I can do to reproduce the issue? I am happy to debug it.
How do you do the conversation? How is the state carried from one call to generate to the next?
I use the prompt templates that are commonly used for each model to represent the conversation, adding to them for each new prompt and response, and passing the updated template to generate.
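In other words, roughly the pattern sketched below, with hypothetical `applyChatTemplate` and `generateText` helpers standing in for whatever the app actually calls; the point is that no KV cache is reused across calls, only the growing prompt:

```swift
// Hypothetical sketch of carrying conversation state by rebuilding the
// prompt from the full history on every turn.
struct Turn {
    let role: String      // "user" or "assistant"
    let content: String
}

func applyChatTemplate(_ turns: [Turn]) -> String {
    // Illustrative Phi-style template; real templates are model-specific.
    turns.map { "<|\($0.role)|>\n\($0.content)<|end|>\n" }.joined() + "<|assistant|>\n"
}

func generateText(prompt: String) -> String {
    // Placeholder: the real app would call into MLX here.
    "..."
}

var history: [Turn] = []

func ask(_ question: String) -> String {
    history.append(Turn(role: "user", content: question))
    let prompt = applyChatTemplate(history)     // full conversation so far
    let reply = generateText(prompt: prompt)    // fresh generation each turn
    history.append(Turn(role: "assistant", content: reply))
    return reply
}
```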
I've been able to reproduce this with LLMEval after updating the swift-transformers dependency to use the latest commit on the main branch. Short prompts work as expected, but after submitting a long prompt (about 5600 characters) with Llama 3 9B 4-bit, I get this crash:
And with Phi 3.5 mini:
I noticed that the active memory indicator in the app grows to a very large number when this happens.
ok perfect, I have a repro using the output of the previous runs:
actually a different error:

<|assistant|>MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,32,182,96). at /Users/dkoski/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eimbjcofifunwybkcvhnzjbqwyri/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

Hopefully that is related though. I also see the big memory spike (actually this matches your other error).
So something is off in the KVCache. The shape of the keys after the prompt on the python side is (1, 32, 512, 96), but the swift side doesn't match. The 256 vs 512 is because I messed with the prefill step size. Anyway, the 0 dimension is not right.
This comes from different shapes as input to Attention; python: (1, 512, 3072). Aha:

model(y[:prefill_step_size][None], cache=cache)

does not translate to:

_ = model(y[..<parameters.prefillStepSize, .newAxis], cache: cache)

The order of the trailing indices matters; it should be:

_ = model(y[.newAxis, ..<parameters.prefillStepSize], cache: cache)
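A minimal sketch of that index-order difference (illustrative values, not the library code): putting `.newAxis` first adds the leading batch dimension, matching the Python `y[:n][None]`, while putting it last treats the sequence as the batch.

```swift
import MLX

// 1-D array of 1024 "tokens", just for demonstration.
let y = MLXArray(Array(0 ..< 1024).map { Int32($0) })   // shape: [1024]
let n = 512

let wrong = y[..<n, .newAxis]   // shape [512, 1]: sequence ends up in the batch axis
let right = y[.newAxis, ..<n]   // shape [1, 512]: batch of 1, sequence of 512

print(wrong.shape)  // [512, 1]
print(right.shape)  // [1, 512]
```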
It is amazing that this gross mismatch of shapes ... mostly works. It sure would be nice to have some typing on shapes. I suppose we could use
@DePasqualeOrg thank you so much for reporting this! Your info got a quick repro and I was able to track down the issue. You can try |
Fantastic, thank you! I tested this with Phi 3.5 mini and Llama 3.1 9B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens in the output.

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.
The handling of the RoPE positional encodings is not quite right for both Llama 3.1 and Phi 3.5, so if your prompt + generation is very long (4k tokens or more) that might explain it. The new KV cache shouldn't change the results at all; if you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.
Since the attention steps are fixed at 512, the maximum size of the attention scores is now bounded.
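To make the memory question concrete, a rough back-of-the-envelope sketch, assuming the largest score tensor at the final prefill step is (heads × step × total tokens) in float16; the head count and prompt length below are illustrative, not taken from a specific model:

```swift
// Rough estimate of peak attention-score memory when the prompt is
// processed in fixed steps. Assumes float16 scores shaped
// (heads, stepSize, keysSoFar); at the last prefill step the scores
// attend over the full prompt, which is the largest single tensor.
func attentionScoreBytes(heads: Int, stepSize: Int, totalTokens: Int,
                         bytesPerElement: Int = 2) -> Int {
    heads * stepSize * totalTokens * bytesPerElement
}

// e.g. 32 heads, step 512, 4096-token prompt, float16:
let bytes = attentionScoreBytes(heads: 32, stepSize: 512, totalTokens: 4096)
print("\(Double(bytes) / 1_048_576) MiB")   // 128.0 MiB
```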
Whenever you're able to update to the latest MLX, I'll test this again and see if that solves the problem. |
I haven't encountered this lately, so I believe the issue is resolved. |
After ab94ffc, I'm getting a crash the second time I try to generate text with my app, which uses mlx-libraries. I can't reproduce this with the LLMEval example app at the moment, but I'll try to find the cause.
MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,8,1,128), (1,8,512,128), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/<app name>-ejjtjaklhfhyarhbwjdbxiatlsar/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217