GPTQ models are not respecting context_length or max_seq_len settings #191

chrsbats opened this issue Dec 6, 2023 · 0 comments

chrsbats commented Dec 6, 2023

No matter what I try, I can't set the context_length of a GPTQ model. It's overridden by ExLlama, which then sets the cache size and context_length to whatever its default is (in this case 2048).

The first problem is that it's actually using max_seq_len to set the context_length, and the Config dataclass doesn't include that field. Even if I monkey-patch the Config dataclass and set the field everywhere I can think of:

        model = "TheBloke/NeuralHermes-2.5-Mistral-7B-GPTQ"
        config = AutoConfig.from_pretrained(model)
        config.max_seq_len = 8000
        config.context_length = 8000
        config.config.max_seq_len = 8000
        config.config.context_length = 8000
        self.config = config
        self.llm = AutoModelForCausalLM.from_pretrained(model,config=self.config,local_files_only=True,max_seq_len=8000)
        self.llm.config.context_length = 8000
        self.llm.config.max_seq_len = 8000

None of these changes the context_length used by the GPTQ model, because it uses the ExLlama config instead.
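You can see this by inspecting the wrapped ExLlama model (via the same private `_model` attribute used below; the attribute name is what I observed, not documented API). It still reports the default:

        # Hypothetical check via the private _model attribute; still the default
        print(self.llm._model.config.max_seq_len)  # 2048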

If I reach in and modify the ExLlama config after loading the model:

        self.llm._model.config.max_seq_len = 8000
        self.llm._model.config.max_input_len = 8000

that correctly sets the context_length, but the cache has already been allocated at size 2048, so it promptly crashes whenever you ask for a long response:

File ~/anaconda3/envs/orac/lib/python3.11/site-packages/exllama/model.py:369, in ExLlamaAttention.fused(self, hidden_states, cache, buffer, input_layernorm, lora)
    365 query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim)
    367 # Get k, v with past
--> 369 key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
    370 value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
    372 # Repeat K/V heads if n_kv_heads < n_heads

RuntimeError: start (0) + length (2049) exceeds dimension size (2048).
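For context, here is a minimal sketch of how the standalone exllama API sizes things (paths are placeholders): the key/value cache tensors are allocated from config.max_seq_len when the cache is constructed, which is why mutating the config after loading can't grow them.

    from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig

    # max_seq_len must be set on the config BEFORE the model/cache are built
    config = ExLlamaConfig("/path/to/config.json")    # placeholder path
    config.model_path = "/path/to/model.safetensors"  # placeholder path
    config.max_seq_len = 8000

    model = ExLlama(config)
    cache = ExLlamaCache(model)  # K/V tensors sized to config.max_seq_len here

So to honor context_length / max_seq_len for GPTQ models, the value needs to be applied to the ExLlama config before the model and cache are created, not patched in afterwards.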