Skip to content

Commit

Permalink
Added requested changes (#4)
Browse files Browse the repository at this point in the history
* Made requesed changes

* Formatted

---------

Co-authored-by: Vis <vishakanshanthakumar@gmail.com>
  • Loading branch information
ericckzhou and univish authored Dec 17, 2023
1 parent 3af7179 commit 4207132
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 197 deletions.
21 changes: 12 additions & 9 deletions docs/core_docs/docs/integrations/chat/azure_ml.mdx
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Azure Machine Learning Chat

You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
You must deploy models on Azure with the endpointUrl, apiKey, and deploymentName
when creating the AzureMLChatParams to call upon later. Must import a ContentFormatter
or create your own using the ChatContentFormatter interface.

```typescript
import { AzureMLChatParams, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
import {
AzureMLChatParams,
LlamaContentFormatter,
} from "langchain/chat_models/azure_ml";

const model = new AzureMLModel({
endpointUrl: "YOUR_ENDPOINT_URL",
endpointApiKey: "YOUR_ENDPOINT_API_KEY",
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
contentFormatter: new LlamaContentFormatter()
const model = new AzureMLOnlineEndpoint({
endpointUrl: "YOUR_ENDPOINT_URL",
endpointApiKey: "YOUR_ENDPOINT_API_KEY",
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
contentFormatter: new LlamaContentFormatter(),
});

const res = model.call(["Foo"]);
const res = model.invoke(["Foo"]);
console.log({ res });
```
```
23 changes: 13 additions & 10 deletions docs/core_docs/docs/integrations/llms/azure_ml.mdx
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Azure Machine Learning

You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
when creating the AzureMLModel to call upon later. Must import a ContentFormatter
You must deploy models on Azure with the endpointUrl, apiKey, and deploymentName
when creating the AzureMLOnlineEndpoint to call upon later. Must import a ContentFormatter
or create your own using the ContentFormatter interface.

```typescript
import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
import {
AzureMLOnlineEndpoint,
LlamaContentFormatter,
} from "langchain/llms/azure_ml";

const model = new AzureMLModel({
endpointUrl: "YOUR_ENDPOINT_URL",
endpointApiKey: "YOUR_ENDPOINT_API_KEY",
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
contentFormatter: new LlamaContentFormatter()
const model = new AzureMLOnlineEndpoint({
endpointUrl: "YOUR_ENDPOINT_URL",
endpointApiKey: "YOUR_ENDPOINT_API_KEY",
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
contentFormatter: new LlamaContentFormatter(),
});

const res = model.call("Foo");
const res = model.invoke("Foo");
console.log({ res });
```
```
18 changes: 10 additions & 8 deletions examples/src/models/chat/chat_azure_ml.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { AzureMLChatModel, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
import {
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
} from "langchain/chat_models/azure_ml";

const model = new AzureMLChatModel({
endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
contentFormatter: new LlamaContentFormatter(), // Only LLAMA currently supported.
const model = new AzureMLChatOnlineEndpoint({
endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
contentFormatter: new LlamaContentFormatter(), // Only LLAMA currently supported.
});

const res = model.call("Foo");
const res = model.invoke("Foo");

console.log({ res });
console.log({ res });
19 changes: 11 additions & 8 deletions examples/src/models/llm/azure_ml.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
import {
AzureMLOnlineEndpoint,
LlamaContentFormatter,
} from "langchain/llms/azure_ml";

const model = new AzureMLModel({
endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
contentFormatter: new LlamaContentFormatter(), // Or any of the other Models: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
const model = new AzureMLOnlineEndpoint({
endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
contentFormatter: new LlamaContentFormatter(), // Or any of the other Models: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
});

const res = model.call("Foo");
const res = model.invoke("Foo");

console.log({ res });
console.log({ res });
125 changes: 58 additions & 67 deletions langchain/src/chat_models/azure_ml.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import { getEnvironmentVariable } from "../util/env.js";
import { BaseMessage } from "../schema/index.js";

export interface ChatContentFormatter {
/**
/**
* Formats the request payload for the AzureML endpoint. It takes a
* prompt and a dictionary of model arguments as input and returns a
* string representing the formatted request payload.
* @param messages A list of messages for the chat so far.
* @param modelArgs A dictionary of model arguments.
* @returns A string representing the formatted request payload.
*/
formatRequestPayload:(messages:BaseMessage[], modelArgs:Record<string, unknown>) => string;
formatRequestPayload: (
messages: BaseMessage[],
modelArgs: Record<string, unknown>
) => string;
/**
* Formats the response payload from the AzureML endpoint. It takes a
* response payload as input and returns a string representing the
Expand All @@ -24,98 +27,86 @@ export interface ChatContentFormatter {
}

export class LlamaContentFormatter implements ChatContentFormatter {
_convertMessageToRecord(message:BaseMessage):Record<string, unknown> {
if (message._getType() === 'human') {
return {role: "user", content: message.content}
} else if (message._getType() === 'ai') {
return {role: "assistant", content: message.content}
} else {
return {role: message._getType(), content: message.content}
}
_convertMessageToRecord(message: BaseMessage): Record<string, unknown> {
if (message._getType() === "human") {
return { role: "user", content: message.content };
} else if (message._getType() === "ai") {
return { role: "assistant", content: message.content };
} else {
return { role: message._getType(), content: message.content };
}
}

formatRequestPayload(
messages: BaseMessage[],
modelArgs: Record<string, unknown>
): string {
let msgs = messages.map(message => {
this._convertMessageToRecord(message)
});
return JSON.stringify(
{"input_data": {
"input_string": msgs,
"parameters": modelArgs
}}
)
}
formatRequestPayload(
messages: BaseMessage[],
modelArgs: Record<string, unknown>
): string {
let msgs = messages.map((message) => {
this._convertMessageToRecord(message);
});
return JSON.stringify({
input_data: {
input_string: msgs,
parameters: modelArgs,
},
});
}

formatResponsePayload(
responsePayload: string
) {
const response = JSON.parse(responsePayload);
return response.output
}
formatResponsePayload(responsePayload: string) {
const response = JSON.parse(responsePayload);
return response.output;
}
}

/**
* Type definition for the input parameters of the AzureMLChatOnlineEndpoint class.
*/
export interface AzureMLChatParams extends BaseChatModelParams {
endpointUrl?: string;
endpointApiKey?: string;
modelArgs?: Record<string, unknown>;
contentFormatter?: ChatContentFormatter;
};

export interface AzureMLChatParams extends BaseChatModelParams {
endpointUrl?: string;
endpointApiKey?: string;
modelArgs?: Record<string, unknown>;
contentFormatter?: ChatContentFormatter;
}

/**
* Class that represents the chat model. It extends the SimpleChatModel class and implements the AzureMLChatInput interface.
*/
export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatParams {
export class AzureMLChatOnlineEndpoint
extends SimpleChatModel
implements AzureMLChatParams
{
static lc_name() {
return "AzureMLChat";
}
static lc_description() {
return "A class for interacting with AzureML Chat models.";
}

static lc_fields() {
return {
endpointUrl: {
lc_description: "The URL of the AzureML endpoint.",
lc_env: "AZUREML_URL",
},
endpointApiKey: {
lc_description: "The API key for the AzureML endpoint.",
lc_env: "AZUREML_API_KEY",
},
contentFormatter: {
lc_description: "The formatter for AzureML API",
}
};
}
endpointUrl: string;
endpointApiKey: string;
modelArgs?: Record<string, unknown>;
contentFormatter: ChatContentFormatter;
httpClient: AzureMLHttpClient;


constructor(fields: AzureMLChatParams) {
super(fields ?? {});
if (!fields?.endpointUrl && !getEnvironmentVariable('AZUREML_URL')) {
if (!fields?.endpointUrl && !getEnvironmentVariable("AZUREML_URL")) {
throw new Error("No Azure ML Url found.");
}
if (!fields?.endpointApiKey && !getEnvironmentVariable('AZUREML_API_KEY')) {
if (!fields?.endpointApiKey && !getEnvironmentVariable("AZUREML_API_KEY")) {
throw new Error("No Azure ML ApiKey found.");
}
if (!fields?.contentFormatter) {
throw new Error("No Content Formatter provided.")
throw new Error("No Content Formatter provided.");
}

this.endpointUrl = fields.endpointUrl || getEnvironmentVariable('AZUREML_URL')+'';
this.endpointApiKey = fields.endpointApiKey || getEnvironmentVariable('AZUREML_API_KEY')+'';
this.httpClient = new AzureMLHttpClient(this.endpointUrl, this.endpointApiKey);

this.endpointUrl =
fields.endpointUrl || getEnvironmentVariable("AZUREML_URL") + "";
this.endpointApiKey =
fields.endpointApiKey || getEnvironmentVariable("AZUREML_API_KEY") + "";
this.httpClient = new AzureMLHttpClient(
this.endpointUrl,
this.endpointApiKey
);
this.contentFormatter = fields.contentFormatter;
this.modelArgs = fields?.modelArgs;
}
Expand All @@ -132,20 +123,20 @@ export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatPara
}

_combineLLMOutput(): Record<string, any> | undefined {
return []
return [];
}

async _call(
messages: BaseMessage[],
modelArgs: Record<string, unknown>
): Promise<string> {
const requestPayload = this.contentFormatter.formatRequestPayload(
messages,
modelArgs
messages,
modelArgs
);
const responsePayload = await this.httpClient.call(requestPayload);
const generatedText = this.contentFormatter.formatResponsePayload(responsePayload);
const generatedText =
this.contentFormatter.formatResponsePayload(responsePayload);
return generatedText;
}
}

25 changes: 14 additions & 11 deletions langchain/src/chat_models/tests/chatazure_ml.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import { test, expect } from "@jest/globals";
import { AzureMLChatModel, LlamaContentFormatter } from "../azure_ml.js";
import {
AzureMLChatOnlineEndpoint,
LlamaContentFormatter,
} from "../azure_ml.js";

test("Test AzureML LLama Call", async () => {
const prompt = "Hi Llama!";
const chat = new AzureMLChatModel({
contentFormatter: new LlamaContentFormatter()
});
const res = await chat.call([prompt]);
expect(typeof res).toBe("string");
console.log(res);
});
const prompt = "Hi Llama!";
const chat = new AzureMLChatOnlineEndpoint({
contentFormatter: new LlamaContentFormatter(),
});

const res = await chat.call([prompt]);
expect(typeof res).toBe("string");

console.log(res);
});
Loading

0 comments on commit 4207132

Please sign in to comment.