Add transformations for text_embedding and text_classification (#143603)
* Add transformations for text_embedding and text_classification

* Better expectation in jest test

* Change type import.

* Update types

* Fix tests

* Hack to work around ml-plugin constant import issue
brianmcgue authored Oct 19, 2022
1 parent 19b09ab commit 8b8c1fc
Showing 7 changed files with 238 additions and 24 deletions.
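In short: when the selected model's inference type is text_classification or text_embedding, the generated ML inference pipeline now gains a trailing set processor that copies the model's prediction into the destination field. As a sketch, the processor generated for text_classification with destination field 'dest' (taken from the test expectations below):

{
  set: {
    copy_from: 'ml.inference.dest.predicted_value',
    description: "Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
    field: 'dest',
    if: 'ml.inference.dest.prediction_probability > 0.5',
  },
}
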
File 1 of 7
@@ -5,23 +5,34 @@
* 2.0.
*/

import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { IngestSetProcessor, MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-plugin/common/constants/trained_models';

import { getMlModelTypesForModelConfig, BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG } from '.';
import { MlInferencePipeline } from '../types/pipelines';

import {
BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG,
generateMlInferencePipelineBody,
getMlModelTypesForModelConfig,
getSetProcessorForInferenceType,
SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
} from '.';

const mockModel: MlTrainedModelConfig = {
inference_config: {
ner: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'pytorch',
tags: ['test_tag'],
version: '1',
};

describe('getMlModelTypesForModelConfig lib function', () => {
const mockModel: MlTrainedModelConfig = {
inference_config: {
ner: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'pytorch',
tags: ['test_tag'],
};
const builtInMockModel: MlTrainedModelConfig = {
inference_config: {
text_classification: {},
@@ -50,3 +61,140 @@ describe('getMlModelTypesForModelConfig lib function', () => {
expect(LOCAL_BUILT_IN_MODEL_TAG).toEqual(BUILT_IN_MODEL_TAG);
});
});

describe('getSetProcessorForInferenceType lib function', () => {
const destinationField = 'dest';

it('local LOCAL_SUPPORTED_PYTORCH_TASKS matches ml plugin', () => {
expect(SUPPORTED_PYTORCH_TASKS).toEqual(LOCAL_SUPPORTED_PYTORCH_TASKS);
});

it('should return expected value for TEXT_CLASSIFICATION', () => {
const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;

const expected: IngestSetProcessor = {
copy_from: 'ml.inference.dest.predicted_value',
description:
"Copy the predicted_value to 'dest' if the prediction_probability is greater than 0.5",
field: destinationField,
if: 'ml.inference.dest.prediction_probability > 0.5',
value: undefined,
};

expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
});

it('should return expected value for TEXT_EMBEDDING', () => {
const inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING;

const expected: IngestSetProcessor = {
copy_from: 'ml.inference.dest.predicted_value',
description: "Copy the predicted_value to 'dest'",
field: destinationField,
value: undefined,
};

expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toEqual(expected);
});

it('should return undefined for unknown inferenceType', () => {
const inferenceType = 'wrongInferenceType';

expect(getSetProcessorForInferenceType(destinationField, inferenceType)).toBeUndefined();
});
});

describe('generateMlInferencePipelineBody lib function', () => {
const expected: MlInferencePipeline = {
description: 'my-description',
processors: [
{
remove: {
field: 'ml.inference.my-destination-field',
ignore_missing: true,
},
},
{
inference: {
field_map: {
'my-source-field': 'MODEL_INPUT_FIELD',
},
model_id: 'test_id',
on_failure: [
{
append: {
field: '_source._ingest.inference_errors',
value: [
{
message:
"Processor 'inference' in pipeline 'my-pipeline' failed with message '{{ _ingest.on_failure_message }}'",
pipeline: 'my-pipeline',
timestamp: '{{{ _ingest.timestamp }}}',
},
],
},
},
],
target_field: 'ml.inference.my-destination-field',
},
},
{
append: {
field: '_source._ingest.processors',
value: [
{
model_version: '1',
pipeline: 'my-pipeline',
processed_timestamp: '{{{ _ingest.timestamp }}}',
types: ['pytorch', 'ner'],
},
],
},
},
],
version: 1,
};

it('should return something expected', () => {
const actual: MlInferencePipeline = generateMlInferencePipelineBody({
description: 'my-description',
destinationField: 'my-destination-field',
model: mockModel,
pipelineName: 'my-pipeline',
sourceField: 'my-source-field',
});

expect(actual).toEqual(expected);
});

it('should return something expected 2', () => {
const mockTextClassificationModel: MlTrainedModelConfig = {
...mockModel,
...{ inference_config: { text_classification: {} } },
};
const actual: MlInferencePipeline = generateMlInferencePipelineBody({
description: 'my-description',
destinationField: 'my-destination-field',
model: mockTextClassificationModel,
pipelineName: 'my-pipeline',
sourceField: 'my-source-field',
});

expect(actual).toEqual(
expect.objectContaining({
description: expect.any(String),
processors: expect.arrayContaining([
expect.objectContaining({
set: {
copy_from: 'ml.inference.my-destination-field.predicted_value',
description:
"Copy the predicted_value to 'my-destination-field' if the prediction_probability is greater than 0.5",
field: 'my-destination-field',
if: 'ml.inference.my-destination-field.prediction_probability > 0.5',
},
}),
]),
})
);
});
});
File 2 of 7
@@ -5,14 +5,25 @@
* 2.0.
*/

import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { IngestSetProcessor, MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';

import { MlInferencePipeline } from '../types/pipelines';

// Getting an error importing this from '@kbn/ml-plugin/common/constants/data_frame_analytics',
// so defining it locally for now, with a test to make sure it matches.
export const BUILT_IN_MODEL_TAG = 'prepackaged';

// Getting an error importing this from '@kbn/ml-plugin/common/constants/trained_models',
// so defining it locally for now, with a test to make sure it matches.
export const SUPPORTED_PYTORCH_TASKS = {
FILL_MASK: 'fill_mask',
NER: 'ner',
QUESTION_ANSWERING: 'question_answering',
TEXT_CLASSIFICATION: 'text_classification',
TEXT_EMBEDDING: 'text_embedding',
ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
} as const;

export interface MlInferencePipelineParams {
description?: string;
destinationField: string;
@@ -36,6 +47,10 @@ export const generateMlInferencePipelineBody = ({
// if the model returned no input field, insert a placeholder
const modelInputField =
model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';

const inferenceType = Object.keys(model.inference_config)[0];
const set = getSetProcessorForInferenceType(destinationField, inferenceType);

return {
description: description ?? '',
processors: [
@@ -51,21 +66,21 @@
[sourceField]: modelInputField,
},
model_id: model.model_id,
target_field: `ml.inference.${destinationField}`,
on_failure: [
{
append: {
field: '_source._ingest.inference_errors',
value: [
{
pipeline: pipelineName,
message: `Processor 'inference' in pipeline '${pipelineName}' failed with message '{{ _ingest.on_failure_message }}'`,
pipeline: pipelineName,
timestamp: '{{{ _ingest.timestamp }}}',
},
],
},
},
],
target_field: `ml.inference.${destinationField}`,
},
},
{
@@ -81,11 +96,39 @@
],
},
},
...(set ? [{ set }] : []),
],
version: 1,
};
};

export const getSetProcessorForInferenceType = (
destinationField: string,
inferenceType: string
): IngestSetProcessor | undefined => {
let set: IngestSetProcessor | undefined;
const prefixedDestinationField = `ml.inference.${destinationField}`;

if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
set = {
copy_from: `${prefixedDestinationField}.predicted_value`,
description: `Copy the predicted_value to '${destinationField}' if the prediction_probability is greater than 0.5`,
field: destinationField,
if: `${prefixedDestinationField}.prediction_probability > 0.5`,
value: undefined,
};
} else if (inferenceType === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
set = {
copy_from: `${prefixedDestinationField}.predicted_value`,
description: `Copy the predicted_value to '${destinationField}'`,
field: destinationField,
value: undefined,
};
}

return set;
};

/**
* Parses model types list from the given configuration of a trained machine learning model
* @param trainedModel configuration for a trained machine learning model
…
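
A quick usage sketch of the new helpers, mirroring the tests above (model id and field names here are hypothetical):

import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';

import { generateMlInferencePipelineBody, getSetProcessorForInferenceType } from '.';

const textClassificationModel: MlTrainedModelConfig = {
  inference_config: { text_classification: {} },
  input: { field_names: ['body_content'] },
  model_id: 'my-model-id', // hypothetical
  model_type: 'pytorch',
  tags: [],
  version: '1',
};

// The set processor is keyed off the destination field and inference type.
const set = getSetProcessorForInferenceType('sentiment', 'text_classification');
// set?.if === 'ml.inference.sentiment.prediction_probability > 0.5'
// Unknown inference types yield undefined, and no set processor is appended.

const pipeline = generateMlInferencePipelineBody({
  description: 'Example pipeline', // hypothetical
  destinationField: 'sentiment',
  model: textClassificationModel,
  pipelineName: 'my-pipeline',
  sourceField: 'body_content',
});
// pipeline.processors now ends with a set processor that copies
// ml.inference.sentiment.predicted_value into 'sentiment' when
// prediction_probability > 0.5.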
File 3 of 7
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import '@elastic/elasticsearch/lib/api/types';

// TODO: Remove once type fixed in elasticsearch-specification
// (add github issue)
declare module '@elastic/elasticsearch/lib/api/types' {
// This workaround adds copy_from and description to the original IngestSetProcessor and makes value
// optional. Ideally it would be value xor copy_from, but that requires a type union, and module
// augmentation only supports declaration merging on interfaces (not type aliases), so we cannot
// express that constraint here.
export interface IngestSetProcessor {
copy_from?: string;
description?: string;
}
}
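
With the merged interface, a set processor can copy an existing field instead of assigning a literal value. A minimal sketch of a processor object that type-checks thanks to this augmentation (destination field 'dest' as in the tests; value is passed explicitly as undefined, mirroring the tests above):

import { IngestSetProcessor } from '@elastic/elasticsearch/lib/api/types';

const copyPrediction: IngestSetProcessor = {
  copy_from: 'ml.inference.dest.predicted_value', // available via the augmentation
  description: "Copy the predicted_value to 'dest'", // available via the augmentation
  field: 'dest',
  value: undefined,
};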
File 4 of 7
@@ -5,7 +5,7 @@
* 2.0.
*/

import { IngestGetPipelineResponse } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { IngestGetPipelineResponse } from '@elastic/elasticsearch/lib/api/types';
import { IScopedClusterClient } from '@kbn/core/server';

export const getCustomPipelines = async (
…
File 5 of 7
@@ -5,7 +5,7 @@
* 2.0.
*/

import { IngestGetPipelineResponse } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { IngestGetPipelineResponse } from '@elastic/elasticsearch/lib/api/types';
import { IScopedClusterClient } from '@kbn/core/server';

export const getPipeline = async (
Expand Down
File 6 of 7
@@ -37,6 +37,9 @@ describe('createMlInferencePipeline util function', () => {
Promise.resolve({
trained_model_configs: [
{
inference_config: {
ner: {},
},
input: {
field_names: ['target-field'],
},
…
File 7 of 7
@@ -22,9 +22,9 @@ import {
* Details of a created pipeline.
*/
export interface CreatedPipeline {
id: string;
created?: boolean;
addedToParentPipeline?: boolean;
created?: boolean;
id: string;
}

/**
@@ -110,8 +110,8 @@ export const createMlInferencePipeline = async (
});

return Promise.resolve({
id: inferencePipelineGeneratedName,
created: true,
id: inferencePipelineGeneratedName,
});
};

@@ -143,8 +143,8 @@ export const addSubPipelineToIndexSpecificMlPipeline = async (
// Verify the parent pipeline exists with a processors array
if (!parentPipeline?.processors) {
return Promise.resolve({
id: pipelineName,
addedToParentPipeline: false,
id: pipelineName,
});
}

@@ -155,8 +155,8 @@
);
if (existingSubPipeline) {
return Promise.resolve({
id: pipelineName,
addedToParentPipeline: false,
id: pipelineName,
});
}

@@ -173,7 +173,7 @@
});

return Promise.resolve({
id: pipelineName,
addedToParentPipeline: true,
id: pipelineName,
});
};
