diff --git a/docs/core_docs/docs/integrations/chat/azure_ml.mdx b/docs/core_docs/docs/integrations/chat/azure_ml.mdx
index a82d6f20406a..14563f403793 100644
--- a/docs/core_docs/docs/integrations/chat/azure_ml.mdx
+++ b/docs/core_docs/docs/integrations/chat/azure_ml.mdx
@@ -1,19 +1,21 @@
 # Azure Machine Learning Chat
 
-You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
-when creating the AzureMLChatParams to call upon later. Must import a ContentFormatter
-or create your own using the ChatContentFormatter interface.
+You must deploy a model on Azure and provide its endpointUrl and endpointApiKey
+when creating the AzureMLChatOnlineEndpoint. You must also import a
+ContentFormatter, or create your own using the ChatContentFormatter interface.
 
 ```typescript
-import { AzureMLChatParams, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/chat_models/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL",
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
-  contentFormatter: new LlamaContentFormatter()
+const model = new AzureMLChatOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL",
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
+  contentFormatter: new LlamaContentFormatter(),
 });
 
-const res = model.call(["Foo"]);
+const res = await model.invoke(["Foo"]);
 console.log({ res });
-```
\ No newline at end of file
+```
diff --git a/docs/core_docs/docs/integrations/llms/azure_ml.mdx b/docs/core_docs/docs/integrations/llms/azure_ml.mdx
index b8b26b88b6a6..fef9a3cd183b 100644
--- a/docs/core_docs/docs/integrations/llms/azure_ml.mdx
+++ b/docs/core_docs/docs/integrations/llms/azure_ml.mdx
@@ -1,19 +1,22 @@
 # Azure Machine Learning
 
-You can deploy models on Azure with the endpointUrl, apiKey, and deploymentName
-when creating the AzureMLModel to call upon later. Must import a ContentFormatter
-or create your own using the ContentFormatter interface.
+You must deploy a model on Azure and provide its endpointUrl, endpointApiKey,
+and deploymentName when creating the AzureMLOnlineEndpoint. You must also import
+a ContentFormatter, or create your own using the ContentFormatter interface.
 ```typescript
-import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
+import {
+  AzureMLOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/llms/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL",
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
-  contentFormatter: new LlamaContentFormatter()
+const model = new AzureMLOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL",
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
+  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
+  contentFormatter: new LlamaContentFormatter(),
 });
 
-const res = model.call("Foo");
+const res = await model.invoke("Foo");
 console.log({ res });
-```
\ No newline at end of file
+```
diff --git a/examples/src/models/chat/chat_azure_ml.ts b/examples/src/models/chat/chat_azure_ml.ts
index beefe9b8b7be..4391e267071f 100644
--- a/examples/src/models/chat/chat_azure_ml.ts
+++ b/examples/src/models/chat/chat_azure_ml.ts
@@ -1,12 +1,14 @@
-import { AzureMLChatModel, LlamaContentFormatter } from "langchain/chat_models/azure_ml";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/chat_models/azure_ml";
 
-const model = new AzureMLChatModel({
-  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
-  contentFormatter: new LlamaContentFormatter(), // Only LLAMA currently supported.
+const model = new AzureMLChatOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZUREML_URL
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZUREML_API_KEY
+  contentFormatter: new LlamaContentFormatter(), // Only Llama is currently supported.
 });
 
-const res = model.call("Foo");
+const res = await model.invoke("Foo");
 
-console.log({ res });
\ No newline at end of file
+console.log({ res });
diff --git a/examples/src/models/llm/azure_ml.ts b/examples/src/models/llm/azure_ml.ts
index 9fb24900645c..eb14bbdb776d 100644
--- a/examples/src/models/llm/azure_ml.ts
+++ b/examples/src/models/llm/azure_ml.ts
@@ -1,12 +1,15 @@
-import { AzureMLModel, LlamaContentFormatter } from "langchain/llms/azure_ml";
+import {
+  AzureMLOnlineEndpoint,
+  LlamaContentFormatter,
+} from "langchain/llms/azure_ml";
 
-const model = new AzureMLModel({
-  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
-  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
-  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
-  contentFormatter: new LlamaContentFormatter(), // Or any of the other Models: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
+const model = new AzureMLOnlineEndpoint({
+  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZUREML_URL
+  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZUREML_API_KEY
+  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
+  contentFormatter: new LlamaContentFormatter(), // Or any of the other formatters: GPT2ContentFormatter, HFContentFormatter, DollyContentFormatter
 });
 
-const res = model.call("Foo");
+const res = await model.invoke("Foo");
 
-console.log({ res });
\ No newline at end of file
+console.log({ res });
diff --git a/langchain/src/chat_models/azure_ml.ts b/langchain/src/chat_models/azure_ml.ts
index 0acdf77ee3a1..b8b3254875c6 100644
--- a/langchain/src/chat_models/azure_ml.ts
+++ b/langchain/src/chat_models/azure_ml.ts
@@ -4,7 +4,7 @@ import { getEnvironmentVariable } from "../util/env.js";
 import { BaseMessage } from "../schema/index.js";
 
 export interface ChatContentFormatter {
-  /** 
+  /**
   * Formats the request payload for the AzureML endpoint. It takes a
   * prompt and a dictionary of model arguments as input and returns a
   * string representing the formatted request payload.
@@ -12,7 +12,10 @@ export interface ChatContentFormatter {
   * @param modelArgs A dictionary of model arguments.
   * @returns A string representing the formatted request payload.
   */
-  formatRequestPayload:(messages:BaseMessage[], modelArgs:Record<string, unknown>) => string;
+  formatRequestPayload: (
+    messages: BaseMessage[],
+    modelArgs: Record<string, unknown>
+  ) => string;
 
   /**
   * Formats the response payload from the AzureML endpoint. It takes a
   * response payload as input and returns a string representing the
@@ -24,98 +27,86 @@ export interface ChatContentFormatter {
 }
 
 export class LlamaContentFormatter implements ChatContentFormatter {
-  _convertMessageToRecord(message:BaseMessage):Record<string, unknown> {
-    if (message._getType() === 'human') {
-      return {role: "user", content: message.content}
-    } else if (message._getType() === 'ai') {
-      return {role: "assistant", content: message.content}
-    } else {
-      return {role: message._getType(), content: message.content}
-    }
+  _convertMessageToRecord(message: BaseMessage): Record<string, unknown> {
+    if (message._getType() === "human") {
+      return { role: "user", content: message.content };
+    } else if (message._getType() === "ai") {
+      return { role: "assistant", content: message.content };
+    } else {
+      return { role: message._getType(), content: message.content };
     }
+  }
 
-  formatRequestPayload(
-    messages: BaseMessage[],
-    modelArgs: Record<string, unknown>
-  ): string {
-    let msgs = messages.map(message => {
-      this._convertMessageToRecord(message)
-    });
-    return JSON.stringify(
-      {"input_data": {
-        "input_string": msgs,
-        "parameters": modelArgs
-      }}
-    )
-  }
+  formatRequestPayload(
+    messages: BaseMessage[],
+    modelArgs: Record<string, unknown>
+  ): string {
+    const msgs = messages.map((message) =>
+      this._convertMessageToRecord(message)
+    );
+    return JSON.stringify({
+      input_data: {
+        input_string: msgs,
+        parameters: modelArgs,
+      },
+    });
+  }
 
-  formatResponsePayload(
-    responsePayload: string
-  ) {
-    const response = JSON.parse(responsePayload);
-    return response.output
-  }
+  formatResponsePayload(responsePayload: string): string {
+    const response = JSON.parse(responsePayload);
+    return response.output;
+  }
 }
 
 /**
  * Type definition for the input parameters of the AzureMLChatOnlineEndpoint class.
  */
-export interface AzureMLChatParams extends BaseChatModelParams {
-    endpointUrl?: string;
-    endpointApiKey?: string;
-    modelArgs?: Record<string, unknown>;
-    contentFormatter?: ChatContentFormatter;
-  };
-
+export interface AzureMLChatParams extends BaseChatModelParams {
+  endpointUrl?: string;
+  endpointApiKey?: string;
+  modelArgs?: Record<string, unknown>;
+  contentFormatter?: ChatContentFormatter;
+}
 
 /**
- * Class that represents the chat model. It extends the SimpleChatModel class and implements the AzureMLChatInput interface.
+ * Class that represents the chat model. It extends the SimpleChatModel class and implements the AzureMLChatParams interface.
  */
-export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatParams {
+export class AzureMLChatOnlineEndpoint
+  extends SimpleChatModel
+  implements AzureMLChatParams
+{
   static lc_name() {
-    return "AzureMLChat";
+    return "AzureMLChatOnlineEndpoint";
   }
   static lc_description() {
     return "A class for interacting with AzureML Chat models.";
   }
-
-  static lc_fields() {
-    return {
-      endpointUrl: {
-        lc_description: "The URL of the AzureML endpoint.",
-        lc_env: "AZUREML_URL",
-      },
-      endpointApiKey: {
-        lc_description: "The API key for the AzureML endpoint.",
-        lc_env: "AZUREML_API_KEY",
-      },
-      contentFormatter: {
-        lc_description: "The formatter for AzureML API",
-      }
-    };
-  }
 
   endpointUrl: string;
   endpointApiKey: string;
   modelArgs?: Record<string, unknown>;
   contentFormatter: ChatContentFormatter;
   httpClient: AzureMLHttpClient;
-
   constructor(fields: AzureMLChatParams) {
     super(fields ?? {});
-    if (!fields?.endpointUrl && !getEnvironmentVariable('AZUREML_URL')) {
+    if (!fields?.endpointUrl && !getEnvironmentVariable("AZUREML_URL")) {
       throw new Error("No Azure ML Url found.");
     }
-    if (!fields?.endpointApiKey && !getEnvironmentVariable('AZUREML_API_KEY')) {
+    if (!fields?.endpointApiKey && !getEnvironmentVariable("AZUREML_API_KEY")) {
       throw new Error("No Azure ML ApiKey found.");
     }
     if (!fields?.contentFormatter) {
-      throw new Error("No Content Formatter provided.")
+      throw new Error("No Content Formatter provided.");
    }
-
-    this.endpointUrl = fields.endpointUrl || getEnvironmentVariable('AZUREML_URL')+'';
-    this.endpointApiKey = fields.endpointApiKey || getEnvironmentVariable('AZUREML_API_KEY')+'';
-    this.httpClient = new AzureMLHttpClient(this.endpointUrl, this.endpointApiKey);
+
+    this.endpointUrl =
+      fields.endpointUrl || getEnvironmentVariable("AZUREML_URL") + "";
+    this.endpointApiKey =
+      fields.endpointApiKey || getEnvironmentVariable("AZUREML_API_KEY") + "";
+    this.httpClient = new AzureMLHttpClient(
+      this.endpointUrl,
+      this.endpointApiKey
+    );
     this.contentFormatter = fields.contentFormatter;
     this.modelArgs = fields?.modelArgs;
   }
@@ -132,7 +123,7 @@ export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatPara
   }
 
   _combineLLMOutput(): Record<string, unknown> | undefined {
-    return []
+    return [];
   }
 
   async _call(
@@ -140,12 +131,12 @@ export class AzureMLChatModel extends SimpleChatModel implements AzureMLChatPara
     messages: BaseMessage[],
     modelArgs: Record<string, unknown>
   ): Promise<string> {
     const requestPayload = this.contentFormatter.formatRequestPayload(
-        messages,
-        modelArgs
+      messages,
+      modelArgs
     );
     const responsePayload = await this.httpClient.call(requestPayload);
-    const generatedText = this.contentFormatter.formatResponsePayload(responsePayload);
+    const generatedText =
+      this.contentFormatter.formatResponsePayload(responsePayload);
     return generatedText;
   }
 }
-
diff --git a/langchain/src/chat_models/tests/chatazure_ml.int.test.ts b/langchain/src/chat_models/tests/chatazure_ml.int.test.ts
index b58112386744..a1f261249184 100644
--- a/langchain/src/chat_models/tests/chatazure_ml.int.test.ts
+++ b/langchain/src/chat_models/tests/chatazure_ml.int.test.ts
@@ -1,14 +1,17 @@
 import { test, expect } from "@jest/globals";
-import { AzureMLChatModel, LlamaContentFormatter } from "../azure_ml.js";
+import {
+  AzureMLChatOnlineEndpoint,
+  LlamaContentFormatter,
+} from "../azure_ml.js";
 
 test("Test AzureML LLama Call", async () => {
-  const prompt = "Hi Llama!";
-  const chat = new AzureMLChatModel({
-    contentFormatter: new LlamaContentFormatter()
-  });
-
-  const res = await chat.call([prompt]);
-  expect(typeof res).toBe("string");
-
-  console.log(res);
-});
\ No newline at end of file
+  const prompt = "Hi Llama!";
+  const chat = new AzureMLChatOnlineEndpoint({
+    contentFormatter: new LlamaContentFormatter(),
+  });
+
+  const res = await chat.call([prompt]);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
+});
diff --git a/langchain/src/llms/azure_ml.ts b/langchain/src/llms/azure_ml.ts
index f186ec1f23c2..bcfe1cf40410 100644
--- a/langchain/src/llms/azure_ml.ts
+++ b/langchain/src/llms/azure_ml.ts
@@ -6,7 +6,11 @@ export class AzureMLHttpClient {
   endpointApiKey: string;
   deploymentName?: string;
 
-  constructor(endpointUrl: string, endpointApiKey: string, deploymentName?: string) {
+  constructor(
+    endpointUrl: string,
+    endpointApiKey: string,
+    deploymentName?: string
+  ) {
     this.deploymentName = deploymentName;
     this.endpointApiKey = endpointApiKey;
     this.endpointUrl = endpointUrl;
@@ -18,29 +22,29 @@ export class AzureMLHttpClient {
   * @param requestPayload The request payload for the AzureML endpoint.
   * @returns A Promise that resolves to the response payload.
   */
-  async call(requestPayload: string): Promise<string> {
-    const response = await fetch(this.endpointUrl, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-        Authorization: `Bearer ${this.endpointApiKey}`,
-      },
-      body: requestPayload,
-    });
-    if (!response.ok) {
-      const error = new Error(
-        `Azure ML LLM call failed with status code ${response.status}`
-      );
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      (error as any).response = response;
-      throw error;
-    }
-    return response.text();
+  async call(requestPayload: string): Promise<string> {
+    const response = await fetch(this.endpointUrl, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${this.endpointApiKey}`,
+      },
+      body: requestPayload,
+    });
+    if (!response.ok) {
+      const error = new Error(
+        `Azure ML LLM call failed with status code ${response.status}`
+      );
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      (error as any).response = response;
+      throw error;
    }
+    return response.text();
+  }
 }
 
 export interface ContentFormatter {
-  /** 
+  /**
   * Formats the request payload for the AzureML endpoint. It takes a
   * prompt and a dictionary of model arguments as input and returns a
   * string representing the formatted request payload.
@@ -48,7 +52,10 @@ export interface ContentFormatter {
   * @param modelArgs A dictionary of model arguments.
   * @returns A string representing the formatted request payload.
   */
-  formatRequestPayload:(prompt:string, modelArgs:Record<string, unknown>) => string;
+  formatRequestPayload: (
+    prompt: string,
+    modelArgs: Record<string, unknown>
+  ) => string;
 
   /**
   * Formats the response payload from the AzureML endpoint. It takes a
   * response payload as input and returns a string representing the
@@ -60,57 +67,69 @@ export interface ContentFormatter {
 }
 
 export class GPT2ContentFormatter implements ContentFormatter {
-  formatRequestPayload(prompt: string, modelArgs: Record<string, unknown>): string {
+  formatRequestPayload(
+    prompt: string,
+    modelArgs: Record<string, unknown>
+  ): string {
     return JSON.stringify({
       inputs: {
         input_string: [prompt],
       },
       parameters: modelArgs,
     });
-  };
+  }
   formatResponsePayload(output: string): string {
-    return JSON.parse(output)[0]["0"]
-  };
+    return JSON.parse(output)[0]["0"];
+  }
 }
 
 export class HFContentFormatter implements ContentFormatter {
-  formatRequestPayload(prompt: string, modelArgs: Record<string, unknown>): string {
+  formatRequestPayload(
+    prompt: string,
+    modelArgs: Record<string, unknown>
+  ): string {
     return JSON.stringify({
       inputs: [prompt],
       parameters: modelArgs,
     });
-  };
+  }
   formatResponsePayload(output: string): string {
-    return JSON.parse(output)[0]["generated_text"]
-  };
+    return JSON.parse(output)[0]["generated_text"];
+  }
 }
 
 export class DollyContentFormatter implements ContentFormatter {
-  formatRequestPayload(prompt: string, modelArgs: Record<string, unknown>): string {
+  formatRequestPayload(
+    prompt: string,
+    modelArgs: Record<string, unknown>
+  ): string {
     return JSON.stringify({
       input_data: {
         input_string: [prompt],
       },
       parameters: modelArgs,
     });
-  };
+  }
   formatResponsePayload(output: string): string {
-    return JSON.parse(output)[0]
-  };
+    return JSON.parse(output)[0];
+  }
 }
 
 export class LlamaContentFormatter implements ContentFormatter {
-  formatRequestPayload(prompt: string, modelArgs: Record<string, unknown>): string {
+  formatRequestPayload(
+    prompt: string,
+    modelArgs: Record<string, unknown>
+  ): string {
     return JSON.stringify({
       input_data: {
         input_string: [prompt],
       },
       parameters: modelArgs,
     });
-  };
+  }
   formatResponsePayload(output: string): string {
-    return JSON.parse(output)[0]["0"]
-  };
+    return JSON.parse(output)[0]["0"];
+  }
 }
 
 export interface AzureMLParams extends BaseLLMParams {
@@ -126,12 +145,12 @@ export interface AzureMLParams extends BaseLLMParams {
  * and provides methods for calling the AzureML endpoint and formatting
  * the request and response payloads.
  */
-export class AzureMLModel extends LLM implements AzureMLParams {
+export class AzureMLOnlineEndpoint extends LLM implements AzureMLParams {
   _llmType() {
     return "azure_ml";
   }
   static lc_name() {
-    return "AzureMLModel";
+    return "AzureMLOnlineEndpoint";
   }
   static lc_description() {
     return "A class for interacting with AzureML models.";
@@ -151,7 +170,7 @@ export class AzureMLModel extends LLM implements AzureMLParams {
       },
       contentFormatter: {
         lc_description: "The formatter for AzureML API",
-      }
+      },
     };
   }
 
@@ -164,20 +183,26 @@ export class AzureMLModel extends LLM implements AzureMLParams {
   constructor(fields: AzureMLParams) {
     super(fields ?? {});
-    if (!fields?.endpointUrl && !getEnvironmentVariable('AZUREML_URL')) {
+    if (!fields?.endpointUrl && !getEnvironmentVariable("AZUREML_URL")) {
       throw new Error("No Azure ML Url found.");
     }
-    if (!fields?.endpointApiKey && !getEnvironmentVariable('AZUREML_API_KEY')) {
+    if (!fields?.endpointApiKey && !getEnvironmentVariable("AZUREML_API_KEY")) {
       throw new Error("No Azure ML ApiKey found.");
     }
     if (!fields?.contentFormatter) {
-      throw new Error("No Content Formatter provided.")
+      throw new Error("No Content Formatter provided.");
     }
 
-    this.endpointUrl = fields.endpointUrl || getEnvironmentVariable('AZUREML_URL')+'';
-    this.endpointApiKey = fields.endpointApiKey || getEnvironmentVariable('AZUREML_API_KEY')+'';
+    this.endpointUrl =
+      fields.endpointUrl || getEnvironmentVariable("AZUREML_URL") + "";
+    this.endpointApiKey =
+      fields.endpointApiKey || getEnvironmentVariable("AZUREML_API_KEY") + "";
     this.deploymentName = fields.deploymentName;
-    this.httpClient = new AzureMLHttpClient(this.endpointUrl, this.endpointApiKey, this.deploymentName);
+    this.httpClient = new AzureMLHttpClient(
+      this.endpointUrl,
+      this.endpointApiKey,
+      this.deploymentName
+    );
     this.contentFormatter = fields.contentFormatter;
     this.modelArgs = fields.modelArgs;
   }
@@ -193,7 +218,10 @@ export class AzureMLModel extends LLM implements AzureMLParams {
     prompt: string,
     modelArgs: Record<string, unknown>
   ): Promise<string> {
-    const requestPayload = this.contentFormatter.formatRequestPayload(prompt, modelArgs);
+    const requestPayload = this.contentFormatter.formatRequestPayload(
+      prompt,
+      modelArgs
+    );
     const responsePayload = await this.httpClient.call(requestPayload);
     return this.contentFormatter.formatResponsePayload(responsePayload);
   }
diff --git a/langchain/src/llms/tests/azure_ml.int.test.ts b/langchain/src/llms/tests/azure_ml.int.test.ts
index ff08c15726fd..abc516ee2e72 100644
--- a/langchain/src/llms/tests/azure_ml.int.test.ts
+++ b/langchain/src/llms/tests/azure_ml.int.test.ts
@@ -1,50 +1,64 @@
 import { test, expect } from "@jest/globals";
-import { AzureMLModel, DollyContentFormatter, GPT2ContentFormatter, HFContentFormatter, LlamaContentFormatter } from "../azure_ml.js";
+import {
+  AzureMLOnlineEndpoint,
+  DollyContentFormatter,
+  GPT2ContentFormatter,
+  HFContentFormatter,
+  LlamaContentFormatter,
+} from "../azure_ml.js";
 
+/* Llama Test
 test("Test AzureML LLama Call", async () => {
-  const prompt = "What is the meaning of Foo?";
-  const model = new AzureMLModel({
-    contentFormatter: new LlamaContentFormatter()
-  });
-
-  const res = await model.call(prompt);
-  expect(typeof res).toBe("string");
-
-  console.log(res);
+  const prompt = "What is the meaning of Foo?";
+  const model = new AzureMLOnlineEndpoint({
+    contentFormatter: new LlamaContentFormatter(),
+  });
+
+  const res = await model.call(prompt);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
 });
+*/
 
+/* GPT2 Test
 test("Test AzureML GPT2 Call", async () => {
-  const prompt = "What is the meaning of Foo?";
-  const model = new AzureMLModel({
-    contentFormatter: new GPT2ContentFormatter()
-  });
-
-  const res = await model.call(prompt);
-  expect(typeof res).toBe("string");
-
-  console.log(res);
+  const prompt = "What is the meaning of Foo?";
+  const model = new AzureMLOnlineEndpoint({
+    contentFormatter: new GPT2ContentFormatter(),
+  });
+
+  const res = await model.call(prompt);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
 });
+*/
 
+/* HF Test
 test("Test AzureML HF Call", async () => {
-  const prompt = "What is the meaning of Foo?";
-  const model = new AzureMLModel({
-    contentFormatter: new HFContentFormatter()
-  });
-
-  const res = await model.call(prompt);
-  expect(typeof res).toBe("string");
-
-  console.log(res);
+  const prompt = "What is the meaning of Foo?";
+  const model = new AzureMLOnlineEndpoint({
+    contentFormatter: new HFContentFormatter(),
+  });
+
+  const res = await model.call(prompt);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
 });
+*/
 
+/* Dolly Test
 test("Test AzureML Dolly Call", async () => {
-  const prompt = "What is the meaning of Foo?";
-  const model = new AzureMLModel({
-    contentFormatter: new DollyContentFormatter()
-  });
-
-  const res = await model.call(prompt);
-  expect(typeof res).toBe("string");
-
-  console.log(res);
+  const prompt = "What is the meaning of Foo?";
+  const model = new AzureMLOnlineEndpoint({
+    contentFormatter: new DollyContentFormatter(),
+  });
+
+  const res = await model.call(prompt);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
 });
+*/
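The docs added above say you can "create your own using the ContentFormatter interface" but never show what that looks like. Below is a minimal sketch of a custom formatter written against the `ContentFormatter` interface from this PR. The endpoint schema used here (`{ prompt, parameters }` in, `{ text }` out) is hypothetical — adapt both methods to the request/response shape your deployment actually serves:

```typescript
import {
  AzureMLOnlineEndpoint,
  ContentFormatter,
} from "langchain/llms/azure_ml";

// Hypothetical scoring schema: the endpoint accepts
// { prompt: string, parameters: {...} } and replies with { text: string }.
// Replace both methods with your deployment's actual payload shapes.
class CustomContentFormatter implements ContentFormatter {
  formatRequestPayload(
    prompt: string,
    modelArgs: Record<string, unknown>
  ): string {
    return JSON.stringify({ prompt, parameters: modelArgs });
  }

  formatResponsePayload(output: string): string {
    return JSON.parse(output).text;
  }
}

const model = new AzureMLOnlineEndpoint({
  // endpointUrl and endpointApiKey can be omitted when the AZUREML_URL
  // and AZUREML_API_KEY environment variables are set, since the
  // constructor falls back to them.
  endpointUrl: "YOUR_ENDPOINT_URL",
  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
  contentFormatter: new CustomContentFormatter(),
});

const res = await model.invoke("Foo");
console.log({ res });
```

Because `AzureMLHttpClient.call` returns the raw response body as text, `formatResponsePayload` receives an unparsed JSON string; parsing it inside the formatter, as the built-in formatters do, keeps all endpoint-specific schema knowledge in one place.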