forked from google/gemma.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gemma.cc
1649 lines (1480 loc) · 66.4 KB
/
gemma.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Lightweight C++ implementation of the gemma model.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
// copybara:import_next_line:gemma_cpp
#include "compression/compress-inl.h"
// copybara:import_next_line:gemma_cpp
#include "ops.h"
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // Path
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>
#include <random>
#include <regex>
#include <string>
#include <vector>
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h"
// copybara:import_next_line:gemma_cpp
#include "configs.h"
// copybara:import_next_line:gemma_cpp
#include "gemma.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// Setting this to true disables fread() calls that read the model file.
// LoadWeights still logs every tensor name/size it would have read, but
// leaves the weight buffers unfilled.
constexpr bool kDryRunFread = false;
// Setting this to false will load and use uncompressed weights.
// Selects CompressedWeights (true) or Weights (false) via the WeightsT alias.
constexpr bool kWeightsAreCompressed = true;
// Set this to true to debug tokenizer tokens.
// When true, GemmaTokenizerImpl::Encode (int overload) prints each token id
// to stderr.
constexpr bool kShowTokenization = false;
namespace gcpp {
// Returns how many of the first `num` entries of `layers` are equal to
// `type`. Usable in constant expressions. `num` must not exceed kNumLayers,
// since entries [0, num) are indexed.
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
    const std::array<LayerAttentionType, kNumLayers>& layers,
    LayerAttentionType type, size_t num) {
  size_t matches = 0;
  for (size_t remaining = num; remaining > 0; --remaining) {
    matches += (layers[remaining - 1] == type) ? 1 : 0;
  }
  return matches;
}
// Number of entries in TConfig::kLayerConfig (over all TConfig::kLayers
// layers) that are standard Gemma attention layers.
template <typename TConfig>
constexpr size_t NumGemmaLayers() {
  constexpr size_t kAllLayers = TConfig::kLayers;
  return NumLayersOfTypeBefore(TConfig::kLayerConfig,
                               LayerAttentionType::kGemma, kAllLayers);
}
// Number of entries in TConfig::kLayerConfig (over all TConfig::kLayers
// layers) that are Griffin recurrent blocks.
template <typename TConfig>
constexpr size_t NumGriffinLayers() {
  constexpr size_t kAllLayers = TConfig::kLayers;
  return NumLayersOfTypeBefore(TConfig::kLayerConfig,
                               LayerAttentionType::kGriffinRecurrentBlock,
                               kAllLayers);
}
// Uncompressed (float) weights of a single layer, filled member-by-member by
// LoadWeights. Storage for both the attention tensors and the Griffin
// tensors is always present; which of them a layer actually reads depends on
// its entry in TConfig::kLayerConfig.
template <class TConfig>
struct Layer {
Layer() = default;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
// One kModelDim x kQKVDim output projection per head (Attention indexes it at
// head * kModelDim * kQKVDim).
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
// Q projections for all kHeads plus K and V projections for all kKVHeads,
// each kQKVDim x kModelDim.
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
std::array<float, kQKVEinsumWSize> qkv_einsum_w;
// FFW up-projection: rows [0, kFFHiddenDim) form the GELU gate and rows
// [kFFHiddenDim, 2 * kFFHiddenDim) the multiplied branch (see FFW).
std::array<float, kGatingEinsumWSize> gating_einsum_w;
// FFW down-projection from kFFHiddenDim back to kModelDim.
std::array<float, kModelDim * kFFHiddenDim> linear_w;
std::array<float, kModelDim> pre_attention_norm_scale;
std::array<float, kModelDim> pre_ffw_norm_scale;
// These fields are only used by Griffin, and do not affect loading of the
// model as it is done per-member.
// TODO(veluca): pull weights that are never used at the same time into a
// union or otherwise reduce the memory usage.
// Bias for gating_einsum_w (same two-halves layout); only read when
// TConfig::kFFBiases.
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
std::array<float, kModelDim> ffw_output_biases;
// Only read for attention layers when TConfig::kSoftmaxAttnOutputBiases.
std::array<float, kModelDim> attention_output_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
std::array<float, kModelDim> griffin_linear_y_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
std::array<float, kModelDim> griffin_linear_x_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
std::array<float, kModelDim> griffin_linear_out_biases;
std::array<float, kModelDim> griffin_conv_biases;
// Two sets (gate and recurrence multiplier) of kHeads block-diagonal
// (kModelDim / kHeads)^2 matrices, as indexed in GriffinRecurrent.
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
// Biases for the two gate sets: [0, kModelDim) and [kModelDim, 2 * kModelDim).
std::array<float, kModelDim * 2> griffin_gate_biases;
std::array<float, kModelDim> griffin_a;
// kConv1dWidth taps, each a kModelDim vector (one weight per channel).
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};
// Rescales data[0, len) in place so that no element's magnitude exceeds
// 1.875f and returns the factor that was divided out. When every value is
// already within range, nothing is modified and 1.0f is returned, so
// multiplying the stored values by the result recovers the originals (up to
// float rounding). Used by LoadWeights when scale_for_compression is set.
float ScaleWeights(float* data, size_t len) {
  constexpr float kMaxRange = 1.875f;
  float* const end = data + len;

  float peak = 0.0f;
  for (const float* it = data; it != end; ++it) {
    peak = std::max(peak, std::abs(*it));
  }
  if (peak <= kMaxRange) return 1.0f;

  const float scale = peak / kMaxRange;
  // Multiply by the reciprocal (rather than dividing) for every element.
  const float inv_scale = 1.0f / scale;
  for (float* it = data; it != end; ++it) {
    *it *= inv_scale;
  }
  return scale;
}
// Owns one separately allocated Layer per model layer. Each layer is allocated
// as its own task on `pool` instead of as one large block, so memory
// initialization can proceed in parallel. Kept apart from Weights so that only
// these pointers need constructing.
template <class TConfig>
struct LayerPointers {
  using TLayer = Layer<TConfig>;

