diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts index a4562d66349cc..330884868acf2 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts @@ -16,6 +16,7 @@ import { TrainedModelState } from '../../../../../../../common/types/pipelines'; import { GetDocumentsApiLogic } from '../../../../api/documents/get_document_logic'; import { MappingsApiLogic } from '../../../../api/mappings/mappings_logic'; import { MLModelsApiLogic } from '../../../../api/ml_models/ml_models_logic'; +import { StartTextExpansionModelApiLogic } from '../../../../api/ml_models/text_expansion/start_text_expansion_model_api_logic'; import { AttachMlInferencePipelineApiLogic } from '../../../../api/pipelines/attach_ml_inference_pipeline'; import { CreateMlInferencePipelineApiLogic } from '../../../../api/pipelines/create_ml_inference_pipeline'; import { FetchMlInferencePipelineProcessorsApiLogic } from '../../../../api/pipelines/fetch_ml_inference_pipeline_processors'; @@ -85,6 +86,7 @@ describe('MlInferenceLogic', () => { FetchMlInferencePipelinesApiLogic ); const { mount: mountGetDocumentsApiLogic } = new LogicMounter(GetDocumentsApiLogic); + const { mount: mountStartTextExpansionModel } = new LogicMounter(StartTextExpansionModelApiLogic); beforeEach(() => { jest.clearAllMocks(); @@ -97,6 +99,7 @@ describe('MlInferenceLogic', () => { mountCreateMlInferencePipelineApiLogic(); mountAttachMlInferencePipelineApiLogic(); mountGetDocumentsApiLogic(); + mountStartTextExpansionModel(); mount(); }); @@ -628,5 +631,16 @@ describe('MlInferenceLogic', () => { }); }); }); + describe('startTextExpansionModelSuccess', () => { + it('fetches ml models', () => { + jest.spyOn(MLInferenceLogic.actions, 'makeMLModelsRequest'); + StartTextExpansionModelApiLogic.actions.apiSuccess({ + deploymentState: 'started', + modelId: 'foo', + }); + + expect(MLInferenceLogic.actions.makeMLModelsRequest).toHaveBeenCalledWith(undefined); + }); + }); }); }); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts index 2cc9a7eabea35..18b0cb8ab8328 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts @@ -39,6 +39,10 @@ import { TrainedModelsApiLogicActions, TrainedModelsApiLogic, } from '../../../../api/ml_models/ml_trained_models_logic'; +import { + StartTextExpansionModelApiLogic, + StartTextExpansionModelApiLogicActions, +} from '../../../../api/ml_models/text_expansion/start_text_expansion_model_api_logic'; import { AttachMlInferencePipelineApiLogic, AttachMlInferencePipelineApiLogicArgs, @@ -156,6 +160,7 @@ interface MLInferenceProcessorsActions { setInferencePipelineConfiguration: (configuration: InferencePipelineConfiguration) => { configuration: InferencePipelineConfiguration; }; + startTextExpansionModelSuccess: StartTextExpansionModelApiLogicActions['apiSuccess']; } export interface AddInferencePipelineModal { @@ -230,6 +235,8 @@ export const MLInferenceLogic = kea< ], PipelinesLogic, ['closeAddMlInferencePipelineModal as closeAddMlInferencePipelineModal'], + StartTextExpansionModelApiLogic, + ['apiSuccess as startTextExpansionModelSuccess'], ], values: [ CachedFetchIndexApiLogic, @@ -316,6 +323,10 @@ export const MLInferenceLogic = kea< }); } }, + startTextExpansionModelSuccess: () => { + // Refresh ML models list when the text expansion model is started + actions.makeMLModelsRequest(undefined); + }, }), path: ['enterprise_search', 'content', 'pipelines_add_ml_inference_pipeline'], reducers: {