Skip to content

Commit

Permalink
Fix: Image preprocessing in Swift
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 12, 2024
1 parent f6faf4c commit f2772d0
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 54 deletions.
15 changes: 15 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"args": "${command:pickArgs}",
"console": "integratedTerminal"
}
]
}
7 changes: 6 additions & 1 deletion python/scripts/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
torch_available = True
except:
torch_available = False

# ONNX is not a very light dependency either
try:
import onnx
Expand All @@ -34,6 +34,7 @@
("unum-cloud/uform-vl-english-large", "gpu", "fp16"),
]


@pytest.mark.skipif(not torch_available, reason="PyTorch is not installed")
@pytest.mark.parametrize("model_name", torch_models)
def test_torch_one_embedding(model_name: str):
Expand Down Expand Up @@ -141,3 +142,7 @@ def test_onnx_many_embeddings(model_specs: Tuple[str, str, str], batch_size: int

except ExecutionProviderError as e:
pytest.skip(f"Execution provider error: {e}")


if __name__ == "__main__":
pytest.main(["-s", "-x", __file__])
3 changes: 3 additions & 0 deletions python/uform/numpy_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def _resize_crop_normalize(self, image: Image):
bottom = (height + self._image_size) / 2

image = image.convert("RGB").crop((left, top, right, bottom))
# At this point `image` is a PIL Image with RGB channels.
# If you convert it to `np.ndarray` it will have shape (H, W, C) where C is the number of channels.
image = (np.array(image).astype(np.float32) / 255.0 - self.image_mean) / self.image_std

# To make it compatible with PyTorch, we need to transpose the image to (C, H, W).
return np.transpose(image, (2, 0, 1))
150 changes: 100 additions & 50 deletions swift/Embeddings.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,69 @@ import Foundation
import Hub // `Config`
import Tokenizers // `AutoTokenizer`

public enum Embedding {
    /// A type-erased wrapper over the scalar buffers CoreML can emit,
    /// so callers get a strongly-typed Swift array instead of an `MLMultiArray`.
    case i32s([Int32])
    case f16s([Float16])
    case f32s([Float32])
    case f64s([Float64])

    /// Copies the contents of an `MLMultiArray` into a typed Swift array.
    ///
    /// - Parameter multiArray: The model output to convert. All elements are
    ///   copied, regardless of the array's rank.
    /// - Returns: `nil` when the array's `dataType` is not one of the four
    ///   supported scalar types.
    init?(from multiArray: MLMultiArray) {
        // Use the total element count rather than `shape[1]`: reading
        // `shape[1]` crashes on rank-1 outputs and silently truncates
        // rank-3+ outputs. For the common `[1, D]` embedding shape the
        // two are identical, so this stays backward compatible.
        let count = multiArray.count
        switch multiArray.dataType {
        case .float64:
            self = .f64s(
                Array(
                    UnsafeBufferPointer(
                        start: multiArray.dataPointer.assumingMemoryBound(to: Float64.self),
                        count: count
                    )
                )
            )
        case .float32:
            self = .f32s(
                Array(
                    UnsafeBufferPointer(
                        start: multiArray.dataPointer.assumingMemoryBound(to: Float32.self),
                        count: count
                    )
                )
            )
        case .float16:
            self = .f16s(
                Array(
                    UnsafeBufferPointer(
                        start: multiArray.dataPointer.assumingMemoryBound(to: Float16.self),
                        count: count
                    )
                )
            )
        case .int32:
            self = .i32s(
                Array(
                    UnsafeBufferPointer(
                        start: multiArray.dataPointer.assumingMemoryBound(to: Int32.self),
                        count: count
                    )
                )
            )
        @unknown default:
            return nil  // Unsupported scalar type — signal failure to the caller.
        }
    }

    /// Converts the stored buffer to `[Float]`, widening (or narrowing,
    /// for `Float64`) each element. `f32s` is returned without copying
    /// element-by-element.
    public func asFloats() -> [Float] {
        switch self {
        case .f32s(let array):
            return array
        case .i32s(let array):
            return array.map { Float($0) }
        case .f16s(let array):
            return array.map { Float($0) }
        case .f64s(let array):
            return array.map { Float($0) }
        }
    }
}

// MARK: - Helpers

func readConfig(fromPath path: String) throws -> [String: Any] {
Expand Down Expand Up @@ -39,18 +102,20 @@ public class TextEncoder {
self.processor = try TextProcessor(configPath: configPath, tokenizerPath: tokenizerPath, model: self.model)
}

public func forward(with text: String) throws -> [Float32] {
public func forward(with text: String) throws -> Embedding {
let inputFeatureProvider = try self.processor.preprocess(text)
let prediction = try self.model.prediction(from: inputFeatureProvider)
let predictionFeature = prediction.featureValue(for: "embeddings")
// The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32
let output = predictionFeature!.multiArrayValue!
return Array(
UnsafeBufferPointer(
start: output.dataPointer.assumingMemoryBound(to: Float32.self),
count: Int(truncating: output.shape[1])
guard let predictionFeature = prediction.featureValue(for: "embeddings"),
let output = predictionFeature.multiArrayValue,
let embedding = Embedding(from: output)
else {
throw NSError(
domain: "TextEncoder",
code: 0,
userInfo: [NSLocalizedDescriptionKey: "Failed to extract embeddings or unsupported data type."]
)
)
}
return embedding
}
}

Expand All @@ -63,20 +128,21 @@ public class ImageEncoder {
self.processor = try ImageProcessor(configPath: configPath)
}

public func forward(with image: CGImage) throws -> [Float32] {
public func forward(with image: CGImage) throws -> Embedding {
let inputFeatureProvider = try self.processor.preprocess(image)
let prediction = try self.model.prediction(from: inputFeatureProvider)
let predictionFeature = prediction.featureValue(for: "embeddings")
// The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32
let output = predictionFeature!.multiArrayValue!
return Array(
UnsafeBufferPointer(
start: output.dataPointer.assumingMemoryBound(to: Float32.self),
count: Int(truncating: output.shape[1])
guard let predictionFeature = prediction.featureValue(for: "embeddings"),
let output = predictionFeature.multiArrayValue,
let embedding = Embedding(from: output)
else {
throw NSError(
domain: "ImageEncoder",
code: 0,
userInfo: [NSLocalizedDescriptionKey: "Failed to extract embeddings or unsupported data type."]
)
)
}
return embedding
}

}

// MARK: - Processors
Expand Down Expand Up @@ -147,18 +213,17 @@ class ImageProcessor {
let originalWidth = CGFloat(image.width)
let originalHeight = CGFloat(image.height)

// Calculate new size preserving the aspect ratio
let widthRatio = CGFloat(imageSize) / originalWidth
let heightRatio = CGFloat(imageSize) / originalHeight
let scaleFactor = max(widthRatio, heightRatio)

let scaledWidth = originalWidth * scaleFactor
let scaledHeight = originalHeight * scaleFactor

// Calculate the crop rectangle
let dx = (scaledWidth - CGFloat(imageSize)) / 2.0
let dy = (scaledHeight - CGFloat(imageSize)) / 2.0
let insetRect = CGRect(x: dx, y: dy, width: CGFloat(imageSize) - dx * 2, height: CGFloat(imageSize) - dy * 2)

// Create a new context (off-screen canvas) with the desired dimensions
guard
let context = CGContext(
data: nil,
Expand All @@ -171,11 +236,9 @@ class ImageProcessor {
)
else { return nil }

// Draw the image in the context with the specified inset (cropping as necessary)
// Draw the scaled and cropped image in the context
context.interpolationQuality = .high
context.draw(image, in: insetRect, byTiling: false)

// Extract the new image from the context
context.draw(image, in: CGRect(x: -dx, y: -dy, width: scaledWidth, height: scaledHeight))
return context.makeImage()
}

Expand All @@ -193,44 +256,31 @@ class ImageProcessor {
bitsPerComponent: 8,
bytesPerRow: 4 * width,
space: colorSpace,
bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
)
context?.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height))

// Convert pixel data to float and normalize
let totalCount = width * height * 4
var floatPixels = [Float](repeating: 0, count: totalCount)
vDSP_vfltu8(pixelData, 1, &floatPixels, 1, vDSP_Length(totalCount))

// Scale the pixel values to [0, 1]
var divisor = Float(255.0)
vDSP_vsdiv(floatPixels, 1, &divisor, &floatPixels, 1, vDSP_Length(totalCount))

// Normalize the pixel values
// Normalize the pixel data
var floatPixels = [Float](repeating: 0, count: width * height * 3)
for c in 0 ..< 3 {
var slice = [Float](repeating: 0, count: width * height)
for i in 0 ..< (width * height) {
slice[i] = (floatPixels[i * 4 + c] - mean[c]) / std[c]
floatPixels[i * 3 + c] = (Float(pixelData[i * 4 + c]) / 255.0 - mean[c]) / std[c]
}
floatPixels.replaceSubrange(c * width * height ..< (c + 1) * width * height, with: slice)
}

// Rearrange the array to C x H x W
var tensor = [Float](repeating: 0, count: width * height * 3)
for y in 0 ..< height {
for x in 0 ..< width {
for c in 0 ..< 3 {
tensor[c * width * height + y * width + x] = floatPixels[y * width * 4 + x * 4 + c]
}
// Create the tensor array
var tensor = [Float](repeating: 0, count: 3 * width * height)
for i in 0 ..< (width * height) {
for c in 0 ..< 3 {
tensor[c * width * height + i] = floatPixels[i * 3 + c]
}
}

// Reshape the tensor to 1 x 3 x H x W and pack into a rank-3 `MLFeatureValue`
let multiArray = try? MLMultiArray(
shape: [1, 3, NSNumber(value: self.imageSize), NSNumber(value: self.imageSize)],
shape: [1, 3, NSNumber(value: height), NSNumber(value: width)],
dataType: .float32
)
for i in 0 ..< (width * height * 3) {
for i in 0 ..< tensor.count {
multiArray?[i] = NSNumber(value: tensor[i])
}
return multiArray
Expand Down
6 changes: 3 additions & 3 deletions swift/EmbeddingsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ final class TokenizerTests: XCTestCase {

var textEmbeddings: [[Float32]] = []
for text in texts {
let embedding: [Float32] = try textModel.forward(with: text)
let embedding: [Float32] = try textModel.forward(with: text).asFloats()
textEmbeddings.append(embedding)
}

Expand Down Expand Up @@ -102,9 +102,9 @@ final class TokenizerTests: XCTestCase {
)
}

let textEmbedding: [Float32] = try textModel.forward(with: text)
let textEmbedding: [Float32] = try textModel.forward(with: text).asFloats()
textEmbeddings.append(textEmbedding)
let imageEmbedding: [Float32] = try imageModel.forward(with: cgImage)
let imageEmbedding: [Float32] = try imageModel.forward(with: cgImage).asFloats()
imageEmbeddings.append(imageEmbedding)
}

Expand Down

0 comments on commit f2772d0

Please sign in to comment.