Skip to content

Commit

Permalink
Add support for GPT-4 Vision (#17)
Browse files Browse the repository at this point in the history
* Add Content

* Make ChatContent public

* Rename ChatContent -> MessageContent

* Update ChatMessage

* Fix init

* Add init for converting image data to base64

* Return emtpy string in case of no text
  • Loading branch information
ronaldmannak authored Apr 13, 2024
1 parent 61812ff commit cf2f607
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 7 deletions.
75 changes: 70 additions & 5 deletions Sources/CleverBird/chat/ChatMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public struct ChatMessage: Codable, Identifiable {
public let role: Role

/// The contents of the message. `content` is required for all messages except assistant messages with function calls.
public let content: String?
public let content: Content?

/// The name and arguments of a function that should be called, as generated by the model.
public let functionCall: FunctionCall?
Expand All @@ -36,14 +36,21 @@ public struct ChatMessage: Codable, Identifiable {
content: String? = nil,
id: String? = nil,
functionCall: FunctionCall? = nil) throws {
try self.init(role: role, media: content != nil ? .text(content!) : nil, id: id, functionCall: functionCall)
}

public init(role: Role,
media: ChatMessage.Content?,
id: String? = nil,
functionCall: FunctionCall? = nil) throws {

// Validation: Content is required for all messages except assistant messages with function calls.
if content == nil && !(role == .assistant && functionCall != nil) {
if media == nil && !(role == .assistant && functionCall != nil) {
throw CleverBirdError.invalidMessageContent
}

self.role = role
self.content = content
self.content = media
self.name = functionCall?.name
if role == .function {
// If the role is "function" I need to set functionCall to nil, otherwise this will
Expand All @@ -58,7 +65,9 @@ public struct ChatMessage: Codable, Identifiable {
} else {
var hasher = Hasher()
hasher.combine(self.role)
hasher.combine(self.content ?? "")
if let content {
hasher.combine(content)
}
let hashValue = abs(hasher.finalize())
let timestamp = Int(Date.now.timeIntervalSince1970*10000)

Expand All @@ -69,7 +78,7 @@ public struct ChatMessage: Codable, Identifiable {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.role = try container.decode(Role.self, forKey: .role)
self.content = try container.decodeIfPresent(String.self, forKey: .content)
self.content = try container.decodeIfPresent(Content.self, forKey: .content)
self.functionCall = try container.decodeIfPresent(FunctionCall.self, forKey: .functionCall)
self.name = try container.decodeIfPresent(String.self, forKey: .name)
self.id = "pending"
Expand All @@ -92,3 +101,59 @@ extension ChatMessage: Equatable {
&& lhs.content == rhs.content
}
}

extension ChatMessage {

public enum Content: Codable, Equatable, CustomStringConvertible, Hashable {

case text(String)
case media([MessageContent])

public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
if let textContent = try? container.decode(String.self) {
self = .text(textContent)
} else if let chatContents = try? container.decode([MessageContent].self) {
self = .media(chatContents)
} else {
throw DecodingError.typeMismatch(MessageContent.self, DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Unsupported type for Content"))
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .text(let text):
try container.encode(text)
case .media(let contents):
try container.encode(contents)
}
}

public static func == (lhs: Content, rhs: Content) -> Bool {
switch (lhs, rhs) {
case (.text(let leftText), .text(let rightText)):
return leftText == rightText
case (.media(let leftMedia), .media(let rightMedia)):
return leftMedia == rightMedia
default:
return false
}
}

public var description: String {
switch self {
case .media(let messageContents):
for messageContent in messageContents {
if case .text(let textValue) = messageContent {
return textValue
}
}
return ""
case .text(let text):
return text
}
}
}
}

27 changes: 26 additions & 1 deletion Sources/CleverBird/chat/ChatThread+tokenCount.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,32 @@ extension ChatThread {
let roleTokens = try tokenEncoder.encode(text: message.role.rawValue).count
let contentTokens: Int
if let content = message.content {
contentTokens = try tokenEncoder.encode(text: content).count
switch content {
case .text(let text):
contentTokens = try tokenEncoder.encode(text: text).count
case .media(let media):
var count = 0
for medium in media {
switch medium {
case .text(let text):
count += try tokenEncoder.encode(text: text).count
case .imageUrl(let url):
// See https://platform.openai.com/docs/guides/vision/calculating-costs
switch url.detail {
// TODO: calculate real values for auto and high
case .auto:
count += 1105
case .high:
count += 1105
case .low:
count += 85
case .none:
count += 1105
}
}
}
contentTokens = count
}
} else if let functionCall = message.functionCall {
let jsonEncoder = JSONEncoder()
let jsonData = try jsonEncoder.encode(functionCall)
Expand Down
10 changes: 10 additions & 0 deletions Sources/CleverBird/chat/ChatThread.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ public class ChatThread: Codable {
}
return self
}

@discardableResult
public func addUserMessage(_ media: [MessageContent]) -> Self {
do {
try addMessage(ChatMessage(role: .user, media: .media(media)))
} catch {
print(error.localizedDescription)
}
return self
}

@discardableResult
public func addAssistantMessage(_ content: String) -> Self {
Expand Down
90 changes: 90 additions & 0 deletions Sources/CleverBird/chat/MessageContent.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//
// ChatContent.swift
//
//
// Created by Ronald Mannak on 4/12/24.
//

import Foundation

public enum MessageContent: Hashable {
case text(String)
case imageUrl(URLDetail)
}

extension MessageContent {
public enum ContentType: String, Codable, Hashable {
case text
case imageUrl = "image_url"
}

public struct URLDetail: Codable, Equatable, Hashable {

public enum Detail: String, Codable {
case low, high, auto
}

let url: String
let detail: Detail?

public init(url: String, detail: Detail? = nil) {
self.url = url
self.detail = detail
}

public init(url: URL, detail: Detail? = nil) {
self.init(url: url.absoluteString, detail: detail)
}

public init(imageData: Data, detail: Detail? = nil) {
let base64 = imageData.base64EncodedString()
self.init(url: "data:image/jpeg;base64,\(base64)", detail: detail)
}
}
}

extension MessageContent: Codable {

private enum CodingKeys: String, CodingKey {
case type, text, imageUrl
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let type = try container.decode(ContentType.self, forKey: .type)

switch type {
case .text:
let text = try container.decode(String.self, forKey: .text)
self = .text(text)
case .imageUrl:
let imageUrl = try container.decode(URLDetail.self, forKey: .imageUrl)
self = .imageUrl(imageUrl)
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case .text(let text):
try container.encode(ContentType.text.rawValue, forKey: .type)
try container.encode(text, forKey: .text)
case .imageUrl(let urlDetail):
try container.encode(ContentType.imageUrl.rawValue, forKey: .type)
try container.encode(urlDetail, forKey: .imageUrl)
}
}
}

extension MessageContent: Equatable {
public static func == (lhs: MessageContent, rhs: MessageContent) -> Bool {
switch (lhs, rhs) {
case (.text(let lhsText), .text(let rhsText)):
return lhsText == rhsText
case (.imageUrl(let lhsUrlDetail), .imageUrl(let rhsUrlDetail)):
return lhsUrlDetail == rhsUrlDetail
default:
return false
}
}
}
24 changes: 24 additions & 0 deletions Tests/CleverBirdTests/MessageContentTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// MessageContentTests.swift
//
//
// Created by Ronald Mannak on 4/12/24.
//

import Foundation
import XCTest
@testable import CleverBird

class MessageContentTests: XCTestCase {

func testURL() {
let content = MessageContent.URLDetail(url: URL(string: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")!)
XCTAssertEqual(content.url, "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")
}

func testBase64() {
let data = "Hello, world".data(using: .utf8)!
let content = MessageContent.URLDetail(imageData: data)
XCTAssertEqual(content.url, "")
}
}
Loading

0 comments on commit cf2f607

Please sign in to comment.