Skip to content

Commit

Permalink
Add source field to distinguish local and external model (#239)
Browse files Browse the repository at this point in the history
* Revert "feat: exclude remote model for admin UI (#225)"

This reverts commit ba01d34.

Signed-off-by: Lin Wang <wonglam@amazon.com>

* feat: add source field to distinguish local and external model

Signed-off-by: Lin Wang <wonglam@amazon.com>

* feat: add miss display words

Signed-off-by: Lin Wang <wonglam@amazon.com>

---------

Signed-off-by: Lin Wang <wonglam@amazon.com>
(cherry picked from commit 63c7a5a)
  • Loading branch information
wanglam authored and github-actions[bot] committed Aug 2, 2023
1 parent f542649 commit 9566494
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 49 deletions.
1 change: 0 additions & 1 deletion public/apis/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ export class Model {
size: number;
states?: MODEL_STATE[];
nameOrId?: string;
exclude?: 'REMOTE_MODEL';
}) {
return InnerHttpProvider.getHttp().get<ModelSearchResponse>(MODEL_API_ENDPOINT, {
query,
Expand Down
11 changes: 9 additions & 2 deletions public/components/monitoring/model_deployment_table.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export const ModelDeploymentTable = ({
{
field: 'model_state',
name: 'Status',
width: '37.5%',
width: '34.5%',
sortable: true,
truncateText: true,
render: (
Expand Down Expand Up @@ -124,10 +124,17 @@ export const ModelDeploymentTable = ({
);
},
},
{
field: 'source',
name: 'Source',
width: '7%',
sortable: false,
truncateText: true,
},
{
field: 'id',
name: 'Model ID',
width: '25%',
width: '21%',
sortable: true,
render: (id: string) => (
<EuiCopy
Expand Down
28 changes: 23 additions & 5 deletions public/components/monitoring/tests/model_deployment_table.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const setup = (props?: Partial<ModelDeploymentTableProps>) => {
notRespondingNodesCount: 2,
planningNodesCount: 3,
planningWorkerNodes: [],
source: 'Local',
},
{
id: 'model-2-id',
Expand All @@ -27,6 +28,7 @@ const setup = (props?: Partial<ModelDeploymentTableProps>) => {
notRespondingNodesCount: 0,
planningNodesCount: 3,
planningWorkerNodes: [],
source: 'Local',
},
{
id: 'model-3-id',
Expand All @@ -35,6 +37,7 @@ const setup = (props?: Partial<ModelDeploymentTableProps>) => {
notRespondingNodesCount: 3,
planningNodesCount: 3,
planningWorkerNodes: [],
source: 'External',
},
],
pagination: { currentPage: 1, pageSize: 10, totalRecords: 100 },
Expand Down Expand Up @@ -122,11 +125,26 @@ describe('<DeployedModelTable />', () => {
expect(within(cells[2] as HTMLElement).getByText('on 3 of 3 nodes')).toBeInTheDocument();
});

it('should render Model ID at third column and copy to clipboard after text clicked', async () => {
it('should display source name at third column', () => {
const columnIndex = 2;
setup();
const header = screen.getAllByRole('columnheader')[columnIndex];
const columnContent = header
.closest('table')
?.querySelectorAll(`tbody tr td:nth-child(${columnIndex + 1})`);
expect(within(header).getByText('Source')).toBeInTheDocument();
expect(columnContent?.length).toBe(3);
const cells = columnContent!;
expect(within(cells[0] as HTMLElement).getByText('Local')).toBeInTheDocument();
expect(within(cells[1] as HTMLElement).getByText('Local')).toBeInTheDocument();
expect(within(cells[2] as HTMLElement).getByText('External')).toBeInTheDocument();
});

it('should render Model ID at forth column and copy to clipboard after text clicked', async () => {
const execCommandOrigin = document.execCommand;
document.execCommand = jest.fn(() => true);

const columnIndex = 2;
const columnIndex = 3;
setup();
const header = screen.getAllByRole('columnheader')[columnIndex];
const columnContent = header
Expand All @@ -146,7 +164,7 @@ describe('<DeployedModelTable />', () => {
});

it('should render Action column and call onViewDetail with the model item of the current table row', async () => {
const columnIndex = 3;
const columnIndex = 4;
const onViewDetailMock = jest.fn();
const { finalProps } = setup({
onViewDetail: onViewDetailMock,
Expand Down Expand Up @@ -258,7 +276,7 @@ describe('<DeployedModelTable />', () => {
},
});

await userEvent.click(within(screen.getAllByRole('columnheader')[2]).getByText('Model ID'));
await userEvent.click(within(screen.getAllByRole('columnheader')[3]).getByText('Model ID'));
expect(finalProps.onChange).toHaveBeenCalledWith(
expect.objectContaining({
sort: {
Expand All @@ -277,7 +295,7 @@ describe('<DeployedModelTable />', () => {
}}
/>
);
await userEvent.click(within(screen.getAllByRole('columnheader')[2]).getByText('Model ID'));
await userEvent.click(within(screen.getAllByRole('columnheader')[3]).getByText('Model ID'));
expect(finalProps.onChange).toHaveBeenCalledWith(
expect.objectContaining({
sort: {
Expand Down
65 changes: 48 additions & 17 deletions public/components/monitoring/tests/use_monitoring.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,54 @@ describe('useMonitoring', () => {
});
await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledTimes(2));
});

it('should return consistent deployedModels', async () => {
jest.spyOn(Model.prototype, 'search').mockRestore();
const searchMock = jest.spyOn(Model.prototype, 'search').mockResolvedValue({
data: [
{
id: 'model-1-id',
name: 'model-1-name',
current_worker_node_count: 1,
planning_worker_node_count: 3,
algorithm: 'TEXT_EMBEDDING',
model_state: '',
model_version: '',
planning_worker_nodes: ['node1', 'node2', 'node3'],
},
{
id: 'model-2-id',
name: 'model-2-name',
current_worker_node_count: 1,
planning_worker_node_count: 3,
algorithm: 'REMOTE',
model_state: '',
model_version: '',
planning_worker_nodes: ['node1', 'node2', 'node3'],
},
],
total_models: 2,
});
const { result, waitFor } = renderHook(() => useMonitoring());

await waitFor(() => {
expect(result.current.deployedModels).toEqual(
expect.arrayContaining([
expect.objectContaining({
id: 'model-1-id',
name: 'model-1-name',
respondingNodesCount: 1,
notRespondingNodesCount: 2,
planningNodesCount: 3,
planningWorkerNodes: ['node1', 'node2', 'node3'],
}),
expect.objectContaining({ source: 'External' }),
])
);
});

searchMock.mockRestore();
});
});

describe('useMonitoring.pageStatus', () => {
Expand Down Expand Up @@ -195,21 +243,4 @@ describe('useMonitoring.pageStatus', () => {

await waitFor(() => expect(result.current.pageStatus).toBe('empty'));
});

it('should return consistent deployedModels', async () => {
const { result, waitFor } = renderHook(() => useMonitoring());

await waitFor(() =>
expect(result.current.deployedModels).toEqual([
{
id: 'model-1-id',
name: 'model-1-name',
respondingNodesCount: 1,
notRespondingNodesCount: 2,
planningNodesCount: 3,
planningWorkerNodes: ['node1', 'node2', 'node3'],
},
])
);
});
});
3 changes: 2 additions & 1 deletion public/components/monitoring/use_monitoring.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ const fetchDeployedModels = async (params: Params) => {
? [MODEL_STATE.loadFailed, MODEL_STATE.loaded, MODEL_STATE.partiallyLoaded]
: states,
sort: [`${params.sort.field}-${params.sort.direction}`],
exclude: 'REMOTE_MODEL',
});
const totalPages = Math.ceil(result.total_models / params.pageSize);
return {
Expand All @@ -64,6 +63,7 @@ const fetchDeployedModels = async (params: Params) => {
current_worker_node_count: workerCount,
planning_worker_node_count: planningCount,
planning_worker_nodes: planningWorkerNodes,
algorithm,
}) => {
return {
id,
Expand All @@ -75,6 +75,7 @@ const fetchDeployedModels = async (params: Params) => {
? planningCount - workerCount
: undefined,
planningWorkerNodes,
source: algorithm === 'REMOTE' ? 'External' : 'Local',
};
}
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const MODEL = {
id: 'id1',
name: 'test',
planningWorkerNodes: ['node-1', 'node-2', 'node-3'],
source: 'External',
};

function setup({ model = MODEL, onClose = jest.fn() }) {
Expand All @@ -26,10 +27,11 @@ describe('<PreviewPanel />', () => {
jest.clearAllMocks();
});

it('should render id and name in panel', () => {
it('should render id, name and source in panel', () => {
setup({});
expect(screen.getByText('test')).toBeInTheDocument();
expect(screen.getByText('id1')).toBeInTheDocument();
expect(screen.getByText('External')).toBeInTheDocument();
});

it('should call onClose when close panel', async () => {
Expand Down
5 changes: 4 additions & 1 deletion public/components/preview_panel/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface PreviewModel {
name: string;
id: string;
planningWorkerNodes: string[];
source: string;
}

interface Props {
Expand All @@ -39,7 +40,7 @@ interface Props {
}

export const PreviewPanel = ({ onClose, model }: Props) => {
const { id, name } = model;
const { id, name, source } = model;
const { data, loading } = useFetcher(APIProvider.getAPI('profile').getModel, id);
const nodes = useMemo(() => {
if (loading) {
Expand Down Expand Up @@ -101,6 +102,8 @@ export const PreviewPanel = ({ onClose, model }: Props) => {
<EuiDescriptionListDescription>
<CopyableText text={id} iconLeft={false} tooltipText="Copy model ID" />
</EuiDescriptionListDescription>
<EuiDescriptionListTitle style={{ fontSize: '14px' }}>Source</EuiDescriptionListTitle>
<EuiDescriptionListDescription>{source}</EuiDescriptionListDescription>
<EuiDescriptionListTitle style={{ fontSize: '14px' }}>
Model status by node
</EuiDescriptionListTitle>
Expand Down
4 changes: 1 addition & 3 deletions server/routes/model_router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,11 @@ export const modelRouter = (router: IRouter) => {
),
states: schema.maybe(schema.oneOf([schema.arrayOf(modelStateSchema), modelStateSchema])),
nameOrId: schema.maybe(schema.string()),
exclude: schema.maybe(schema.literal('REMOTE_MODEL')),
}),
},
},
async (context, request) => {
const { from, size, sort, states, nameOrId, exclude } = request.query;
const { from, size, sort, states, nameOrId } = request.query;
try {
const payload = await ModelService.search({
client: context.core.opensearch.client,
Expand All @@ -56,7 +55,6 @@ export const modelRouter = (router: IRouter) => {
sort: typeof sort === 'string' ? [sort] : sort,
states: typeof states === 'string' ? [states] : states,
nameOrId,
exclude,
});
return opensearchDashboardsResponseFactory.ok({ body: payload });
} catch (err) {
Expand Down
1 change: 0 additions & 1 deletion server/services/model_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ export class ModelService {
sort?: ModelSearchSort[];
states?: MODEL_STATE[];
nameOrId?: string;
exclude?: 'REMOTE_MODEL';
}) {
const {
body: { hits },
Expand Down
21 changes: 4 additions & 17 deletions server/services/utils/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ import { generateTermQuery } from './query';
export const generateModelSearchQuery = ({
states,
nameOrId,
exclude,
}: {
states?: MODEL_STATE[];
nameOrId?: string;
exclude?: 'REMOTE_MODEL';
}) => ({
bool: {
must: [
Expand All @@ -35,21 +33,10 @@ export const generateModelSearchQuery = ({
]
: []),
],
must_not: [
{
exists: {
field: 'chunk_number',
},
must_not: {
exists: {
field: 'chunk_number',
},
...(exclude === 'REMOTE_MODEL'
? [
{
term: {
algorithm: 'REMOTE',
},
},
]
: []),
],
},
},
});

0 comments on commit 9566494

Please sign in to comment.