-
Notifications
You must be signed in to change notification settings - Fork 10
/
memcpy_metal_objc.mm
168 lines (128 loc) · 4.66 KB
/
memcpy_metal_objc.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#import <Metal/Metal.h>
#import "memcpy_metal_objc.h"
// Constant parameters handed to the compute kernel. The bytes of this struct
// are copied verbatim into the constants buffer, which is bound at buffer
// index 2 when the kernel is dispatched.
// NOTE(review): the layout presumably has to match the corresponding struct in
// the shader source (not visible here) — confirm before changing any field.
struct memcpy_constants
{
// Number of int-sized elements to copy.
uint num_elements;
};
// Copies a block of int-sized elements from an input Metal buffer to an output
// Metal buffer, either with the "my_memcpy" compute kernel or with a blit
// encoder. Library loading, pipeline creation, buffer allocation and the
// command queue come from the superclass declared in memcpy_metal_objc.h.
@implementation MemcpyMetalObjC
{
    id<MTLComputePipelineState> _mPSO;    // pipeline state for the "my_memcpy" function
    id<MTLBuffer> _mIn;                   // copy source; kernel buffer index 0
    id<MTLBuffer> _mOut;                  // copy destination; kernel buffer index 1
    id<MTLBuffer> _mConst;                // holds one memcpy_constants; kernel buffer index 2

    uint _mNumElementsInt;                // payload size in int-sized elements (rounded up)
    uint _mNumGroupsPerGrid;              // threadgroups dispatched along x
    uint _mNumThreadsPerGroup;            // threads per threadgroup along x
    bool _mUseManagedBuffer;              // true: buffers come from the "managed" allocator
}

// Sets up the pipeline, the three buffers and the dispatch geometry.
//
// @param num_bytes         Requested payload size in bytes; rounded up to a
//                          whole number of int-sized elements.
// @param useManagedBuffer  Selects getManagedMTLBufferForBytes:for: over
//                          getSharedMTLBufferForBytes:for: for all buffers.
// @return The initialized instance, or nil if [super init] returned nil.
- (instancetype) initWithNumBytes:(size_t) num_bytes UseManagedBuffer:(bool) useManagedBuffer
{
    self = [super init];
    if (self)
    {
        _mUseManagedBuffer = useManagedBuffer;

        // Round the byte count up to whole elements.
        // NOTE(review): the result is narrowed from size_t to uint, so very
        // large requests would be truncated — confirm callers stay in range.
        _mNumElementsInt = (num_bytes + sizeof(int) - 1) / sizeof(int);

        if (num_bytes < 1024 * sizeof(int)) {
            // Small payload: a single threadgroup, width rounded up to a
            // multiple of 32.
            _mNumThreadsPerGroup = (_mNumElementsInt + 31) / 32 * 32;
            _mNumGroupsPerGrid   = 1;
        }
        else {
            // Large payload: 1024-wide threadgroups, enough of them to cover
            // every element.
            _mNumThreadsPerGroup = 1024;
            _mNumGroupsPerGrid   = (_mNumElementsInt + 1023) / 1024;
        }

        [ self loadLibraryWithName:@"./memcpy.metallib" ];
        _mPSO = [ self getPipelineStateForFunction:@"my_memcpy" ];

        const size_t dataBytes  = _mNumElementsInt * sizeof(int);
        const size_t constBytes = sizeof(struct memcpy_constants);

        if ( _mUseManagedBuffer ) {
            _mIn    = [ self getManagedMTLBufferForBytes: dataBytes  for:@"_mIn" ];
            _mOut   = [ self getManagedMTLBufferForBytes: dataBytes  for:@"_mOut" ];
            _mConst = [ self getManagedMTLBufferForBytes: constBytes for:@"_mConst" ];
        }
        else {
            _mIn    = [ self getSharedMTLBufferForBytes: dataBytes  for:@"_mIn" ];
            _mOut   = [ self getSharedMTLBufferForBytes: dataBytes  for:@"_mOut" ];
            _mConst = [ self getSharedMTLBufferForBytes: constBytes for:@"_mConst" ];
        }

        // Messaging a nil buffer yields a NULL contents pointer; fail with a
        // readable assertion instead of writing through NULL below.
        assert( _mConst != nil );

        struct memcpy_constants c;
        memset( &c, 0, sizeof(c) );
        c.num_elements = _mNumElementsInt;
        memcpy( _mConst.contents, &c, sizeof(c) );
    }
    return self;
}

