
Crash on concatenate after latest update #114

Closed
DePasqualeOrg opened this issue Aug 29, 2024 · 18 comments

@DePasqualeOrg
Contributor

After ab94ffc, I'm getting a crash the second time I try to generate text with my app, which uses mlx-libraries. I can't reproduce this with the LLMEval example app at the moment, but I'll try to find the cause.

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,8,1,128), (1,8,512,128), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/<app name>-ejjtjaklhfhyarhbwjdbxiatlsar/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

@awni
Member

awni commented Aug 29, 2024

Which model were you running?

@DePasqualeOrg
Contributor Author

So far it has happened on Phi 3.5 mini 4-bit and Llama 3.1 9B 4-bit. I haven't tested other models yet.

@awni
Member

awni commented Aug 29, 2024

Can you provide more details on what exactly you are running? Most likely the line it's breaking at is in the new KV cache. It looks like one of the inputs to that function has the wrong order.

Are you using custom model code or the same models as in the example?

@DePasqualeOrg
Contributor Author

I'm using the models from mlx-libraries. Generally this happens on the second or third prompt in a conversation. I'm still trying to investigate this on my end but wanted to open this issue in case others are having similar problems.

@davidkoski
Collaborator

I was using mlx-community/Phi-3-mini-4k-instruct-4bit as the primary use case so I know that one works generally.

Is there something I can do to reproduce the issue? I am happy to debug it.

@davidkoski
Collaborator

Generally this happens on the second or third prompt in a conversation

How do you do the conversation? How is the state carried from one call to generate to the next?

@DePasqualeOrg
Contributor Author

I use the prompt templates that are commonly used for each model to represent the conversation, adding to them for each new prompt and response, and passing the updated template to generate when a new prompt is submitted. I've had to build this myself, since swift-transformers doesn't include it, although this may change soon. I'll post an update here when I can reproduce this with LLMEval.

@DePasqualeOrg
Contributor Author

I've been able to reproduce this with LLMEval after updating the swift-transformers dependency to use the latest commit on the main branch. Short prompts work as expected, but after submitting a long prompt (about 5600 characters) with Llama 3 9B 4-bit, I get this crash:

MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,8,170,128). at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

And with Phi 3.5 mini:

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,32,2,96), (1,32,512,96), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

I noticed that the active memory indicator in the app grows to a very large number when this happens.

@davidkoski
Collaborator

ok perfect, I have a repro using the output of the previous runs:

p2.txt

@davidkoski
Collaborator

davidkoski commented Aug 30, 2024

actually a different error:

<|assistant|>MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,32,182,96). at /Users/dkoski/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eimbjcofifunwybkcvhnzjbqwyri/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

hopefully that is related though. I also see the big memory spike

(actually this matches your other error)

@davidkoski
Collaborator

So something is off in the KVCache. The shape of the keys after the prompt on the python side:

(1, 32, 512, 96)

swift:

- 0 : 256
- 1 : 32
- 2 : 256
- 3 : 96

The 256 vs 512 is because I messed with the prefill step size. Anyway the 0 dimension is not right.

@davidkoski
Collaborator

This comes from different shapes as input to Attention:

python: (1, 512, 3072)
swift: [256, 1, 3072]

aha:

        model(y[:prefill_step_size][None], cache=cache)

does not translate to:

            _ = model(y[..<parameters.prefillStepSize, .newAxis], cache: cache)

in Swift, the new axis from the trailing [None] actually needs to come first:

y[.newAxis, ..<parameters.prefillStepSize]
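
A minimal sketch of the difference, assuming y is the 1-D array of prompt token ids and a prefill step size of 512 (illustrative values, not the actual generate code):

    import MLX

    // Illustrative only: y stands in for the 1-D array of prompt token ids.
    let y = MLXArray(Array(Int32(0) ..< 1024))   // shape [1024]

    // Slicing first and then adding the new axis puts the tokens on dim 0.
    let wrong = y[..<512, .newAxis]              // shape [512, 1]

    // Putting the new (batch) axis first matches Python's y[:512][None].
    let right = y[.newAxis, ..<512]              // shape [1, 512]

    print(wrong.shape, right.shape)              // [512, 1] [1, 512]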

@davidkoski
Collaborator

It is amazing that this gross mismatch of shapes ... mostly works. It sure would be nice to have some typing on shapes. I suppose we could use a precondition.
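
For illustration, a shape guard along those lines might look like the sketch below; the function and parameter names are hypothetical, not the actual KVCache API:

    import MLX

    // Hypothetical shape check of the kind suggested above; not the real KVCache code.
    // keys and values are assumed to be laid out as [batch, kvHeads, seqLen, headDim].
    func update(keys: MLXArray, values: MLXArray) {
        precondition(keys.ndim == 4,
                     "expected 4-D keys [B, H, L, D], got \(keys.shape)")
        precondition(keys.shape == values.shape,
                     "keys \(keys.shape) and values \(values.shape) must match")
        // ... concatenate onto the cached keys/values along axis 2 ...
    }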

@davidkoski
Collaborator

@DePasqualeOrg thank you so much for reporting this! Your info got me a quick repro, and I was able to track down the issue. You can try kvcache2 from #115

@DePasqualeOrg
Contributor Author

DePasqualeOrg commented Aug 30, 2024

Fantastic, thank you! I tested this with Phi 3.5 mini and Llama 3.1 9B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 9B. I guess this is due to the new KV cache?

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

@awni
Member

awni commented Aug 30, 2024

I tested this with Phi 3.5 mini and Llama 3.1 9B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 9B

The handling of the RoPE positional encodings is not quite right for both Llama 3.1 and Phi 3.5. So if your prompt + generation is very long (like 4k tokens or more), that might explain it. The new KV Cache shouldn't change the results at all. If you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

Since the attention steps are fixed at 512, the maximum size of the attention scores is now 512 * 512 * num_heads * 2, which is not that big. The memory bottleneck for long prompts will most likely be the memory used by the KV cache. That will scale as the product of the following factors (a rough example calculation follows the list):

  • num layers
  • 2 (keys and values)
  • length of prompt + generation
  • num_kv_heads
  • head_dim
  • 2 bytes
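
For example, a rough estimate using the KV-head count (8) and head dim (128) from the Llama error shapes above, and assuming 32 layers, a 4096-token prompt + generation, and float16 storage:

    // Back-of-the-envelope KV cache size using the factors listed above.
    // Layer count and sequence length are assumed values for illustration;
    // the KV-head count and head dim are taken from the error shapes above.
    let numLayers = 32
    let numKVHeads = 8
    let headDim = 128
    let tokens = 4096          // prompt + generation
    let bytesPerElement = 2    // float16

    let kvBytes = numLayers * 2 * tokens * numKVHeads * headDim * bytesPerElement
    print("KV cache ≈ \(Double(kvBytes) / 1_048_576) MB")   // about 512 MB for these numbers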

@DePasqualeOrg
Contributor Author

The new KV Cache shouldn't change the results at all. If you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.

Whenever you're able to update to the latest MLX, I'll test this again and see if that solves the problem.

@DePasqualeOrg
Contributor Author

I haven't encountered this lately, so I believe the issue is resolved.
