Skip to content

Commit

Permalink
feat: adds tune LLM sample
Browse files Browse the repository at this point in the history
  • Loading branch information
telpirion committed Jun 26, 2023
2 parents 5116f77 + a8e2ce0 commit 2e78e88
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
82 changes: 82 additions & 0 deletions ai-platform/snippets/test/tuning.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* eslint-disable */

'use strict';

const {assert} = require('chai');
const {describe, it} = require('mocha');
const uuid = require('uuid')
const sinon = require('sinon');

const aiplatform = require('@google-cloud/aiplatform');
const clientOptions = {
apiEndpoint: 'europe-west4-aiplatform.googleapis.com',
};
const pipelineClient = new aiplatform.v1.PipelineServiceClient(clientOptions);

const {tuneModel} = require('../tuning');

const projectId = process.env.CAIP_PROJECT_ID;
const location = 'europe-west4';
const timestampId = `${new Date().toISOString().replace(/(:|\.)/g, '-').toLowerCase()}`
const pipelineJobName = `my-tuning-pipeline-${timestampId}`
const modelDisplayName = `my-tuned-model-${timestampId}`
const bucketName = `ucaip-samples-europe-west4/training_pipeline_output`;
const bucketUri = `gs://${bucketName}/tune-model-nodejs`

describe('Tune a model', () => {
const stubConsole = function () {
sinon.stub(console, 'error');
sinon.stub(console, 'log');
};

const restoreConsole = function () {
console.log.restore();
console.error.restore();
};

after(async () => {
// Cancel and delete the pipeline job
const name = pipelineClient.pipelineJobPath(
projectId,
location,
pipelineJobName
);

const cancelRequest = {
name,
};

pipelineClient.cancelPipelineJob(cancelRequest).then(() => {
const deleteRequest = {
name,
};

return pipelineClient.deletePipeline(deleteRequest);
});
});

beforeEach(stubConsole);
afterEach(restoreConsole);

it('should prompt-tune an existing model', async () => {
// Act
await tuneModel(projectId, pipelineJobName, modelDisplayName, bucketUri);

// Assert
assert.include(console.log.firstCall.args, 'Tuning pipeline job:');
});
});
101 changes: 101 additions & 0 deletions ai-platform/snippets/tuning.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

async function main(
project,
pipelineJobId,
modelDisplayName,
gcsOutputDirectory,
location = 'europe-west4',
datasetUri = 'gs://cloud-samples-data/ai-platform/generative_ai/headline_classification.jsonl',
trainSteps = 10
) {
// [START aiplatform_model_tuning]
/**
* TODO(developer): Uncomment these variables before running the sample.\
* (Not necessary if passing values as arguments)
*/
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {PipelineServiceClient} = aiplatform.v1;

// Import the helper module for converting arbitrary protobuf.Value objects.
const {helpers} = aiplatform;

// Specifies the location of the api endpoint
const clientOptions = {
apiEndpoint: 'europe-west4-aiplatform.googleapis.com',
};
const model = 'text-bison@001';

const pipelineClient = new PipelineServiceClient(clientOptions);

async function tuneLLM() {
// Configure the parent resource
const parent = `projects/${project}/locations/${location}`;

const parameters = {
train_steps: helpers.toValue(trainSteps),
project: helpers.toValue(project),
location: helpers.toValue('us-central1'),
dataset_uri: helpers.toValue(datasetUri),
large_model_reference: helpers.toValue(model),
model_display_name: helpers.toValue(modelDisplayName),
};

const runtimeConfig = {
gcsOutputDirectory,
parameterValues: parameters,
};

const pipelineJob = {
templateUri:
'https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v1.0.0',
displayName: 'my-tuning-job',
runtimeConfig,
};

const createPipelineRequest = {
parent,
pipelineJob,
pipelineJobId,
};
await new Promise((resolve, reject) => {
pipelineClient.createPipelineJob(createPipelineRequest).then(
response => resolve(response),
e => reject(e)
);
}).then(response => {
const [result] = response;
console.log('Tuning pipeline job:');
console.log(`\tName: ${result.name}`);
console.log(
`\tCreate time: ${new Date(1970, 0, 1)
.setSeconds(result.createTime.seconds)
.toLocaleString()}`
);
console.log(`\tStatus: ${result.status}`);
});
}

await tuneLLM();
// [END aiplatform_model_tuning]
}

exports.tuneModel = main;

0 comments on commit 2e78e88

Please sign in to comment.