diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 813613a3..c3b72118 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -34,6 +34,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.query.QueryBuilder; @@ -114,6 +115,10 @@ public CreateWorkflowTransportAction( @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } User user = getUserContext(client); String workflowId = request.getWorkflowId(); try { diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java index 2974f522..fbaa449e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java @@ -24,8 +24,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -43,6 +45,7 @@ public class DeleteWorkflowTransportAction extends HandledTransportAction listener) { if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); User user = getUserContext(client); diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index 9699de5a..44e99af7 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -32,6 +32,7 @@ import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; @@ -148,6 +149,10 @@ private void executeDeprovisionRequest( ActionListener listener, ThreadContext.StoredContext context ) { + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); String allowDelete = request.getParams().get(ALLOW_DELETE); GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 58d37edf..8e1e711b 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -24,11 +24,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -45,6 +47,7 @@ public class GetWorkflowTransportAction extends HandledTransportAction listener) { if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { - + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); - User user = getUserContext(client); // Retrieve workflow by ID diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 45f37416..c4fd6b51 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -25,6 +25,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -32,6 +33,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; @@ -71,6 +73,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction listener) { // Retrieve use case template from global context + String tenantId = request.getTemplate() == null ? null : request.getTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); - User user = getUserContext(client); // Stash thread context to interact with system index diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index 8e501228..5ae101e3 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -34,6 +34,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.TenantAwareHelper; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -130,7 +131,10 @@ public ReprovisionWorkflowTransportAction( @Override protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionListener listener) { - + String tenantId = request.getUpdatedTemplate() == null ? null : request.getUpdatedTemplate().getTenantId(); + if (!TenantAwareHelper.validateTenantId(flowFrameworkSettings.isMultiTenancyEnabled(), tenantId, listener)) { + return; + } String workflowId = request.getWorkflowId(); User user = getUserContext(client); diff --git a/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java b/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java new file mode 100644 index 00000000..e2eaceaf --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/TenantAwareHelper.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; + +import java.util.Objects; + +public class TenantAwareHelper { + + /** + * Validates the tenant ID based on the multi-tenancy feature setting. + * + * @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled. + * @param tenantId The tenant ID to validate. + * @param listener The action listener to handle failure cases. + * @return true if the tenant ID is valid or if multi-tenancy is not enabled; false if the tenant ID is invalid and multi-tenancy is enabled. + */ + public static boolean validateTenantId(boolean isMultiTenancyEnabled, String tenantId, ActionListener listener) { + if (isMultiTenancyEnabled && tenantId == null) { + listener.onFailure(new FlowFrameworkException("You don't have permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else { + return true; + } + } + + /** + * Validates the tenant resource by comparing the tenant ID from the request with the tenant ID from the resource. + * + * @param isMultiTenancyEnabled whether the multi-tenancy feature is enabled. + * @param tenantIdFromRequest The tenant ID obtained from the request. + * @param tenantIdFromResource The tenant ID obtained from the resource. + * @param listener The action listener to handle failure cases. + * @return true if the tenant IDs match or if multi-tenancy is not enabled; false if the tenant IDs do not match and multi-tenancy is enabled. + */ + public static boolean validateTenantResource( + boolean isMultiTenancyEnabled, + String tenantIdFromRequest, + String tenantIdFromResource, + ActionListener listener + ) { + if (isMultiTenancyEnabled && !Objects.equals(tenantIdFromRequest, tenantIdFromResource)) { + listener.onFailure(new FlowFrameworkException("You don't have permission to access this resource", RestStatus.FORBIDDEN)); + return false; + } else return true; + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java index ef4dfe09..6012e5b8 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java @@ -46,12 +46,14 @@ public class DeleteWorkflowTransportActionTests extends OpenSearchTestCase { private Client client; private DeleteWorkflowTransportAction deleteWorkflowTransportAction; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private FlowFrameworkSettings flowFrameworkSettings; @Override public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); ClusterService clusterService = mock(ClusterService.class); ClusterSettings clusterSettings = new ClusterSettings( @@ -64,6 +66,7 @@ public void setUp() throws Exception { mock(TransportService.class), mock(ActionFilters.class), flowFrameworkIndicesHandler, + flowFrameworkSettings, client, clusterService, xContentRegistry(), diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java index 07d53300..4c59206c 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -62,6 +62,7 @@ public class GetWorkflowTransportActionTests extends OpenSearchTestCase { private NamedXContentRegistry xContentRegistry; private GetWorkflowTransportAction getTemplateTransportAction; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private FlowFrameworkSettings flowFrameworkSettings; private Template template; private EncryptorUtils encryptorUtils; @@ -71,6 +72,7 @@ public void setUp() throws Exception { this.client = mock(Client.class); this.xContentRegistry = mock(NamedXContentRegistry.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); this.sdkClient = SdkClientFactory.createSdkClient(client, xContentRegistry, Collections.emptyMap()); this.encryptorUtils = new EncryptorUtils(mock(ClusterService.class), client, sdkClient, xContentRegistry); ClusterService clusterService = mock(ClusterService.class); @@ -84,6 +86,7 @@ public void setUp() throws Exception { mock(TransportService.class), mock(ActionFilters.class), flowFrameworkIndicesHandler, + flowFrameworkSettings, client, encryptorUtils, clusterService, diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 623270a2..77050e83 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -69,6 +69,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; private Template template; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private FlowFrameworkSettings flowFrameworkSettings; private EncryptorUtils encryptorUtils; private PluginsService pluginsService; @@ -79,6 +80,7 @@ public void setUp() throws Exception { this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); this.encryptorUtils = mock(EncryptorUtils.class); this.pluginsService = mock(PluginsService.class); ClusterService clusterService = mock(ClusterService.class); @@ -95,6 +97,7 @@ public void setUp() throws Exception { client, workflowProcessSorter, flowFrameworkIndicesHandler, + flowFrameworkSettings, encryptorUtils, pluginsService, clusterService, diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index 6e1e65d3..edba29b1 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -78,6 +78,7 @@ public void setUp() throws Exception { this.workflowStepFactory = mock(WorkflowStepFactory.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.flowFrameworkSettings = mock(FlowFrameworkSettings.class); this.encryptorUtils = mock(EncryptorUtils.class); this.pluginsService = mock(PluginsService.class); diff --git a/src/test/java/org/opensearch/flowframework/util/TenantAwareHelperTests.java b/src/test/java/org/opensearch/flowframework/util/TenantAwareHelperTests.java new file mode 100644 index 00000000..750094a6 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/TenantAwareHelperTests.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.junit.Before; +import org.junit.Test; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; + +public class TenantAwareHelperTests { + + @Mock + private ActionListener actionListener; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void testValidateTenantId_MultiTenancyEnabled_TenantIdNull() { + boolean result = TenantAwareHelper.validateTenantId(true, null, actionListener); + assertFalse(result); + ArgumentCaptor captor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(actionListener).onFailure(captor.capture()); + FlowFrameworkException exception = captor.getValue(); + assert exception.status() == RestStatus.FORBIDDEN; + assert exception.getMessage().equals("You don't have permission to access this resource"); + } + + @Test + public void testValidateTenantId_MultiTenancyEnabled_TenantIdPresent() { + assertTrue(TenantAwareHelper.validateTenantId(true, "_tenant_id", actionListener)); + } + + @Test + public void testValidateTenantId_MultiTenancyDisabled() { + assertTrue(TenantAwareHelper.validateTenantId(false, null, actionListener)); + } + + @Test + public void testValidateTenantResource_MultiTenancyEnabled_TenantIdMismatch() { + boolean result = TenantAwareHelper.validateTenantResource(true, null, "different_tenant_id", actionListener); + assertFalse(result); + ArgumentCaptor captor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(actionListener).onFailure(captor.capture()); + FlowFrameworkException exception = captor.getValue(); + assert exception.status() == RestStatus.FORBIDDEN; + assert exception.getMessage().equals("You don't have permission to access this resource"); + } + + @Test + public void testValidateTenantResource_MultiTenancyEnabled_TenantIdMatch() { + assertTrue(TenantAwareHelper.validateTenantResource(true, "_tenant_id", "_tenant_id", actionListener)); + } + + @Test + public void testValidateTenantResource_MultiTenancyDisabled() { + assertTrue(TenantAwareHelper.validateTenantResource(false, "_tenant_id", "different_tenant_id", actionListener)); + } +}