Skip to content

Commit

Permalink
fixed #40 (#41)
Browse files Browse the repository at this point in the history
* save

* save

* fixed #39

---------

Co-authored-by: jjjkkkjjj-mizuno <jkado@mizuno.co.jp>
  • Loading branch information
2 people authored and jjjkkkjjj committed Feb 6, 2023
1 parent 39cdead commit ba20aaa
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 13 deletions.
6 changes: 5 additions & 1 deletion Sources/Matft/core/general/print.swift
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ extension MfData: CustomStringConvertible{

ret += "\n"

ret += "isView\t: \(self._isView)\n"
ret += "isView\t: \(self._isView)"
if self._isView{
ret += ", source\t: \(String(describing: self._base))"
}
ret += "\n"
ret += "offset\t: \(self.offset)\n"

return ret
Expand Down
23 changes: 20 additions & 3 deletions Sources/Matft/core/object/mfarray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import Foundation
import Accelerate
import CoreML

open class MfArray: MfArrayProtocol{
public typealias MFDATA = MfData
Expand All @@ -16,6 +17,7 @@ open class MfArray: MfArrayProtocol{

public internal(set) var base: MfArray?


/// Create a mfarray from Swift Array
/// - Parameters:
/// - array: A Swift Array
Expand Down Expand Up @@ -108,6 +110,21 @@ open class MfArray: MfArrayProtocol{
self.mfdata = MfData(refdata: base.mfdata, offset: offset)
self.mfstructure = mfstructure//mfstructure will be copied because mfstructure is struct
}

/// Create a VIEW or Copy mfarray from MLShapedArray
/// - Parameters:
/// - base: A base MLShapedArray
/// - share: Whether to share memories or not, by default to true
@available(macOS 12.0, *)
public init (base: inout MLMultiArray, share: Bool = true){
precondition([MLMultiArrayDataType.float, MLMultiArrayDataType.double].contains(base.dataType), "Must be float or double in share mode")
// note that base is not assigned here!
let mftype = MfType.mftype(value: base.dataType)
let mfdata = MfData(source: share ? base : nil, data_real_ptr: base.dataPointer, storedSize: base.count, mftype: mftype, offset: 0)

self.mfdata = mfdata
self.mfstructure = MfStructure(shape: base.shape.map{ Int(truncating: $0) }, strides: base.strides.map{ Int(truncating: $0) })
}

deinit {
self.base = nil
Expand Down Expand Up @@ -152,7 +169,7 @@ extension MfArray{
}
}

internal var storedData: [Any]{
public var storedData: [Any]{
if let base = self.base{
return base.storedData
}
Expand All @@ -175,7 +192,7 @@ extension MfArray{
return self
}
else{
let mfdata = MfData(data_real_ptr: self.mfdata.data_real, storedSize: self.mfdata.storedSize, mftype: self.mfdata.mftype, offset: self.mfdata.offset)
let mfdata = MfData(source: self.mfdata, data_real_ptr: self.mfdata.data_real, storedSize: self.mfdata.storedSize, mftype: self.mfdata.mftype, offset: self.mfdata.offset)
return MfArray(mfdata: mfdata, mfstructure: self.mfstructure)
}
}
Expand All @@ -184,7 +201,7 @@ extension MfArray{
return nil
}
else{
let mfdata = MfData(data_real_ptr: self.mfdata.data_imag!, storedSize: self.mfdata.storedSize, mftype: self.mfdata.mftype, offset: self.mfdata.offset)
let mfdata = MfData(source: self.mfdata, data_real_ptr: self.mfdata.data_imag!, storedSize: self.mfdata.storedSize, mftype: self.mfdata.mftype, offset: self.mfdata.offset)
return MfArray(mfdata: mfdata, mfstructure: self.mfstructure)
}
}
Expand Down
55 changes: 46 additions & 9 deletions Sources/Matft/core/object/mfdata.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
import Foundation
import Accelerate

internal enum MfDataSource{
case mfdata
case mlshapedarray
}

public class MfData: MfDataProtocol{
private var _base: MfData? // must be referenced because refdata could be freed automatically?
private var _isOwner: Bool = true
internal var _base: MfDataBasable? // must be referenced because refdata could be freed automatically?
private var _fromOtherDataSource: Bool = false
internal var data_real: UnsafeMutableRawPointer
internal var data_imag: UnsafeMutableRawPointer?

Expand All @@ -22,7 +27,7 @@ public class MfData: MfDataProtocol{

/// Whether to be VIEW or not
internal var _isView: Bool{
return !(self._base == nil && self._isOwner)
return (self._base != nil || self._fromOtherDataSource)
}

/// Whether to be real or not
Expand Down Expand Up @@ -76,20 +81,52 @@ public class MfData: MfDataProtocol{

/// Pass a pointer directly.
/// - Parameters:
/// - source: A source data. If the source is nil, COPY it, otherwise, SHARE it.
/// - data_real_ptr: A real data pointer
/// - data_imag_ptr: A imag data pointer
/// - storedSize: A size
/// - mftype: Type
/// - Important: The given dataptr will NOT be freed. So don't forget to free manually.
internal init(data_real_ptr: UnsafeMutableRawPointer, data_imag_ptr: UnsafeMutableRawPointer? = nil, storedSize: Int, mftype: MfType, offset: Int){
self._isOwner = false
self.data_real = data_real_ptr
self.data_imag = data_imag_ptr
/// - Important: The given dataptr will NOT be freed in SHARE mode. So don't forget to free manually.
internal init(source: MfDataBasable?, data_real_ptr: UnsafeMutableRawPointer, data_imag_ptr: UnsafeMutableRawPointer? = nil, storedSize: Int, mftype: MfType, offset: Int){
self._base = source
self._fromOtherDataSource = source != nil
self.storedSize = storedSize
self.mftype = mftype
self.offset = offset

if self._fromOtherDataSource{
self.data_real = data_real_ptr
self.data_imag = data_imag_ptr
}
else{
switch MfType.storedType(mftype) {
case .Float:
self.data_real = allocate_unsafeMRPtr(type: Float.self, count: storedSize)
memcpy(self.data_real, data_real_ptr, self.storedByteSize)

if let data_imag_ptr = data_imag_ptr{
self.data_imag = allocate_unsafeMRPtr(type: Float.self, count: storedSize)
memcpy(self.data_imag, data_imag_ptr, self.storedByteSize)
}
else{
self.data_imag = nil
}
case .Double:
self.data_real = allocate_unsafeMRPtr(type: Double.self, count: storedSize)
memcpy(self.data_real, data_real_ptr, self.storedByteSize)

if let data_imag_ptr = data_imag_ptr{
self.data_imag = allocate_unsafeMRPtr(type: Double.self, count: storedSize)
memcpy(self.data_imag, data_imag_ptr, self.storedByteSize)
}
else{
self.data_imag = nil
}
}

}

}


/// Create a zero padded MfData
/// - Parameters:
Expand Down
13 changes: 13 additions & 0 deletions Sources/Matft/core/object/mftype.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import Foundation
import Accelerate
import CoreML

public enum MfType: Int{
case None
Expand Down Expand Up @@ -64,6 +65,18 @@ public enum MfType: Int{
return MfType.mftype(value: value as Any)
}

@available(macOS 10.13, *)
static internal func mftype(value: MLMultiArrayDataType) -> MfType{
switch value {
case .double:
return .Double
case .float:
return .Float
default:
return .Object // Not supported
}
}

static public func priority(_ a: MfType, _ b: MfType) -> MfType{
if a.rawValue < b.rawValue{
return b
Expand Down
7 changes: 7 additions & 0 deletions Sources/Matft/core/protocol/mfarrayProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import Foundation

public protocol HasMfDataProtocol{
/// The data object
var mfdata: MfData { get }

/// The offset index
Expand All @@ -21,6 +22,8 @@ public protocol HasMfDataProtocol{
var storedSize: Int { get }
/// The size of the stored data (byte)
var storedByteSize: Int { get }
/// Whether to share the memory or not
var isView: Bool { get }
}

extension HasMfDataProtocol{
Expand All @@ -39,9 +42,13 @@ extension HasMfDataProtocol{
public var storedByteSize: Int{
return self.mfdata.storedByteSize
}
public var isView: Bool{
return self.mfdata._isView
}
}

public protocol HasMfStructurProtocol{
/// The structure object
var mfstructure: MfStructure { get }

/// The shape of ndarray
Expand Down
16 changes: 16 additions & 0 deletions Sources/Matft/core/protocol/mfdataProtocol.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//
// File.swift
//
//
// Created by AM19A0 on 2023/02/06.
//

import Foundation
import CoreML

public protocol MfDataBasable {}

extension MfData: MfDataBasable{}

@available(macOS 10.13, *)
extension MLMultiArray: MfDataBasable{}
46 changes: 46 additions & 0 deletions Tests/MatftTests/CreationTest.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,55 @@
import XCTest
//@testable import Matft
import Matft
import CoreML

final class CreationTests: XCTestCase {

@available(macOS 12.0, *)
func testFromMLMultiArray() {
do {
let scalars = Array<Float>(stride(from: 0, to: 28, by: 2))
var mlmarr = try! MLMultiArray(shape: [2,7], dataType: .float)
for (i, s) in scalars.enumerated(){
mlmarr[[i/7,i%7] as [NSNumber]] = s as NSNumber
}
let mfarray = MfArray(base: &mlmarr)
XCTAssertEqual(mfarray, Matft.arange(start: 0, to: 28, by: 2, shape: [2, 7], mftype: .Float, mforder: .Row))
XCTAssertTrue(mfarray.isView)
}

do {
let scalars = Array<Float>(stride(from: 0, to: 28, by: 2))
var mlmarr = try! MLMultiArray(shape: [2,7], dataType: .float)
for (i, s) in scalars.enumerated(){
mlmarr[[i/7,i%7] as [NSNumber]] = s as NSNumber
}

let mfarray = MfArray(base: &mlmarr, share: false)
XCTAssertEqual(mfarray, Matft.arange(start: 0, to: 28, by: 2, shape: [2, 7], mftype: .Float, mforder: .Row))
XCTAssertFalse(mfarray.isView)
}

}

@available(macOS 12.0, *)
func testFromMLMultiArrayShare() {
do {
let scalars = Array<Float>(stride(from: 0, to: 28, by: 2))
var mlmarr = try! MLMultiArray(shape: [2,7], dataType: .float)
for (i, s) in scalars.enumerated(){
mlmarr[[i/7,i%7] as [NSNumber]] = s as NSNumber
}
let mfarray = MfArray(base: &mlmarr)
mfarray[0] = MfArray([1])
for i in 0..<scalars.count / 2{
let val = mlmarr[[0,i%7] as [NSNumber]] as! Float
XCTAssertEqual(mfarray[0, i].scalar as! Float, val)
}
XCTAssertTrue(mfarray.isView)
}
}

func testAppend() {
do {
let x = MfArray([1,2,3])
Expand Down

0 comments on commit ba20aaa

Please sign in to comment.