[Model] Add gemma model support. #259
Conversation
Please add support in example.cpp as well.
Gemma support for example.cpp will be added in a follow-up PR; it is not included in this PR.
src/layers/attention.h
Outdated
@@ -218,6 +218,7 @@ class Attention {
bool useSelfAttn, bool doLnBefore, int *positionIds = nullptr) {

auto hiddenSize = ctx->hiddenSize;
auto attSize = ctx->attHeadNum * ctx->attHeadSize;
If 'attSize' is not used, let's remove it.
Done
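For reference, a minimal sketch of the prologue after this cleanup. The struct below is a stand-in reduced to the fields that appear in the diff above, not the project's real DecoderContext:

// Stand-in for the project's DecoderContext, reduced to the fields in the diff above
// (assumption: the real class carries many more members).
struct DecoderContext {
    int hiddenSize;
    int attHeadNum;
    int attHeadSize;
};

// Prologue after the review fix: the unused local 'attSize'
// (ctx->attHeadNum * ctx->attHeadSize) is no longer computed.
static int readHiddenSize(const DecoderContext *ctx) {
    auto hiddenSize = ctx->hiddenSize;
    return hiddenSize;
}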
ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
if (ctx->actType == DecoderContext::SILU) {
    ctx->mmHelper->compute_silu(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
} else if (ctx->actType == DecoderContext::SWIGLU) { // chatglm2/3
Let's use the original path.
Done
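For context on the two branches above: compute_silu appears to fuse the GEMM with the activation, and the SWIGLU branch is the gated variant used by chatglm2/3. Below is a minimal scalar sketch of just the activation math, with the GEMM and the project's mmHelper API left out; the names and the half/half layout are illustrative assumptions only.

#include <cmath>
#include <vector>

// SiLU: x * sigmoid(x). In the project this is fused into the GEMM epilogue;
// here it is shown standalone for clarity.
static float silu(float x) {
    return x / (1.0f + std::exp(-x));
}

// SwiGLU (the chatglm2/3-style gated variant): treat the first half of a row as
// the gate and the second half as the "up" projection; output = silu(gate) * up.
static std::vector<float> swiglu(const std::vector<float> &row) {
    size_t half = row.size() / 2;
    std::vector<float> out(half);
    for (size_t i = 0; i < half; ++i) {
        out[i] = silu(row[i]) * row[i + half];
    }
    return out;
}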
src/utils/decoder_util.h
Outdated
// compute gelu on the left half and then add it with the right half
template <typename T1, typename T2>
static void geluSum(hpj::Matrix<T1> &src, hpj::Matrix<T2> &dst) {
    __m512 c1 = _mm512_set1_ps(0.044715f);
Please make c1 const.
Done
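The 0.044715f constant above comes from the tanh approximation of GELU, gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A scalar reference of that formula, with the constants marked const per the review comment; the real geluSum vectorizes this with AVX-512 intrinsics and then combines the result with the other half of the matrix.

#include <cmath>

// Tanh-approximation GELU, scalar reference only; constants are const as requested.
static float geluTanh(float x) {
    const float c1 = 0.044715f;
    const float sqrt2OverPi = 0.7978845608f; // sqrt(2 / pi)
    const float inner = sqrt2OverPi * (x + c1 * x * x * x);
    return 0.5f * x * (1.0f + std::tanh(inner));
}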
Waiting on the Silu API check.
TODO list:
gelu ACT support.
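To visualize the remaining TODO item: gelu would presumably become one more branch in the actType dispatch shown earlier. This is a hypothetical sketch only; ActType and applyActivation are stand-ins, and a DecoderContext::GELU value with a fused mmHelper call is an assumption, not part of this PR.

#include <cmath>

enum class ActType { SILU, SWIGLU, GELU };

// Hypothetical element-wise dispatch; the real code fuses these into GEMM kernels.
static void applyActivation(ActType act, float *data, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        const float x = data[i];
        switch (act) {
        case ActType::SILU:
            data[i] = x / (1.0f + std::exp(-x));
            break;
        case ActType::GELU:
            // Same tanh approximation as in the geluTanh sketch above.
            data[i] = 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
            break;
        case ActType::SWIGLU:
            // Gated variant works on paired halves; see the swiglu sketch above.
            break;
        }
    }
}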