diff --git a/.eslintrc.yml b/.eslintrc.yml index 213e622..c5f1202 100644 --- a/.eslintrc.yml +++ b/.eslintrc.yml @@ -14,3 +14,10 @@ rules: "import/order": - error - newlines-between: always + "import/extensions": + - error + - ignorePackages + - js: always + "quote-props": + - error + - consistent diff --git a/.prettierrc.yml b/.prettierrc.yml new file mode 100644 index 0000000..bf836d7 --- /dev/null +++ b/.prettierrc.yml @@ -0,0 +1 @@ +quoteProps: "consistent" diff --git a/src/commands/config/default.js b/src/commands/config/default.js new file mode 100644 index 0000000..2b9cf76 --- /dev/null +++ b/src/commands/config/default.js @@ -0,0 +1,55 @@ +import { stdin, stdout } from "node:process"; +import { createInterface } from "node:readline"; +import { promisify } from "node:util"; + +import { DEFAULT_ENDPOINT } from "../../utils/constants.js"; +import { mergeConfig } from "../../utils/config.js"; + +export const defaultCommandDefinition = [ + "$0", + "Modify configuration", + {}, + async (args) => { + const rl = createInterface({ + input: stdin, + output: stdout, + }); + const question = promisify(rl.question).bind(rl); + rl.on("close", () => { + console.log(); + }); + + const config = { + configuration: {}, + credentials: {}, + }; + if (args.profile) { + config.configuration.profiles = { [args.profile]: {} }; + config.credentials.profiles = { [args.profile]: {} }; + } + + const endpoint = await question( + `Endpoint (${ + args.endpoint ?? process.env.GENAI_ENDPOINT ?? DEFAULT_ENDPOINT + }): ` + ); + if (endpoint) { + if (args.profile) { + config.configuration.profiles[args.profile].endpoint = endpoint; + } else { + config.configuration.endpoint = endpoint; + } + } + const apiKey = await question(`API Key (${args.apiKey ?? 
"none"}): `); + if (apiKey) { + if (args.profile) { + config.credentials.profiles[args.profile].apiKey = apiKey; + } else { + config.credentials.apiKey = apiKey; + } + } + + rl.close(); + mergeConfig(config); + }, +]; diff --git a/src/commands/config/index.js b/src/commands/config/index.js new file mode 100644 index 0000000..f7338eb --- /dev/null +++ b/src/commands/config/index.js @@ -0,0 +1,15 @@ +import { defaultCommandDefinition } from "./default.js"; +import { profilesCommandDefinition } from "./profiles.js"; +import { removeCommandDefinition } from "./remove.js"; +import { showCommandDefinition } from "./show.js"; + +export const configCommandDefinition = [ + "config", + "Manage CLI configuration", + (yargs) => + yargs + .command(...defaultCommandDefinition) + .command(...showCommandDefinition) + .command(...profilesCommandDefinition) + .command(...removeCommandDefinition), +]; diff --git a/src/commands/config/profiles.js b/src/commands/config/profiles.js new file mode 100644 index 0000000..83acbca --- /dev/null +++ b/src/commands/config/profiles.js @@ -0,0 +1,13 @@ +import { allProfiles } from "../../utils/config.js"; + +export const profilesCommandDefinition = [ + "profiles", + "List configuration profiles", + {}, + () => { + const profiles = allProfiles(); + profiles.forEach((profile) => { + console.log(profile); + }); + }, +]; diff --git a/src/commands/config/remove.js b/src/commands/config/remove.js new file mode 100644 index 0000000..cd99098 --- /dev/null +++ b/src/commands/config/remove.js @@ -0,0 +1,10 @@ +import { deleteProfileConfig } from "../../utils/config.js"; + +export const removeCommandDefinition = [ + "remove", + "Remove configuration", + {}, + (args) => { + deleteProfileConfig(args.profile); + }, +]; diff --git a/src/commands/config/show.js b/src/commands/config/show.js new file mode 100644 index 0000000..60f80d7 --- /dev/null +++ b/src/commands/config/show.js @@ -0,0 +1,12 @@ +import { loadProfileConfig } from "../../utils/config.js"; 
+import { prettyPrint } from "../../utils/print.js"; + +export const showCommandDefinition = [ + "show", + "Show configuration", + {}, + (args) => { + const config = loadProfileConfig(args.profile); + prettyPrint(config); + }, +]; diff --git a/src/commands/files/delete.js b/src/commands/files/delete.js new file mode 100644 index 0000000..c217cc6 --- /dev/null +++ b/src/commands/files/delete.js @@ -0,0 +1,12 @@ +export const deleteCommandDefinition = [ + "delete <id>", + "Delete a file", + (yargs) => + yargs.positional("id", { + type: "string", + description: "Identifier of the file to be deleted", + }), + async (args) => { + await args.client.file({ id: args.id }, { delete: true }); + }, +]; diff --git a/src/commands/files/download.js b/src/commands/files/download.js new file mode 100644 index 0000000..fc3e41d --- /dev/null +++ b/src/commands/files/download.js @@ -0,0 +1,35 @@ +import { stdout } from "node:process"; +import { createWriteStream } from "node:fs"; +import { pipeline } from "node:stream/promises"; + +import { groupOptions } from "../../utils/yargs.js"; + +export const downloadCommandDefinition = [ + "download <id>", + "Download a file", + (yargs) => + yargs + .positional("id", { + type: "string", + description: "Identifier of the file to download", + }) + .options( + groupOptions({ + output: { + alias: "o", + describe: "Filepath to write the file to", + type: "string", + normalize: true, + requiresArg: true, + coerce: (output) => createWriteStream(output), + }, + }) + ), + async (args) => { + const { download } = await args.client.file({ + id: args.id, + }); + const readable = await download(); + await pipeline(readable, args.output ?? 
stdout); + }, +]; diff --git a/src/commands/files/index.js b/src/commands/files/index.js new file mode 100644 index 0000000..4e695ed --- /dev/null +++ b/src/commands/files/index.js @@ -0,0 +1,21 @@ +import { clientMiddleware } from "../../middleware/client.js"; + +import { deleteCommandDefinition } from "./delete.js"; +import { downloadCommandDefinition } from "./download.js"; +import { infoCommandDefinition } from "./info.js"; +import { listCommandDefinition } from "./list.js"; +import { uploadCommandDefinition } from "./upload.js"; + +export const filesCommandDefinition = [ + "files", + "Upload and manage files", + (yargs) => + yargs + .middleware(clientMiddleware) + .command(...listCommandDefinition) + .command(...infoCommandDefinition) + .command(...uploadCommandDefinition) + .command(...downloadCommandDefinition) + .command(...deleteCommandDefinition) + .demandCommand(1, 1, "Please choose a command"), +]; diff --git a/src/commands/files/info.js b/src/commands/files/info.js new file mode 100644 index 0000000..4429f0a --- /dev/null +++ b/src/commands/files/info.js @@ -0,0 +1,17 @@ +import { prettyPrint } from "../../utils/print.js"; + +export const infoCommandDefinition = [ + "info <id>", + "Show detailed information about a file", + (yargs) => + yargs.positional("id", { + type: "string", + description: "Identifier of the file", + }), + async (args) => { + const { id, file_name, purpose, created_at } = await args.client.file({ + id: args.id, + }); + prettyPrint({ id, name: file_name, purpose, created_at }); + }, +]; diff --git a/src/commands/files/list.js b/src/commands/files/list.js new file mode 100644 index 0000000..dd5ec24 --- /dev/null +++ b/src/commands/files/list.js @@ -0,0 +1,24 @@ +import { FilePurposeSchema } from "@ibm-generative-ai/node-sdk"; + +import { groupOptions } from "../../utils/yargs.js"; + +export const listCommandDefinition = [ + "list", + "List all files", + groupOptions({ + purpose: { + alias: "p", + type: "array", + description: "Filter 
listed by purpose", + requiresArg: true, + choices: FilePurposeSchema.options, + }, + }), + async (args) => { + for await (const file of args.client.files()) { + if (!args.purpose || args.purpose.includes(file.purpose)) { + console.log(`${file.id} (${file.file_name})`); + } + } + }, +]; diff --git a/src/commands/files/upload.js b/src/commands/files/upload.js new file mode 100644 index 0000000..7984027 --- /dev/null +++ b/src/commands/files/upload.js @@ -0,0 +1,44 @@ +import { createReadStream } from "node:fs"; +import path from "node:path"; + +import { FilePurposeSchema } from "@ibm-generative-ai/node-sdk"; + +import { groupOptions } from "../../utils/yargs.js"; +import { prettyPrint } from "../../utils/print.js"; + +export const uploadCommandDefinition = [ + "upload <file>", + "Upload a file", + (yargs) => + yargs + .positional("file", { + type: "string", + describe: "Filepath to the file to be uploaded", + normalize: true, + }) + .options( + groupOptions({ + purpose: { + alias: "p", + description: "Purpose of the file", + requiresArg: true, + demandOption: true, + choices: FilePurposeSchema.options, + }, + name: { + alias: "n", + type: "string", + description: "Name of the file", + requiresArg: true, + }, + }) + ), + async (args) => { + const { id, file_name, purpose, created_at } = await args.client.file({ + purpose: args.purpose, + filename: args.name ?? 
path.parse(args.file).base, + file: createReadStream(args.file), + }); + prettyPrint({ id, name: file_name, purpose, created_at }); + }, +]; diff --git a/src/commands/generate/config.js b/src/commands/generate/config.js new file mode 100644 index 0000000..4a685ae --- /dev/null +++ b/src/commands/generate/config.js @@ -0,0 +1,94 @@ +import isEmpty from "lodash/isEmpty.js"; +import merge from "lodash/merge.js"; + +import { pickDefined } from "../../utils/common.js"; +import { prettyPrint } from "../../utils/print.js"; +import { groupOptions } from "../../utils/yargs.js"; + +export const configCommandDefinition = [ + "config", + "Read and modify generate configuration", + (yargs) => + yargs + .options( + groupOptions( + { + "input-text": { + type: "boolean", + description: "Include input text", + }, + "generated-tokens": { + type: "boolean", + description: "Include the list of individual generated tokens", + }, + "input-tokens": { + type: "boolean", + description: "Include the list of input tokens", + }, + "token-logprobs": { + type: "boolean", + description: "Include logprob for each token", + }, + "token-ranks": { + type: "boolean", + description: "Include rank of each returned token", + }, + "top-n-tokens": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "Include top n candidate tokens at the position of each returned token", + }, + }, + "Configuration:" + ) + ) + .middleware((args) => { + const return_options = pickDefined({ + input_text: args.inputText, + generated_tokens: args.generatedTokens, + input_tokens: args.inputTokens, + token_logprobs: args.tokenLogprobs, + input_ranks: args.tokenRanks, + top_n_tokens: args.topNTokens, + }); + args.parameters = !isEmpty(return_options) + ? 
merge({}, args.parameters, { return_options }) + : args.parameters; + }) + .options( + groupOptions({ + reset: { + describe: "Reset the config", + type: "boolean", + }, + replace: { + describe: "Replace the entire config", + type: "boolean", + conflicts: "reset", + }, + }) + ), + async (args) => { + const hasInput = args.model || args.parameters; + + const output = hasInput + ? await args.client.generateConfig( + { + model_id: args.model, + parameters: args.parameters, + }, + { + strategy: args.replace ? "replace" : "merge", + timeout: args.timeout, + } + ) + : await args.client.generateConfig({ + reset: args.reset, + timeout: args.timeout, + }); + + prettyPrint(output); + }, +]; diff --git a/src/commands/generate/default.js b/src/commands/generate/default.js new file mode 100644 index 0000000..1d7c744 --- /dev/null +++ b/src/commands/generate/default.js @@ -0,0 +1,103 @@ +import { createReadStream, createWriteStream } from "node:fs"; +import { stdin, stdout } from "node:process"; + +import { BaseError as BaseSDKError } from "@ibm-generative-ai/node-sdk"; + +import { readJSONStream } from "../../utils/streams.js"; +import { groupOptions } from "../../utils/yargs.js"; +import { parseInput } from "../../utils/parsers.js"; + +export const defaultCommandDefinition = [ + ["$0 [inputs..]"], // Default subcommand for generate command + "Generate a text based on an input. Outputs will follow JSONL format. Inputs coming from stdin MUST follow the JSONL format.", + (yargs) => + yargs + .positional("inputs", { + describe: "Text serving as an input for the generation", + array: true, + conflicts: "input", + }) + .options( + groupOptions({ + file: { + alias: "f", + describe: + "File to read the inputs from. 
File MUST follow JSONL format", + array: true, + normalize: true, + requiresArg: true, + conflicts: "inputs", + coerce: async (files) => { + const inputs = await Promise.all( + files.map((file) => readJSONStream(createReadStream(file))) + ); + return inputs.flat().map(parseInput); + }, + }, + output: { + alias: "o", + describe: "File to write the outputs to", + type: "string", + normalize: true, + requiresArg: true, + coerce: (output) => { + if (typeof output !== "string") + throw new Error("Only a single output file must be specified"); + return createWriteStream(output); + }, + }, + }) + ), + async (args) => { + const fileInputs = args.file; + const inlineInputs = args.inputs; + const inputs = + inlineInputs ?? + fileInputs ?? + (await readJSONStream(stdin)).map(parseInput); + + const outputStream = args.output ?? stdout; + const mappedInputs = inputs.map((input) => ({ + model_id: args.model ?? "default", + parameters: args.parameters, + input, + })); + + let hasError = false; + if (args.parameters?.stream) { + try { + for await (const chunk of args.client.generate(mappedInputs, { + timeout: args.timeout, + stream: true, + })) { + outputStream.write(chunk.generated_text); + } + } catch (err) { + hasError = true; + outputStream.write(JSON.stringify({ error: err.message })); + } finally { + outputStream.write("\n"); + } + } else { + for (const promise of args.client.generate(mappedInputs, { + timeout: args.timeout, + })) { + try { + const output = await promise; + outputStream.write(JSON.stringify(output.generated_text)); + } catch (err) { + hasError = true; + outputStream.write(JSON.stringify({ error: err.message })); + } finally { + outputStream.write("\n"); + } + } + } + + if (hasError) { + throw new BaseSDKError( + "Errors have been encountered during generation, see output" + ); + } + }, +]; diff --git a/src/commands/generate/index.js b/src/commands/generate/index.js new file mode 100644 index 0000000..43e9b18 --- /dev/null +++ b/src/commands/generate/index.js 
@@ -0,0 +1,160 @@ +import isEmpty from "lodash/isEmpty.js"; + +import { clientMiddleware } from "../../middleware/client.js"; +import { pickDefined } from "../../utils/common.js"; +import { groupOptions } from "../../utils/yargs.js"; + +import { interactiveCommandDefinition } from "./interactive.js"; +import { configCommandDefinition } from "./config.js"; +import { defaultCommandDefinition } from "./default.js"; + +export const generateCommandDefinition = [ + "generate", + "Generate a text from an input text", + (yargs) => + yargs + .middleware(clientMiddleware) + .options( + groupOptions( + { + "model": { + alias: "m", + describe: "Select a model to be used for generation", + requiresArg: true, + type: "string", + coerce: (parameters) => { + if (typeof parameters !== "string") + throw new Error("Only a single model must be specified"); + return parameters; + }, + }, + "stream": { + type: "boolean", + description: + "Enables to stream partial progress as server-sent events.", + }, + "decoding-method": { + type: "string", + requiresArg: true, + nargs: 1, + choices: ["greedy", "sample"], + description: + "Represents the strategy used for picking the tokens during generation of the output text", + }, + "decay-factor": { + type: "number", + requiresArg: true, + nargs: 1, + description: "Represents the factor of exponential decay", + }, + "decay-start-index": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "A number of generated tokens after which the decay factor should take effect", + }, + "max-new-tokens": { + type: "number", + requiresArg: true, + nargs: 1, + description: "The maximum number of new tokens to be generated", + }, + "min-new-tokens": { + type: "number", + requiresArg: true, + nargs: 1, + description: "The minimum number of new tokens to be generated", + }, + "random-seed": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "Random number generator seed to use in sampling mode for experimental repeatability", + 
}, + "stop-sequences": { + array: true, + type: "string", + requiresArg: true, + description: + "One or more strings which will cause the text generation to stop if detected in the output", + }, + "temperature": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "A value used to modify the next-token probabilities in sampling mode", + }, + "time-limit": { + type: "number", + requiresArg: true, + nargs: 1, + description: "Time limit in milliseconds", + }, + "top-k": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "The number of highest probability vocabulary tokens to keep for top-k-filtering", + }, + "top-p": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "The number of highest probability vocabulary tokens to keep for top-p-filtering", + }, + "repetition-penalty": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "Represents the penalty for repeated tokens in the output", + }, + "truncate-input-tokens": { + type: "number", + requiresArg: true, + nargs: 1, + description: + "Represents the number to which input tokens would be truncated", + }, + }, + "Configuration:" + ) + ) + .middleware((args) => { + const length_penalty = pickDefined({ + decay_factor: args.decayFactor, + start_index: args.decayStartIndex, + }); + const parameters = pickDefined({ + decoding_method: args.decodingMethod, + length_penalty: !isEmpty(length_penalty) ? length_penalty : undefined, + max_new_tokens: args.maxNewTokens, + min_new_tokens: args.minNewTokens, + random_seed: args.randomSeed, + stop_sequences: args.stopSequences, + temperature: args.temperature, + time_limit: args.timeLimit, + top_k: args.topK, + top_p: args.topP, + repetition_penalty: args.repetitionPenalty, + truncate_input_tokens: args.truncateInputTokens, + stream: args.stream, + }); + args.parameters = !isEmpty(parameters) ? 
parameters : undefined; + }) + .command(...defaultCommandDefinition) + .command(...interactiveCommandDefinition) + .command(...configCommandDefinition) + .example('$0 generate "Hello World"', "Supply single input") + .example("$0 generate -f inputs.jsonl", "Supply JSONL file with inputs") + .example( + "$0 generate config -m google/flan-t5-xxl --random-seed 2", + "Modify generate configuration with a given model and parameters" + ) + .demandCommand(1, 1, "Please choose a command"), +]; diff --git a/src/commands/generate/interactive.js b/src/commands/generate/interactive.js new file mode 100644 index 0000000..97b05e5 --- /dev/null +++ b/src/commands/generate/interactive.js @@ -0,0 +1,105 @@ +import { createInterface } from "node:readline"; +import { stdin, stdout } from "node:process"; +import { promisify } from "node:util"; + +import { isAbortError, isCancelOperationKey } from "../../utils/common.js"; + +export const interactiveCommandDefinition = [ + "interactive", + "Interactive context-free generate session", + {}, + async (args) => { + if (args.parameters?.stream === false) { + throw new Error("Stream is automatically enabled for interactive mode."); + } + + const ctx = { + isTerminated: false, + isProcessing: false, + abortOperation: () => {}, + }; + + const onUserActionCancel = () => { + ctx.abortOperation(); + if (!ctx.isProcessing) { + ctx.isTerminated = true; + } + }; + + const rl = createInterface({ + input: new Proxy(stdin, { + get(target, prop) { + if (prop === "on" || prop === "addEventListener") { + return (event, handler) => { + if (event === "data") { + const originalHandler = handler; + handler = (chunk) => { + if (!ctx.isProcessing || isCancelOperationKey(chunk)) { + originalHandler(chunk); + } + }; + } + return target.on(event, handler); + }; + } + return target[prop]; + }, + }), + output: stdout, + prompt: "", + }) + .on("SIGINT", onUserActionCancel) + .on("SIGTSTP", onUserActionCancel) + .on("SIGCONT", onUserActionCancel) + .on("history", 
(history) => { + if (ctx.isProcessing) { + history.shift(); + } + }); + + while (!ctx.isTerminated) { + ctx.isProcessing = false; + + try { + const controller = new AbortController(); + ctx.abortOperation = () => controller.abort(); + + const question = promisify(rl.question).bind(rl); + const input = await question("GenAI> ", controller); + + ctx.isProcessing = true; + if (input.length > 0) { + const stream = args.client.generate( + { + input: input, + model_id: args.model ?? "default", + parameters: args.parameters, + }, + { + timeout: args.timeout, + signal: controller.signal, + stream: true, + } + ); + + for await (const chunk of stream) { + rl.write(chunk.generated_text); + } + rl.write("\n"); + } + } catch (err) { + if (isAbortError(err)) { + // Clear line due to broken cursor + rl.write("\n"); + rl.write(null, { ctrl: true, name: "u" }); + } else { + console.error(err.message); + } + } + } + + rl.write("Goodbye"); + rl.write("\n"); + rl.close(); + }, +]; diff --git a/src/commands/history/default.js b/src/commands/history/default.js new file mode 100644 index 0000000..1814090 --- /dev/null +++ b/src/commands/history/default.js @@ -0,0 +1,51 @@ +import { + HistoryOriginSchema, + HistoryStatusSchema, +} from "@ibm-generative-ai/node-sdk"; +import dayjs from "dayjs"; + +import { parseDateTime } from "../../utils/parsers.js"; +import { groupOptions } from "../../utils/yargs.js"; +import { prettyPrint } from "../../utils/print.js"; + +export const defaultCommandDefinition = [ + "$0", + "Show the history of inference (past 30 days)", + (yargs) => + yargs.options( + groupOptions({ + from: { + type: "string", + requiresArg: true, + coerce: parseDateTime, + description: "Lower bound of the history timeframe [e.g. YYYY-MM-DD]", + }, + to: { + type: "string", + requiresArg: true, + coerce: parseDateTime, + description: "Upper bound of the history timeframe [e.g. 
YYYY-MM-DD]", + }, + status: { + choices: HistoryStatusSchema.options, + description: "Filter history by status", + }, + origin: { + choices: HistoryOriginSchema.options, + description: "Filter history by origin", + }, + }) + ), + async (args) => { + const { status, origin, from, to } = args; + for await (const output of args.client.history({ status, origin })) { + const createdAt = dayjs(output.created_at); + if ( + (!from || createdAt.isAfter(from)) && + (!to || createdAt.isBefore(to)) + ) { + prettyPrint(output); + } + } + }, +]; diff --git a/src/commands/history/index.js b/src/commands/history/index.js new file mode 100644 index 0000000..18c6725 --- /dev/null +++ b/src/commands/history/index.js @@ -0,0 +1,10 @@ +import { clientMiddleware } from "../../middleware/client.js"; + +import { defaultCommandDefinition } from "./default.js"; + +export const historyCommandDefinition = [ + "history", + "Show the history of inference (past 30 days)", + (yargs) => + yargs.middleware(clientMiddleware).command(...defaultCommandDefinition), +]; diff --git a/src/commands/models/index.js b/src/commands/models/index.js new file mode 100644 index 0000000..5af77ec --- /dev/null +++ b/src/commands/models/index.js @@ -0,0 +1,17 @@ +import { clientMiddleware } from "../../middleware/client.js"; + +import { infoCommandDefinition } from "./info.js"; +import { listCommandDefinition } from "./list.js"; +import { schemaCommandDefinition } from "./schema.js"; + +export const modelsCommandDefinition = [ + "models", + "Show information about available models", + (yargs) => + yargs + .middleware(clientMiddleware) + .command(...listCommandDefinition) + .command(...infoCommandDefinition) + .command(...schemaCommandDefinition) + .demandCommand(1, 1, "Please choose a command"), +]; diff --git a/src/commands/models/info.js b/src/commands/models/info.js new file mode 100644 index 0000000..c6bb971 --- /dev/null +++ b/src/commands/models/info.js @@ -0,0 +1,36 @@ +import { prettyPrint } from 
"../../utils/print.js"; + +export const infoCommandDefinition = [ + "info <model>", + "Show detailed information about a model", + { + model: { + type: "string", + }, + }, + async (args) => { + const { + id, + name, + size, + description, + token_limit, + source_model, + family, + tasks, + tags, + } = await args.client.model({ id: args.model }); + + prettyPrint({ + id, + name, + size, + description, + source_model: source_model?.id, + family, + token_limit, + tasks, + tags, + }); + }, +]; diff --git a/src/commands/models/list.js b/src/commands/models/list.js new file mode 100644 index 0000000..6e41dc8 --- /dev/null +++ b/src/commands/models/list.js @@ -0,0 +1,11 @@ +export const listCommandDefinition = [ + "list", + "List all available models", + {}, + async (args) => { + const models = await args.client.models(); + models.forEach((model) => { + console.log(model.id); + }); + }, +]; diff --git a/src/commands/models/schema.js b/src/commands/models/schema.js new file mode 100644 index 0000000..81ebb43 --- /dev/null +++ b/src/commands/models/schema.js @@ -0,0 +1,27 @@ +import { prettyPrint } from "../../utils/print.js"; + +export const schemaCommandDefinition = [ + "schema <model>", + "Show validation schema for a model", + { + type: { + alias: "t", + describe: "Type of the schema to show", + demandOption: true, + choices: ["generate", "tokenize"], + requiresArg: true, + type: "string", + }, + model: { + type: "string", + }, + }, + async (args) => { + const { schema_generate, schema_tokenize } = await args.client.model({ + id: args.model, + }); + + if (args.type === "generate") prettyPrint(schema_generate.value); + else if (args.type === "tokenize") prettyPrint(schema_tokenize.value); + }, +]; diff --git a/src/commands/tokenize/default.js b/src/commands/tokenize/default.js new file mode 100644 index 0000000..386a41d --- /dev/null +++ b/src/commands/tokenize/default.js @@ -0,0 +1,120 @@ +import { stdin, stdout } from "node:process"; +import { createReadStream, createWriteStream } 
from "node:fs"; + +import { BaseError as BaseSDKError } from "@ibm-generative-ai/node-sdk"; + +import { clientMiddleware } from "../../middleware/client.js"; +import { parseInput } from "../../utils/parsers.js"; +import { readJSONStream } from "../../utils/streams.js"; +import { groupOptions } from "../../utils/yargs.js"; + +export const defaultCommandDefinition = [ + "$0 [inputs..]", + "Convert provided inputs to tokens. Tokenization is model specific.", + (yargs) => + yargs + .middleware(clientMiddleware) + .options( + groupOptions( + { + model: { + alias: "m", + describe: "Select a model to be used for generation", + requiresArg: true, + type: "string", + coerce: (parameters) => { + if (typeof parameters !== "string") + throw new Error("Only a single model must be specified"); + return parameters; + }, + }, + returnTokens: { + type: "boolean", + default: true, + description: "Return tokens with the response. Defaults to true.", + }, + }, + "Configuration:" + ) + ) + .positional("inputs", { + describe: "Text serving as an input for the generation", + array: true, + conflicts: "input", + }) + .options( + groupOptions({ + file: { + alias: "f", + describe: + "File to read the inputs from. File MUST follow JSONL format", + array: true, + normalize: true, + requiresArg: true, + conflicts: "inputs", + coerce: async (files) => { + const inputs = await Promise.all( + files.map((file) => readJSONStream(createReadStream(file))) + ); + return inputs.flat().map(parseInput); + }, + }, + output: { + alias: "o", + describe: "File to write the outputs to", + type: "string", + normalize: true, + requiresArg: true, + coerce: (output) => { + if (typeof output !== "string") + throw new Error("Only a single output file must be specified"); + return createWriteStream(output); + }, + }, + }) + ), + async (args) => { + const fileInputs = args.file; + const inlineInputs = args.inputs; + const inputs = + inlineInputs ?? + fileInputs ?? 
+ (await readJSONStream(stdin)).map(parseInput); + + const outputStream = args.output ?? stdout; + const results = await Promise.allSettled( + inputs.map(async (input) => { + const { token_count, tokens } = await args.client.tokenize( + { + model_id: args.model ?? "default", + parameters: { + return_tokens: args.returnTokens, + }, + input, + }, + { + timeout: args.timeout, + } + ); + return { token_count, tokens }; + }) + ); + + let hasError = false; + for (const result of results) { + if (result.status === "rejected") { + hasError = true; + outputStream.write(JSON.stringify({ error: result.reason?.message })); + } else { + outputStream.write(JSON.stringify(result.value)); + } + outputStream.write("\n"); + } + + if (hasError) { + throw new BaseSDKError( + "Errors have been encountered during tokenization, see output" + ); + } + }, +]; diff --git a/src/commands/tokenize/index.js b/src/commands/tokenize/index.js new file mode 100644 index 0000000..e91ea1c --- /dev/null +++ b/src/commands/tokenize/index.js @@ -0,0 +1,10 @@ +import { defaultCommandDefinition } from "./default.js"; + +export const tokenizeCommandDefinition = [ + "tokenize", + "Convert provided inputs to tokens", + (yargs) => + yargs + .command(...defaultCommandDefinition) + .demandCommand(1, 1, "Please choose a command"), +]; diff --git a/src/commands/tunes/create.js b/src/commands/tunes/create.js new file mode 100644 index 0000000..e38cd03 --- /dev/null +++ b/src/commands/tunes/create.js @@ -0,0 +1,145 @@ +import isEmpty from "lodash/isEmpty.js"; + +import { pickDefined } from "../../utils/common.js"; +import { groupOptions } from "../../utils/yargs.js"; +import { prettyPrint } from "../../utils/print.js"; + +export const createCommandDefinition = [ + "create", + "Create a tuned model", + (yargs) => + yargs + .options( + groupOptions({ + name: { + type: "string", + description: "Name of the tuned model", + requiresArg: true, + demandOption: true, + }, + model: { + type: "string", + description: "Model 
to be tuned", + requiresArg: true, + demandOption: true, + }, + task: { + description: "Tuning task", + requiresArg: true, + demandOption: true, + choices: ["generation", "classification", "summarization"], + }, + method: { + type: "string", + description: "Tuning method", + requiresArg: true, + demandOption: true, + }, + training: { + type: "array", + description: "Uploaded files to be used for training", + requiresArg: true, + demandOption: true, + }, + validation: { + type: "array", + description: "Uploaded files to be used for validation", + requiresArg: true, + }, + evaluation: { + type: "array", + description: "Uploaded files to be used for evaluation", + requiresArg: true, + }, + }) + ) + .options( + groupOptions( + { + "accumulate-steps": { + type: "number", + requiresArg: true, + default: 16, + description: + "Number of training steps to use to combine gradients", + }, + "batch-size": { + type: "number", + requiresArg: true, + default: 16, + description: + "Number of samples to work through before updating the internal model parameters", + }, + "learning-rate": { + type: "number", + requiresArg: true, + default: 0.3, + description: + "Learning rate to be used while tuning prompt vectors", + }, + "max-input-tokens": { + type: "number", + requiresArg: true, + default: 256, + description: + "Maximum number of tokens that are accepted in the input field for each example", + }, + "max-output-tokens": { + type: "number", + requiresArg: true, + default: 128, + description: + "Maximum number of tokens that are accepted in the output field for each example", + }, + "num-epochs": { + type: "number", + requiresArg: true, + default: 20, + description: + "The number of times to cycle through the training data set", + }, + "num-virtual-tokens": { + type: "number", + requiresArg: true, + default: 100, + description: "Number of virtual tokens to be used for training", + }, + "verbalizer": { + type: "string", + requiresArg: true, + description: + "Verbalizer template to be 
used for formatting data at train and inference time", + }, + }, + "Parameters:" + ) + ) + .middleware((args) => { + const parameters = pickDefined({ + accumulate_steps: args.accumulateSteps, + batch_size: args.batchSize, + init_method: args.initMethod, + init_text: args.initText, + learning_rate: args.learningRate, + max_input_tokens: args.maxInputTokens, + max_output_tokens: args.maxOutputTokens, + num_epochs: args.numEpochs, + num_virtual_tokens: args.numVirtualTokens, + verbalizer: args.verbalizer, + }); + args.parameters = !isEmpty(parameters) ? parameters : undefined; + }), + async (args) => { + const { id, name } = await args.client.tune({ + name: args.name, + model_id: args.model, + task_id: args.task, + method_id: args.method, + parameters: args.parameters, + training_file_ids: args.training, + validation_file_ids: args.validation, + evaluation_file_ids: args.evaluation, + }); + prettyPrint({ id, name }); + }, +]; diff --git a/src/commands/tunes/delete.js b/src/commands/tunes/delete.js new file mode 100644 index 0000000..8b568ab --- /dev/null +++ b/src/commands/tunes/delete.js @@ -0,0 +1,12 @@ +export const deleteCommandDefinition = [ + "delete <id>", + "Delete a tuned model", + (yargs) => + yargs.positional("id", { + type: "string", + description: "Identifier of the tuned model", + }), + async (args) => { + await args.client.tune({ id: args.id }, { delete: true }); + }, +]; diff --git a/src/commands/tunes/download.js b/src/commands/tunes/download.js new file mode 100644 index 0000000..1a0a263 --- /dev/null +++ b/src/commands/tunes/download.js @@ -0,0 +1,31 @@ +import { stdout } from "node:process"; +import { pipeline } from "node:stream/promises"; + +import { + BaseError as BaseSDKError, + TuneAssetTypeSchema, +} from "@ibm-generative-ai/node-sdk"; + +export const downloadCommandDefiniton = [ + "download <id> <asset>", + "Download assets of a completed tuned model", + (yargs) => + yargs + .positional("id", { + type: "string", + description: "Identifier of the tuned 
model", + }) + .positional("asset", { + describe: "Type of the asset", + choices: TuneAssetTypeSchema.options, + }), + async (args) => { + const { downloadAsset, status } = await args.client.tune({ + id: args.id, + }); + if (status !== "COMPLETED") + throw new BaseSDKError("Only completed tunes have assets available"); + const readable = await downloadAsset(args.type); + await pipeline(readable, args.output ?? stdout); + }, +]; diff --git a/src/commands/tunes/index.js b/src/commands/tunes/index.js new file mode 100644 index 0000000..7bb3315 --- /dev/null +++ b/src/commands/tunes/index.js @@ -0,0 +1,23 @@ +import { clientMiddleware } from "../../middleware/client.js"; + +import { createCommandDefinition } from "./create.js"; +import { deleteCommandDefinition } from "./delete.js"; +import { downloadCommandDefiniton } from "./download.js"; +import { infoCommandDefinition } from "./info.js"; +import { listCommandDefinition } from "./list.js"; +import { methodsCommandDefinition } from "./methods.js"; + +export const tunesCommandDefinition = [ + "tunes", + "Create and manage tuned models", + (yargs) => + yargs + .middleware(clientMiddleware) + .command(...methodsCommandDefinition) + .command(...listCommandDefinition) + .command(...infoCommandDefinition) + .command(...createCommandDefinition) + .command(...downloadCommandDefiniton) + .command(...deleteCommandDefinition) + .demandCommand(1, 1, "Please choose a command"), +]; diff --git a/src/commands/tunes/info.js b/src/commands/tunes/info.js new file mode 100644 index 0000000..f045994 --- /dev/null +++ b/src/commands/tunes/info.js @@ -0,0 +1,17 @@ +import { prettyPrint } from "../../utils/print.js"; + +export const infoCommandDefinition = [ + "info ", + "Show detailed information about a tuned model", + (yargs) => + yargs.positional("id", { + type: "string", + description: "Identifier of the tuned model", + }), + async (args) => { + const { id, name } = await args.client.tune({ + id: args.id, + }); + prettyPrint({ id, 
name }); + }, +]; diff --git a/src/commands/tunes/list.js b/src/commands/tunes/list.js new file mode 100644 index 0000000..70a424e --- /dev/null +++ b/src/commands/tunes/list.js @@ -0,0 +1,18 @@ +import { groupOptions } from "../../utils/yargs.js"; + +export const listCommandDefinition = [ + "list", + "List all tuned models", + groupOptions({ + name: { + type: "string", + description: "Filter tuned models by name", + requiresArg: true, + }, + }), + async (args) => { + for await (const tune of args.client.tunes({ search: args.name })) { + console.log(`${tune.id} (${tune.name})`); + } + }, +]; diff --git a/src/commands/tunes/methods.js b/src/commands/tunes/methods.js new file mode 100644 index 0000000..0b76881 --- /dev/null +++ b/src/commands/tunes/methods.js @@ -0,0 +1,11 @@ +export const methodsCommandDefinition = [ + "methods", + "List all tune methods", + {}, + async (args) => { + const methods = await args.client.tuneMethods(); + methods.forEach((method) => { + console.log(`${method.id} (${method.name})`); + }); + }, +]; diff --git a/src/parser.js b/src/parser.js index 25cf922..1d749c2 100644 --- a/src/parser.js +++ b/src/parser.js @@ -1,40 +1,14 @@ -import { stdin, stdout } from "node:process"; -import { createReadStream, createWriteStream } from "node:fs"; -import { createInterface } from "node:readline"; -import { promisify } from "node:util"; -import path from "node:path"; -import { pipeline } from "node:stream/promises"; - import yargs from "yargs"; import { hideBin } from "yargs/helpers"; -import _ from "lodash"; -import dayjs from "dayjs"; -import { - BaseError as BaseSDKError, - FilePurposeSchema, - HistoryStatusSchema, - HistoryOriginSchema, - TuneAssetTypeSchema, -} from "@ibm-generative-ai/node-sdk"; -import { parseDateTime, parseInput } from "./utils/parsers.js"; -import { readJSONStream } from "./utils/streams.js"; -import { prettyPrint } from "./utils/print.js"; -import { - isAbortError, - pickDefined, - isCancelOperationKey, -} from 
"./utils/common.js"; -import { groupOptions } from "./utils/yargs.js"; -import { - loadProfileConfig, - mergeConfig, - deleteProfileConfig, - allProfiles, -} from "./utils/config.js"; -import { clientMiddleware } from "./middleware/client.js"; import { profileMiddleware } from "./middleware/profile.js"; -import { DEFAULT_ENDPOINT } from "./utils/constants.js"; +import { generateCommandDefinition } from "./commands/generate/index.js"; +import { modelsCommandDefinition } from "./commands/models/index.js"; +import { tokenizeCommandDefinition } from "./commands/tokenize/index.js"; +import { filesCommandDefinition } from "./commands/files/index.js"; +import { historyCommandDefinition } from "./commands/history/index.js"; +import { tunesCommandDefinition } from "./commands/tunes/index.js"; +import { configCommandDefinition } from "./commands/config/index.js"; export const parser = yargs(hideBin(process.argv)) .options({ @@ -67,1076 +41,13 @@ export const parser = yargs(hideBin(process.argv)) .help() .alias("h", "help") .updateStrings({ "Options:": "Global Options:" }) - .command("config", "Manage CLI configuration", (yargs) => - yargs - .command("$0", "Modify configuration", {}, async (args) => { - const rl = createInterface({ - input: stdin, - output: stdout, - }); - const question = promisify(rl.question).bind(rl); - rl.on("close", () => { - console.log(); - }); - - const config = { - configuration: {}, - credentials: {}, - }; - if (args.profile) { - config.configuration.profiles = { [args.profile]: {} }; - config.credentials.profiles = { [args.profile]: {} }; - } - - const endpoint = await question( - `Endpoint (${ - args.endpoint ?? process.env.GENAI_ENDPOINT ?? DEFAULT_ENDPOINT - }): ` - ); - if (endpoint) { - if (args.profile) { - config.configuration.profiles[args.profile].endpoint = endpoint; - } else { - config.configuration.endpoint = endpoint; - } - } - const apiKey = await question(`API Key (${args.apiKey ?? 
"none"}): `); - if (apiKey) { - if (args.profile) { - config.credentials.profiles[args.profile].apiKey = apiKey; - } else { - config.credentials.apiKey = apiKey; - } - } - - rl.close(); - mergeConfig(config); - }) - .command("show", "Show configuration", {}, (args) => { - const config = loadProfileConfig(args.profile); - prettyPrint(config); - }) - .command("profiles", "List configuration profiles", {}, (args) => { - const profiles = allProfiles(); - profiles.forEach((profile) => { - console.log(profile); - }); - }) - .command("remove", "Remove configuration", {}, (args) => { - deleteProfileConfig(args.profile); - }) - ) - .command("generate", "Generate a text from an input text", (yargs) => - yargs - .middleware(clientMiddleware) - .options( - groupOptions( - { - model: { - alias: "m", - describe: "Select a model to be used for generation", - requiresArg: true, - type: "string", - coerce: (parameters) => { - if (typeof parameters !== "string") - throw new Error("Only a single model must be specified"); - return parameters; - }, - }, - stream: { - type: "boolean", - description: - "Enables to stream partial progress as server-sent events.", - }, - "decoding-method": { - type: "string", - requiresArg: true, - nargs: 1, - choices: ["greedy", "sample"], - description: - "Represents the strategy used for picking the tokens during generation of the output text", - }, - "decay-factor": { - type: "number", - requiresArg: true, - nargs: 1, - description: "Represents the factor of exponential decay", - }, - "decay-start-index": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "A number of generated tokens after which the decay factor should take effect", - }, - "max-new-tokens": { - type: "number", - requiresArg: true, - nargs: 1, - description: "The maximum number of new tokens to be generated", - }, - "min-new-tokens": { - type: "number", - requiresArg: true, - nargs: 1, - description: "The minimum number of new tokens to be generated", - }, - 
"random-seed": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "Random number generator seed to use in sampling mode for experimental repeatability", - }, - "stop-sequences": { - array: true, - type: "string", - requiresArg: true, - description: - "One or more strings which will cause the text generation to stop if detected in the output", - }, - temperature: { - type: "number", - requiresArg: true, - nargs: 1, - description: - "A value used to modify the next-token probabilities in sampling mode", - }, - "time-limit": { - type: "number", - requiresArg: true, - nargs: 1, - description: "Time limit in milliseconds", - }, - "top-k": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "The number of highest probability vocabulary tokens to keep for top-k-filtering", - }, - "top-p": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "The number of highest probability vocabulary tokens to keep for top-p-filtering", - }, - "repetition-penalty": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "Represents the penalty for repeated tokens in the output", - }, - "truncate-input-tokens": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "Represents the number to which input tokens would be truncated", - }, - }, - "Configuration:" - ) - ) - .middleware((args) => { - const length_penalty = pickDefined({ - decay_factor: args.decayFactor, - start_index: args.decayStartIndex, - }); - const parameters = pickDefined({ - decoding_method: args.decodingMethod, - length_penalty: !_.isEmpty(length_penalty) - ? 
length_penalty - : undefined, - max_new_tokens: args.maxNewTokens, - min_new_tokens: args.minNewTokens, - random_seed: args.randomSeed, - stop_sequences: args.stopSequences, - temperature: args.temperature, - time_limit: args.timeLimit, - top_k: args.topK, - top_p: args.topP, - repetition_penalty: args.repetitionPenalty, - truncate_input_tokens: args.truncateInputTokens, - stream: args.stream, - }); - args.parameters = !_.isEmpty(parameters) ? parameters : undefined; - }) - .command( - ["$0 [inputs..]"], // Default subcommand for generate command - "Generate a text based on an input. Outputs will follow JSONL format. Inputs coming from stdin MUST follow the JSONL format.", - (yargs) => - yargs - .positional("inputs", { - describe: "Text serving as an input for the generation", - array: true, - conflicts: "input", - }) - .options( - groupOptions({ - file: { - alias: "f", - describe: - "File to read the inputs from. File MUST follow JSONL format", - array: true, - normalize: true, - requiresArg: true, - conflicts: "inputs", - coerce: async (files) => { - const inputs = await Promise.all( - files.map((file) => - readJSONStream(createReadStream(file)) - ) - ); - return inputs.flat().map(parseInput); - }, - }, - output: { - alias: "o", - describe: "File to write the outputs to", - type: "string", - normalize: true, - requiresArg: true, - coerce: (output) => { - if (typeof output !== "string") - throw new Error( - "Only a single output file must be specified" - ); - return createWriteStream(output); - }, - }, - }) - ), - async (args) => { - const fileInputs = args.file; - const inlineInputs = args.inputs; - const inputs = - inlineInputs ?? - fileInputs ?? - (await readJSONStream(stdin)).map(parseInput); - - const outputStream = args.output ?? stdout; - const mappedInputs = inputs.map((input) => ({ - model_id: args.model ?? 
"default", - parameters: args.parameters, - input, - })); - - let hasError = false; - if (args.parameters?.stream) { - try { - for await (const chunk of args.client.generate(mappedInputs, { - timeout: args.timeout, - stream: true, - })) { - outputStream.write(chunk.generated_text); - } - } catch (err) { - hasError = true; - outputStream.write(JSON.stringify({ error: err.message })); - } finally { - outputStream.write("\n"); - } - } else { - for (const promise of args.client.generate(mappedInputs, { - timeout: args.timeout, - })) { - try { - const output = await promise; - outputStream.write(JSON.stringify(output.generated_text)); - } catch (err) { - hasError = true; - outputStream.write(JSON.stringify({ error: err.message })); - } finally { - outputStream.write("\n"); - } - } - } - - if (hasError) { - throw new BaseSDKError( - "Errors have been encountered during generation, see output" - ); - } - } - ) - .command( - "interactive", - "Interactive context-free generate session", - {}, - async (args) => { - if (args.parameters?.stream === false) { - throw new Error( - "Stream is automatically enabled for interactive mode." 
- ); - } - - const ctx = { - isTerminated: false, - isProcessing: false, - abortOperation: () => {}, - }; - - const onUserActionCancel = () => { - ctx.abortOperation(); - if (!ctx.isProcessing) { - ctx.isTerminated = true; - } - }; - - const rl = createInterface({ - input: new Proxy(stdin, { - get(target, prop) { - if (prop === "on" || prop === "addEventListener") { - return (event, handler) => { - if (event === "data") { - const originalHandler = handler; - handler = (chunk) => { - if (!ctx.isProcessing || isCancelOperationKey(chunk)) { - originalHandler(chunk); - } - }; - } - return target.on(event, handler); - }; - } - return target[prop]; - }, - }), - output: stdout, - prompt: "", - }) - .on("SIGINT", onUserActionCancel) - .on("SIGTSTP", onUserActionCancel) - .on("SIGCONT", onUserActionCancel) - .on("history", (history) => { - if (ctx.isProcessing) { - history.shift(); - } - }); - - while (!ctx.isTerminated) { - ctx.isProcessing = false; - - try { - const controller = new AbortController(); - ctx.abortOperation = () => controller.abort(); - - const question = promisify(rl.question).bind(rl); - const input = await question("GenAI> ", controller); - - ctx.isProcessing = true; - if (input.length > 0) { - const stream = args.client.generate( - { - input: input, - model_id: args.model ?? 
"default", - parameters: args.parameters, - }, - { - timeout: args.timeout, - signal: controller.signal, - stream: true, - } - ); - - for await (const chunk of stream) { - rl.write(chunk.generated_text); - } - rl.write("\n"); - } - } catch (err) { - if (isAbortError(err)) { - // Clear line due to broken cursor - rl.write("\n"); - rl.write(null, { ctrl: true, name: "u" }); - } else { - console.error(err.message); - } - } - } - - rl.write("Goodbye"); - rl.write("\n"); - rl.close(); - } - ) - .command( - "config", - "Read and modify generate configuration", - (yargs) => - yargs - .options( - groupOptions( - { - "input-text": { - type: "boolean", - description: "Include input text", - }, - "generated-tokens": { - type: "boolean", - description: - "Include the list of individual generated tokens", - }, - "input-tokens": { - type: "boolean", - description: "Include the list of input tokens", - }, - "token-logprobs": { - type: "boolean", - description: "Include logprob for each token", - }, - "token-ranks": { - type: "boolean", - description: "Include rank of each returned token", - }, - "top-n-tokens": { - type: "number", - requiresArg: true, - nargs: 1, - description: - "Include top n candidate tokens at the position of each returned token", - }, - }, - "Configuration:" - ) - ) - .middleware((args) => { - const return_options = pickDefined({ - input_text: args.inputText, - generated_tokens: args.generatedTokens, - input_tokens: args.inputTokens, - token_logprobs: args.tokenLogprobs, - input_ranks: args.tokenRanks, - top_n_tokens: args.topNTokens, - }); - args.parameters = !_.isEmpty(return_options) - ? 
_.merge({}, args.parameters, { return_options }) - : args.parameters; - }) - .options( - groupOptions({ - reset: { - describe: "Reset the config", - type: "boolean", - }, - replace: { - describe: "Replace the entire config", - type: "boolean", - conflicts: "reset", - }, - }) - ), - async (args) => { - const hasInput = args.model || args.parameters; - - const output = hasInput - ? await args.client.generateConfig( - { - model_id: args.model, - parameters: args.parameters, - }, - { - strategy: args.replace ? "replace" : "merge", - timeout: args.timeout, - } - ) - : await args.client.generateConfig({ - reset: args.reset, - timeout: args.timeout, - }); - - prettyPrint(output); - } - ) - .example('$0 generate "Hello World"', "Supply single input") - .example("$0 generate -f inputs.jsonl", "Supply JSONL file with inputs") - .example( - "$0 generate config -m google/flan-t5-xxl --random-seed 2", - "Modify generate configuration with a given model and parameters" - ) - .demandCommand(1, 1, "Please choose a command") - ) - .command("models", "Show information about available models", (yargs) => - yargs - .middleware(clientMiddleware) - .command("list", "List all available models", {}, async (args) => { - const models = await args.client.models(); - models.forEach((model) => { - console.log(model.id); - }); - }) - .command( - "info ", - "Show detailed information about a model", - { - model: { - type: "string", - }, - }, - async (args) => { - const { - id, - name, - size, - description, - token_limit, - source_model, - family, - tasks, - tags, - } = await args.client.model({ id: args.model }); - - prettyPrint({ - id, - name, - size, - description, - source_model: source_model?.id, - family, - token_limit, - tasks, - tags, - }); - } - ) - .command( - "schema ", - "Show validation schema for a model", - { - type: { - alias: "t", - describe: "Type of the schema to show", - demandOption: true, - choices: ["generate", "tokenize"], - requiresArg: true, - type: "string", - }, - 
model: { - type: "string", - }, - }, - async (args) => { - const { schema_generate, schema_tokenize } = await args.client.model({ - id: args.model, - }); - - if (args.type === "generate") prettyPrint(schema_generate.value); - else if (args.type === "tokenize") prettyPrint(schema_tokenize.value); - } - ) - .demandCommand(1, 1, "Please choose a command") - ) - .command("tokenize", "Convert provided inputs to tokens", (yargs) => - yargs - .command( - "$0 [inputs..]", - "Convert provided inputs to tokens. Tokenization is model specific.", - (yargs) => - yargs - .middleware(clientMiddleware) - .options( - groupOptions( - { - model: { - alias: "m", - describe: "Select a model to be used for generation", - requiresArg: true, - type: "string", - coerce: (parameters) => { - if (typeof parameters !== "string") - throw new Error( - "Only a single model must be specified" - ); - return parameters; - }, - }, - returnTokens: { - type: "boolean", - default: true, - description: - "Return tokens with the response. Defaults to true.", - }, - }, - "Configuration:" - ) - ) - .positional("inputs", { - describe: "Text serving as an input for the generation", - array: true, - conflicts: "input", - }) - .options( - groupOptions({ - file: { - alias: "f", - describe: - "File to read the inputs from. 
File MUST follow JSONL format", - array: true, - normalize: true, - requiresArg: true, - conflicts: "inputs", - coerce: async (files) => { - const inputs = await Promise.all( - files.map((file) => - readJSONStream(createReadStream(file)) - ) - ); - return inputs.flat().map(parseInput); - }, - }, - output: { - alias: "o", - describe: "File to write the outputs to", - type: "string", - normalize: true, - requiresArg: true, - coerce: (output) => { - if (typeof output !== "string") - throw new Error( - "Only a single output file must be specified" - ); - return createWriteStream(output); - }, - }, - }) - ), - async (args) => { - const fileInputs = args.file; - const inlineInputs = args.inputs; - const inputs = - inlineInputs ?? - fileInputs ?? - (await readJSONStream(stdin)).map(parseInput); - - const outputStream = args.output ?? stdout; - const results = await Promise.allSettled( - inputs.map(async (input) => { - const { token_count, tokens } = await args.client.tokenize( - { - model_id: args.model ?? 
"default", - parameters: { - return_tokens: args.returnTokens, - }, - input, - }, - { - timeout: args.timeout, - } - ); - return { token_count, tokens }; - }) - ); - - let hasError = false; - for (const result of results) { - if (result.status === "rejected") { - hasError = true; - outputStream.write( - JSON.stringify({ error: result.reason?.message }) - ); - } else { - outputStream.write(JSON.stringify(result.value)); - } - outputStream.write("\n"); - } - - if (hasError) { - throw new BaseSDKError( - "Errors have been encountered during tokenization, see output" - ); - } - } - ) - .demandCommand(1, 1, "Please choose a command") - ) - .command("files", "Upload and manage files", (yargs) => - yargs - .middleware(clientMiddleware) - .command( - "list", - "List all files", - groupOptions({ - purpose: { - alias: "p", - type: "array", - description: "Filter listed by purpose", - requiresArg: true, - choices: FilePurposeSchema.options, - }, - }), - async (args) => { - for await (const file of args.client.files()) { - if (!args.purpose || args.purpose.includes(file.purpose)) { - console.log(`${file.id} (${file.file_name})`); - } - } - } - ) - .command( - "info ", - "Show detailed information about a file", - (yargs) => - yargs.positional("id", { - type: "string", - description: "Identifier of the file", - }), - async (args) => { - const { id, file_name, purpose, created_at } = await args.client.file( - { - id: args.id, - } - ); - prettyPrint({ id, name: file_name, purpose, created_at }); - } - ) - .command( - "upload ", - "Upload a file", - (yargs) => - yargs - .positional("file", { - type: "string", - describe: "Filepath to the file to be uploaded", - normalize: true, - }) - .options( - groupOptions({ - purpose: { - alias: "p", - description: "Purpose of the file", - requiresArg: true, - demandOption: true, - choices: FilePurposeSchema.options, - }, - name: { - alias: "n", - type: "string", - description: "Name of the file", - requiresArg: true, - }, - }) - ), - async 
(args) => { - const { id, file_name, purpose, created_at } = await args.client.file( - { - purpose: args.purpose, - filename: args.name ?? path.parse(args.file).base, - file: createReadStream(args.file), - } - ); - prettyPrint({ id, name: file_name, purpose, created_at }); - } - ) - .command( - "download ", - "Download a file", - (yargs) => - yargs - .positional("id", { - type: "string", - description: "Identifier of the file to download", - }) - .options( - groupOptions({ - output: { - alias: "o", - describe: "Filepath to write the file to", - type: "string", - normalize: true, - requiresArg: true, - coerce: (output) => createWriteStream(output), - }, - }) - ), - async (args) => { - const { download } = await args.client.file({ - id: args.id, - }); - const readable = await download(); - await pipeline(readable, args.output ?? stdout); - } - ) - .command( - "delete ", - "Delete a file", - (yargs) => - yargs.positional("id", { - type: "string", - description: "Identifier of the file to be deleted", - }), - async (args) => { - await args.client.file({ id: args.id }, { delete: true }); - } - ) - .demandCommand(1, 1, "Please choose a command") - ) - .command("history", "Show the history of inference (past 30 days)", (yargs) => - yargs.middleware(clientMiddleware).command( - "$0", - "Show the history of inference (past 30 days)", - (yargs) => - yargs.options( - groupOptions({ - from: { - type: "string", - requiresArg: true, - coerce: parseDateTime, - description: - "Lower bound of the history timeframe [e.g. YYYY-MM-DD]", - }, - to: { - type: "string", - requiresArg: true, - coerce: parseDateTime, - description: - "Upper bound of the history timeframe [e.g. 
YYYY-MM-DD]", - }, - status: { - choices: HistoryStatusSchema.options, - description: "Filter history by status", - }, - origin: { - choices: HistoryOriginSchema.options, - description: "Filter history by origin", - }, - }) - ), - async (args) => { - const { status, origin, from, to } = args; - for await (const output of args.client.history({ status, origin })) { - const createdAt = dayjs(output.created_at); - if ( - (!from || createdAt.isAfter(from)) && - (!to || createdAt.isBefore(to)) - ) { - prettyPrint(output); - } - } - } - ) - ) - .command("tunes", "Create and manage tuned models", (yargs) => - yargs - .middleware(clientMiddleware) - .command("methods", "List all tune methods", {}, async (args) => { - const methods = await args.client.tuneMethods(); - methods.forEach((method) => { - console.log(`${method.id} (${method.name})`); - }); - }) - .command( - "list", - "List all tuned models", - groupOptions({ - name: { - type: "string", - description: "Filter tuned models by name", - requiresArg: true, - }, - }), - async (args) => { - for await (const tune of args.client.tunes({ search: args.name })) { - console.log(`${tune.id} (${tune.name})`); - } - } - ) - .command( - "info ", - "Show detailed information about a tuned model", - (yargs) => - yargs.positional("id", { - type: "string", - description: "Identifier of the tuned model", - }), - async (args) => { - const { id, name } = await args.client.tune({ - id: args.id, - }); - prettyPrint({ id, name }); - } - ) - .command( - "create", - "Create a tuned model", - (yargs) => - yargs - .options( - groupOptions({ - name: { - type: "string", - description: "Name of the tuned model", - requiresArg: true, - demandOption: true, - }, - model: { - type: "string", - description: "Model to be tuned", - requiresArg: true, - demandOption: true, - }, - task: { - description: "Tuning task", - requiresArg: true, - demandOption: true, - choices: ["generation", "classification", "summarization"], - }, - method: { - type: "string", 
- description: "Tuning method", - requiresArg: true, - demandOption: true, - }, - training: { - type: "array", - description: "Uploaded files to be used for training", - requiresArg: true, - demandOption: true, - }, - validation: { - type: "array", - description: "Uploaded files to be used for validation", - requiresArg: true, - }, - evaluation: { - type: "array", - description: "Uploaded files to be used for evaluation", - requiresArg: true, - }, - }) - ) - .options( - groupOptions( - { - "accumulate-steps": { - type: "number", - requiresArg: true, - default: 16, - description: - "Number of training steps to use to combine gradients", - }, - "batch-size": { - type: "number", - requiresArg: true, - default: 16, - description: - "Number of samples to work through before updating the internal model parameters", - }, - "learning-rate": { - type: "number", - requiresArg: true, - default: 0.3, - description: - "Learning rate to be used while tuning prompt vectors", - }, - "max-input-tokens": { - type: "number", - requiresArg: true, - default: 256, - description: - "Maximum number of tokens that are accepted in the input field for each example", - }, - "max-output-tokens": { - type: "number", - requiresArg: true, - default: 128, - description: - "Maximum number of tokens that are accepted in the output field for each example", - }, - "num-epochs": { - type: "number", - requiresArg: true, - default: 20, - description: - "The number of times to cycle through the training data set", - }, - "num-virtual-tokens": { - type: "number", - requiresArg: true, - default: 100, - description: - "Number of virtual tokens to be used for training", - }, - verbalizer: { - type: "string", - requiresArg: true, - description: - "Verbalizer template to be used for formatting data at train and inference time", - }, - }, - "Parameters:" - ) - ) - .middleware((args) => { - const parameters = pickDefined({ - accumulate_steps: args.accumulateSteps, - batch_size: args.batchSize, - init_method: 
args.initMethod, - init_text: args.initText, - learning_rate: args.learningRate, - max_input_tokens: args.maxInputTokens, - max_output_tokens: args.maxOutputTokens, - num_epochs: args.numEpochs, - num_virtual_tokens: args.numVirtualTokens, - verbalizer: args.verbalizer, - }); - args.parameters = !_.isEmpty(parameters) ? parameters : undefined; - }), - async (args) => { - const { id, name } = await args.client.tune({ - name: args.name, - model_id: args.model, - task_id: args.task, - method_id: args.method, - parameters: args.parameters, - training_file_ids: args.training, - validation_file_ids: args.validation, - evaluation_file_ids: args.evaluation, - }); - prettyPrint({ id, name }); - } - ) - .command( - "download ", - "Download assets of a completed tuned model", - (yargs) => - yargs - .positional("id", { - type: "string", - description: "Identifier of the tuned model", - }) - .positional("asset", { - describe: "Type of the asset", - choices: TuneAssetTypeSchema.options, - }), - async (args) => { - const { downloadAsset, status } = await args.client.tune({ - id: args.id, - }); - if (status !== "COMPLETED") - throw new BaseSDKError( - "Only completed tunes have assets available" - ); - const readable = await downloadAsset(args.type); - await pipeline(readable, args.output ?? 
stdout); - } - ) - .command( - "delete ", - "Delete a tuned model", - (yargs) => - yargs.positional("id", { - type: "string", - description: "Identifier of the tuned model", - }), - async (args) => { - await args.client.tune({ id: args.id }, { delete: true }); - } - ) - .demandCommand(1, 1, "Please choose a command") - ) + .command(...configCommandDefinition) + .command(...generateCommandDefinition) + .command(...modelsCommandDefinition) + .command(...tokenizeCommandDefinition) + .command(...filesCommandDefinition) + .command(...historyCommandDefinition) + .command(...tunesCommandDefinition) .demandCommand(1, 1, "Please choose a command") .strict() .fail(false) diff --git a/src/utils/common.js b/src/utils/common.js index f4f67d8..da7a11e 100644 --- a/src/utils/common.js +++ b/src/utils/common.js @@ -1,6 +1,6 @@ -import _ from "lodash"; +import pickBy from "lodash/pickBy.js"; -export const pickDefined = (obj) => _.pickBy(obj, (v) => v !== undefined); +export const pickDefined = (obj) => pickBy(obj, (v) => v !== undefined); export const isAbortError = (err) => Boolean(err && err.name === "AbortError"); diff --git a/src/utils/config.js b/src/utils/config.js index 864fa12..a674f7c 100644 --- a/src/utils/config.js +++ b/src/utils/config.js @@ -3,7 +3,9 @@ import path from "path"; import fs, { writeFileSync } from "fs"; import YAML from "yaml"; -import _ from "lodash"; +import pick from "lodash/pick.js"; +import merge from "lodash/merge.js"; +import cloneDeep from "lodash/cloneDeep.js"; import { z } from "zod"; const CONFIG_DIR_PATH = path.join(os.homedir(), ".genai"); @@ -82,7 +84,7 @@ export const storeConfig = (config) => { export const mergeConfig = (config) => { const currentConfig = loadConfig(); - storeConfig(_.merge({}, currentConfig, config)); + storeConfig(merge({}, currentConfig, config)); }; export function loadProfileConfig(profile) { @@ -98,7 +100,7 @@ export function loadProfileConfig(profile) { } export function deleteProfileConfig(profile) { - const 
config = _.cloneDeep(loadConfig()); + const config = cloneDeep(loadConfig()); if (profile) { [config.configuration.profiles, config.credentials.profiles].forEach( (profiles) => { @@ -110,8 +112,8 @@ export function deleteProfileConfig(profile) { storeConfig(config); } else { storeConfig({ - configuration: _.pick(config.configuration, "profiles"), - credentials: _.pick(config.credentials, "profiles"), + configuration: pick(config.configuration, "profiles"), + credentials: pick(config.credentials, "profiles"), }); } }