✨ support claude 3 params and stream handling
danielglh committed Mar 12, 2024
1 parent 50461af commit 30cd968
Showing 5 changed files with 379 additions and 67 deletions.
67 changes: 12 additions & 55 deletions src/libs/agent-runtime/anthropic/index.ts
@@ -9,13 +9,11 @@ import { AgentRuntimeErrorType } from '../error';
 import {
   ChatCompetitionOptions,
   ChatStreamPayload,
-  ModelProvider,
-  OpenAIChatMessage,
-  UserMessageContentPart,
+  ModelProvider
 } from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
-import { parseDataUri } from '../utils/uriParser';
+import { buildAnthropicMessages } from '../utils/anthropicHelpers';
 
 export class LobeAnthropicAI implements LobeRuntimeAI {
   private client: Anthropic;
@@ -26,40 +24,22 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
     this.client = new Anthropic({ apiKey });
   }
 
-  private buildAnthropicMessages = (
-    messages: OpenAIChatMessage[],
-  ): Anthropic.Messages.MessageParam[] =>
-    messages.map((message) => this.convertToAnthropicMessage(message));
-
-  private convertToAnthropicMessage = (
-    message: OpenAIChatMessage,
-  ): Anthropic.Messages.MessageParam => {
-    const content = message.content as string | UserMessageContentPart[];
-
-    return {
-      content:
-        typeof content === 'string' ? content : content.map((c) => this.convertToAnthropicBlock(c)),
-      role: message.role === 'function' || message.role === 'system' ? 'assistant' : message.role,
-    };
-  };
-
   async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
     const { messages, model, max_tokens, temperature, top_p } = payload;
     const system_message = messages.find((m) => m.role === 'system');
     const user_messages = messages.filter((m) => m.role !== 'system');
 
-    const requestParams: Anthropic.MessageCreateParams = {
-      max_tokens: max_tokens || 4096,
-      messages: this.buildAnthropicMessages(user_messages),
-      model: model,
-      stream: true,
-      system: system_message?.content as string,
-      temperature: temperature,
-      top_p: top_p,
-    };
-
     try {
-      const response = await this.client.messages.create(requestParams);
+      const response = await this.client.messages.create({
+        max_tokens: max_tokens || 4096,
+        messages: buildAnthropicMessages(user_messages),
+        model: model,
+        stream: true,
+        system: system_message?.content as string,
+        temperature: temperature,
+        top_p: top_p,
+      });
 
       const [prod, debug] = response.tee();
 
       if (process.env.DEBUG_ANTHROPIC_CHAT_COMPLETION === '1') {
@@ -91,29 +71,6 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
       });
     }
   }
-
-  private convertToAnthropicBlock(
-    content: UserMessageContentPart,
-  ): Anthropic.ContentBlock | Anthropic.ImageBlockParam {
-    switch (content.type) {
-      case 'text': {
-        return content;
-      }
-
-      case 'image_url': {
-        const { mimeType, base64 } = parseDataUri(content.image_url.url);
-
-        return {
-          source: {
-            data: base64 as string,
-            media_type: mimeType as Anthropic.ImageBlockParam.Source['media_type'],
-            type: 'base64',
-          },
-          type: 'image',
-        };
-      }
-    }
-  }
 }
 
 export default LobeAnthropicAI;
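
The helper logic deleted above was extracted into src/libs/agent-runtime/utils/anthropicHelpers.ts, one of the three changed files not rendered on this page. Judging only from the removed private methods, the extracted module presumably looks like the sketch below; buildAnthropicMessages is confirmed by the new import, but the internal helper names are assumptions.

// Sketch of ../utils/anthropicHelpers, inferred from the private methods deleted above.
// Only buildAnthropicMessages is confirmed by the new import; other names are assumed.
import Anthropic from '@anthropic-ai/sdk';

import { OpenAIChatMessage, UserMessageContentPart } from '../types';
import { parseDataUri } from './uriParser';

const buildAnthropicBlock = (
  content: UserMessageContentPart,
): Anthropic.ContentBlock | Anthropic.ImageBlockParam => {
  switch (content.type) {
    case 'text': {
      // Text parts already match Anthropic's text block shape.
      return content;
    }

    case 'image_url': {
      // Convert an OpenAI-style data-URI image into an Anthropic base64 image block.
      const { mimeType, base64 } = parseDataUri(content.image_url.url);

      return {
        source: {
          data: base64 as string,
          media_type: mimeType as Anthropic.ImageBlockParam.Source['media_type'],
          type: 'base64',
        },
        type: 'image',
      };
    }
  }
};

const buildAnthropicMessage = (message: OpenAIChatMessage): Anthropic.Messages.MessageParam => {
  const content = message.content as string | UserMessageContentPart[];

  return {
    content: typeof content === 'string' ? content : content.map((c) => buildAnthropicBlock(c)),
    // Anthropic's Messages API only accepts 'user' and 'assistant' roles, so
    // 'function' and 'system' entries are folded into 'assistant'.
    role: message.role === 'function' || message.role === 'system' ? 'assistant' : message.role,
  };
};

export const buildAnthropicMessages = (
  messages: OpenAIChatMessage[],
): Anthropic.Messages.MessageParam[] => messages.map((message) => buildAnthropicMessage(message));

Extracting this conversion into a shared module presumably lets the new Bedrock runtime reuse the same OpenAI-to-Anthropic message mapping that the tests below exercise.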
225 changes: 225 additions & 0 deletions src/libs/agent-runtime/bedrock/index.test.ts
@@ -0,0 +1,225 @@
// @vitest-environment node
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime';
import * as debugStreamModule from '../utils/debugStream';
import { LobeBedrockAI } from './index';

const provider = 'bedrock';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

vi.mock('@aws-sdk/client-bedrock-runtime', async (importOriginal) => {
  const module = await importOriginal();
  return {
    ...(module as any),
    InvokeModelWithResponseStreamCommand: vi.fn(),
  };
});

let instance: LobeBedrockAI;

beforeEach(() => {
  instance = new LobeBedrockAI({
    region: 'us-west-2',
    accessKeyId: 'test-access-key-id',
    accessKeySecret: 'test-access-key-secret',
  });

  vi.spyOn(instance['client'], 'send').mockReturnValue(new ReadableStream() as any);
});

afterEach(() => {
  vi.clearAllMocks();
});

describe('LobeBedrockAI', () => {
  describe('init', () => {
    it('should correctly initialize with AWS credentials', async () => {
      const instance = new LobeBedrockAI({
        region: 'us-west-2',
        accessKeyId: 'test-access-key-id',
        accessKeySecret: 'test-access-key-secret',
      });
      expect(instance).toBeInstanceOf(LobeBedrockAI);
    });
  });

  describe('chat', () => {
    describe('Claude model', () => {
      it('should return a Response on successful API call', async () => {
        const result = await instance.chat({
          messages: [{ content: 'Hello', role: 'user' }],
          model: 'anthropic.claude-v2:1',
          temperature: 0,
        });

        // Assert
        expect(result).toBeInstanceOf(Response);
      });

      it('should handle text messages correctly', async () => {
        // Arrange
        const mockStream = new ReadableStream({
          start(controller) {
            controller.enqueue('Hello, world!');
            controller.close();
          },
        });
        const mockResponse = Promise.resolve(mockStream);
        (instance['client'].send as Mock).mockResolvedValue(mockResponse);

        // Act
        const result = await instance.chat({
          messages: [{ content: 'Hello', role: 'user' }],
          model: 'anthropic.claude-v2:1',
          temperature: 0,
          top_p: 1,
        });

        // Assert
        expect(instance['client'].send).toHaveBeenCalledWith(
          new InvokeModelWithResponseStreamCommand({
            accept: 'application/json',
            body: JSON.stringify({
              anthropic_version: 'bedrock-2023-05-31',
              max_tokens: 4096,
              messages: [{ content: 'Hello', role: 'user' }],
              temperature: 0,
              top_p: 1,
            }),
            contentType: 'application/json',
            modelId: 'anthropic.claude-v2:1',
          }),
        );
        expect(result).toBeInstanceOf(Response);
      });

      it('should handle system prompt correctly', async () => {
        // Arrange
        const mockStream = new ReadableStream({
          start(controller) {
            controller.enqueue('Hello, world!');
            controller.close();
          },
        });
        const mockResponse = Promise.resolve(mockStream);
        (instance['client'].send as Mock).mockResolvedValue(mockResponse);

        // Act
        const result = await instance.chat({
          messages: [
            { content: 'You are an awesome greeter', role: 'system' },
            { content: 'Hello', role: 'user' },
          ],
          model: 'anthropic.claude-v2:1',
          temperature: 0,
          top_p: 1,
        });

        // Assert
        expect(instance['client'].send).toHaveBeenCalledWith(
          new InvokeModelWithResponseStreamCommand({
            accept: 'application/json',
            body: JSON.stringify({
              anthropic_version: 'bedrock-2023-05-31',
              max_tokens: 4096,
              messages: [{ content: 'Hello', role: 'user' }],
              system: 'You are an awesome greeter',
              temperature: 0,
              top_p: 1,
            }),
            contentType: 'application/json',
            modelId: 'anthropic.claude-v2:1',
          }),
        );
        expect(result).toBeInstanceOf(Response);
      });

      it('should call Anthropic model with supported options', async () => {
        // Arrange
        const mockStream = new ReadableStream({
          start(controller) {
            controller.enqueue('Hello, world!');
            controller.close();
          },
        });
        const mockResponse = Promise.resolve(mockStream);
        (instance['client'].send as Mock).mockResolvedValue(mockResponse);

        // Act
        const result = await instance.chat({
          max_tokens: 2048,
          messages: [{ content: 'Hello', role: 'user' }],
          model: 'anthropic.claude-v2:1',
          temperature: 0.5,
          top_p: 1,
        });

        // Assert
        expect(instance['client'].send).toHaveBeenCalledWith(
          new InvokeModelWithResponseStreamCommand({
            accept: 'application/json',
            body: JSON.stringify({
              anthropic_version: 'bedrock-2023-05-31',
              max_tokens: 2048,
              messages: [{ content: 'Hello', role: 'user' }],
              temperature: 0.5,
              top_p: 1,
            }),
            contentType: 'application/json',
            modelId: 'anthropic.claude-v2:1',
          }),
        );
        expect(result).toBeInstanceOf(Response);
      });

      it('should call Anthropic model without unsupported options', async () => {
        // Arrange
        const mockStream = new ReadableStream({
          start(controller) {
            controller.enqueue('Hello, world!');
            controller.close();
          },
        });
        const mockResponse = Promise.resolve(mockStream);
        (instance['client'].send as Mock).mockResolvedValue(mockResponse);

        // Act
        const result = await instance.chat({
          frequency_penalty: 0.5, // Unsupported option
          max_tokens: 2048,
          messages: [{ content: 'Hello', role: 'user' }],
          model: 'anthropic.claude-v2:1',
          presence_penalty: 0.5, // Unsupported option
          temperature: 0.5,
          top_p: 1,
        });

        // Assert
        expect(instance['client'].send).toHaveBeenCalledWith(
          new InvokeModelWithResponseStreamCommand({
            accept: 'application/json',
            body: JSON.stringify({
              anthropic_version: 'bedrock-2023-05-31',
              max_tokens: 2048,
              messages: [{ content: 'Hello', role: 'user' }],
              temperature: 0.5,
              top_p: 1,
            }),
            contentType: 'application/json',
            modelId: 'anthropic.claude-v2:1',
          }),
        );
        expect(result).toBeInstanceOf(Response);
      });
    });
  });
});
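
Taken together, these assertions pin down the Claude request shape on Bedrock: an InvokeModelWithResponseStreamCommand whose JSON body carries anthropic_version 'bedrock-2023-05-31', a 4096-token default for max_tokens, the system message hoisted into a top-level system field, and OpenAI-only options (frequency_penalty, presence_penalty) silently dropped. The actual src/libs/agent-runtime/bedrock/index.ts is not rendered on this page; the following is a minimal sketch consistent with the tests above, with the stream-to-Response plumbing assumed rather than taken from the commit.

// Minimal sketch of a Claude-on-Bedrock chat method consistent with the tests above.
// The real bedrock/index.ts from this commit is not shown; stream handling is assumed.
import {
  BedrockRuntimeClient,
  InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';

import { ChatStreamPayload } from '../types';

interface LobeBedrockAIParams {
  accessKeyId: string;
  accessKeySecret: string;
  region: string;
}

export class LobeBedrockAI {
  private client: BedrockRuntimeClient;

  constructor({ region, accessKeyId, accessKeySecret }: LobeBedrockAIParams) {
    this.client = new BedrockRuntimeClient({
      credentials: { accessKeyId, secretAccessKey: accessKeySecret },
      region,
    });
  }

  async chat(payload: ChatStreamPayload): Promise<Response> {
    const { max_tokens, messages, model, temperature, top_p } = payload;
    const system_message = messages.find((m) => m.role === 'system');
    const user_messages = messages.filter((m) => m.role !== 'system');

    const command = new InvokeModelWithResponseStreamCommand({
      accept: 'application/json',
      // Key order matches the JSON.stringify payloads asserted in the tests;
      // frequency_penalty and presence_penalty are intentionally not forwarded.
      // The real runtime presumably reuses buildAnthropicMessages here so that
      // multimodal content parts are converted, not just plain strings.
      body: JSON.stringify({
        anthropic_version: 'bedrock-2023-05-31',
        max_tokens: max_tokens || 4096,
        messages: user_messages,
        system: system_message?.content as string | undefined, // omitted when undefined
        temperature: temperature,
        top_p: top_p,
      }),
      contentType: 'application/json',
      modelId: model,
    });

    const res = await this.client.send(command);

    // Adapt the AWS event stream into a web ReadableStream for the Response.
    const stream = new ReadableStream<Uint8Array>({
      async start(controller) {
        for await (const event of res.body ?? []) {
          if (event.chunk?.bytes) controller.enqueue(event.chunk.bytes);
        }
        controller.close();
      },
    });

    return new Response(stream);
  }
}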
