Skip to content

Commit

Permalink
Added Setting to Toggle Data Source Management Code Paths (#2723)
Browse files Browse the repository at this point in the history
Signed-off-by: Frank Dattalo <fddattal@amazon.com>
  • Loading branch information
fddattal committed Jul 9, 2024
1 parent aec4825 commit d639796
Show file tree
Hide file tree
Showing 17 changed files with 809 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand All @@ -26,11 +27,14 @@
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasources.exceptions.DataSourceClientException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.opensearch.util.RestRequestUtil;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
Expand All @@ -45,13 +49,16 @@
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;

@RequiredArgsConstructor
public class RestAsyncQueryManagementAction extends BaseRestHandler {

public static final String ASYNC_QUERY_ACTIONS = "async_query_actions";
public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query";

private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class);

private final OpenSearchSettings settings;

@Override
public String getName() {
return ASYNC_QUERY_ACTIONS;
Expand Down Expand Up @@ -100,6 +107,9 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
if (!dataSourcesEnabled()) {
return dataSourcesDisabledError(restRequest);
}
switch (restRequest.method()) {
case POST:
return executePostRequest(restRequest, nodeClient);
Expand Down Expand Up @@ -272,4 +282,21 @@ private void addCustomerErrorMetric(RestRequest.Method requestMethod) {
break;
}
}

private boolean dataSourcesEnabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}

private RestChannelConsumer dataSourcesDisabledError(RestRequest request) {

RestRequestUtil.consumeAllRequestParameters(request);

return channel -> {
reportError(
channel,
new IllegalAccessException(
String.format("%s setting is false", Settings.Key.DATASOURCES_ENABLED.getKeyValue())),
BAD_REQUEST);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ private DataSourceServiceImpl createDataSourceService() {
String masterKey = "a57d991d9b573f75b9bba1df";
DataSourceMetadataStorage dataSourceMetadataStorage =
new OpenSearchDataSourceMetadataStorage(
client, clusterService, new EncryptorImpl(masterKey));
client,
clusterService,
new EncryptorImpl(masterKey),
(OpenSearchSettings) pluginSettings);
return new DataSourceServiceImpl(
new ImmutableSet.Builder<DataSourceFactory>()
.add(new GlueDataSourceFactory(pluginSettings))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.opensearch.sql.spark.rest;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.threadpool.ThreadPool;

public class RestAsyncQueryManagementActionTest {

private OpenSearchSettings settings;
private RestRequest request;
private RestChannel channel;
private NodeClient nodeClient;
private ThreadPool threadPool;
private RestAsyncQueryManagementAction unit;

@BeforeEach
public void setup() {
settings = Mockito.mock(OpenSearchSettings.class);
request = Mockito.mock(RestRequest.class);
channel = Mockito.mock(RestChannel.class);
nodeClient = Mockito.mock(NodeClient.class);
threadPool = Mockito.mock(ThreadPool.class);

Mockito.when(nodeClient.threadPool()).thenReturn(threadPool);

unit = new RestAsyncQueryManagementAction(settings);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreDisabled() {
setDataSourcesEnabled(false);
unit.handleRequest(request, channel, nodeClient);
Mockito.verifyNoInteractions(nodeClient);
ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
Mockito.verify(channel, Mockito.times(1)).sendResponse(response.capture());
Assertions.assertEquals(400, response.getValue().status().getStatus());
JsonObject actualResponseJson =
new Gson().fromJson(response.getValue().content().utf8ToString(), JsonObject.class);
JsonObject expectedResponseJson = new JsonObject();
expectedResponseJson.addProperty("status", 400);
expectedResponseJson.add("error", new JsonObject());
expectedResponseJson.getAsJsonObject("error").addProperty("type", "IllegalAccessException");
expectedResponseJson.getAsJsonObject("error").addProperty("reason", "Invalid Request");
expectedResponseJson
.getAsJsonObject("error")
.addProperty("details", "plugins.query.datasources.enabled setting is false");
Assertions.assertEquals(expectedResponseJson, actualResponseJson);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreEnabled() {
setDataSourcesEnabled(true);
Mockito.when(request.method()).thenReturn(RestRequest.Method.GET);
unit.handleRequest(request, channel, nodeClient);
Mockito.verify(threadPool, Mockito.times(1))
.schedule(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any());
Mockito.verifyNoInteractions(channel);
}

@Test
public void testGetName() {
Assertions.assertEquals("async_query_actions", unit.getName());
}

private void setDataSourcesEnabled(boolean value) {
Mockito.when(settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED)).thenReturn(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public enum Key {
ENCYRPTION_MASTER_KEY("plugins.query.datasources.encryption.masterkey"),
DATASOURCES_URI_HOSTS_DENY_LIST("plugins.query.datasources.uri.hosts.denylist"),
DATASOURCES_LIMIT("plugins.query.datasources.limit"),
DATASOURCES_ENABLED("plugins.query.datasources.enabled"),

METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"),
METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
Expand All @@ -37,14 +40,19 @@
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.opensearch.util.RestRequestUtil;

@RequiredArgsConstructor
public class RestDataSourceQueryAction extends BaseRestHandler {

public static final String DATASOURCE_ACTIONS = "datasource_actions";
public static final String BASE_DATASOURCE_ACTION_URL = "/_plugins/_query/_datasources";

private static final Logger LOG = LogManager.getLogger(RestDataSourceQueryAction.class);

private final OpenSearchSettings settings;

@Override
public String getName() {
return DATASOURCE_ACTIONS;
Expand Down Expand Up @@ -115,6 +123,9 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
if (!enabled()) {
return disabledError(restRequest);
}
switch (restRequest.method()) {
case POST:
return executePostRequest(restRequest, nodeClient);
Expand Down Expand Up @@ -314,4 +325,22 @@ private static boolean isClientError(Exception e) {
|| e instanceof IllegalArgumentException
|| e instanceof IllegalStateException;
}

private boolean enabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}

private RestChannelConsumer disabledError(RestRequest request) {

RestRequestUtil.consumeAllRequestParameters(request);

return channel -> {
reportError(
channel,
new OpenSearchStatusException(
String.format("%s setting is false", Settings.Key.DATASOURCES_ENABLED.getKeyValue()),
BAD_REQUEST),
BAD_REQUEST);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.encryptor.Encryptor;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.service.DataSourceMetadataStorage;
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;

public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataStorage {

Expand All @@ -61,6 +63,7 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
private final ClusterService clusterService;

private final Encryptor encryptor;
private final OpenSearchSettings settings;

/**
* This class implements DataSourceMetadataStorage interface using OpenSearch as underlying
Expand All @@ -71,14 +74,21 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
* @param encryptor Encryptor.
*/
public OpenSearchDataSourceMetadataStorage(
Client client, ClusterService clusterService, Encryptor encryptor) {
Client client,
ClusterService clusterService,
Encryptor encryptor,
OpenSearchSettings settings) {
this.client = client;
this.clusterService = clusterService;
this.encryptor = encryptor;
this.settings = settings;
}

@Override
public List<DataSourceMetadata> getDataSourceMetadata() {
if (!isEnabled()) {
return Collections.emptyList();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Collections.emptyList();
Expand All @@ -88,6 +98,9 @@ public List<DataSourceMetadata> getDataSourceMetadata() {

@Override
public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
return Optional.empty();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Optional.empty();
Expand All @@ -101,6 +114,9 @@ public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName)

@Override
public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
Expand Down Expand Up @@ -134,6 +150,9 @@ public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
UpdateRequest updateRequest =
new UpdateRequest(DATASOURCE_INDEX_NAME, dataSourceMetadata.getName());
Expand Down Expand Up @@ -163,6 +182,9 @@ public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void deleteDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
DeleteRequest deleteRequest = new DeleteRequest(DATASOURCE_INDEX_NAME);
deleteRequest.id(datasourceName);
deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down Expand Up @@ -302,4 +324,8 @@ private void handleSigV4PropertiesEncryptionDecryption(
.ifPresent(list::add);
encryptOrDecrypt(propertiesMap, isEncryption, list);
}

private boolean isEnabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}
}
Loading

0 comments on commit d639796

Please sign in to comment.