Skip to content

Commit

Permalink
[ENH] Add an init method that eagerly reports errors with the tenant …
Browse files Browse the repository at this point in the history
…or DB (#2537)
  • Loading branch information
AlabasterAxe authored Jul 19, 2024
1 parent e93bacf commit 63207d5
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 133 deletions.
4 changes: 0 additions & 4 deletions clients/js/src/AdminClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ export class AdminClient {
...this.api.options.headers,
...this.authProvider.authenticate(),
};
this.api.options.headers = {
...this.api.options.headers,
...this.authProvider.authenticate(),
};
}
}

Expand Down
85 changes: 46 additions & 39 deletions clients/js/src/ChromaClient.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Configuration, ApiApi as DefaultApi } from "./generated";
import { handleSuccess } from "./utils";
import { handleSuccess, validateTenantDatabase } from "./utils";
import { Collection } from "./Collection";
import {
ChromaClientParams,
Expand All @@ -25,10 +25,11 @@ export class ChromaClient {
* @ignore
*/
private api: DefaultApi & ConfigOptions;
private tenant: string = DEFAULT_TENANT;
private database: string = DEFAULT_DATABASE;
private _adminClient?: AdminClient;
private tenant: string;
private database: string;
private _adminClient: AdminClient;
private authProvider: ClientAuthProvider | undefined;
private _initPromise: Promise<void> | undefined;

/**
* Creates a new ChromaClient instance.
Expand All @@ -44,13 +45,12 @@ export class ChromaClient {
* ```
*/
constructor({
path,
path = "http://localhost:8000",
fetchOptions,
auth,
tenant = DEFAULT_TENANT,
database = DEFAULT_DATABASE,
}: ChromaClientParams = {}) {
if (path === undefined) path = "http://localhost:8000";
this.tenant = tenant;
this.database = database;
this.authProvider = undefined;
Expand All @@ -71,17 +71,25 @@ export class ChromaClient {
}

this._adminClient = new AdminClient({
path: path,
fetchOptions: fetchOptions,
auth: auth,
tenant: tenant,
database: database,
path,
fetchOptions,
auth,
tenant,
database,
});
}

/** @ignore */
private init(): Promise<void> {
if (!this._initPromise) {
this._initPromise = validateTenantDatabase(
this._adminClient,
this.tenant,
this.database,
);
}

// TODO: Validate tenant and database on client creation
// this got tricky because:
// - the constructor is sync but the generated api is async
// - we need to inject auth information so a simple rewrite/fetch does not work
return this._initPromise;
}

/**
Expand All @@ -96,7 +104,8 @@ export class ChromaClient {
* await client.reset();
* ```
*/
public async reset(): Promise<boolean> {
async reset(): Promise<boolean> {
await this.init();
return await this.api.reset(this.api.options);
}

Expand All @@ -110,7 +119,7 @@ export class ChromaClient {
* const version = await client.version();
* ```
*/
public async version(): Promise<string> {
async version(): Promise<string> {
const response = await this.api.version(this.api.options);
return await handleSuccess(response);
}
Expand All @@ -125,7 +134,7 @@ export class ChromaClient {
* const heartbeat = await client.heartbeat();
* ```
*/
public async heartbeat(): Promise<number> {
async heartbeat(): Promise<number> {
const response = await this.api.heartbeat(this.api.options);
let ret = await handleSuccess(response);
return ret["nanosecond heartbeat"];
Expand Down Expand Up @@ -153,15 +162,12 @@ export class ChromaClient {
* });
* ```
*/
public async createCollection({
async createCollection({
name,
metadata,
embeddingFunction,
embeddingFunction = new DefaultEmbeddingFunction(),
}: CreateCollectionParams): Promise<Collection> {
if (embeddingFunction === undefined) {
embeddingFunction = new DefaultEmbeddingFunction();
}

await this.init();
const newCollection = await this.api
.createCollection(
this.tenant,
Expand Down Expand Up @@ -211,15 +217,12 @@ export class ChromaClient {
* });
* ```
*/
public async getOrCreateCollection({
async getOrCreateCollection({
name,
metadata,
embeddingFunction,
embeddingFunction = new DefaultEmbeddingFunction(),
}: GetOrCreateCollectionParams): Promise<Collection> {
if (embeddingFunction === undefined) {
embeddingFunction = new DefaultEmbeddingFunction();
}

await this.init();
const newCollection = await this.api
.createCollection(
this.tenant,
Expand Down Expand Up @@ -259,10 +262,10 @@ export class ChromaClient {
* });
* ```
*/
public async listCollections({
limit,
offset,
}: ListCollectionsParams = {}): Promise<CollectionType[]> {
async listCollections({ limit, offset }: ListCollectionsParams = {}): Promise<
CollectionType[]
> {
await this.init();
const response = await this.api.listCollections(
limit,
offset,
Expand All @@ -284,7 +287,9 @@ export class ChromaClient {
* const collections = await client.countCollections();
* ```
*/
public async countCollections(): Promise<number> {
async countCollections(): Promise<number> {
await this.init();

const response = await this.api.countCollections(
this.tenant,
this.database,
Expand All @@ -308,10 +313,12 @@ export class ChromaClient {
* });
* ```
*/
public async getCollection({
async getCollection({
name,
embeddingFunction,
}: GetCollectionParams): Promise<Collection> {
await this.init();

const response = await this.api
.getCollection(name, this.tenant, this.database, this.api.options)
.then(handleSuccess);
Expand Down Expand Up @@ -339,9 +346,9 @@ export class ChromaClient {
* });
* ```
*/
public async deleteCollection({
name,
}: DeleteCollectionParams): Promise<void> {
async deleteCollection({ name }: DeleteCollectionParams): Promise<void> {
await this.init();

return await this.api
.deleteCollection(name, this.tenant, this.database, this.api.options)
.then(handleSuccess);
Expand Down
8 changes: 2 additions & 6 deletions clients/js/src/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ export class BasicAuthClientProvider implements ClientAuthProvider {
* @throws {Error} If neither credentials provider or text credentials are supplied.
*/
constructor(textCredentials: string | undefined) {
const envVarTextCredentials = process.env.CHROMA_CLIENT_AUTH_CREDENTIALS;

const creds = textCredentials ?? envVarTextCredentials;
const creds = textCredentials ?? process.env.CHROMA_CLIENT_AUTH_CREDENTIALS;
if (creds === undefined) {
throw new Error(
"Credentials must be supplied via environment variable (CHROMA_CLIENT_AUTH_CREDENTIALS) or passed in as configuration.",
Expand All @@ -64,9 +62,7 @@ export class TokenAuthClientProvider implements ClientAuthProvider {
textCredentials: any,
headerType: TokenHeaderType = "AUTHORIZATION",
) {
const envVarTextCredentials = process.env.CHROMA_CLIENT_AUTH_CREDENTIALS;

const creds = textCredentials ?? envVarTextCredentials;
const creds = textCredentials ?? process.env.CHROMA_CLIENT_AUTH_CREDENTIALS;
if (creds === undefined) {
throw new Error(
"Credentials must be supplied via environment variable (CHROMA_CLIENT_AUTH_CREDENTIALS) or passed in as configuration.",
Expand Down
13 changes: 11 additions & 2 deletions clients/js/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Api } from "./generated";
import Count200Response = Api.Count200Response;
import { AdminClient } from "./AdminClient";
import { ChromaConnectionError } from "./Errors";

// a function to convert a non-Array object to an Array
export function toArray<T>(obj: T | Array<T>): Array<T> {
Expand Down Expand Up @@ -72,16 +73,24 @@ export async function validateTenantDatabase(
try {
await adminClient.getTenant({ name: tenant });
} catch (error) {
if (error instanceof ChromaConnectionError) {
throw error;
}
throw new Error(
`Error: ${error}, Could not connect to tenant ${tenant}. Are you sure it exists?`,
`Could not connect to tenant ${tenant}. Are you sure it exists? Underlying error:
${error}`,
);
}

try {
await adminClient.getDatabase({ name: database, tenantName: tenant });
} catch (error) {
if (error instanceof ChromaConnectionError) {
throw error;
}
throw new Error(
`Error: ${error}, Could not connect to database ${database} for tenant ${tenant}. Are you sure it exists?`,
`Could not connect to database ${database} for tenant ${tenant}. Are you sure it exists? Underlying error:
${error}`,
);
}
}
Expand Down
21 changes: 12 additions & 9 deletions clients/js/test/auth.basic.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { expect, test } from "@jest/globals";
import { chromaBasic } from "./initClientWithAuth";
import chromaNoAuth from "./initClient";
import { ChromaForbiddenError } from "../src/Errors";

test("it should get the version without auth needed", async () => {
const version = await chromaNoAuth.version();
Expand All @@ -15,19 +14,23 @@ test("it should get the heartbeat without auth needed", async () => {
expect(heartbeat).toBeGreaterThan(0);
});

test("it should raise error when non authenticated", async () => {
await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf(
ChromaForbiddenError,
);
test("it should throw error when non authenticated", async () => {
try {
await chromaNoAuth.listCollections();
} catch (e) {
expect(e).toBeInstanceOf(Error);
}
});

test("it should list collections", async () => {
await chromaBasic.reset();
let collections = await chromaBasic.listCollections();
const client = chromaBasic();

await client.reset();
let collections = await client.listCollections();
expect(collections).toBeDefined();
expect(collections).toBeInstanceOf(Array);
expect(collections.length).toBe(0);
await chromaBasic.createCollection({ name: "test" });
collections = await chromaBasic.listCollections();
await client.createCollection({ name: "test" });
collections = await client.listCollections();
expect(collections.length).toBe(1);
});
60 changes: 20 additions & 40 deletions clients/js/test/auth.token.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import {
cloudClient,
} from "./initClientWithAuth";
import chromaNoAuth from "./initClient";
import { ChromaForbiddenError } from "../src/Errors";

test("it should get the version without auth needed", async () => {
const version = await chromaNoAuth.version();
expect(version).toBeDefined();
Expand All @@ -21,59 +19,41 @@ test("it should get the heartbeat without auth needed", async () => {
});

test("it should raise error when non authenticated", async () => {
await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf(
ChromaForbiddenError,
);
await expect(chromaNoAuth.listCollections()).rejects.toBeInstanceOf(Error);
});

if (!process.env.XTOKEN_TEST) {
test("it should list collections with default token config", async () => {
await chromaTokenDefault.reset();
let collections = await chromaTokenDefault.listCollections();
if (process.env.XTOKEN_TEST) {
test.each([
["x-token", chromaTokenXToken],
["cloud client", cloudClient],
])(`it should list collections with %s`, async (_, clientBuilder) => {
const client = clientBuilder();
await client.reset();
let collections = await client.listCollections();
expect(collections).toBeDefined();
expect(collections).toBeInstanceOf(Array);
expect(collections.length).toBe(0);
const collection = await chromaTokenDefault.createCollection({
await client.createCollection({
name: "test",
});
collections = await chromaTokenDefault.listCollections();
expect(collections.length).toBe(1);
});

test("it should list collections with explicit bearer token config", async () => {
await chromaTokenBearer.reset();
let collections = await chromaTokenBearer.listCollections();
expect(collections).toBeDefined();
expect(collections).toBeInstanceOf(Array);
expect(collections.length).toBe(0);
const collection = await chromaTokenBearer.createCollection({
name: "test",
});
collections = await chromaTokenBearer.listCollections();
collections = await client.listCollections();
expect(collections.length).toBe(1);
});
} else {
test("it should list collections with explicit x-token token config", async () => {
await chromaTokenXToken.reset();
let collections = await chromaTokenXToken.listCollections();
test.each([
["default token", chromaTokenDefault],
["bearer token", chromaTokenBearer],
])(`it should list collections with %s`, async (_, clientBuilder) => {
const client = clientBuilder();
await client.reset();
let collections = await client.listCollections();
expect(collections).toBeDefined();
expect(collections).toBeInstanceOf(Array);
expect(collections.length).toBe(0);
const collection = await chromaTokenXToken.createCollection({
await client.createCollection({
name: "test",
});
collections = await chromaTokenXToken.listCollections();
expect(collections.length).toBe(1);
});

test("it should list collections with explicit x-token token config in CloudClient", async () => {
await cloudClient.reset();
let collections = await cloudClient.listCollections();
expect(collections).toBeDefined();
expect(collections).toBeInstanceOf(Array);
expect(collections.length).toBe(0);
const collection = await cloudClient.createCollection({ name: "test" });
collections = await cloudClient.listCollections();
collections = await client.listCollections();
expect(collections.length).toBe(1);
});
}
Loading

0 comments on commit 63207d5

Please sign in to comment.