Skip to content

Commit

Permalink
streamlined adding and removing hooks for httpclient
Browse files Browse the repository at this point in the history
  • Loading branch information
TripleADC committed Nov 9, 2024
1 parent a6eb0b3 commit 2db1dd4
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 5 deletions.
103 changes: 102 additions & 1 deletion libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { ChatCompletionStreamRequest as MistralAIChatCompletionStreamRequest } f
import { UsageInfo as MistralAITokenUsage } from "@mistralai/mistralai/models/components/usageinfo.js";
import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
import { ChatCompletionResponse as MistralAIChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";
import { HTTPClient as MistralAIHTTPClient} from "@mistralai/mistralai/lib/http.js";
import { BeforeRequestHook, HTTPClient as MistralAIHTTPClient, RequestErrorHook, ResponseHook} from "@mistralai/mistralai/lib/http.js";
import { RetryConfig as MistralAIRetryConfig } from "@mistralai/mistralai/lib/retries.js";
import {
MessageType,
Expand Down Expand Up @@ -162,6 +162,24 @@ export interface ChatMistralAIInput
* The seed to use for random sampling. If set, different calls will generate deterministic results.
*/
seed?: number;
/**
* A list of custom hooks that must follow (req: Request) => Awaitable<Request | void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
beforeRequestHooks?: Array<BeforeRequestHook>;
/**
* A list of custom hooks that must follow (err: unknown, req: Request) => Awaitable<void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
requestErrorHooks?: Array<RequestErrorHook>;
/**
* A list of custom hooks that must follow (res: Response, req: Request) => Awaitable<void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
responseHooks?: Array<ResponseHook>;
/**
* Custom HTTP client to manage API requests
* Allows users to add custom fetch implementations, hooks, as well as error and response processing.
Expand Down Expand Up @@ -853,6 +871,12 @@ export class ChatMistralAI<

streamUsage = true;

beforeRequestHooks?: Array<BeforeRequestHook>;

requestErrorHooks?: Array<RequestErrorHook>;

responseHooks?: Array<ResponseHook>;

httpClient?: MistralAIHTTPClient;

backoffStrategy = "none";
Expand Down Expand Up @@ -887,13 +911,17 @@ export class ChatMistralAI<
this.modelName = fields?.model ?? fields?.modelName ?? this.model;
this.model = this.modelName;
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
this.beforeRequestHooks = fields?.beforeRequestHooks ?? this.beforeRequestHooks;
this.requestErrorHooks = fields?.requestErrorHooks ?? this.requestErrorHooks;
this.responseHooks = fields?.responseHooks ?? this.responseHooks;
this.httpClient = fields?.httpClient ?? this.httpClient;
this.backoffStrategy = fields?.backoffStrategy ?? this.backoffStrategy;
this.backoffInitialInterval = fields?.backoffInitialInterval ?? this.backoffInitialInterval;
this.backoffMaxInterval = fields?.backoffMaxInterval ?? this.backoffMaxInterval;
this.backoffExponent = fields?.backoffExponent ?? this.backoffExponent;
this.backoffMaxElapsedTime = fields?.backoffMaxElapsedTime ?? this.backoffMaxElapsedTime;
this.retryConnectionErrors = fields?.retryConnectionErrors ?? this.retryConnectionErrors;
this.addAllHooksToHttpClient();
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand Down Expand Up @@ -1159,6 +1187,79 @@ export class ChatMistralAI<
}
}

addAllHooksToHttpClient() {
try {
// To prevent duplicate hooks
this.removeAllHooksFromHttpClient();

// If the user wants to use hooks, but hasn't created an HTTPClient yet
const hasHooks = [
this.beforeRequestHooks,
this.requestErrorHooks,
this.responseHooks
].some(hook => hook && hook.length > 0);
if(hasHooks && !this.httpClient) {
this.httpClient = new MistralAIHTTPClient();
}

if(this.beforeRequestHooks) {
for(const hook of this.beforeRequestHooks) {
this.httpClient?.addHook("beforeRequest", hook);
}
}

if(this.requestErrorHooks) {
for(const hook of this.requestErrorHooks) {
this.httpClient?.addHook("requestError", hook);
}
}

if(this.responseHooks) {
for(const hook of this.responseHooks) {
this.httpClient?.addHook("response", hook);
}
}
} catch {
throw new Error("Error in adding all hooks");
}
}

removeAllHooksFromHttpClient() {
try {
if(this.beforeRequestHooks) {
for(const hook of this.beforeRequestHooks) {
this.httpClient?.removeHook("beforeRequest", hook);
}
}

if(this.requestErrorHooks) {
for(const hook of this.requestErrorHooks) {
this.httpClient?.removeHook("requestError", hook);
}
}

if(this.responseHooks) {
for(const hook of this.responseHooks) {
this.httpClient?.removeHook("response", hook);
}
}
} catch {
throw new Error("Error in removing hooks");
}
}

removeHookFromHttpClient(
hook: BeforeRequestHook | RequestErrorHook | ResponseHook
) {
try {
this.httpClient?.removeHook("beforeRequest", hook as BeforeRequestHook);
this.httpClient?.removeHook("requestError", hook as RequestErrorHook);
this.httpClient?.removeHook("response", hook as ResponseHook);
} catch {
throw new Error("Error in removing hook");
}
}

/** @ignore */
_combineLLMOutput() {
return [];
Expand Down
106 changes: 104 additions & 2 deletions libs/langchain-mistralai/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { EmbeddingRequest as MistralAIEmbeddingsRequest} from "@mistralai/mistralai/src/models/components/embeddingrequest.js";
import { EmbeddingResponse as MistralAIEmbeddingsResponse} from "@mistralai/mistralai/src/models/components/embeddingresponse.js";
import { HTTPClient as MistralAIHTTPClient} from "@mistralai/mistralai/lib/http.js";
import { BeforeRequestHook, HTTPClient as MistralAIHTTPClient, RequestErrorHook, ResponseHook} from "@mistralai/mistralai/lib/http.js";

/**
* Interface for MistralAIEmbeddings parameters. Extends EmbeddingsParams and
Expand Down Expand Up @@ -46,6 +46,24 @@ export interface MistralAIEmbeddingsParams extends EmbeddingsParams {
* @default {true}
*/
stripNewLines?: boolean;
/**
* A list of custom hooks that must follow (req: Request) => Awaitable<Request | void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
beforeRequestHooks?: Array<BeforeRequestHook>;
/**
* A list of custom hooks that must follow (err: unknown, req: Request) => Awaitable<void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
requestErrorHooks?: Array<RequestErrorHook>;
/**
* A list of custom hooks that must follow (res: Response, req: Request) => Awaitable<void>
* They are automatically added when a ChatMistralAI Object is created
* @default {[]}
*/
responseHooks?: Array<ResponseHook>;
/**
* Optional custom HTTP client to manage API requests
* Allows users to add custom fetch implementations, hooks, as well as error and response processing.
Expand Down Expand Up @@ -75,6 +93,12 @@ export class MistralAIEmbeddings

serverURL?: string;

beforeRequestHooks?: Array<BeforeRequestHook>;

requestErrorHooks?: Array<RequestErrorHook>;

responseHooks?: Array<ResponseHook>;

httpClient?: MistralAIHTTPClient;

constructor(fields?: Partial<MistralAIEmbeddingsParams>) {
Expand All @@ -90,9 +114,14 @@ export class MistralAIEmbeddings
this.encodingFormat = fields?.encodingFormat ?? this.encodingFormat;
this.batchSize = fields?.batchSize ?? this.batchSize;
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
this.httpClient = fields?.httpClient ?? undefined;
this.beforeRequestHooks = fields?.beforeRequestHooks ?? this.beforeRequestHooks;
this.requestErrorHooks = fields?.requestErrorHooks ?? this.requestErrorHooks;
this.responseHooks = fields?.responseHooks ?? this.responseHooks;
this.httpClient = fields?.httpClient ?? this.httpClient;
this.addAllHooksToHttpClient();
}


/**
* Method to generate embeddings for an array of documents. Splits the
* documents into batches and makes requests to the MistralAI API to generate
Expand Down Expand Up @@ -166,6 +195,79 @@ export class MistralAIEmbeddings
});
}

addAllHooksToHttpClient() {
try {
// To prevent duplicate hooks
this.removeAllHooksFromHttpClient();

// If the user wants to use hooks, but hasn't created an HTTPClient yet
const hasHooks = [
this.beforeRequestHooks,
this.requestErrorHooks,
this.responseHooks
].some(hook => hook && hook.length > 0);
if(hasHooks && !this.httpClient) {
this.httpClient = new MistralAIHTTPClient();
}

if(this.beforeRequestHooks) {
for(const hook of this.beforeRequestHooks) {
this.httpClient?.addHook("beforeRequest", hook);
}
}

if(this.requestErrorHooks) {
for(const hook of this.requestErrorHooks) {
this.httpClient?.addHook("requestError", hook);
}
}

if(this.responseHooks) {
for(const hook of this.responseHooks) {
this.httpClient?.addHook("response", hook);
}
}
} catch {
throw new Error("Error in adding all hooks");
}
}

removeAllHooksFromHttpClient() {
try {
if(this.beforeRequestHooks) {
for(const hook of this.beforeRequestHooks) {
this.httpClient?.removeHook("beforeRequest", hook);
}
}

if(this.requestErrorHooks) {
for(const hook of this.requestErrorHooks) {
this.httpClient?.removeHook("requestError", hook);
}
}

if(this.responseHooks) {
for(const hook of this.responseHooks) {
this.httpClient?.removeHook("response", hook);
}
}
} catch {
throw new Error("Error in removing hooks");
}
}

removeHookFromHttpClient(
hook: BeforeRequestHook | RequestErrorHook | ResponseHook
) {
try {
this.httpClient?.removeHook("beforeRequest", hook as BeforeRequestHook);
this.httpClient?.removeHook("requestError", hook as RequestErrorHook);
this.httpClient?.removeHook("response", hook as ResponseHook);
} catch {
throw new Error("Error in removing hook");
}
}

/** @ignore */
private async imports() {
const { Mistral } = await import("@mistralai/mistralai");
Expand Down
Loading

0 comments on commit 2db1dd4

Please sign in to comment.