Added AzureML LLM #1

Merged
merged 11 commits on Nov 26, 2023
1 change: 1 addition & 0 deletions docs/api_refs/typedoc.json
@@ -89,6 +89,7 @@
"./langchain/src/llms/portkey.ts",
"./langchain/src/llms/yandex.ts",
"./langchain/src/llms/fake.ts",
"./langchain/src/llms/azure_ml.ts",
"./langchain/src/prompts/index.ts",
"./langchain/src/prompts/load.ts",
"./langchain/src/vectorstores/analyticdb.ts",
17 changes: 17 additions & 0 deletions docs/core_docs/docs/integrations/llms/azure_ml.mdx
@@ -0,0 +1,17 @@
# Azure Machine Learning

You can deploy models to Azure Machine Learning and then query them by passing the
`endpointUrl`, `endpointApiKey`, and `deploymentName` when creating an `AzureMLModel`.

```typescript
import { AzureMLModel } from "langchain/llms/azure_ml";

const model = new AzureMLModel({
  endpointUrl: "YOUR_ENDPOINT_URL",
  endpointApiKey: "YOUR_ENDPOINT_API_KEY",
  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME",
});

const res = await model.call("Foo");
console.log({ res });
```
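
The same values can instead be read from environment variables, in which case the constructor fields may be omitted. A minimal sketch, assuming `AZURE_ML_ENDPOINTURL`, `AZURE_ML_APIKEY` and `AZURE_ML_NAME` are set:

```typescript
import { AzureMLModel } from "langchain/llms/azure_ml";

// Assumes AZURE_ML_ENDPOINTURL, AZURE_ML_APIKEY and AZURE_ML_NAME are set.
const model = new AzureMLModel({});

const res = await model.call("Foo");
console.log({ res });
```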
11 changes: 11 additions & 0 deletions examples/src/models/llm/azure_ml.ts
@@ -0,0 +1,11 @@
import { AzureMLModel } from "langchain/llms/azure_ml";

const model = new AzureMLModel({
  endpointUrl: "YOUR_ENDPOINT_URL", // Or set as process.env.AZURE_ML_ENDPOINTURL
  endpointApiKey: "YOUR_ENDPOINT_API_KEY", // Or set as process.env.AZURE_ML_APIKEY
  deploymentName: "YOUR_MODEL_DEPLOYMENT_NAME", // Or set as process.env.AZURE_ML_NAME
});

const res = await model.call("Foo");

console.log({ res });
3 changes: 3 additions & 0 deletions langchain/.gitignore
@@ -211,6 +211,9 @@ llms/yandex.d.ts
llms/fake.cjs
llms/fake.js
llms/fake.d.ts
llms/azure_ml.cjs
llms/azure_ml.js
llms/azure_ml.d.ts
prompts.cjs
prompts.js
prompts.d.ts
8 changes: 8 additions & 0 deletions langchain/package.json
@@ -223,6 +223,9 @@
"llms/fake.cjs",
"llms/fake.js",
"llms/fake.d.ts",
"llms/azure_ml.cjs",
"llms/azure_ml.js",
"llms/azure_ml.d.ts",
"prompts.cjs",
"prompts.js",
"prompts.d.ts",
@@ -1763,6 +1766,11 @@
"import": "./llms/fake.js",
"require": "./llms/fake.cjs"
},
"./llms/azure_ml": {
"types": "./llms/azure_ml.d.ts",
"import": "./llms/azure_ml.js",
"require": "./llms/azure_ml.cjs"
},
"./prompts": {
"types": "./prompts.d.ts",
"import": "./prompts.js",
1 change: 1 addition & 0 deletions langchain/scripts/create-entrypoints.js
@@ -368,6 +368,7 @@ const requiresOptionalDependency = [
"llms/llama_cpp",
"llms/writer",
"llms/portkey",
"llms/azure_ml",
"prompts/load",
"vectorstores/analyticdb",
"vectorstores/cassandra",
158 changes: 158 additions & 0 deletions langchain/src/llms/azure_ml.ts
@@ -0,0 +1,158 @@
import { getEnvironmentVariable } from "../util/env.js";
import { BaseLLMParams, LLM } from "./base.js";

/**
 * Interface for the AzureML API response.
 */
interface AzureMLResponse {
  id: string;
  version: string;
  created: string;
  inputs: {
    input_string: string[];
  };
  parameters: {
    [key: string]: string;
  };
  global_parameters: {
    [key: string]: string;
  };
  output: string;
}
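
// An illustrative response payload (assumed for documentation purposes;
// the exact values depend on the deployed model):
// {
//   "id": "...",
//   "version": "1",
//   "created": "2023-11-26T00:00:00Z",
//   "inputs": { "input_string": ["Foo"] },
//   "parameters": { "temperature": "0.5" },
//   "global_parameters": {},
//   "output": "Bar"
// }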

export interface AzureMLInput {
  endpointUrl?: string;
  endpointApiKey?: string;
  deploymentName?: string;
}

/**
 * Class that represents an AzureML model. It extends the LLM base class
 * and provides methods for calling the AzureML endpoint and formatting
 * the request and response payloads.
 */
export class AzureMLModel extends LLM implements AzureMLInput {
  _llmType() {
    return "azure_ml";
  }

  static lc_name() {
    return "AzureMLModel";
  }

  static lc_description() {
    return "A class for interacting with AzureML models.";
  }

  static lc_fields() {
    return {
      endpointUrl: {
        lc_description: "The URL of the AzureML endpoint.",
        lc_env: "AZURE_ML_ENDPOINTURL",
      },
      endpointApiKey: {
        lc_description: "The API key for the AzureML endpoint.",
        lc_env: "AZURE_ML_APIKEY",
      },
      deploymentName: {
        lc_description: "The name of the AzureML deployment.",
        lc_env: "AZURE_ML_NAME",
      },
    };
  }

  endpointUrl: string;
  endpointApiKey: string;
  deploymentName: string;

  constructor(fields: AzureMLInput & BaseLLMParams) {
    super(fields ?? {});

    // Fall back to environment variables when explicit fields are not given.
    const endpointUrl =
      fields?.endpointUrl ?? getEnvironmentVariable("AZURE_ML_ENDPOINTURL");
    const endpointApiKey =
      fields?.endpointApiKey ?? getEnvironmentVariable("AZURE_ML_APIKEY");
    const deploymentName =
      fields?.deploymentName ?? getEnvironmentVariable("AZURE_ML_NAME");

    if (endpointUrl === undefined) {
      throw new Error("No Azure ML endpointUrl found.");
    }
    if (endpointApiKey === undefined) {
      throw new Error("No Azure ML endpointApiKey found.");
    }
    if (deploymentName === undefined) {
      throw new Error("No Azure ML deploymentName found.");
    }

    this.endpointUrl = endpointUrl;
    this.endpointApiKey = endpointApiKey;
    this.deploymentName = deploymentName;
  }

  /**
   * Formats the request payload for the AzureML endpoint. It takes a
   * prompt and a dictionary of model arguments as input and returns a
   * string representing the formatted request payload.
   * @param prompt The prompt for the AzureML model.
   * @param modelArgs A dictionary of model arguments.
   * @returns A string representing the formatted request payload.
   */
  formatRequestPayload(prompt: string, modelArgs: Record<string, unknown>) {
    return JSON.stringify({
      inputs: {
        input_string: [prompt],
      },
      parameters: modelArgs,
    });
  }
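
  // For example, formatRequestPayload("Foo", { temperature: 0.5 }) returns
  // '{"inputs":{"input_string":["Foo"]},"parameters":{"temperature":0.5}}'.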

  /**
   * Formats the response payload from the AzureML endpoint. It takes a
   * response payload as input and returns a string representing the
   * formatted response.
   * @param responsePayload The response payload from the AzureML endpoint.
   * @returns A string representing the formatted response.
   */
  formatResponsePayload(responsePayload: string) {
    const response = JSON.parse(responsePayload) as AzureMLResponse;
    return response.output;
  }
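
  // For example, formatResponsePayload('{"output":"Bar"}') returns "Bar";
  // the other AzureMLResponse fields are ignored here.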

  /**
   * Sends a formatted request payload to the AzureML endpoint and returns
   * the raw response body. Named callEndpoint rather than call so that it
   * does not shadow the base LLM.call method, which expects a prompt.
   * @param requestPayload The request payload for the AzureML endpoint.
   * @returns A Promise that resolves to the raw response payload.
   */
  async callEndpoint(requestPayload: string): Promise<string> {
    const response = await fetch(this.endpointUrl, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${this.endpointApiKey}`,
      },
      body: requestPayload,
    });
    if (!response.ok) {
      const error = new Error(
        `Azure ML LLM call failed with status code ${response.status}`
      );
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (error as any).response = response;
      throw error;
    }
    return response.text();
  }
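
  // Note: on failure the thrown Error carries the fetch Response object as
  // error.response, so callers can inspect the status or headers.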

  /**
   * Calls the AzureML endpoint with the provided prompt and model arguments.
   * It returns a Promise that resolves to the generated text.
   * @param prompt The prompt for the AzureML model.
   * @param modelArgs A dictionary of model arguments.
   * @returns A Promise that resolves to the generated text.
   */
  async _call(
    prompt: string,
    modelArgs: Record<string, unknown>
  ): Promise<string> {
    const requestPayload = this.formatRequestPayload(prompt, modelArgs);
    const responsePayload = await this.callEndpoint(requestPayload);
    return this.formatResponsePayload(responsePayload);
  }
}
16 changes: 16 additions & 0 deletions langchain/src/llms/tests/azure_ml.int.test.ts
@@ -0,0 +1,16 @@
import { test, expect } from "@jest/globals";
import { AzureMLModel } from "../azure_ml.js";

test("Test AzureML Model", async () => {
  const prompt = "Foo";
  const model = new AzureMLModel({
    endpointUrl: process.env.AZURE_ML_ENDPOINTURL,
    endpointApiKey: process.env.AZURE_ML_APIKEY,
    deploymentName: process.env.AZURE_ML_NAME,
  });

  const res = await model.call(prompt);
  expect(typeof res).toBe("string");

  console.log(res);
});
1 change: 1 addition & 0 deletions langchain/src/load/import_constants.ts
@@ -38,6 +38,7 @@ export const optionalImportEntrypoints = [
"langchain/llms/llama_cpp",
"langchain/llms/writer",
"langchain/llms/portkey",
"langchain/llms/azure_ml",
"langchain/prompts/load",
"langchain/vectorstores/analyticdb",
"langchain/vectorstores/cassandra",
3 changes: 3 additions & 0 deletions langchain/src/load/import_type.d.ts
@@ -112,6 +112,9 @@ export interface OptionalImportMap {
"langchain/llms/portkey"?:
| typeof import("../llms/portkey.js")
| Promise<typeof import("../llms/portkey.js")>;
"langchain/llms/azure_ml"?:
| typeof import("../llms/azure_ml.js")
| Promise<typeof import("../llms/azure_ml.js")>;
"langchain/prompts/load"?:
| typeof import("../prompts/load.js")
| Promise<typeof import("../prompts/load.js")>;