  explicit LayerPointers(hwy::ThreadPool& pool) {
    const auto allocate_layer = [this](uint64_t layer_idx,
                                       size_t /*thread*/) {
      layers[layer_idx] = hwy::AllocateAligned<TLayer>(1);
    };
    pool.Run(0, TConfig::kLayers, allocate_layer);
  }

  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
};
// Uncompressed model weights. Lives inside a raw aligned byte buffer (see
// LoadWeights): only `layer_ptrs` is placement-constructed there, and it must
// be destroyed explicitly (DeleteLayersPtrs) before the buffer is freed.
template <class TConfig>
struct Weights {
// No ctor/dtor, allocated via AllocateAligned.
// kVocabSize * kModelDim floats; presumably one kModelDim row per token id —
// TODO confirm against the embedding lookup.
std::array<float, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
std::array<float, TConfig::kModelDim> final_norm_scale;
// Per-layer storage, allocated separately for each layer.
LayerPointers<TConfig> layer_ptrs;
// Factors returned by ScaleWeights, stored in the order LoadWeights rescales
// tensors. Only written when LoadWeights is called with
// scale_for_compression; otherwise left as allocated.
std::array<float, TConfig::kNumTensorScales> scales;
// Non-owning access to layer `layer`; no bounds check (layer must be
// < TConfig::kLayers).
const Layer<TConfig>* GetLayer(size_t layer) const {
return layer_ptrs.layers[layer].get();
}
Layer<TConfig>* GetLayer(size_t layer) {
return layer_ptrs.layers[layer].get();
}
};
// Reads uncompressed (float) weights for TConfig from `checkpoint` into a
// newly allocated Weights<TConfig>. The result is returned as an aligned byte
// buffer that callers reinterpret as Weights<TConfig>. Per-layer storage is
// allocated in parallel on `pool`.
//
// Tensors are read sequentially in a fixed order: the embedding and the final
// norm scale, then for each layer either the attention or the Griffin
// weights, followed by the FFW weights, the norms, and any optional biases.
// The file layout must match this order exactly.
//
// If `scale_for_compression` is true, selected matrices are rescaled in place
// via ScaleWeights. The returned factors are stored in read order in
// `scales`, and exactly TConfig::kNumTensorScales of them are required.
//
// Aborts if any of the following happens:
//  - the file is missing or cannot be opened;
//  - a read fails;
//  - the scale count does not match.
//
// Note: `layer_ptrs` is placement-constructed inside the returned buffer and
// must be destroyed explicitly (DeleteLayersPtrs) before the buffer is freed.
template <typename TConfig>
hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
    const Path& checkpoint, hwy::ThreadPool& pool,
    bool scale_for_compression = false) {
  PROFILER_ZONE("Startup.LoadWeights");
  if (!std::filesystem::exists(checkpoint.path)) {
    HWY_ABORT("The model weights file '%s' does not exist.",
              checkpoint.path.c_str());
  }
  using TWeights = Weights<TConfig>;
  hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8 =
      hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
  // Only layer_ptrs is constructed; doing so allocates all per-layer storage.
  new (&weights->layer_ptrs) LayerPointers<TConfig>(pool);

  size_t scale_pos = 0;
  // Remains nullptr in dry-run mode, where it is never used.
  FILE* fptr = nullptr;
  if constexpr (kDryRunFread) {
    fprintf(stderr, "Dry-Run, not reading model-file.\n");
  } else {
    fptr = fopen(checkpoint.path.c_str(), "rb");
    if (fptr == nullptr) {
      HWY_ABORT("Failed to open model file %s - does it exist?",
                checkpoint.path.c_str());
    }
  }

  bool ok = true;
  uint64_t total_size = 0;
  // Logs and reads one tensor of `size` bytes into `var`. `layer == -1` marks
  // a tensor that is not part of a layer. Failures accumulate in `ok` and are
  // reported once after all reads. `total_size` counts the bytes requested.
  auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
    if (layer == -1) {
      fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
    } else {
      fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
              size, name);
    }
    if constexpr (!kDryRunFread) {
      ok &= 1 == fread(var, size, 1, fptr);
      total_size += size;
    }
  };

  do_fread(&(weights->embedder_input_embedding), -1, "embedder_input_embedding",
           sizeof(weights->embedder_input_embedding));
  do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
           sizeof(weights->final_norm_scale));