// Number of threadgroups dispatched by performComputationKernel.
- (uint) numGroupsPerGrid
{
    return _mNumGroupsPerGrid;
}

// Number of threads per threadgroup used by performComputationKernel.
- (uint) numThreadsPerGroup
{
    return _mNumThreadsPerGroup;
}

// Payload size in bytes after rounding up to whole int-sized elements.
// NOTE(review): the product is narrowed to uint by the return type.
- (uint) numBytes
{
    return _mNumElementsInt * sizeof(int);
}

// CPU-visible pointer to the contents of the input buffer.
-(void*) getRawPointerIn
{
    return (void*)_mIn.contents;
}

// CPU-visible pointer to the contents of the output buffer.
-(void*) getRawPointerOut
{
    return (void*)_mOut.contents;
}

// Copies the input buffer to the output buffer with the compute kernel and
// blocks until the command buffer has completed.
-(void) performComputationKernel
{
    // Nothing to copy. A zero-element request also yields zero threads per
    // threadgroup, which is not a dispatch worth encoding.
    if ( _mNumElementsInt == 0 ) {
        return;
    }

#if TARGET_OS_OSX
    if ( _mUseManagedBuffer ) {
        // Report the CPU-side writes to the input and constants buffers
        // before the GPU reads them.
        [_mIn    didModifyRange: NSMakeRange(0, _mNumElementsInt * sizeof(int) ) ];
        [_mConst didModifyRange: NSMakeRange(0, sizeof( struct memcpy_constants ) ) ];
    }
#endif

    id<MTLCommandBuffer> commandBuffer = [ self.commandQueue commandBuffer ];
    assert( commandBuffer != nil );

    id<MTLComputeCommandEncoder> computeEncoder = [ commandBuffer computeCommandEncoder ];
    assert( computeEncoder != nil );

    [ computeEncoder setComputePipelineState: _mPSO ];
    [ computeEncoder setBuffer:_mIn    offset:0 atIndex:0 ];
    [ computeEncoder setBuffer:_mOut   offset:0 atIndex:1 ];
    [ computeEncoder setBuffer:_mConst offset:0 atIndex:2 ];

    [ computeEncoder dispatchThreadgroups:MTLSizeMake( _mNumGroupsPerGrid,   1, 1)
                    threadsPerThreadgroup:MTLSizeMake( _mNumThreadsPerGroup, 1, 1) ];
    [ computeEncoder endEncoding ];

#if TARGET_OS_OSX
    if ( _mUseManagedBuffer ) {
        // Bring the GPU-written output back to the CPU-visible copy. The
        // encoder is only created where the synchronize call is compiled in.
        id<MTLBlitCommandEncoder> blitEncoder = [ commandBuffer blitCommandEncoder ];
        assert( blitEncoder != nil );
        [ blitEncoder synchronizeResource:_mOut ];
        [ blitEncoder endEncoding ];
    }
#endif

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];
}

// Copies the input buffer to the output buffer with a blit encoder and blocks
// until the command buffer has completed.
-(void) performComputationBlit
{
    // Nothing to copy for a zero-element request.
    if ( _mNumElementsInt == 0 ) {
        return;
    }

#if TARGET_OS_OSX
    if ( _mUseManagedBuffer ) {
        // Report the CPU-side writes to the input buffer before the GPU reads it.
        [_mIn didModifyRange: NSMakeRange(0, _mNumElementsInt * sizeof(int) ) ];
    }
#endif

    id<MTLCommandBuffer> commandBuffer = [ self.commandQueue commandBuffer ];
    assert( commandBuffer != nil );

    id<MTLBlitCommandEncoder> blitEncoder = [ commandBuffer blitCommandEncoder ];
    assert( blitEncoder != nil );

    [ blitEncoder copyFromBuffer: _mIn
                    sourceOffset: 0
                        toBuffer: _mOut
               destinationOffset: 0
                            size: sizeof(int) * _mNumElementsInt ];

#if TARGET_OS_OSX
    if ( _mUseManagedBuffer ) {
        // Bring the GPU-written output back to the CPU-visible copy.
        [ blitEncoder synchronizeResource:_mOut ];
    }
#endif

    [blitEncoder endEncoding];
    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];
}

@end