device_cuda.go

// +build cuda

package gorgonia

import "github.com/chewxy/cu"

// Device represents the device on which the code will be executed. It can be either a GPU or a CPU.
type Device cu.Device

// CPU is the default device the graph will be executed on.
const CPU = Device(cu.CPU)

// String implements fmt.Stringer and runtime.Stringer.
func (d Device) String() string { return cu.Device(d).String() }

// Alloc allocates memory on the device. If the device is the CPU, the allocation is a no-op,
// because Go handles all allocations on the CPU.
func (d Device) Alloc(extern External, size int64) (Memory, error) {
	if d == CPU {
		cudaLogf("device is CPU")
		return nil, nil // arguably this should be an error, as this path shouldn't be reached
	}

	machine := extern.(CUDAMachine)
	ctxes := machine.Contexts()
	if len(ctxes) == 0 {
		cudaLogf("allocate nothing")
		return nil, nil
	}
	ctx := ctxes[int(d)]

	cudaLogf("calling ctx.MemAlloc(%d)", size)
	return ctx.MemAlloc(size)
}

// Free returns device memory to the CUDAMachine's pool rather than freeing it outright.
// If mem is not a cu.DevicePtr (i.e. it is CPU memory managed by Go), it is a no-op.
func (d Device) Free(extern External, mem Memory, size int64) (err error) {
	var devptr cu.DevicePtr
	var ok bool
	if devptr, ok = mem.(cu.DevicePtr); !ok {
		return nil
	}

	machine := extern.(CUDAMachine)
	machine.Put(d, devptr, size)

	// FUTURE: actually free memory if there ain't enough to go round
	// ctx := machine.Contexts()[int(d)]
	// cudaLogf("MemFree %v", devptr)
	// ctx.MemFree(devptr)
	return nil
}
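
// Illustrative sketch, not part of the original file: how Alloc and Free are
// intended to pair up. It assumes extern is a CUDAMachine with at least one
// context; the function name allocScratch is hypothetical.
func allocScratch(extern External, d Device, size int64) (err error) {
	var mem Memory
	if mem, err = d.Alloc(extern, size); err != nil {
		return err
	}
	// ... use mem via the machine's CUDA contexts ...

	// Free hands mem back to the machine's pool rather than releasing it to CUDA.
	return d.Free(extern, mem, size)
}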