Skip to content

Commit

Permalink
wrangler: Add AI binding (#3992)
Browse files Browse the repository at this point in the history
* wrangler: Add AI binding

Added binding for the AI project.

* Workers AI: added example
  • Loading branch information
edevil authored Sep 21, 2023
1 parent 7ba16cd commit 3556474
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 0 deletions.
36 changes: 36 additions & 0 deletions .changeset/brave-stingrays-shout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
"wrangler": patch
---

Add AI binding that will be used to interact with the AI project.

Example `wrangler.toml`

name = "ai-worker"
main = "src/index.ts"

[ai]
binding = "AI"

Example script:

import Ai from "@cloudflare/ai"

export default {
async fetch(request: Request, env: Env): Promise<Response> {
const ai = new Ai(env.AI);

const story = await ai.run({
model: 'llama-2',
input: {
prompt: 'Tell me a story about the future of the Cloudflare dev platform'
}
});

return new Response(JSON.stringify(story));
},
};

export interface Env {
AI: any;
}
59 changes: 59 additions & 0 deletions packages/wrangler/src/__tests__/configuration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ describe("normalizeAndValidateConfig()", () => {
site: undefined,
text_blobs: undefined,
browser: undefined,
ai: undefined,
triggers: {
crons: [],
},
Expand Down Expand Up @@ -1620,6 +1621,64 @@ describe("normalizeAndValidateConfig()", () => {
});
});

describe("[ai]", () => {
it("should error if ai is an array", () => {
const { diagnostics } = normalizeAndValidateConfig(
{ ai: [] } as unknown as RawConfig,
undefined,
{ env: undefined }
);

expect(diagnostics.hasWarnings()).toBe(false);
expect(diagnostics.renderErrors()).toMatchInlineSnapshot(`
"Processing wrangler configuration:
- The field \\"ai\\" should be an object but got []."
`);
});

it("should error if ai is a string", () => {
const { diagnostics } = normalizeAndValidateConfig(
{ ai: "BAD" } as unknown as RawConfig,
undefined,
{ env: undefined }
);

expect(diagnostics.hasWarnings()).toBe(false);
expect(diagnostics.renderErrors()).toMatchInlineSnapshot(`
"Processing wrangler configuration:
- The field \\"ai\\" should be an object but got \\"BAD\\"."
`);
});

it("should error if ai is a number", () => {
const { diagnostics } = normalizeAndValidateConfig(
{ ai: 999 } as unknown as RawConfig,
undefined,
{ env: undefined }
);

expect(diagnostics.hasWarnings()).toBe(false);
expect(diagnostics.renderErrors()).toMatchInlineSnapshot(`
"Processing wrangler configuration:
- The field \\"ai\\" should be an object but got 999."
`);
});

it("should error if ai is null", () => {
const { diagnostics } = normalizeAndValidateConfig(
{ ai: null } as unknown as RawConfig,
undefined,
{ env: undefined }
);

expect(diagnostics.hasWarnings()).toBe(false);
expect(diagnostics.renderErrors()).toMatchInlineSnapshot(`
"Processing wrangler configuration:
- The field \\"ai\\" should be an object but got null."
`);
});
});

describe("[kv_namespaces]", () => {
it("should error if kv_namespaces is an object", () => {
const { diagnostics } = normalizeAndValidateConfig(
Expand Down
37 changes: 37 additions & 0 deletions packages/wrangler/src/__tests__/deploy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8367,6 +8367,43 @@ export default{
});
});

describe("ai", () => {
it("should upload ai bindings", async () => {
writeWranglerToml({
ai: { binding: "AI_BIND" },
browser: { binding: "MYBROWSER" },
});
await fs.promises.writeFile("index.js", `export default {};`);
mockSubDomainRequest();
mockUploadWorkerRequest({
expectedBindings: [
{
type: "browser",
name: "MYBROWSER",
},
{
type: "ai",
name: "AI_BIND",
},
],
});

await runWrangler("deploy index.js");
expect(std.out).toMatchInlineSnapshot(`
"Total Upload: xx KiB / gzip: xx KiB
Your worker has access to the following bindings:
- Browser:
- Name: MYBROWSER
- AI:
- Name: AI_BIND
Uploaded test-name (TIMINGS)
Published test-name (TIMINGS)
https://test-name.test-sub-domain.workers.dev
Current Deployment ID: Galaxy-Class"
`);
});
});

describe("mtls_certificates", () => {
it("should upload mtls_certificate bindings", async () => {
writeWranglerToml({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function createWorkerBundleFormData(workerBundle: BundleResult): FormData {
wasm_modules: undefined,
text_blobs: undefined,
browser: undefined,
ai: undefined,
data_blobs: undefined,
durable_objects: undefined,
queues: undefined,
Expand Down
9 changes: 9 additions & 0 deletions packages/wrangler/src/config/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,15 @@ interface EnvironmentNonInheritable {
}
| undefined;

/**
* Binding to the AI project.
*/
ai:
| {
binding: string;
}
| undefined;

/**
* "Unsafe" tables for features that aren't directly supported by wrangler.
*
Expand Down
8 changes: 8 additions & 0 deletions packages/wrangler/src/config/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export function printBindings(bindings: CfWorkerInit["bindings"]) {
analytics_engine_datasets,
text_blobs,
browser,
ai,
unsafe,
vars,
wasm_modules,
Expand Down Expand Up @@ -296,6 +297,13 @@ export function printBindings(bindings: CfWorkerInit["bindings"]) {
});
}

if (ai !== undefined) {
output.push({
type: "AI",
entries: [{ key: "Name", value: ai.binding }],
});
}

if (unsafe?.bindings !== undefined && unsafe.bindings.length > 0) {
output.push({
type: "Unsafe",
Expand Down
6 changes: 6 additions & 0 deletions packages/wrangler/src/config/validation-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,12 @@ export const getBindingNames = (value: unknown): string[] => {
} else if (isNamespaceList(value)) {
return value.map(({ binding }) => binding);
} else if (isRecord(value)) {
// browser and AI bindings are single values with a similar shape
// { binding = "name" }
if (value["binding"] !== undefined) {
return [value["binding"] as string];
}

return Object.keys(value).filter((k) => value[k] !== undefined);
} else {
return [];
Expand Down
37 changes: 37 additions & 0 deletions packages/wrangler/src/config/validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,16 @@ function normalizeAndValidateEnvironment(
validateBrowserBinding(envName),
undefined
),
ai: notInheritable(
diagnostics,
topLevelEnv,
rawConfig,
rawEnv,
envName,
"ai",
validateAIBinding(envName),
undefined
),
zone_id: rawEnv.zone_id,
logfwdr: inheritable(
diagnostics,
Expand Down Expand Up @@ -1893,6 +1903,30 @@ const validateBrowserBinding =
return isValid;
};

const validateAIBinding =
(envName: string): ValidatorFn =>
(diagnostics, field, value, config) => {
const fieldPath =
config === undefined ? `${field}` : `env.${envName}.${field}`;

if (typeof value !== "object" || value === null || Array.isArray(value)) {
diagnostics.errors.push(
`The field "${fieldPath}" should be an object but got ${JSON.stringify(
value
)}.`
);
return false;
}

let isValid = true;
if (!isRequiredProperty(value, "binding", "string")) {
diagnostics.errors.push(`binding should have a string "binding" field.`);
isValid = false;
}

return isValid;
};

/**
* Check that the given field is a valid "unsafe" binding object.
*
Expand Down Expand Up @@ -1920,6 +1954,7 @@ const validateUnsafeBinding: ValidatorFn = (diagnostics, field, value) => {
"data_blob",
"text_blob",
"browser",
"ai",
"kv_namespace",
"durable_object_namespace",
"d1_database",
Expand Down Expand Up @@ -2278,6 +2313,7 @@ const validateBindingsHaveUniqueNames = (
analytics_engine_datasets,
text_blobs,
browser,
ai,
unsafe,
vars,
define,
Expand All @@ -2294,6 +2330,7 @@ const validateBindingsHaveUniqueNames = (
"Analytics Engine Dataset": getBindingNames(analytics_engine_datasets),
"Text Blob": getBindingNames(text_blobs),
Browser: getBindingNames(browser),
AI: getBindingNames(ai),
Unsafe: getBindingNames(unsafe),
"Environment Variable": getBindingNames(vars),
Definition: getBindingNames(define),
Expand Down
1 change: 1 addition & 0 deletions packages/wrangler/src/deploy/deploy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ See https://developers.cloudflare.com/workers/platform/compatibility-dates for m
vars: { ...config.vars, ...props.vars },
wasm_modules: config.wasm_modules,
browser: config.browser,
ai: config.ai,
text_blobs: {
...config.text_blobs,
...(assets.manifest &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export type WorkerMetadataBinding =
| { type: "wasm_module"; name: string; part: string }
| { type: "text_blob"; name: string; part: string }
| { type: "browser"; name: string }
| { type: "ai"; name: string }
| { type: "data_blob"; name: string; part: string }
| { type: "kv_namespace"; name: string; namespace_id: string }
| {
Expand Down Expand Up @@ -268,6 +269,13 @@ export function createWorkerUploadForm(worker: CfWorkerInit): FormData {
});
}

if (bindings.ai !== undefined) {
metadataBindings.push({
name: bindings.ai.binding,
type: "ai",
});
}

for (const [name, filePath] of Object.entries(bindings.text_blobs || {})) {
metadataBindings.push({
name,
Expand Down
9 changes: 9 additions & 0 deletions packages/wrangler/src/deployment-bundle/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ export interface CfBrowserBinding {
binding: string;
}

/**
* A binding to the AI project
*/

export interface CfAIBinding {
binding: string;
}

/**
* A binding to a data blob (in service-worker format)
*/
Expand Down Expand Up @@ -256,6 +264,7 @@ export interface CfWorkerInit {
wasm_modules: CfWasmModuleBindings | undefined;
text_blobs: CfTextBlobBindings | undefined;
browser: CfBrowserBinding | undefined;
ai: CfAIBinding | undefined;
data_blobs: CfDataBlobBindings | undefined;
durable_objects: { bindings: CfDurableObject[] } | undefined;
queues: CfQueue[] | undefined;
Expand Down
1 change: 1 addition & 0 deletions packages/wrangler/src/dev.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ function getBindings(
wasm_modules: configParam.wasm_modules,
text_blobs: configParam.text_blobs,
browser: configParam.browser,
ai: configParam.ai,
data_blobs: configParam.data_blobs,
durable_objects: {
bindings: [
Expand Down
7 changes: 7 additions & 0 deletions packages/wrangler/src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,13 @@ export function mapBindings(bindings: WorkerMetadataBinding[]): RawConfig {
};
}
break;
case "ai":
{
configObj.ai = {
binding: binding.name,
};
}
break;
case "r2_bucket":
{
configObj.r2_buckets = [
Expand Down
1 change: 1 addition & 0 deletions packages/wrangler/src/secret/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ export const secret = (secretYargs: CommonYargsArgv) => {
analytics_engine_datasets: [],
wasm_modules: {},
browser: undefined,
ai: undefined,
text_blobs: {},
data_blobs: {},
dispatch_namespaces: [],
Expand Down

0 comments on commit 3556474

Please sign in to comment.