Added requested changes #4

Merged 2 commits on Dec 17, 2023
21 changes: 12 additions & 9 deletions docs/core_docs/docs/integrations/chat/azure_ml.mdx
@@ -1,19 +1,22 @@
 # Azure Machine Learning Chat
 
-You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
+You must deploy models on Azure with the endpointUrl, apiKey, and deploymentName
 when creating the AzureMLChatParams to call upon later. Must import a ContentFormatter
 or create your own using the ChatContentFormatter interface.
 
 ```typescript
-import { AzureMLChatParams, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/chat_models/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL",
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
-  contentFormatter: new LlamaContentFormatter()
+const model = new AzureMLChatOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL",
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
+  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
+  contentFormatter: new LlamaContentFormatter(),
 });
 
-const res = model.call(["Foo"]);
+const res = model.invoke(["Foo"]);
 console.log({ res });
 ```
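The page tells readers they can bring their own formatter but never shows one. A minimal sketch against the `ChatContentFormatter` interface this PR exports — the class name and the `{ prompt, params }` endpoint shape are hypothetical, not part of the PR:

```typescript
import { BaseMessage } from "langchain/schema";
import { ChatContentFormatter } from "langchain/chat_models/azure_ml";

// Hypothetical formatter for a deployment that expects { prompt, params }.
class MyChatContentFormatter implements ChatContentFormatter {
  formatRequestPayload(
    messages: BaseMessage[],
    modelArgs: Record<string, unknown>
  ): string {
    // Flatten the chat history into a single prompt string.
    const prompt = messages
      .map((m) => `${m._getType()}: ${m.content}`)
      .join("\n");
    return JSON.stringify({ prompt, params: modelArgs });
  }

  formatResponsePayload(responsePayload: string): string {
    // Mirrors LlamaContentFormatter: assumes the endpoint returns { output }.
    return JSON.parse(responsePayload).output;
  }
}
```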
23 changes: 13 additions & 10 deletions docs/core_docs/docs/integrations/llms/azure_ml.mdx
@@ -1,19 +1,22 @@
 # Azure Machine Learning
 
-You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
-when creating the AzureMLModel to call upon later. Must import a ContentFormatter
+You must deploy models on Azure with the endpointUrl, apiKey, and deploymentName
+when creating the AzureMLOnlineEndpoint to call upon later. Must import a ContentFormatter
 or create your own using the ContentFormatter interface.
 
 ```typescript
-import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
+import {
+  AzureMLOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/llms/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL",
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
-  contentFormatter: new LlamaContentFormatter()
+const model = new AzureMLOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL",
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
+  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
+  contentFormatter: new LlamaContentFormatter(),
 });
 
-const res = model.call("Foo");
+const res = model.invoke("Foo");
 console.log({ res });
 ```
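Neither page demonstrates `modelArgs`, although the params interfaces accept it; per `LlamaContentFormatter.formatRequestPayload` later in this PR, whatever is passed there is forwarded as `parameters` in the request body. A sketch, assuming the LLM class takes `modelArgs` the same way the chat class does (the specific keys are whatever the deployed model understands, not fixed by LangChain):

```typescript
import {
  AzureMLOnlineEndpoint,
  LlamaContentFormatter,
} from "langchain/llms/azure_ml";

const model = new AzureMLOnlineEndpoint({
  endpointUrl: "YOUR_ENDPOINT_URL",
  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
  contentFormatter: new LlamaContentFormatter(),
  // Forwarded verbatim as "parameters" in the request payload;
  // temperature/max_new_tokens are example keys, not required names.
  modelArgs: { temperature: 0.5, max_new_tokens: 128 },
});

const res = await model.invoke("Foo");
console.log({ res });
```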
18 changes: 10 additions & 8 deletions examples/src/models/chat/chat_azure_ml.ts
@@ -1,12 +1,14 @@
-import { AzureMLChatModel, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/chat_models/azure_ml";
 
-const model = new AzureMLChatModel({
-  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
-  contentFormatter: new LlamaContentFormatter(), // Only LLAMA currently supported.
+const model = new AzureMLChatOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
+  contentFormatter: new LlamaContentFormatter(), // Only LLAMA currently supported.
 });
 
-const res = model.call("Foo");
+const res = model.invoke("Foo");
 
-console.log({ res });
\ No newline at end of file
+console.log({ res });
19 changes: 11 additions & 8 deletions examples/src/models/llm/azure_ml.ts
@@ -1,12 +1,15 @@
-import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
+import {
+  AzureMLOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/llms/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
-  contentFormatter: new LlamaContentFormatter(), // Or any of the other Models: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
+const model = new AzureMLOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
+  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
+  contentFormatter: new LlamaContentFormatter(), // Or any of the other Models: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
 });
 
-const res = model.call("Foo");
+const res = model.invoke("Foo");
 
-console.log({ res });
\ No newline at end of file
+console.log({ res });
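The inline comments point at `process.env.AZURE_ML_ENDPOINTURL`-style variables, but the constructor added in `langchain/src/chat_models/azure_ml.ts` below falls back to `AZUREML_URL` and `AZUREML_API_KEY`. A sketch using the names the code actually reads, assuming the LLM class mirrors the chat constructor:

```typescript
import {
  AzureMLOnlineEndpoint,
  LlamaContentFormatter,
} from "langchain/llms/azure_ml";

// Assumes AZUREML_URL and AZUREML_API_KEY are set in the environment;
// the constructor falls back to them when the fields are omitted.
const model = new AzureMLOnlineEndpoint({
  contentFormatter: new LlamaContentFormatter(),
});

const res = await model.invoke("Foo");
console.log({ res });
```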
125 changes: 58 additions & 67 deletions langchain/src/chat_models/azure_ml.ts
@@ -4,15 +4,18 @@ import { getEnvironmentVariable } from "../util/env.js";
 import { BaseMessage } from "../schema/index.js";
 
 export interface ChatContentFormatter {
-    /**
+  /**
   * Formats the request payload for the AzureML endpoint. It takes a
   * prompt and a dictionary of model arguments as input and returns a
   * string representing the formatted request payload.
   * @param messages A list of messages for the chat so far.
   * @param modelArgs A dictionary of model arguments.
   * @returns A string representing the formatted request payload.
   */
-  formatRequestPayload:(messages:BaseMessage[], modelArgs:Record<string, unknown>) => string;
+  formatRequestPayload: (
+    messages: BaseMessage[],
+    modelArgs: Record<string, unknown>
+  ) => string;
   /**
    * Formats the response payload from the AzureML endpoint. It takes a
    * response payload as input and returns a string representing the
@@ -24,98 +27,86 @@ export interface ChatContentFormatter {
 }
 
 export class LlamaContentFormatter implements ChatContentFormatter {
-  _convertMessageToRecord(message:BaseMessage):Record<string, unknown> {
-    if (message._getType() === 'human') {
-      return {role: "user", content: message.content}
-    } else if (message._getType() === 'ai') {
-      return {role: "assistant", content: message.content}
-    } else {
-      return {role: message._getType(), content: message.content}
-    }
+  _convertMessageToRecord(message: BaseMessage): Record<string, unknown> {
+    if (message._getType() === "human") {
+      return { role: "user", content: message.content };
+    } else if (message._getType() === "ai") {
+      return { role: "assistant", content: message.content };
+    } else {
+      return { role: message._getType(), content: message.content };
+    }
   }
 
-  formatRequestPayload(
-    messages: BaseMessage[],
-    modelArgs: Record<string, unknown>
-  ): string {
-    let msgs = messages.map(message => {
-      this._convertMessageToRecord(message)
-    });
-    return JSON.stringify(
-      {"input_data": {
-        "input_string": msgs,
-        "parameters": modelArgs
-      }}
-    )
-  }
+  formatRequestPayload(
+    messages: BaseMessage[],
+    modelArgs: Record<string, unknown>
+  ): string {
+    // Convert each message to the role/content record the endpoint expects;
+    // the arrow body must return the record, or msgs becomes undefined[].
+    const msgs = messages.map((message) =>
+      this._convertMessageToRecord(message)
+    );
+    return JSON.stringify({
+      input_data: {
+        input_string: msgs,
+        parameters: modelArgs,
+      },
+    });
+  }
 
-  formatResponsePayload(
-    responsePayload: string
-  ) {
-    const response = JSON.parse(responsePayload);
-    return response.output
-  }
+  formatResponsePayload(responsePayload: string) {
+    const response = JSON.parse(responsePayload);
+    return response.output;
+  }
 }
 
 /**
  * Type definition for the input parameters of the AzureMLChatOnlineEndpoint class.
  */
-export interface AzureMLChatParams extends BaseChatModelParams {
-  endpointUrl?: string;
-  endpointApiKey?: string;
-  modelArgs?: Record<string, unknown>;
-  contentFormatter?: ChatContentFormatter;
-};
+export interface AzureMLChatParams extends BaseChatModelParams {
+  endpointUrl?: string;
+  endpointApiKey?: string;
+  modelArgs?: Record<string, unknown>;
+  contentFormatter?: ChatContentFormatter;
+}
 
 /**
  * Class that represents the chat model. It extends the SimpleChatModel class and implements the AzureMLChatInput interface.
  */
-export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatParams {
+export class AzureMLChatOnlineEndpoint
+  extends SimpleChatModel
+  implements AzureMLChatParams
+{
   static lc_name() {
     return "AzureMLChat";
   }
-  static lc_description() {
-    return "A class for interacting with AzureML Chat models.";
-  }
 
-  static lc_fields() {
-    return {
-      endpointUrl: {
-        lc_description: "The URL of the AzureML endpoint.",
-        lc_env: "AZUREML_URL",
-      },
-      endpointApiKey: {
-        lc_description: "The API key for the AzureML endpoint.",
-        lc_env: "AZUREML_API_KEY",
-      },
-      contentFormatter: {
-        lc_description: "The formatter for AzureML API",
-      }
-    };
-  }
   endpointUrl: string;
   endpointApiKey: string;
   modelArgs?: Record<string, unknown>;
   contentFormatter: ChatContentFormatter;
   httpClient: AzureMLHttpClient;
 
   constructor(fields: AzureMLChatParams) {
     super(fields ?? {});
-    if (!fields?.endpointUrl && !getEnvironmentVariable('AZUREML_URL')) {
+    if (!fields?.endpointUrl && !getEnvironmentVariable("AZUREML_URL")) {
       throw new Error("No Azure ML Url found.");
     }
-    if (!fields?.endpointApiKey && !getEnvironmentVariable('AZUREML_API_KEY')) {
+    if (!fields?.endpointApiKey && !getEnvironmentVariable("AZUREML_API_KEY")) {
       throw new Error("No Azure ML ApiKey found.");
     }
     if (!fields?.contentFormatter) {
-      throw new Error("No Content Formatter provided.")
+      throw new Error("No Content Formatter provided.");
     }
 
-    this.endpointUrl = fields.endpointUrl || getEnvironmentVariable('AZUREML_URL')+'';
-    this.endpointApiKey = fields.endpointApiKey || getEnvironmentVariable('AZUREML_API_KEY')+'';
-    this.httpClient = new AzureMLHttpClient(this.endpointUrl, this.endpointApiKey);
+    this.endpointUrl =
+      fields.endpointUrl || getEnvironmentVariable("AZUREML_URL") + "";
+    this.endpointApiKey =
+      fields.endpointApiKey || getEnvironmentVariable("AZUREML_API_KEY") + "";
+    this.httpClient = new AzureMLHttpClient(
+      this.endpointUrl,
+      this.endpointApiKey
+    );
     this.contentFormatter = fields.contentFormatter;
     this.modelArgs = fields?.modelArgs;
   }
@@ -132,20 +123,20 @@ export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatParams {
   }
 
   _combineLLMOutput(): Record<string, any> | undefined {
-    return []
+    return [];
   }
 
   async _call(
     messages: BaseMessage[],
     modelArgs: Record<string, unknown>
   ): Promise<string> {
     const requestPayload = this.contentFormatter.formatRequestPayload(
-      messages,
-      modelArgs
+      messages,
+      modelArgs
     );
     const responsePayload = await this.httpClient.call(requestPayload);
-    const generatedText = this.contentFormatter.formatResponsePayload(responsePayload);
+    const generatedText =
+      this.contentFormatter.formatResponsePayload(responsePayload);
     return generatedText;
   }
 }
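To make the Llama wire format concrete: given `formatRequestPayload` above, a single human message plus one model argument serializes to the `input_data` envelope shown in the comment (a sketch; `temperature` is just an example key):

```typescript
import { HumanMessage } from "langchain/schema";
import { LlamaContentFormatter } from "langchain/chat_models/azure_ml";

const formatter = new LlamaContentFormatter();
const payload = formatter.formatRequestPayload(
  [new HumanMessage("Hi Llama!")],
  { temperature: 0.5 }
);
// {"input_data":{"input_string":[{"role":"user","content":"Hi Llama!"}],
//  "parameters":{"temperature":0.5}}}
console.log(payload);
```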

25 changes: 14 additions & 11 deletions langchain/src/chat_models/tests/chatazure_ml.int.test.ts
@@ -1,14 +1,17 @@
 import { test, expect } from "@jest/globals";
-import { AzureMLChatModel, LlamaContentFormatter } from "../azure_ml.js";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "../azure_ml.js";
 
 test("Test AzureML LLama Call", async () => {
-  const prompt = "Hi Llama!";
-  const chat = new AzureMLChatModel({
-    contentFormatter: new LlamaContentFormatter()
-  });
-  const res = await chat.call([prompt]);
-  expect(typeof res).toBe("string");
-  console.log(res);
-});
+  const prompt = "Hi Llama!";
+  const chat = new AzureMLChatOnlineEndpoint({
+    contentFormatter: new LlamaContentFormatter(),
+  });
+
+  const res = await chat.call([prompt]);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
+});
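The test drives the model through `call` with a bare string array; with the rest of the PR moving to `invoke`, the equivalent using an explicit message object would look like this (a sketch, not part of the PR; assumes `AZUREML_URL` and `AZUREML_API_KEY` are set so the constructor fallbacks apply):

```typescript
import { HumanMessage } from "langchain/schema";
import {
  AzureMLChatOnlineEndpoint,
  LlamaContentFormatter,
} from "langchain/chat_models/azure_ml";

// Relies on the AZUREML_URL / AZUREML_API_KEY fallbacks in the constructor.
const chat = new AzureMLChatOnlineEndpoint({
  contentFormatter: new LlamaContentFormatter(),
});

const res = await chat.invoke([new HumanMessage("Hi Llama!")]);
console.log(res);
```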