// Reads member `name` of the current layer (`layer_view`, index `layer`).
#define READ_WEIGHTS(name)                                                 \
  do {                                                                     \
    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
  } while (0)
// Rescales member `name` of the current layer and records its factor. The
// bounds check keeps `weights->scales` from being overrun before the final
// scale-count assertion gets a chance to fire.
#define SCALE_WEIGHTS(name)                                               \
  do {                                                                    \
    if (ok && !kDryRunFread && scale_for_compression) {                   \
      HWY_ASSERT(scale_pos < weights->scales.size());                     \
      weights->scales[scale_pos++] =                                      \
          ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
    }                                                                     \
  } while (0)

  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    auto type = TConfig::kLayerConfig[layer];
    Layer<TConfig>* layer_view = weights->GetLayer(layer);
    // Make sure we don't have uninitialized memory.
    hwy::ZeroBytes(layer_view, sizeof(*layer_view));
    if (type == LayerAttentionType::kGemma) {
      READ_WEIGHTS(attn_vec_einsum_w);
      READ_WEIGHTS(qkv_einsum_w);
      SCALE_WEIGHTS(attn_vec_einsum_w);
      SCALE_WEIGHTS(qkv_einsum_w);
    } else {
      READ_WEIGHTS(griffin_linear_x_w);
      READ_WEIGHTS(griffin_linear_x_biases);
      READ_WEIGHTS(griffin_linear_y_w);
      READ_WEIGHTS(griffin_linear_y_biases);
      READ_WEIGHTS(griffin_linear_out_w);
      READ_WEIGHTS(griffin_linear_out_biases);
      READ_WEIGHTS(griffin_conv_w);
      READ_WEIGHTS(griffin_conv_biases);
      READ_WEIGHTS(griffin_gate_w);
      READ_WEIGHTS(griffin_gate_biases);
      READ_WEIGHTS(griffin_a);
      SCALE_WEIGHTS(griffin_linear_x_w);
      SCALE_WEIGHTS(griffin_linear_y_w);
      SCALE_WEIGHTS(griffin_linear_out_w);
      SCALE_WEIGHTS(griffin_gate_w);
    }
    READ_WEIGHTS(gating_einsum_w);
    READ_WEIGHTS(linear_w);
    SCALE_WEIGHTS(gating_einsum_w);
    SCALE_WEIGHTS(linear_w);
    READ_WEIGHTS(pre_attention_norm_scale);
    READ_WEIGHTS(pre_ffw_norm_scale);
    if (TConfig::kFFBiases) {
      READ_WEIGHTS(ffw_gating_biases);
      READ_WEIGHTS(ffw_output_biases);
    }
    if (TConfig::kSoftmaxAttnOutputBiases &&
        type == LayerAttentionType::kGemma) {
      READ_WEIGHTS(attention_output_biases);
    }
  }
  // Both helper macros are local to this function.
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS

  if (!ok) {
    // %u matches the unsigned argument (previously %d).
    HWY_ABORT(
        "Failed to read from %s - might be a directory, or too small? "
        "expected size: %u kB",
        checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
  }
  if (!kDryRunFread) {
    HWY_ASSERT(0 == fclose(fptr));
    if (scale_for_compression) {
      HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
    }
  }
  return weights_u8;
}
// Compressed counterpart of Layer<TConfig>, with the same tensors and sizes:
//  - large matrices are stored as TConfig::WeightT;
//  - norm scales are stored as bf16;
//  - biases, griffin_a and the conv weights stay float.
template <class TConfig>
struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::Layer<TConfig>;
using WeightT = typename TConfig::WeightT;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
// Compressed Parameters
// We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
CompressedArray<float, kModelDim> ffw_output_biases;
CompressedArray<float, kModelDim> attention_output_biases;
// Matrix sizes come from Layer<TConfig> so both representations agree.
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
griffin_gate_w;
CompressedArray<float, kModelDim> griffin_a;
CompressedArray<float, kModelDim> griffin_linear_y_biases;
CompressedArray<float, kModelDim> griffin_linear_x_biases;
CompressedArray<float, kModelDim> griffin_linear_out_biases;
CompressedArray<float, kModelDim> griffin_conv_biases;
CompressedArray<float, kModelDim * 2> griffin_gate_biases;
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
};
// Owns one separately allocated CompressedLayer per model layer. Each layer is
// allocated as its own task on `pool` instead of as one large block, so memory
// initialization can proceed in parallel. Kept apart from CompressedWeights
// so that only these pointers need constructing, not the CompressedArray
// members.
template <class TConfig>
struct CompressedLayerPointers {
  using CLayer = CompressedLayer<TConfig>;

  explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
    const auto allocate_layer = [this](uint64_t layer_idx,
                                       size_t /*thread*/) {
      c_layers[layer_idx] = hwy::AllocateAligned<CLayer>(1);
    };
    pool.Run(0, TConfig::kLayers, allocate_layer);
  }

  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
};
// Compressed model weights, held in a raw aligned buffer. `c_layer_ptrs` is
// the only member with a destructor, and DeleteLayersPtrs destroys it
// explicitly.
template <class TConfig>
struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned.
CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
// Non-owning access to layer `layer`; no bounds check (layer must be
// < TConfig::kLayers).
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
return c_layer_ptrs.c_layers[layer].get();
}
CompressedLayer<TConfig>* GetLayer(size_t layer) {
return c_layer_ptrs.c_layers[layer].get();
}
};
// Weight container used for inference: CompressedWeights<TConfig> when
// kWeightsAreCompressed is true, otherwise the float Weights<TConfig>.
template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
Weights<TConfig>>;
// Aligned.
// Scratch buffers for processing up to TBatchSize tokens at once. Most arrays
// hold one slice per batch entry, laid out consecutively.
template <class TConfig, size_t TBatchSize>
struct Activations {
static constexpr size_t kBatchSize = TBatchSize;
using LayerConfig = Layer<TConfig>;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
// KV-cache stride between consecutive positions: K (or V) vectors of every
// KV head of every Gemma attention layer.
static constexpr size_t kCachePosSize =
NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
// KV-cache stride between consecutive attention layers within one position.
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
std::array<float, kBatchSize * kModelDim> x; // input
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
std::array<float, kBatchSize * kHeads * kQKVDim> q; // query vector
// One row of up to kSeqLen scores per (batch entry, head).
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
att; // attention vector
std::array<float, kBatchSize * kHeads * kQKVDim> att_out; // attention output
// Sized for one kModelDim vector per (head, batch entry).
std::array<float, kHeads * kBatchSize * kModelDim>
att_post1; // attention output after linear transformation, per head
std::array<float, kBatchSize * kModelDim>
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
// Per batch entry: gate half followed by multiplier half (2 * kFFHiddenDim).
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
// std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
// bf_ffw_hidden;
std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// Griffin layer internal activations
std::array<float, kBatchSize * kModelDim> griffin_x;
std::array<float, kBatchSize * kModelDim> griffin_y;
std::array<float, kBatchSize * kModelDim> griffin_gate_x;
std::array<float, kBatchSize * kModelDim> griffin_multiplier;
};
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
// define an abstract base class.
struct GemmaInterface {
virtual ~GemmaInterface() = default;
// Tokenizer owned by the implementation; the pointer stays valid for the
// implementation's lifetime (presumably — GemmaImpl returns a member).
virtual const GemmaTokenizer* Tokenizer() const = 0;
// Runs generation for `prompt` starting at `start_pos`, emitting tokens via
// `stream_token`. NOTE(review): exact semantics of max_tokens vs.
// max_generated_tokens, accept_token and verbosity are defined by the
// implementation, not visible here — confirm before relying on them.
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0;
// Presumably scores `prompt` under the model and returns its cross-entropy
// — TODO confirm units (nats vs. bits, total vs. per-token).
virtual float ComputeCrossEntropy(size_t max_tokens,
const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
int verbosity) = 0;
};
// Builds a KVCache sized for Config:
//  - attention K/V storage per position (Gemma layers only), for
//    Config::kSeqLen positions;
//  - a Conv1D history of (kConv1dWidth - 1) inputs per Griffin layer (none
//    when kConv1dWidth is 0);
//  - one kModelDim recurrent state per Griffin layer.
template <class Config>
KVCache CreateKVCache() {
  constexpr size_t kConv1dWidth = Config::kConv1dWidth;
  constexpr size_t kGriffinLayers = NumGriffinLayers<Config>();
  constexpr size_t kKVPerPosition =
      NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim;
  constexpr size_t kConvHistory = (kConv1dWidth == 0) ? 0 : kConv1dWidth - 1;
  constexpr size_t kConv1dCacheSize =
      kGriffinLayers * kConvHistory * Config::kModelDim;
  constexpr size_t kRGLRUCacheSize = kGriffinLayers * Config::kModelDim;
  return CreateKVCache(kKVPerPosition, Config::kSeqLen, kConv1dCacheSize,
                       kRGLRUCacheSize);
}
// Maps the runtime model enum to its compile-time config and builds a
// KVCache for it. Aborts on an unrecognized model type.
KVCache CreateKVCache(Model type) {
  switch (type) {
    case Model::GEMMA_2B: return CreateKVCache<ConfigGemma2B>();
    case Model::GEMMA_7B: return CreateKVCache<ConfigGemma7B>();
    case Model::GRIFFIN_2B: return CreateKVCache<ConfigGriffin2B>();
    default: break;
  }
  HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
// GemmaTokenizer backed by a SentencePiece processor. Every method returns
// true iff the underlying SentencePiece call reports an OK status.
class GemmaTokenizerImpl : public GemmaTokenizer {
 public:
  // Takes ownership of an already-loaded SentencePiece processor.
  GemmaTokenizerImpl(
      std::unique_ptr<sentencepiece::SentencePieceProcessor>&& impl)
      : impl_(std::move(impl)) {}

  // Splits `input` into string pieces.
  bool Encode(const std::string& input,
              std::vector<std::string>* pieces) const override {
    const bool encoded = impl_->Encode(input, pieces).ok();
    return encoded;
  }

  // Converts `input` to token ids; with kShowTokenization, also prints each
  // id to stderr.
  bool Encode(const std::string& input,
              std::vector<int>* pieces) const override {
    const bool encoded = impl_->Encode(input, pieces).ok();
    if constexpr (kShowTokenization) {
      const int count = static_cast<int>(pieces->size());
      for (int idx = 0; idx < count; ++idx) {
        fprintf(stderr, "%3d: %d\n", idx, (*pieces)[idx]);
      }
    }
    return encoded;
  }

  // Given a sequence of ids, decodes it into a detokenized output.
  bool Decode(const std::vector<int>& ids,
              std::string* detokenized) const override {
    const bool decoded = impl_->Decode(ids, detokenized).ok();
    return decoded;
  }

 private:
  std::unique_ptr<sentencepiece::SentencePieceProcessor> impl_;
};
namespace {
// The weights structs live in raw byte buffers whose owners never run their
// destructors. These helpers explicitly end the lifetime of the per-layer
// pointer array, which releases every per-layer allocation.
template <class Config>
void DeleteLayersPtrs(CompressedWeights<Config>* c_weights) {
  std::destroy_at(&c_weights->c_layer_ptrs);
}
template <class Config>
void DeleteLayersPtrs(Weights<Config>* weights) {
  std::destroy_at(&weights->layer_ptrs);
}
}  // namespace
// Concrete GemmaInterface for one compile-time model Config. Owns the
// tokenizer, the raw weights buffer (interpreted as WeightsT<Config>) and the
// activation scratch buffers.
template <class Config>
struct GemmaImpl : public GemmaInterface {
GemmaImpl(std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
hwy::AlignedFreeUniquePtr<uint8_t[]>& weights_u8,
hwy::ThreadPool& pool);
// Destroys the per-layer pointers constructed in place inside the weights
// buffer. The buffer itself is then freed by the `weights_u8` member.
// NOTE(review): assumes weights_u8 is non-null; an empty or moved-from
// buffer would be dereferenced here.
~GemmaImpl() {
WeightsT<Config>* weights =
reinterpret_cast<WeightsT<Config>*>(weights_u8.get());
DeleteLayersPtrs(weights);
}
const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&,
int verbosity) override;
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
int verbosity) override;
GemmaTokenizerImpl tokenizer;
hwy::AlignedFreeUniquePtr<uint8_t[]> weights_u8;
// Activations for batches of kPrefillBatchSize tokens (presumably prompt
// prefill) and for single-token steps, respectively.
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
};
} // namespace gcpp
#endif // GEMMA_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Griffin recurrent block for the token at pos = batch_start + batch_idx.
// It reads pre_att_rms_out and writes the block output to att_post2. The
// arrays griffin_x, griffin_y, griffin_gate_x and griffin_multiplier serve as
// scratch. Stages:
//   1. x = W_x * in + b_x, y = GELU(W_y * in + b_y).
//   2. Causal per-channel Conv1D over x, using the last kConv1dWidth inputs.
//   3. RG-LRU: per-head gates, then a recurrent update of the state kept in
//      kv_cache.rglru_cache.
//   4. out = W_out * (y * new_state) + b_out.
// `layer` indexes the per-layer conv1d/rglru caches. Since CreateKVCache
// sizes those by NumGriffinLayers, it is presumably the index among Griffin
// layers only — confirm at the call site.
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t batch_idx, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(batch_idx < kBatchSize);
static constexpr size_t kModelDim =
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
const size_t batch_offset = batch_idx * kModelDim;
const size_t pos = batch_start + batch_idx;
// X / Y linear layers.
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
Gelu(y, kModelDim);
// Conv1D.
// Each layer owns a ring of (kConv1dWidth - 1) past inputs in conv1d_cache,
// indexed by position modulo (kConv1dWidth - 1).
{
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)];
cache[0] = x;
for (size_t i = 1; i < kConv1dWidth; i++) {
cache[i] =
kv_cache.conv1d_cache.get() + layer_offset +
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
// Keep the current (pre-conv) input; x is overwritten with the result.
auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
// Even and odd taps go to separate accumulators (presumably to shorten the
// MulAdd dependency chain), which is why the width must be even.
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
// Tap (kConv1dWidth - 1 - j) of griffin_conv_w multiplies the input from j
// steps ago (cache[j]).
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
// Overwrite the oldest ring slot with the current input; at pos + 1 this
// slot is exactly where cache[1] (time t-1) will point.
hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i);
}
}
// RGLRU
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
// Recurrent state for this layer: one kModelDim vector, updated in place.
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * kModelDim;
// Heads are independent: each uses its own block-diagonal gate matrices and
// touches only its kModelDim / kHeads slice of x, gate_x, a and rnn_state.
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
constexpr size_t kHeadDim = kModelDim / kHeads;
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim;
// Input gate uses matrix `head`; recurrence gate uses matrix kHeads + head.
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
layer_weights->griffin_gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
// a *= griffin_a: the product is used as log(a) below. NOTE(review):
// griffin_a is presumably stored pre-transformed (e.g. negative) so that
// exp() yields a decay in (0, 1) — confirm in the checkpoint converter.
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin_a.data() + head_offset, fn_mul);
// x *= input gate.
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
// state = sqrt(1 - a^2) * gated_x + a * state (multiplier forced to 1 at
// pos 0); then x = y * state joins the two branches.
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
// Shadows the outer `a` pointer from here on (it was already read above).
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
// Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_out_w, 0, x,
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
}
// Self-attention for the token at pos = batch_start + batch_idx in attention
// layer `layer`. Steps:
//  - Project pre_att_rms_out to Q (per head) and to K/V.
//  - Apply RoPE to Q and K, over half of kQKVDim if TConfig::kUseHalfRope.
//  - Store K/V in `kv_cache` at slot `cache_pos`, which is pos modulo kSeqLen
//    under TConfig::kUseLocalAttention.
//  - Score Q against the first `cache_num` cached keys, softmax the scores,
//    and take the weighted sum of the cached values.
//  - Project each head back to kModelDim and accumulate all heads into
//    att_post2. Head 0 also adds attention_output_biases when
//    TConfig::kSoftmaxAttnOutputBiases.
// The KV cache holds NumGemmaLayers<TConfig>() layers per position, so
// `layer` must count attention layers only.
//
// Per-batch scratch indexing:
//  - `att` uses a stride of kHeads * kSeqLen per batch entry, matching its
//    size. The previous stride of kHeads * kQKVDim made different batch
//    entries' score rows overlap whenever kQKVDim < kSeqLen.
//  - `att_post1` uses a kModelDim slot per batch entry within each head's
//    kBatchSize * kModelDim slice, as its size implies.
template <size_t kBatchSize, typename LayerT, class TConfig>
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                            Activations<TConfig, kBatchSize>& activations,
                            const LayerT* layer_weights, KVCache& kv_cache,
                            hwy::ThreadPool& pool) {
  PROFILER_ZONE("Gen.Attention");
  const size_t pos = batch_start + batch_idx;
  HWY_DASSERT(batch_idx < kBatchSize);
  static constexpr size_t kQKVDim = gcpp::Activations<TConfig, 1>::kQKVDim;
  static constexpr size_t kCachePosSize =
      gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
  static constexpr size_t kCacheLayerSize =
      gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
  static constexpr size_t kModelDim =
      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kKVHeads = TConfig::kKVHeads;
  static const float kQueryScale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));

  size_t cache_pos = pos;
  size_t cache_num = pos + 1;
  if constexpr (TConfig::kUseLocalAttention) {
    cache_pos %= TConfig::kSeqLen;
    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
  }
  // Each head's score row in `att` has room for kSeqLen entries.
  HWY_DASSERT(cache_num <= static_cast<size_t>(TConfig::kSeqLen));

  float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;

  // Q projection for `head`, reading qkv_einsum_w starting at head_offset.
  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
    float* HWY_RESTRICT q =
        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
    MatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
                                   head_offset + 0 * kQKVDim * kModelDim, x, q);
  };

  // K and V projections written directly into the cache at kv_offset; RoPE is
  // applied to K.
  auto ProjKV = [&](size_t k_offset, size_t v_offset,
                    size_t kv_offset) HWY_ATTR {
    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
    TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
                                         v_offset, x, k, v);
    Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
  };

  // Scores, softmax, weighted sum of V and output projection for `head`.
  // head_offset selects the KV head's slice within a cache layer.
  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
    // Calculate scores
    float* HWY_RESTRICT q =
        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
    float* HWY_RESTRICT head_att = activations.att.data() +
                                   head * TConfig::kSeqLen +
                                   batch_idx * kHeads * TConfig::kSeqLen;

    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);

    // Compute Q dot K scores
    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
      const size_t cache_offset =
          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
      const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
      const float score = Dot(q, k2, kQKVDim);
      head_att[pos2] = score;
    }
    Softmax(head_att, cache_num);

    // Weighted summation
    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                  batch_idx * kHeads * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
      const size_t cache_offset =
          pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
      float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
    }

    // Linear projection from kQKVDim back to kModelDim. Head 0 writes
    // directly into att_post2; other heads use their att_post1 slot and are
    // summed in after all heads finish.
    float* HWY_RESTRICT head_out =
        head == 0 ? activations.att_post2.data() + batch_idx * kModelDim
                  : activations.att_post1.data() +
                        head * kBatchSize * kModelDim + batch_idx * kModelDim;
    if (head == 0) {
      MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
          layer_weights->attention_output_biases.data(), head_out);
    } else {
      MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
                                     head * kModelDim * kQKVDim, att_out,
                                     head_out);
    }
  };

  if constexpr (kHeads == kKVHeads) {
    // Multi-Head Attention
    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      // linear projections to QKV
      const size_t head_offset = TConfig::kInterleaveQKV
                                     ? 3 * kQKVDim * kModelDim
                                     : kQKVDim * kModelDim;
      const size_t mat_offset =
          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
      const size_t q_offset = head * head_offset + 0 * mat_offset;
      const size_t k_offset = head * head_offset + 1 * mat_offset;
      const size_t v_offset = head * head_offset + 2 * mat_offset;
      ProjQ(head, q_offset);
      const size_t kv_offset =
          cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
      ProjKV(k_offset, v_offset, kv_offset);
      Attn(head, head * kQKVDim);
    });
  } else {
    // Multi-Query Attention: a single K/V head shared by all query heads.
    // K/V are projected once and every head reads cache slice 0, so grouped
    // attention (1 < kKVHeads < kHeads) would be computed incorrectly.
    static_assert(kKVHeads == 1,
                  "Attention supports only kKVHeads == kHeads or kKVHeads == 1");
    // All kHeads Q matrices come first in qkv_einsum_w, followed by K then V.
    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
    const size_t kv_offset =
        cache_pos * kCachePosSize + layer * kCacheLayerSize;
    ProjKV(k_offset, v_offset, kv_offset);
    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      ProjQ(head, head * kQKVDim * kModelDim);
      Attn(head, 0);
    });
  }

  // accumulate output across all heads into att_post2. head 0 already wrote
  // directly to att_post2.
  for (size_t head = 1; head < kHeads; ++head) {
    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim +
                batch_idx * kModelDim,
            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
  }
}
template <size_t kBatchSize, typename LayerT, typename TConfig>
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                      size_t batch_idx, const LayerT* layer_weights,
                      hwy::ThreadPool& pool) {
  HWY_DASSERT(batch_idx < kBatchSize);
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  // This batch entry's slice of ffw_hidden spans 2 * kFFHiddenDim floats: the
  // gate values first, followed by the values they are multiplied with.
  const size_t hidden_start = 2 * kFFHiddenDim * batch_idx;

  {
    PROFILER_ZONE("Gen.FFW.GatedGELU");
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    using VF = hn::Vec<DF>;

    const hwy::bfloat16_t* HWY_RESTRICT normed_x =
        activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
    float* HWY_RESTRICT gate = activations.ffw_hidden.data() + hidden_start;
    float* HWY_RESTRICT multiplier = gate + kFFHiddenDim;

    // gating_einsum_w holds both halves in one matrix. The two halves of rows
    // are applied as separate MatVecs; fusing them would be possible, but
    // keeping them apart could help on NUMA systems (e.g. multiple sockets).
    // Second half of the rows -> multiplier branch.
    MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
        layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, normed_x,
        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, multiplier,
        pool);
    // First half of the rows -> gate, which goes through GELU below.
    MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
        layer_weights->gating_einsum_w, 0, normed_x,
        layer_weights->ffw_gating_biases.data(), gate, pool);

    // In place: gate[i] = multiplier[i] * Gelu(gate[i]).
    hn::Transform1(DF(), gate, kFFHiddenDim, multiplier,
                   [](DF d, VF g, VF m) HWY_ATTR {
                     return hn::Mul(m, Gelu(d, g));
                   });
  }

  PROFILER_ZONE("Gen.FFW\\GatedGELU");
  // Projection from kFFHiddenDim back to kModelDim, plus optional bias.
  MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
      layer_weights->linear_w, 0,
      activations.ffw_hidden.data() + hidden_start,
      layer_weights->ffw_output_biases.data(),
      activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                          const WeightArrayT& weights,
                          Activations<TConfig, kBatchSize>& activations,
                          KVCache& kv_cache, hwy::ThreadPool& pool,
                          hwy::ThreadPool& inner_pool) {
  PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
  static constexpr size_t kModelDim = TConfig::kModelDim;
  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
      EmbeddingScaling<TConfig>();

  // Embed each prompt token into its own kModelDim-wide row of `x`, then scale
  // it by kEmbScaling. Each task touches only its own row.
  pool.Run(0, num_tokens, [&](const uint64_t i, size_t /*thread*/) HWY_ATTR {
    auto* row = activations.x.data() + i * kModelDim;
    Decompress(weights.embedder_input_embedding, tokens[i] * kModelDim, row,
               kModelDim);
    MulByConst(kEmbScaling, row, kModelDim);
  });

  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    auto type = TConfig::kLayerConfig[layer];
    const auto* layer_weights = weights.GetLayer(layer);
    size_t layer_of_type =
        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);

    // Pre-attention norm followed by attention (or the Griffin recurrent
    // block), applied to the tokens one after another in order.
    for (size_t i = 0; i < num_tokens; ++i) {
      RMSNorm(activations.x.data() + i * kModelDim,
              layer_weights->pre_attention_norm_scale.data(),
              activations.pre_att_rms_out.data() + i * kModelDim, kModelDim);
      if (type != LayerAttentionType::kGemma) {
        GriffinRecurrent<kBatchSize>(pos, i, layer_of_type, activations,
                                     layer_weights, kv_cache, pool);
      } else {
        Attention<kBatchSize>(pos, i, layer_of_type, activations,
                              layer_weights, kv_cache, pool);
      }
    }

    // Per token, in parallel: add the attention output into `x`, apply the
    // pre-FFW norm, run FFW, and add its output into `x`. FFW gets
    // `inner_pool` rather than `pool`, presumably because this lambda is
    // already executing on `pool`'s workers.
    // TODO: sink the loop into these functions, i.e. make them matmuls.
    pool.Run(0, num_tokens, [&](const uint64_t i, size_t /*thread*/) HWY_ATTR {
      auto* row = activations.x.data() + i * kModelDim;
      AddFrom(activations.att_post2.data() + i * kModelDim, row, kModelDim);
      RMSNorm(row, layer_weights->pre_ffw_norm_scale.data(),
              activations.bf_pre_ffw_rms_out.data() + i * kModelDim,
              kModelDim);
      FFW<kBatchSize>(activations, i, layer_weights, inner_pool);
      AddFrom(activations.ffw_out.data() + i * kModelDim, row, kModelDim);
    });
  }  // foreach layer

  // Final norm of every token's row, in place.
  pool.Run(0, num_tokens, [&](const uint64_t i, size_t /*thread*/) HWY_ATTR {
    RMSNormInplace(weights.final_norm_scale.data(),
                   activations.x.data() + i * kModelDim, kModelDim);
  });
}
// n = 1 specialization: embeds `token`, runs it through every layer at
// position `pos`, and leaves activations.x normalized by final_norm_scale.
// Note: `inner_pool` is not used by this overload.
template <typename WeightArrayT, class TConfig>
void Transformer(int token, size_t pos, const WeightArrayT& weights,
                 Activations<TConfig, 1>& activations, KVCache& kv_cache,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  auto* x = activations.x.data();

  Decompress(weights.embedder_input_embedding, token * kModelDim, x,
             kModelDim);
  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
      EmbeddingScaling<TConfig>();
  MulByConst(kEmbScaling, x, kModelDim);

  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    auto type = TConfig::kLayerConfig[layer];
    const auto* layer_weights = weights.GetLayer(layer);
    size_t layer_of_type =
        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);

    // Attention sub-block: norm, attention/recurrence, add into x.
    RMSNorm(x, layer_weights->pre_attention_norm_scale.data(),
            activations.pre_att_rms_out.data(), kModelDim);
    if (type != LayerAttentionType::kGemma) {
      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
                          kv_cache, pool);
    } else {
      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
                   pool);
    }
    AddFrom(activations.att_post2.data(), x, kModelDim);

    // FFW sub-block: norm, FFW, add into x.
    RMSNorm(x, layer_weights->pre_ffw_norm_scale.data(),
            activations.bf_pre_ffw_rms_out.data(), kModelDim);
    FFW<1>(activations, /*batch_idx=*/0, layer_weights, pool);
    AddFrom(activations.ffw_out.data(), x, kModelDim);
  }

  RMSNormInplace(weights.final_norm_scale.data(), x, kModelDim);
}
// Clamps the generation limits in place so that the prompt and the generated
// tokens fit into the token budget. A warning is printed to stderr for every
// value that is adjusted.
//
// max_tokens: overall token budget. Unless TConfig::kUseLocalAttention is set,
//   it is capped at TConfig::kSeqLen.
// max_generated_tokens: reduced to max_tokens - 1 whenever it is >= max_tokens,
//   so at least one slot of the budget remains for the prompt (0 if
//   max_tokens is 0).
// prompt_size: unless TConfig::kUseLocalAttention is set, reduced so that
//   prompt_size + max_generated_tokens <= max_tokens.
template <class TConfig>
void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
                 size_t& prompt_size) {
  if (!TConfig::kUseLocalAttention) {
    if (max_tokens > TConfig::kSeqLen) {
      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
              max_tokens, TConfig::kSeqLen);
      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
    }
  }

  // Use >= rather than >. If max_generated_tokens == max_tokens, the prompt
  // check below would shrink prompt_size to 0. Also avoid max_tokens - 1 when
  // max_tokens is 0, because that unsigned subtraction would wrap around.
  if (max_generated_tokens != 0 && max_generated_tokens >= max_tokens) {
    fprintf(stderr,
            "WARNING: max_generated_tokens %zu >= max_tokens %zu, "
            "truncating.\n",
            max_generated_tokens, max_tokens);
    max_generated_tokens = (max_tokens == 0) ? 0 : max_tokens - 1;
  }

  if (!TConfig::kUseLocalAttention) {
    // max_generated_tokens <= max_tokens holds here, so the subtraction cannot
    // wrap. Comparing this way also avoids overflow in the sum
    // prompt_size + max_generated_tokens.
    if (prompt_size > max_tokens - max_generated_tokens) {
      fprintf(stderr,
              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
              "max_tokens %zu, truncating.\n",
              prompt_size, max_generated_tokens, max_tokens);
      prompt_size = max_tokens - max_generated_tokens;
    }
  }
}
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,