Skip to content

Commit

Permalink
MPSMatrix improvements (#157)
Browse files Browse the repository at this point in the history
* Add batched MPSMatrix

* MPSMatrix from SubArray
  • Loading branch information
tgymnich authored Sep 5, 2023
1 parent 19f3df1 commit 69aa51e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
60 changes: 59 additions & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

export MPSMatrixDescriptor

@objcwrapper MPSMatrixDescriptor <: NSObject
@objcwrapper immutable=false MPSMatrixDescriptor <: NSObject

@objcproperties MPSMatrixDescriptor begin
@autoproperty rows::NSUInteger setter=setRows
@autoproperty columns::NSUInteger setter=setColumns
@autoproperty matrices::NSUInteger
@autoproperty dataType::MPSDataType setter=setDataType
@autoproperty rowBytes::NSUInteger setter=setRowBytes
@autoproperty matrixBytes::NSUInteger
end


# Mapping from Julia types to the Performance Shader bitfields
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
Expand Down Expand Up @@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
return obj
end

function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dataType)
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
columns:columns::NSUInteger
matrices:matrices::NSUInteger
rowBytes:rowBytes::NSUInteger
matrixBytes:matrixBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
end

#
# matrix object
Expand All @@ -58,6 +79,19 @@ export MPSMatrix

@objcwrapper immutable=false MPSMatrix <: NSObject

@objcproperties MPSMatrix begin
@autoproperty device::id{MTLDevice}
@autoproperty rows::NSUInteger
@autoproperty columns::NSUInteger
@autoproperty matrices::NSUInteger
@autoproperty dataType::MPSDataType
@autoproperty rowBytes::NSUInteger
@autoproperty matrixBytes::NSUInteger
@autoproperty offset::NSUInteger
@autoproperty data::id{MTLBuffer}
end


"""
MPSMatrix(arr::MtlMatrix)
Expand All @@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
desc = MPSMatrixDescriptor(n_rows, n_cols, sizeof(T)*n_cols, T)
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
obj = MPSMatrix(mat)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
return obj
end


"""
MPSMatrix(arr::MtlArray{T,3})
Metal batched matrix representation used in Performance Shaders.
Note that this results in a transposed view of the input,
as Metal stores matrices row-major instead of column-major.
"""
function MPSMatrix(arr::MtlArray{T,3}) where T
n_cols, n_rows, n_matrices = size(arr)
row_bytes = sizeof(T)*n_cols
desc = MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
obj = MPSMatrix(mat)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
return obj
end

#
# matrix multiplication
#
Expand Down
2 changes: 2 additions & 0 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ function MPSVector(arr::MtlVector{T}) where T
desc = MPSVectorDescriptor(len, T)
vec = @objc [MPSVector alloc]::id{MPSVector}
obj = MPSVector(vec)
offset = arr.offset * sizeof(T)
finalizer(release, obj)
@objc [obj::id{MPSVector} initWithBuffer:arr::id{MTLBuffer}
offset:offset::NSUInteger
descriptor:desc::id{MPSVectorDescriptor}]::id{MPSVector}
return obj
end
Expand Down
20 changes: 20 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ if MPS.is_supported(current_device())
end
end

@testset "test matrix vector multiplication of views" begin
N = 20
a = rand(Float32, N,N)
b = rand(Float32, N)

mtl_a = mtl(a)
mtl_b = mtl(b)

view_a = @view a[:,10:end]
view_b = @view b[10:end]

mtl_view_a = @view mtl_a[:,10:end]
mtl_view_b = @view mtl_b[10:end]

mtl_c = mtl_view_a * mtl_view_b
c = view_a * view_b

@test mtl_c == mtl(c)
end

@testset "mixed-precision matrix vector multiplication" begin
N = 10
rows = N
Expand Down

0 comments on commit 69aa51e

Please sign in to comment.