From ceb3ee6adc863c59c68207d1986613077dcebaa5 Mon Sep 17 00:00:00 2001
From: Dogan AY <dogan.ay@forestadmin.com>
Date: Tue, 24 Dec 2024 11:48:54 +0100
Subject: [PATCH] feat(segment): support templating for segment live query

---
 packages/agent/src/routes/access/count.ts     |  6 +-
 packages/agent/src/routes/access/list.ts      |  6 +-
 packages/agent/src/services/index.ts          |  5 ++
 .../src/services/segment-query-handler.ts     | 47 +++++++++++++++
 .../forest-admin-http-driver-services.ts      |  2 +
 .../__factories__/segment-query-handler.ts    |  7 +++
 .../services/segment-query-handler.test.ts    | 58 +++++++++++++++++++
 .../src/decorators/segment/collection.ts      |  4 +-
 .../decorators/segment/collection.test.ts     |  6 +-
 .../interfaces/query/filter/unpaginated.ts    | 10 +++-
 10 files changed, 144 insertions(+), 7 deletions(-)
 create mode 100644 packages/agent/src/services/segment-query-handler.ts
 create mode 100644 packages/agent/test/__factories__/segment-query-handler.ts
 create mode 100644 packages/agent/test/services/segment-query-handler.test.ts

