diff --git a/.changeset/shiny-brooms-punch.md b/.changeset/shiny-brooms-punch.md new file mode 100644 index 000000000..e94629532 --- /dev/null +++ b/.changeset/shiny-brooms-punch.md @@ -0,0 +1,5 @@ +--- +'@sap-ai-sdk/orchestration': minor +--- + +[Improvement] Add `buildDocumentGroundingConfig()` convenience function to create document grounding configuration in the Orchestration client. diff --git a/packages/orchestration/README.md b/packages/orchestration/README.md index 295b24d4a..14fb69c89 100644 --- a/packages/orchestration/README.md +++ b/packages/orchestration/README.md @@ -330,6 +330,7 @@ return response.getContent(); ### Grounding Grounding enables integrating external, contextually relevant, domain-specific, or real-time data into AI processes. +The grounding configuration can be provided as a raw JSON object or by using the `buildDocumentGroundingConfig()` function, which requires only the minimal mandatory values. ```ts const orchestrationClient = new OrchestrationClient({ @@ -346,21 +347,16 @@ const orchestrationClient = new OrchestrationClient({ ], defaults: {} }, - grounding: { - type: 'document_grounding_service', - config: { - filters: [ + grounding: buildDocumentGroundingConfig( + input_params: ['groundingRequest'], + output_param: 'groundingOutput', + filters: [ { id: 'filter1', - data_repositories: ['*'], - search_config: {}, - data_repository_type: 'vector' + data_repositories: ['repository-id'] } ], - input_params: ['groundingRequest'], - output_param: 'groundingOutput' - } - } + ) }); const response = await orchestrationClient.chatCompletion({ diff --git a/packages/orchestration/src/index.ts b/packages/orchestration/src/index.ts index 215aec0fb..57eae616a 100644 --- a/packages/orchestration/src/index.ts +++ b/packages/orchestration/src/index.ts @@ -42,12 +42,17 @@ export type { OrchestrationModuleConfig, LlmModuleConfig, Prompt, + DocumentGroundingServiceConfig, + DocumentGroundingServiceFilter, LlmModelParams } from './orchestration-types.js'; export { OrchestrationClient } from './orchestration-client.js'; -export { buildAzureContentFilter } from './orchestration-utils.js'; +export { + buildAzureContentFilter, + buildDocumentGroundingConfig +} from './orchestration-utils.js'; export { OrchestrationResponse } from './orchestration-response.js'; diff --git a/packages/orchestration/src/orchestration-completion-post-request.test.ts b/packages/orchestration/src/orchestration-completion-post-request.test.ts index 6dd43d977..c7b19fcb0 100644 --- a/packages/orchestration/src/orchestration-completion-post-request.test.ts +++ b/packages/orchestration/src/orchestration-completion-post-request.test.ts @@ -1,6 +1,6 @@ import { constructCompletionPostRequest } from './orchestration-client.js'; import { buildAzureContentFilter } from './orchestration-utils.js'; -import type { CompletionPostRequest } from './client/api/schema'; +import type { CompletionPostRequest } from './client/api/schema/index.js'; import type { OrchestrationModuleConfig } from './orchestration-types.js'; describe('construct completion post request', () => { diff --git a/packages/orchestration/src/orchestration-types.ts b/packages/orchestration/src/orchestration-types.ts index 8634ef880..8edf616b9 100644 --- a/packages/orchestration/src/orchestration-types.ts +++ b/packages/orchestration/src/orchestration-types.ts @@ -1,6 +1,8 @@ import type { ChatModel } from './model-types.js'; import type { ChatMessages, + DataRepositoryType, + DocumentGroundingFilter, FilteringModuleConfig, GroundingModuleConfig, MaskingModuleConfig, @@ -69,3 +71,36 @@ export interface OrchestrationModuleConfig { */ grounding?: GroundingModuleConfig; } + +/** + * Represents a filter configuration for the Document Grounding Service. + */ +export type DocumentGroundingServiceFilter = Omit< + DocumentGroundingFilter, + 'data_repository_type' +> & { + /** + * Defines the type of data repository. + * If not set, the default value is 'vector'. + */ + data_repository_type?: DataRepositoryType; +}; + +/** + * Represents the configuration for the Document Grounding Service. + */ +export interface DocumentGroundingServiceConfig { + /** + * Defines the filters to apply during the grounding process. + */ + filters?: DocumentGroundingServiceFilter[]; + /** + * Contains the input parameters used for grounding input questions. + */ + input_params: string[]; + /** + * Parameter name used for grounding output. + * @example "groundingOutput" + */ + output_param: string; +} diff --git a/packages/orchestration/src/orchestration-utils.test.ts b/packages/orchestration/src/orchestration-utils.test.ts index f70d502d9..70d584b01 100644 --- a/packages/orchestration/src/orchestration-utils.test.ts +++ b/packages/orchestration/src/orchestration-utils.test.ts @@ -1,176 +1,237 @@ import { constructCompletionPostRequest } from './orchestration-client.js'; -import { buildAzureContentFilter } from './orchestration-utils.js'; +import { + buildAzureContentFilter, + buildDocumentGroundingConfig +} from './orchestration-utils.js'; import type { CompletionPostRequest, FilteringModuleConfig } from './client/api/schema/index.js'; -import type { OrchestrationModuleConfig } from './orchestration-types.js'; +import type { + OrchestrationModuleConfig, + DocumentGroundingServiceConfig +} from './orchestration-types.js'; -describe('filter utility', () => { - const config: OrchestrationModuleConfig = { - llm: { - model_name: 'gpt-35-turbo-16k', - model_params: { max_tokens: 50, temperature: 0.1 } - }, - templating: { - template: [ - { role: 'user', content: 'Create {number} paraphrases of {phrase}' } - ] - } - }; +describe('orchestration utils', () => { + describe('azure filter', () => { + const config: OrchestrationModuleConfig = { + llm: { + model_name: 'gpt-35-turbo-16k', + model_params: { max_tokens: 50, temperature: 0.1 } + }, + templating: { + template: [ + { role: 'user', content: 'Create {number} paraphrases of {phrase}' } + ] + } + }; + const prompt = { inputParams: { phrase: 'I hate you.', number: '3' } }; - const prompt = { inputParams: { phrase: 'I hate you.', number: '3' } }; + afterEach(() => { + config.filtering = undefined; + }); - afterEach(() => { - config.filtering = undefined; - }); + it('constructs filter configuration with only input', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety', + config: { + Hate: 4, + SelfHarm: 0 + } + } + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); - it('constructs filter configuration with only input', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter({ Hate: 4, SelfHarm: 0 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - input: { - filters: [ - { - type: 'azure_content_safety', - config: { - Hate: 4, - SelfHarm: 0 + it('constructs filter configuration with only output', async () => { + const filtering: FilteringModuleConfig = { + output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + output: { + filters: [ + { + type: 'azure_content_safety', + config: { + Sexual: 2, + Violence: 6 + } } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); - it('constructs filter configuration with only output', async () => { - const filtering: FilteringModuleConfig = { - output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - output: { - filters: [ - { - type: 'azure_content_safety', - config: { - Sexual: 2, - Violence: 6 + it('constructs filter configuration with both input and output', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter({ + Hate: 4, + SelfHarm: 0, + Sexual: 2, + Violence: 6 + }), + output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) + }; + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety', + config: { + Hate: 4, + SelfHarm: 0, + Sexual: 2, + Violence: 6 + } } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); + ] + }, + output: { + filters: [ + { + type: 'azure_content_safety', + config: { + Sexual: 2, + Violence: 6 + } + } + ] + } + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); - it('constructs filter configuration with both input and output', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter({ - Hate: 4, - SelfHarm: 0, - Sexual: 2, - Violence: 6 - }), - output: buildAzureContentFilter({ Sexual: 2, Violence: 6 }) - }; - const expectedFilterConfig: FilteringModuleConfig = { - input: { - filters: [ - { - type: 'azure_content_safety', - config: { - Hate: 4, - SelfHarm: 0, - Sexual: 2, - Violence: 6 + it('omits filters if not set', async () => { + const filtering: FilteringModuleConfig = { + input: buildAzureContentFilter(), + output: buildAzureContentFilter() + }; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + const expectedFilterConfig: FilteringModuleConfig = { + input: { + filters: [ + { + type: 'azure_content_safety' } - } - ] - }, - output: { - filters: [ - { - type: 'azure_content_safety', - config: { - Sexual: 2, - Violence: 6 + ] + }, + output: { + filters: [ + { + type: 'azure_content_safety' } - } - ] - } - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); + ] + } + }; + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toEqual(expectedFilterConfig); + }); - it('omits filters if not set', async () => { - const filtering: FilteringModuleConfig = { - input: buildAzureContentFilter(), - output: buildAzureContentFilter() - }; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - const expectedFilterConfig: FilteringModuleConfig = { - input: { + it('omits filter configuration if not set', async () => { + const filtering: FilteringModuleConfig = {}; + config.filtering = filtering; + const completionPostRequest: CompletionPostRequest = + constructCompletionPostRequest(config, prompt); + expect( + completionPostRequest.orchestration_config.module_configurations + .filtering_module_config + ).toBeUndefined(); + }); + + it('throw error when configuring empty filter', async () => { + const createFilterConfig = () => { + { + buildAzureContentFilter({}); + } + }; + expect(createFilterConfig).toThrow( + 'Filter property cannot be an empty object' + ); + }); + }); + describe('document grounding', () => { + it('builds grounding configuration with minimal required properties', () => { + const groundingConfig: DocumentGroundingServiceConfig = { filters: [ { - type: 'azure_content_safety' + id: 'filter-id' } - ] - }, - output: { + ], + input_params: ['input'], + output_param: 'output' + }; + expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ + type: 'document_grounding_service', + config: { + filters: [ + { + id: 'filter-id', + data_repository_type: 'vector' + } + ], + input_params: ['input'], + output_param: 'output' + } + }); + }); + + it('overrides default data repository type', () => { + const groundingConfig: DocumentGroundingServiceConfig = { filters: [ { - type: 'azure_content_safety' + id: 'filter-id', + data_repositories: ['repo1', 'repo2'], + data_repository_type: 'custom-type' } - ] - } - }; - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toEqual(expectedFilterConfig); - }); - - it('omits filter configuration if not set', async () => { - const filtering: FilteringModuleConfig = {}; - config.filtering = filtering; - const completionPostRequest: CompletionPostRequest = - constructCompletionPostRequest(config, prompt); - expect( - completionPostRequest.orchestration_config.module_configurations - .filtering_module_config - ).toBeUndefined(); - }); - - it('throw error when configuring empty filter', async () => { - const createFilterConfig = () => { - { - buildAzureContentFilter({}); - } - }; - expect(createFilterConfig).toThrow( - 'Filter property cannot be an empty object' - ); + ], + input_params: ['input'], + output_param: 'output' + }; + expect(buildDocumentGroundingConfig(groundingConfig)).toEqual({ + type: 'document_grounding_service', + config: { + filters: [ + { + id: 'filter-id', + data_repositories: ['repo1', 'repo2'], + data_repository_type: 'custom-type' + } + ], + input_params: ['input'], + output_param: 'output' + } + }); + }); }); }); diff --git a/packages/orchestration/src/orchestration-utils.ts b/packages/orchestration/src/orchestration-utils.ts index acc118ebf..7f553f281 100644 --- a/packages/orchestration/src/orchestration-utils.ts +++ b/packages/orchestration/src/orchestration-utils.ts @@ -1,5 +1,7 @@ +import type { DocumentGroundingServiceConfig } from './orchestration-types.js'; import type { AzureContentSafety, + GroundingModuleConfig, InputFilteringConfig, OutputFilteringConfig } from './client/api/schema/index.js'; @@ -24,3 +26,26 @@ export function buildAzureContentFilter( ] }; } + +/** + * Convenience function to create Document Grounding configuration. + * @param groundingConfig - Configuration for the document grounding service. + * @returns An object with the full grounding configuration. + */ +export function buildDocumentGroundingConfig( + groundingConfig: DocumentGroundingServiceConfig +): GroundingModuleConfig { + return { + type: 'document_grounding_service', + config: { + input_params: groundingConfig.input_params, + output_param: groundingConfig.output_param, + ...(groundingConfig.filters && { + filters: groundingConfig.filters?.map(filter => ({ + data_repository_type: 'vector', + ...filter + })) + }) + } + }; +} diff --git a/sample-code/src/orchestration.ts b/sample-code/src/orchestration.ts index 9eef3c85d..beb7ec779 100644 --- a/sample-code/src/orchestration.ts +++ b/sample-code/src/orchestration.ts @@ -1,7 +1,8 @@ import { readFile } from 'node:fs/promises'; import { OrchestrationClient, - buildAzureContentFilter + buildAzureContentFilter, + buildDocumentGroundingConfig } from '@sap-ai-sdk/orchestration'; import { createLogger } from '@sap-cloud-sdk/util'; import type { @@ -241,21 +242,11 @@ export async function orchestrationGrounding(): Promise { } ] }, - grounding: { - type: 'document_grounding_service', - config: { - filters: [ - { - id: 'filter1', - data_repositories: ['*'], - search_config: {}, - data_repository_type: 'vector' - } - ], - input_params: ['groundingRequest'], - output_param: 'groundingOutput' - } - } + grounding: buildDocumentGroundingConfig({ + input_params: ['groundingRequest'], + output_param: 'groundingOutput', + filters: [{ id: 'filter1' }] + }) }); return orchestrationClient.chatCompletion({ diff --git a/tests/type-tests/test/orchestration.test-d.ts b/tests/type-tests/test/orchestration.test-d.ts index 7b932dfdc..180afade6 100644 --- a/tests/type-tests/test/orchestration.test-d.ts +++ b/tests/type-tests/test/orchestration.test-d.ts @@ -5,6 +5,8 @@ import { OrchestrationResponse, TokenUsage, ChatModel, + GroundingModuleConfig, + buildDocumentGroundingConfig, LlmModelParams } from '@sap-ai-sdk/orchestration'; @@ -241,3 +243,31 @@ expectType>( expect('custom-model'); expect('gemini-1.0-pro'); + +/** + * Grounding util + */ +expectType( + buildDocumentGroundingConfig({ + input_params: ['test'], + output_param: 'test' + }) +); + +expectError( + buildDocumentGroundingConfig({ + input_params: ['test'] + }) +); + +expectType( + buildDocumentGroundingConfig({ + input_params: ['test'], + output_param: 'test', + filters: [ + { + id: 'test' + } + ] + }) +);