Skip to content

Commit

Permalink
Fix Kotlin Native memory management for tensor data
Browse files Browse the repository at this point in the history
  • Loading branch information
erksch committed Jun 24, 2022
1 parent 5e387fa commit 413dcbd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 25 deletions.
10 changes: 5 additions & 5 deletions src/iosMain/kotlin/de/voize/pytorch_lite_multiplatform/Tensor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import kotlinx.cinterop.*
import cocoapods.LibTorchWrapper.Tensor as LibTorchWrapperTensor

actual abstract class Tensor {
abstract fun getTensor(): LibTorchWrapperTensor
abstract fun getTensor(nativePlacement: NativePlacement): LibTorchWrapperTensor
}

actual class LongTensor actual constructor(
private val data: LongArray,
private val shape: LongArray
) : Tensor() {
override fun getTensor(): LibTorchWrapperTensor {
return memScoped {
override fun getTensor(nativePlacement: NativePlacement): LibTorchWrapperTensor {
return with(nativePlacement) {
val cData = allocArray<LongVar>(data.size)
val cShape = allocArray<LongVar>(shape.size)
data.forEachIndexed { index, value -> cData[index] = value }
Expand All @@ -26,8 +26,8 @@ actual class FloatTensor actual constructor(
private val data: FloatArray,
private val shape: LongArray
) : Tensor() {
override fun getTensor(): LibTorchWrapperTensor {
return memScoped {
override fun getTensor(nativePlacement: NativePlacement): LibTorchWrapperTensor {
return with(nativePlacement) {
val cData = allocArray<FloatVar>(data.size)
val cShape = allocArray<LongVar>(shape.size)
data.forEachIndexed { index, value -> cData[index] = value }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,25 @@ actual class TorchModule actual constructor(path: String) {
private val module = LibTorchWrapperTorchModule(fileAtPath = path)

actual fun runMethod(methodName: String, inputs: List<Tensor>): ModelOutput {
return memScoped {
val output = module.runMethod(methodName, inputs.map { it.getTensor() })

output?.let {
ModelOutput(
(it.data as List<Float>).toFloatArray(),
(it.shape as List<Long>).toLongArray(),
)
} ?: throw IllegalArgumentException("Model output can not be null")
}
val output = memScoped {
module.runMethod(methodName, inputs.map { it.getTensor(this) })
} ?: throw IllegalArgumentException("Model output can not be null")

return ModelOutput(
(output.data as List<Float>).toFloatArray(),
(output.shape as List<Long>).toLongArray(),
)
}

actual fun runMethod(methodName: String, inputs: Map<String, Tensor>): ModelOutput {
return memScoped {
val output = module.runMethodMap(methodName, inputs.mapValues { it.value.getTensor() })

output?.let {
ModelOutput(
(it.data as List<Float>).toFloatArray(),
(it.shape as List<Long>).toLongArray(),
)
} ?: throw IllegalArgumentException("Model output can not be null")
}
val output = memScoped {
module.runMethodMap(methodName, inputs.mapValues { it.value.getTensor(this) })
} ?: throw IllegalArgumentException("Model output can not be null")

return ModelOutput(
(output.data as List<Float>).toFloatArray(),
(output.shape as List<Long>).toLongArray(),
)
}

actual fun forward(inputs: List<Tensor>): ModelOutput {
Expand Down

0 comments on commit 413dcbd

Please sign in to comment.