Skip to content

Commit

Permalink
Patch ISQ for Mixtral (#730)
Browse files Browse the repository at this point in the history
* Correctly recast xs to original dtype

* Reset

* See if it works now

* Apply fix to xlora model
  • Loading branch information
EricLBuehler authored Sep 1, 2024
1 parent db88c11 commit 3f5d0a9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
3 changes: 2 additions & 1 deletion mistralrs-core/src/models/mixtral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ impl DecoderLayer {
let residual = &xs;
let xs = xs
.apply(&self.post_attention_layernorm)?
.apply(&self.block_sparse_moe)?;
.apply(&self.block_sparse_moe)?
.to_dtype(residual.dtype())?;
residual + xs
}
}
Expand Down
15 changes: 9 additions & 6 deletions mistralrs-core/src/xlora_models/mixtral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,12 +560,15 @@ impl DecoderLayer {
)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.block_sparse_moe.forward(
&xs.apply(&self.post_attention_layernorm)?,
scalings.clone(),
global_scaling_weight,
is_scaling_pass,
)?;
let xs = self
.block_sparse_moe
.forward(
&xs.apply(&self.post_attention_layernorm)?,
scalings.clone(),
global_scaling_weight,
is_scaling_pass,
)?
.to_dtype(residual.dtype())?;
residual + xs
}
}
Expand Down

0 comments on commit 3f5d0a9

Please sign in to comment.