Changed vectorize function for tile (with a nelts list) in batch_matmul #77

Open: wants to merge 3 commits into master
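In outline, the diff below replaces the `vectorize[nelts, dot](cols)` call in `batch_matmul` with a manually tiled loop: an `unroll` over a compile-time `VariadicList` of descending SIMD widths (128 down to 1). The per-output accumulators move from `SIMD` values in a `StaticTuple` to stack-allocated float buffers, widths larger than `nelts` (4 * the hardware SIMD width) are skipped at compile time with `@parameter if`, and each remaining width consumes as much of the row as it can before handing the leftover tail to the next narrower width.

Here is the same idea on a plain dot product, as a minimal sketch (the `dot_tiled` helper is hypothetical, not code from this PR; it assumes the 2023-era Mojo toolchain this repo targets, with `VariadicList` in scope as it is in llama2.mojo):

    from algorithm import unroll
    from memory import memset_zero, stack_allocation
    from memory.unsafe import DTypePointer
    from sys.info import simdwidthof

    fn dot_tiled(
        a: DTypePointer[DType.float32],
        b: DTypePointer[DType.float32],
        size: Int,
    ) -> Float32:
        alias nelts = 4 * simdwidthof[DType.float32]()
        # try the widths from widest to narrowest so each pass shrinks
        # the tail that is left for the next one
        alias widths = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
        let acc = stack_allocation[nelts, DType.float32]()
        memset_zero(acc, nelts)
        var j = 0

        @parameter
        fn step[z: Int]():
            @parameter
            if widths[z] <= nelts:  # compile-time skip of unsupported widths
                let bound = size - size % widths[z]
                while j < bound:
                    # accumulate a widths[z]-wide partial product at offset 0
                    acc.simd_store[widths[z]](
                        0,
                        acc.simd_load[widths[z]](0)
                        + a.simd_load[widths[z]](j) * b.simd_load[widths[z]](j),
                    )
                    j += widths[z]

        unroll[len(widths), step]()
        return acc.simd_load[nelts](0).reduce_add()

Because `j` persists across the unrolled steps, every element is read exactly once, and correctness does not depend on `size` being a multiple of any width.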
93 changes: 55 additions & 38 deletions llama2.mojo
@@ -1,8 +1,8 @@
 from algorithm import sum
-from algorithm import vectorize, parallelize, unroll
+from algorithm import vectorize, parallelize, unroll, tile
 from builtin import string
-from math import round
-from memory import memset_zero, memcpy
+from math import round, log2
+from memory import memset_zero, memcpy, stack_allocation
 from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
 from random import rand
@@ -19,7 +19,7 @@ import time

 var workers = 0
 
-alias nelts = (4*simdwidthof[DType.float32]())
+alias nelts = (4 * simdwidthof[DType.float32]())
 
 alias PointerString = Pointer[UInt8]
 alias BufferPtrType = DTypePointer[DType.uint8]
@@ -371,7 +371,6 @@ struct RunState:
     var key_cache: TensorF32 # (layer, seq_len, dim)
     var value_cache: TensorF32 # (layer, seq_len, dim)
 
-
     fn __init__(inout self, config: Config) raises:
         self.x = TensorF32(config.dim)
         self.xb = TensorF32(config.dim)
@@ -449,10 +448,10 @@ fn read_file(file_name: String, inout buf: FileBuf) raises:
     let cp_buf: BufferPtrType = BufferPtrType.alloc(cp_size)
 
     let data_ptr = data._as_ptr().bitcast[DType.uint8]()
 
     for i in range(cp_size):
-        cp_buf.store(i,data_ptr.load(i))
+        cp_buf.store(i, data_ptr.load(i))
 
     # don't free data
     _ = data
 
@@ -571,41 +570,47 @@ fn batch_matmul[
     rows: Int,
     cols: Int,
 ):
+    alias nelts_list = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
+
     @parameter
     fn compute_row(i: Int):
-        var tmp = StaticTuple[n, SIMD[DType.float32, nelts]]()
+        var tmp = StaticTuple[n, BufferPtrFloat32]()
 
         @parameter
         fn init[k: Int]():
-            tmp[k] = SIMD[DType.float32, nelts](0)
+            tmp[k] = stack_allocation[nelts, DType.float32]()
+            memset_zero(tmp[k], nelts)
 
         unroll[n, init]()
         let row_offset = i * cols
+        var j = 0
 
-        @parameter
-        fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts: # take care of tail array elements with length < nelts
-                let a = A.simd_load[_nelts](j)
-
-                @parameter
-                fn _multiply_tail[k: Int]():
-                    tmp[k][0] += (
-                        a * B[k].simd_load[_nelts](row_offset + j)
-                    ).reduce_add()
-
-                unroll[n, _multiply_tail]()
-            else:
-                let a = A.simd_load[nelts](j)
-
-                @parameter
-                fn _multiply[k: Int]():
-                    tmp[k] += a * B[k].simd_load[nelts](row_offset + j)
-
-                unroll[n, _multiply]()
-
-        vectorize[nelts, dot](cols)
+        @parameter
+        fn dot[z: Int]():
+            # keep only widths of 4 * simdwidth or less in the list;
+            # larger widths give undefined behavior
+            @parameter
+            if nelts_list[z] <= nelts:
+                let range = cols - cols % nelts_list[z]
+                while j < range:
+                    let a = A.simd_load[nelts_list[z]](j)
+
+                    @parameter
+                    fn _multiply_tail[k: Int]():
+                        tmp[k].simd_store[nelts_list[z]](
+                            0,
+                            tmp[k].simd_load[nelts_list[z]](0)
+                            + a * B[k].simd_load[nelts_list[z]](row_offset + j),
+                        )
+
+                    unroll[n, _multiply_tail]()
+                    j += nelts_list[z]
+
+        unroll[len(nelts_list), dot]()
 
         @parameter
         fn _reduce[k: Int]():
-            C[k].store(i, tmp[k].reduce_add())
+            C[k].store(i, tmp[k].simd_load[nelts](0).reduce_add())
 
         unroll[n, _reduce]()

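To see how the descending widths in `nelts_list` cover a whole row, take cols = 300 on a machine where nelts = 32: the 32-wide pass runs while j < 300 - 300 % 32 = 288, the 16-wide pass adds nothing (its bound is also 288), the 8-wide pass carries j to 296, the 4-wide pass to 300, and the 2- and 1-wide passes find nothing left. The scalar remainder loop that `vectorize` supplied is replaced by progressively narrower SIMD loads, and no element is visited twice.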
@@ -643,7 +648,9 @@ fn matmul(C: TensorSlice, A: TensorF32, B: TensorSlice) raises:
     # B (d,n) @ A (n,) -> C (d,)
     matmul_dimension_checks(A.shape(), B.shape())
     batch_matmul[1](
-        StaticTuple[1, BufferPtrFloat32](C.data(),),
+        StaticTuple[1, BufferPtrFloat32](
+            C.data(),
+        ),
         A.data(),
         StaticTuple[1, BufferPtrFloat32](B.data()),
         B.dim(0),
@@ -671,8 +678,9 @@ fn rope_rotation_llama(
 ) -> None:
     # stories model, llama2
     let head_size = config.head_size
+
     @parameter
-    fn head_loop(i:Int):
+    fn head_loop(i: Int):
         # Simple vectorization with (head_size // 2) steps gave junk transformer output.
         # Maybe because the nelt ranges end up overlapping between the steps.
         for j in range(0, config.head_size, 2):
@@ -687,8 +695,8 @@
             let k1 = state.k[i * head_size + j + 1]
             state.k[i * head_size + j] = k0 * fcr - k1 * fci
             state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
-    parallelize[head_loop](config.n_heads, workers)
 
+    parallelize[head_loop](config.n_heads, workers)
 
 
 @always_inline
@@ -755,7 +763,7 @@ fn transformer(

     # Multihead attention. Iterate over all heads in parallel.
     @parameter
-    fn loop_over_heads(h:Int):
+    fn loop_over_heads(h: Int):
         # Get the query vector for this head
         let q_offset = h * head_size

@@ -1020,8 +1028,17 @@ fn main() raises:
     var tok = Tokenizer(config.vocab_size, tbuf)
 
     # print the layers number and vocab size
-    print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]",
-          "| n layers:", config.n_layers, "| vocab size:", tok.vocab_size)
+    print(
+        "checkpoint size: ",
+        fbuf.size,
+        "[",
+        fbuf.size // 1024 // 1024,
+        "MB ]",
+        "| n layers:",
+        config.n_layers,
+        "| vocab size:",
+        tok.vocab_size,
+    )
 
     # Create and initialize the application RunState
     var state = RunState(config)