diff --git a/packages/agent/src/routes/access/count.ts b/packages/agent/src/routes/access/count.ts
index 17736714ba..73a1976b0b 100644
--- a/packages/agent/src/routes/access/count.ts
+++ b/packages/agent/src/routes/access/count.ts
@@ -17,7 +17,11 @@ export default class CountRoute extends CollectionRoute {
     if (this.collection.schema.countable) {
       const scope = await this.services.authorization.getScope(this.collection, context);
       const caller = QueryStringParser.parseCaller(context);
-      const filter = ContextFilterFactory.build(this.collection, context, scope);
+      let filter = ContextFilterFactory.build(this.collection, context, scope);
+      filter = await this.services.segmentQueryHandler.handleLiveQuerySegmentFilter(
+        context,
+        filter,
+      );
 
       const aggregation = new Aggregation({ operation: 'Count' });
       const aggregationResult = await this.collection.aggregate(caller, filter, aggregation);
diff --git a/packages/agent/src/routes/access/list.ts b/packages/agent/src/routes/access/list.ts
index 5005c30c72..e2a3484a8c 100644
--- a/packages/agent/src/routes/access/list.ts
+++ b/packages/agent/src/routes/access/list.ts
@@ -14,7 +14,11 @@ export default class ListRoute extends CollectionRoute {
     await this.services.authorization.assertCanBrowse(context, this.collection.name);
 
     const scope = await this.services.authorization.getScope(this.collection, context);
-    const paginatedFilter = ContextFilterFactory.buildPaginated(this.collection, context, scope);
+    let paginatedFilter = ContextFilterFactory.buildPaginated(this.collection, context, scope);
+    paginatedFilter = await this.services.segmentQueryHandler.handleLiveQuerySegmentFilter(
+      context,
+      paginatedFilter,
+    );
 
     const records = await this.collection.list(
       QueryStringParser.parseCaller(context),
diff --git a/packages/agent/src/services/index.ts b/packages/agent/src/services/index.ts
index cf1afc335c..b0c7d5b3e8 100644
--- a/packages/agent/src/services/index.ts
+++ b/packages/agent/src/services/index.ts
@@ -2,6 +2,7 @@ import { ChartHandlerInterface } from '@forestadmin/forestadmin-client';
 
 import authorizationServiceFactory from './authorization';
 import AuthorizationService from './authorization/authorization';
+import SegmentQueryHandler from './segment-query-handler';
 import Serializer from './serializer';
 import { AgentOptionsWithDefaults } from '../types';
 
@@ -9,6 +10,7 @@ export type ForestAdminHttpDriverServices = {
   serializer: Serializer;
   authorization: AuthorizationService;
   chartHandler: ChartHandlerInterface;
+  segmentQueryHandler: SegmentQueryHandler;
 };
 
 export default (options: AgentOptionsWithDefaults): ForestAdminHttpDriverServices => {
@@ -16,5 +18,8 @@ export default (options: AgentOptionsWithDefaults): ForestAdminHttpDriverService
     authorization: authorizationServiceFactory(options),
     serializer: new Serializer(),
     chartHandler: options.forestAdminClient.chartHandler,
+    segmentQueryHandler: new SegmentQueryHandler(
+      options.forestAdminClient.contextVariablesInstantiator,
+    ),
   };
 };
diff --git a/packages/agent/src/services/segment-query-handler.ts b/packages/agent/src/services/segment-query-handler.ts
new file mode 100644
index 0000000000..e7b2aa42cd
--- /dev/null
+++ b/packages/agent/src/services/segment-query-handler.ts
@@ -0,0 +1,47 @@
+import type { Caller, PaginatedFilter } from '@forestadmin/datasource-toolkit';
+import type { ContextVariablesInstantiatorInterface } from '@forestadmin/forestadmin-client';
+
+import { ContextVariablesInjector } from '@forestadmin/forestadmin-client';
+import { Context } from 'koa';
+
+export default class SegmentQueryHandler {
+  private readonly contextVariablesInstantiator: ContextVariablesInstantiatorInterface;
+
+  constructor(contextVariablesInstantiator: ContextVariablesInstantiatorInterface) {
+    this.contextVariablesInstantiator = contextVariablesInstantiator;
+  }
+
+  public async handleLiveQuerySegmentFilter(context: Context, paginatedFilter: PaginatedFilter) {
+    if (paginatedFilter.liveQuerySegment) {
+      const { renderingId, id: userId } = <Caller>context.state.user;
+      const contextVariables = await this.contextVariablesInstantiator.buildContextVariables({
+        userId,
+        renderingId,
+      });
+      const contextVariablesUsed: Record<string, unknown> = {};
+
+      const replaceContextVariable = (contextVariableName: string) => {
+        const contextVariableRenamed = contextVariableName.replace(/\./g, '_');
+        contextVariablesUsed[contextVariableRenamed] =
+          contextVariables.getValue(contextVariableName);
+
+        return `$${contextVariableRenamed}`;
+      };
+
+      const replacedQuery = ContextVariablesInjector.injectContextInValueCustom(
+        paginatedFilter.liveQuerySegment.query,
+        replaceContextVariable,
+      );
+
+      return paginatedFilter.override({
+        liveQuerySegment: {
+          query: replacedQuery,
+          contextVariables: contextVariablesUsed,
+          connectionName: paginatedFilter.liveQuerySegment.connectionName,
+        },
+      });
+    }
+
+    return paginatedFilter;
+  }
+}
diff --git a/packages/agent/test/__factories__/forest-admin-http-driver-services.ts b/packages/agent/test/__factories__/forest-admin-http-driver-services.ts
index d0eb3d5ee9..f40c40af1a 100644
--- a/packages/agent/test/__factories__/forest-admin-http-driver-services.ts
+++ b/packages/agent/test/__factories__/forest-admin-http-driver-services.ts
@@ -1,12 +1,14 @@
 import { Factory } from 'fishery';
 
 import factoryAuthorization from './authorization/authorization';
+import factorySegmentQueryHandler from './segment-query-handler';
 import factorySerializer from './serializer';
 import { ForestAdminHttpDriverServices } from '../../src/services';
 
 export default Factory.define<ForestAdminHttpDriverServices>(() => ({
   serializer: factorySerializer.build(),
   authorization: factoryAuthorization.mockAllMethods().build(),
+  segmentQueryHandler: factorySegmentQueryHandler.build(),
   chartHandler: {
     getChartWithContextInjected: jest.fn(),
     getQueryForChart: jest.fn(),
diff --git a/packages/agent/test/__factories__/segment-query-handler.ts b/packages/agent/test/__factories__/segment-query-handler.ts
new file mode 100644
index 0000000000..0f6d9e5ed6
--- /dev/null
+++ b/packages/agent/test/__factories__/segment-query-handler.ts
@@ -0,0 +1,7 @@
+import { Factory } from 'fishery';
+
+import SegmentQueryHandler from '../../src/services/segment-query-handler';
+
+export default Factory.define<SegmentQueryHandler>(
+  () => new SegmentQueryHandler({ buildContextVariables: jest.fn().mockResolvedValue({}) }),
+);
diff --git a/packages/agent/test/services/segment-query-handler.test.ts b/packages/agent/test/services/segment-query-handler.test.ts
new file mode 100644
index 0000000000..3739d8a70d
--- /dev/null
+++ b/packages/agent/test/services/segment-query-handler.test.ts
@@ -0,0 +1,58 @@
+import { PaginatedFilter } from '@forestadmin/datasource-toolkit';
+import { Context } from 'koa';
+
+import SegmentQueryHandler from '../../src/services/segment-query-handler';
+
+describe('SegmentQueryHandler', () => {
+  const spyBuildContextVariables = jest.fn().mockResolvedValue({});
+
+  const setupService = (): SegmentQueryHandler => {
+    return new SegmentQueryHandler({ buildContextVariables: spyBuildContextVariables });
+  };
+
+  afterEach(() => jest.resetAllMocks());
+
+  describe('when no liveQuerySegment is defined', () => {
+    test('should do nothing', async () => {
+      const context = {
+        state: { user: {} },
+      } as Context;
+      const paginatedFilter = new PaginatedFilter({});
+
+      const service = setupService();
+
+      const result = await service.handleLiveQuerySegmentFilter(context, paginatedFilter);
+      expect(spyBuildContextVariables).not.toHaveBeenCalled();
+      expect(result).toStrictEqual(paginatedFilter);
+    });
+  });
+
+  describe('when liveQuerySegment is defined', () => {
+    test('should handle the live query', async () => {
+      const context = {
+        state: { user: { id: 42, renderingId: 101 } },
+      } as Context;
+      const paginatedFilter = new PaginatedFilter({
+        liveQuerySegment: {
+          query: 'SELECT id FROM "users" WHERE email like {{currentUser.email}};',
+          connectionName: 'main',
+        },
+      });
+
+      spyBuildContextVariables.mockResolvedValue({ getValue: () => 'johndoe@forestadmin.com' });
+      const service = setupService();
+
+      const result = await service.handleLiveQuerySegmentFilter(context, paginatedFilter);
+      expect(spyBuildContextVariables).toHaveBeenCalled();
+      expect(result).toStrictEqual(
+        paginatedFilter.override({
+          liveQuerySegment: {
+            query: 'SELECT id FROM "users" WHERE email like $currentUser_email;',
+            contextVariables: { currentUser_email: 'johndoe@forestadmin.com' },
+            connectionName: 'main',
+          },
+        }),
+      );
+    });
+  });
+});
diff --git a/packages/datasource-customizer/src/decorators/segment/collection.ts b/packages/datasource-customizer/src/decorators/segment/collection.ts
index 565cfbec7b..e4a8f25600 100644
--- a/packages/datasource-customizer/src/decorators/segment/collection.ts
+++ b/packages/datasource-customizer/src/decorators/segment/collection.ts
@@ -72,13 +72,13 @@ export default class SegmentCollectionDecorator extends CollectionDecorator {
     const { liveQuerySegment } = filter;
 
     if (liveQuerySegment) {
-      const { query, connectionName } = liveQuerySegment;
+      const { query, connectionName, contextVariables } = liveQuerySegment;
 
       try {
         const result = (await this.dataSource.executeNativeQuery(
           connectionName,
           query,
-          {},
+          contextVariables,
         )) as Record<string, unknown>[];
 
         const [primaryKey] = SchemaUtils.getPrimaryKeys(this.childCollection.schema);
diff --git a/packages/datasource-customizer/test/decorators/segment/collection.test.ts b/packages/datasource-customizer/test/decorators/segment/collection.test.ts
index 8128931203..acd1e0f096 100644
--- a/packages/datasource-customizer/test/decorators/segment/collection.test.ts
+++ b/packages/datasource-customizer/test/decorators/segment/collection.test.ts
@@ -128,7 +128,11 @@ describe('SegmentCollectionDecorator', () => {
         const filter = await segmentDecorator.refineFilter(
           factories.caller.build(),
           factories.filter.build({
-            liveQuerySegment: { query: 'SELECT id from toto', connectionName: 'main' },
+            liveQuerySegment: {
+              query: 'SELECT id from toto',
+              connectionName: 'main',
+              contextVariables: { toto: 1 },
+            },
           }),
         );
 
diff --git a/packages/datasource-toolkit/src/interfaces/query/filter/unpaginated.ts b/packages/datasource-toolkit/src/interfaces/query/filter/unpaginated.ts
index f511add2bc..8143659265 100644
--- a/packages/datasource-toolkit/src/interfaces/query/filter/unpaginated.ts
+++ b/packages/datasource-toolkit/src/interfaces/query/filter/unpaginated.ts
@@ -1,11 +1,17 @@
 import ConditionTree, { PlainConditionTree } from '../condition-tree/nodes/base';
 
+export type LiveQuerySegment = {
+  query: string;
+  connectionName: string;
+  contextVariables?: Record<string, unknown>;
+};
+
 export type FilterComponents = {
   conditionTree?: ConditionTree;
   search?: string | null;
   searchExtended?: boolean;
   segment?: string;
-  liveQuerySegment?: { query: string; connectionName: string };
+  liveQuerySegment?: LiveQuerySegment;
 };
 
 export type PlainFilter = {
@@ -20,7 +26,7 @@ export default class Filter {
   search?: string;
   searchExtended?: boolean;
   segment?: string;
-  liveQuerySegment?: { query: string; connectionName: string };
+  liveQuerySegment?: LiveQuerySegment;
 
   get isNestable(): boolean {
     return !this.search && !this.segment;