From 5dc8edbb42ee3a17b34ffbcaa53abfc4da2c6e3b Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 2 May 2023 08:58:53 -0700 Subject: [PATCH] [Feature/Extensions] Profile detector (#882) --- build.gradle | 8 +- .../ad/AnomalyDetectorExtension.java | 10 +- .../ad/AnomalyDetectorProfileRunner.java | 23 ++- .../ad/rest/RestGetAnomalyDetectorAction.java | 74 ++++--- .../org/opensearch/ad/task/ADTaskManager.java | 182 +++++++++--------- .../ADTaskProfileTransportAction.java | 70 ++++--- .../ad/transport/ProfileTransportAction.java | 16 +- .../transport/RCFPollingTransportAction.java | 45 +++-- .../ad/transport/RCFPollingTests.java | 52 +---- 9 files changed, 234 insertions(+), 246 deletions(-) diff --git a/build.gradle b/build.gradle index 55a2e35db..514f67aab 100644 --- a/build.gradle +++ b/build.gradle @@ -776,7 +776,13 @@ List jacocoExclusions = [ 'org.opensearch.ad.ratelimit.ResultWriteRequest', 'org.opensearch.ad.AnomalyDetectorJobRunner.*', 'org.opensearch.ad.util.RestHandlerUtils', - 'org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction.*' + 'org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction.*', + 'org.opensearch.ad.transport.RCFPollingAction', + 'org.opensearch.ad.transport.RCFPollingRequest', + 'org.opensearch.ad.transport.RCFPollingTransportAction', + 'org.opensearch.ad.transport.RCFPollingTransportAction.*', + 'org.opensearch.ad.transport.RCFPollingResponse', + ] diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorExtension.java b/src/main/java/org/opensearch/ad/AnomalyDetectorExtension.java index 53cdcc5d2..e69827bcc 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorExtension.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorExtension.java @@ -92,6 +92,8 @@ import org.opensearch.ad.transport.ADResultBulkTransportAction; import org.opensearch.ad.transport.ADStatsNodesAction; import org.opensearch.ad.transport.ADStatsNodesTransportAction; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileTransportAction; import org.opensearch.ad.transport.AnomalyDetectorJobAction; import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; import org.opensearch.ad.transport.AnomalyResultAction; @@ -112,6 +114,8 @@ import org.opensearch.ad.transport.PreviewAnomalyDetectorTransportAction; import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileTransportAction; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingTransportAction; import org.opensearch.ad.transport.RCFResultAction; import org.opensearch.ad.transport.RCFResultTransportAction; import org.opensearch.ad.transport.SearchADTasksAction; @@ -578,7 +582,7 @@ public PooledObject wrap(LinkedBuffer obj) { xContentRegistry, anomalyDetectionIndices, nodeFilter, - null, // hashRing + // null, //hashring MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 adTaskCacheManager, threadPool ); @@ -801,7 +805,9 @@ public List> getExecutorBuilders(Settings settings) { new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), new ActionHandler<>(ForwardADTaskAction.INSTANCE, ForwardADTaskTransportAction.class), new ActionHandler<>(ADBatchAnomalyResultAction.INSTANCE, ADBatchAnomalyResultTransportAction.class), - new ActionHandler<>(ADCancelTaskAction.INSTANCE, ADCancelTaskTransportAction.class) + new ActionHandler<>(ADCancelTaskAction.INSTANCE, ADCancelTaskTransportAction.class), + new ActionHandler<>(RCFPollingAction.INSTANCE, RCFPollingTransportAction.class), + new ActionHandler<>(ADTaskProfileAction.INSTANCE, ADTaskProfileTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index 210168868..1ce755d48 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -49,8 +49,10 @@ import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.NumericSetting; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileRequest; import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingRequest; import org.opensearch.ad.transport.RCFPollingResponse; import org.opensearch.ad.util.DiscoveryNodeFilterer; @@ -78,7 +80,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); - private SDKRestClient client; + private SDKRestClient sdkRestClient; private SDKNamedXContentRegistry xContentRegistry; private DiscoveryNodeFilterer nodeFilter; private final TransportService transportService; @@ -86,7 +88,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { private final int maxTotalEntitiesToTrack; public AnomalyDetectorProfileRunner( - SDKRestClient client, + SDKRestClient sdkRestClient, SDKNamedXContentRegistry xContentRegistry, DiscoveryNodeFilterer nodeFilter, long requiredSamples, @@ -94,7 +96,7 @@ public AnomalyDetectorProfileRunner( ADTaskManager adTaskManager ) { super(requiredSamples); - this.client = client; + this.sdkRestClient = sdkRestClient; this.xContentRegistry = xContentRegistry; this.nodeFilter = nodeFilter; if (requiredSamples <= 0) { @@ -119,7 +121,7 @@ private void calculateTotalResponsesToWait( ActionListener listener ) { GetRequest getDetectorRequest = new GetRequest(ANOMALY_DETECTORS_INDEX, detectorId); - client.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> { + sdkRestClient.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> { if (getDetectorResponse != null && getDetectorResponse.isExists()) { try ( XContentParser xContentParser = XContentType.JSON @@ -153,7 +155,7 @@ private void prepareProfile( ) { String detectorId = detector.getDetectorId(); GetRequest getRequest = new GetRequest(ANOMALY_DETECTOR_JOB_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(getResponse -> { + sdkRestClient.get(getRequest, ActionListener.wrap(getResponse -> { if (getResponse != null && getResponse.isExists()) { try ( XContentParser parser = XContentType.JSON @@ -298,7 +300,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + sdkRestClient.search(request, ActionListener.wrap(searchResponse -> { Map aggMap = searchResponse.getAggregations().asMap(); InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES); long value = totalEntities.getValue(); @@ -321,7 +323,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + sdkRestClient.search(searchRequest, ActionListener.wrap(searchResponse -> { DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); Aggregations aggs = searchResponse.getAggregations(); if (aggs == null) { @@ -383,7 +385,7 @@ private void profileStateRelated( ) { if (enabled) { RCFPollingRequest request = new RCFPollingRequest(detector.getDetectorId()); - // client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener)); + sdkRestClient.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener)); } else { DetectorProfile.Builder builder = new DetectorProfile.Builder(); if (profilesToCollect.contains(DetectorProfileName.STATE)) { @@ -402,7 +404,8 @@ private void profileModels( ) { DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); ProfileRequest profileRequest = new ProfileRequest(detector.getDetectorId(), profiles, forMultiEntityDetector, dataNodes); - // client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress + sdkRestClient.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init + // progress } private ActionListener onModelResponse( @@ -482,7 +485,7 @@ private void confirmMultiEntityDetectorInitStatus( MultiResponsesDelegateActionListener listener ) { SearchRequest searchLatestResult = createInittedEverRequest(detector.getDetectorId(), enabledTime, detector.getResultIndex()); - client.search(searchLatestResult, onInittedEver(enabledTime, profile, profilesToCollect, detector, totalUpdates, listener)); + sdkRestClient.search(searchLatestResult, onInittedEver(enabledTime, profile, profilesToCollect, detector, totalUpdates, listener)); } private ActionListener onInittedEver( diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index cebe36fee..489440c58 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -12,6 +12,7 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.util.RestHandlerUtils.DETECTOR_ID; +import static org.opensearch.ad.util.RestHandlerUtils.PROFILE; import static org.opensearch.ad.util.RestHandlerUtils.TYPE; import java.io.IOException; @@ -108,13 +109,6 @@ protected ExtensionRestResponse prepareRequest(RestRequest request) throws IOExc return getAnomalyDetectorResponse(request, response); } - @Override - public List replacedRouteHandlers() { - String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); - String newPath = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID); - return ImmutableList.of(new ReplacedRouteHandler(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path, handleRequest)); - } - private Function handleRequest = (request) -> { try { return prepareRequest(request); @@ -124,54 +118,58 @@ public List replacedRouteHandlers() { } }; - /*@Override - public List routes() { + @Override + public List routeHandlers() { return ImmutableList - .of( - // Opensearch-only API. Considering users may provide entity in the search body, support POST as well. - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE) - ), - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE) - ) - ); - }*/ - - /* @Override - public List replacedRoutes() { - String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); - String newPath = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); + .of( + // Opensearch-only API. Considering users may provide entity in the search body, support POST as well. + new RouteHandler( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), + handleRequest + ), + new RouteHandler( + RestRequest.Method.POST, + String + .format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE), + handleRequest + ) + ); + } + + @Override + public List replacedRouteHandlers() { + String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); + String newPath = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID); return ImmutableList .of( - new ReplacedRoute(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path), - new ReplacedRoute(RestRequest.Method.HEAD, newPath, RestRequest.Method.HEAD, path), - new ReplacedRoute( + new ReplacedRouteHandler(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path, handleRequest), + new ReplacedRouteHandler(RestRequest.Method.HEAD, newPath, RestRequest.Method.HEAD, path, handleRequest), + new ReplacedRouteHandler( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), + String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE) + String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE), + handleRequest ), - // types is a profile names. See a complete list of supported profiles names in - // org.opensearch.ad.model.ProfileName. - new ReplacedRoute( + new ReplacedRouteHandler( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE), + String + .format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE), RestRequest.Method.GET, String .format( Locale.ROOT, "%s/{%s}/%s/{%s}", - AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, + AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE, TYPE - ) + ), + handleRequest ) ); - }*/ + } private Entity buildEntity(RestRequest request, String detectorId) throws IOException { if (Strings.isEmpty(detectorId)) { diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index ae908babc..5281f6bb1 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -101,7 +101,6 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.auth.UserIdentity; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.ADTaskCancelledException; import org.opensearch.ad.common.exception.AnomalyDetectionException; import org.opensearch.ad.common.exception.DuplicateTaskException; @@ -196,14 +195,15 @@ public class ADTaskManager { private final Logger logger = LogManager.getLogger(this.getClass()); static final String STATE_INDEX_NOT_EXIST_MSG = "State index does not exist."; private final Set retryableErrors = ImmutableSet.of(EXCEED_HISTORICAL_ANALYSIS_LIMIT, NO_ELIGIBLE_NODE_TO_RUN_DETECTOR); - private final SDKRestClient client; + private final SDKRestClient sdkRestClient; private final OpenSearchAsyncClient sdkJavaAsyncClient; - private final SDKClusterService clusterService; + private final SDKClusterService sdkClusterService; private final SDKNamedXContentRegistry xContentRegistry; private final AnomalyDetectionIndices detectionIndices; private final DiscoveryNodeFilterer nodeFilter; private final ADTaskCacheManager adTaskCacheManager; - private final HashRing hashRing; + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // private final HashRing hashRing; private volatile Integer maxOldAdTaskDocsPerDetector; private volatile Integer pieceIntervalSeconds; private volatile boolean deleteADResultWhenDeleteDetector; @@ -220,43 +220,47 @@ public class ADTaskManager { public ADTaskManager( Settings settings, - SDKClusterService clusterService, - SDKRestClient client, + SDKClusterService sdkClusterService, + SDKRestClient sdkRestClient, OpenSearchAsyncClient sdkJavaAsyncClient, SDKNamedXContentRegistry xContentRegistry, AnomalyDetectionIndices detectionIndices, DiscoveryNodeFilterer nodeFilter, - HashRing hashRing, + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // HashRing hashRing, ADTaskCacheManager adTaskCacheManager, ThreadPool threadPool ) { - this.client = client; + this.sdkRestClient = sdkRestClient; this.sdkJavaAsyncClient = sdkJavaAsyncClient; this.xContentRegistry = xContentRegistry; this.detectionIndices = detectionIndices; this.nodeFilter = nodeFilter; - this.clusterService = clusterService; + this.sdkClusterService = sdkClusterService; this.adTaskCacheManager = adTaskCacheManager; - this.hashRing = hashRing; + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // this.hashRing = hashRing; this.maxOldAdTaskDocsPerDetector = MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings); - clusterService + sdkClusterService .getClusterSettings() .addSettingsUpdateConsumer(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, it -> maxOldAdTaskDocsPerDetector = it); this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); + sdkClusterService + .getClusterSettings() + .addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); this.deleteADResultWhenDeleteDetector = DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.get(settings); - clusterService + sdkClusterService .getClusterSettings() .addSettingsUpdateConsumer(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, it -> deleteADResultWhenDeleteDetector = it); this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); + sdkClusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); this.maxRunningEntitiesPerDetector = MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS.get(settings); - clusterService + sdkClusterService .getClusterSettings() .addSettingsUpdateConsumer(MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS, it -> maxRunningEntitiesPerDetector = it); @@ -265,7 +269,7 @@ public ADTaskManager( .withType(TransportRequestOptions.Type.REG) .withTimeout(REQUEST_TIMEOUT.get(settings)) .build(); - clusterService + sdkClusterService .getClusterSettings() .addSettingsUpdateConsumer( REQUEST_TIMEOUT, @@ -411,7 +415,7 @@ public void forwardRequestToLeadNode( ); }, listener); */ - client + sdkRestClient .execute( ForwardADTaskAction.INSTANCE, forwardADTaskRequest, @@ -458,7 +462,7 @@ public void startHistoricalAnalysis( ); }, listener); */ - DiscoveryNode owningNode = clusterService.localNode(); + DiscoveryNode owningNode = sdkClusterService.localNode(); logger.debug("coordinating node is : {} for detector: {}", owningNode.getId(), detectorId); forwardDetectRequestToCoordinatingNode( detector, @@ -518,7 +522,7 @@ protected void forwardDetectRequestToCoordinatingNode( ); */ Version adVersion = Version.CURRENT; - client + sdkRestClient .execute( ForwardADTaskAction.INSTANCE, // We need to check AD version of remote node as we may send clean detector cache request to old @@ -553,7 +557,7 @@ protected void forwardADTaskToCoordinatingNode( new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) ); */ - client + sdkRestClient .execute( ForwardADTaskAction.INSTANCE, new ForwardADTaskRequest(adTask, adTaskAction), @@ -587,7 +591,7 @@ protected void forwardStaleRunningEntitiesToCoordinatingNode( new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) ); */ - client + sdkRestClient .execute( ForwardADTaskAction.INSTANCE, new ForwardADTaskRequest(adTask, adTaskAction, staleRunningEntity), @@ -638,13 +642,13 @@ public void checkTaskSlots( /* @anomaly-detection commented until we have support for the hashring: https://github.com/opensearch-project/opensearch-sdk-java/issues/200 hashRing.getNodesWithSameLocalAdVersion(nodes -> { */ - DiscoveryNode[] extensionNode = { clusterService.localNode() }; + DiscoveryNode[] extensionNode = { sdkClusterService.localNode() }; int maxAdTaskSlots = extensionNode.length * maxAdBatchTaskPerNode; ADStatsRequest adStatsRequest = new ADStatsRequest(extensionNode); adStatsRequest .addAll(ImmutableSet.of(AD_USED_BATCH_TASK_SLOT_COUNT.getName(), AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName())); - client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { + sdkRestClient.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { int totalUsedTaskSlots = 0; // Total entity tasks running on worker nodes int totalAssignedTaskSlots = 0; // Total assigned task slots on coordinating nodes for (ADStatsNodeResponse response : adStatsResponse.getNodes()) { @@ -762,7 +766,7 @@ protected void scaleTaskLaneOnCoordinatingNode( new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) ); */ - client + sdkRestClient .execute( ForwardADTaskAction.INSTANCE, new ForwardADTaskRequest(adTask, approvedTaskSlot, ADTaskAction.SCALE_ENTITY_TASK_SLOTS), @@ -934,7 +938,7 @@ public void stopDetector( */ public void getDetector(String detectorId, Consumer> function, ActionListener listener) { GetRequest getRequest = new GetRequest(ANOMALY_DETECTORS_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(response -> { + sdkRestClient.get(getRequest, ActionListener.wrap(response -> { if (!response.isExists()) { function.accept(Optional.empty()); return; @@ -1071,7 +1075,7 @@ public void getAndExecuteOnLatestADTasks( searchRequest.source(sourceBuilder); searchRequest.indices(DETECTION_STATE_INDEX); - client.search(searchRequest, ActionListener.wrap(r -> { + sdkRestClient.search(searchRequest, ActionListener.wrap(r -> { // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 // getTotalHits will be null when we track_total_hits is false in the query request. // Add more checking here to cover some unknown cases. @@ -1162,7 +1166,7 @@ private void resetRealtimeDetectorTaskState( ADTask adTask = runningRealtimeTasks.get(0); String detectorId = adTask.getDetectorId(); GetRequest getJobRequest = new GetRequest(ANOMALY_DETECTOR_JOB_INDEX).id(detectorId); - client.get(getJobRequest, ActionListener.wrap(r -> { + sdkRestClient.get(getJobRequest, ActionListener.wrap(r -> { if (r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry.getRegistry(), r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -1314,11 +1318,11 @@ private void stopHistoricalAnalysis( /* @anomaly-detection commented until we have support for the hashring: https://github.com/opensearch-project/opensearch-sdk-java/issues/200 DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalAdVersion(); */ - DiscoveryNode[] dataNodes = { clusterService.localNode() }; + DiscoveryNode[] dataNodes = { sdkClusterService.localNode() }; String userName = user == null ? null : user.getName(); ADCancelTaskRequest cancelTaskRequest = new ADCancelTaskRequest(detectorId, taskId, userName, dataNodes); - client + sdkRestClient .execute( ADCancelTaskAction.INSTANCE, cancelTaskRequest, @@ -1504,50 +1508,52 @@ public void getLatestHistoricalTaskProfile( private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { String detectorId = adDetectorLevelTask.getDetectorId(); - hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { - ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); - client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { - if (response.hasFailures()) { - listener.onFailure(response.failures().get(0)); - return; - } + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { - List adEntityTaskProfiles = new ArrayList<>(); - ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask); - for (ADTaskProfileNodeResponse node : response.getNodes()) { - ADTaskProfile taskProfile = node.getAdTaskProfile(); - if (taskProfile != null) { - if (taskProfile.getNodeId() != null) { - // HC detector: task profile from coordinating node - // Single entity detector: task profile from worker node - detectorTaskProfile.setTaskId(taskProfile.getTaskId()); - detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); - detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); - detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); - detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); - detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); - detectorTaskProfile.setNodeId(taskProfile.getNodeId()); - detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); - detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); - detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); - detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); - detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); - detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType()); - } - if (taskProfile.getEntityTaskProfiles() != null) { - adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); - } + ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, sdkClusterService.localNode()); + sdkRestClient.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + listener.onFailure(response.failures().get(0)); + return; + } + + List adEntityTaskProfiles = new ArrayList<>(); + ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask); + for (ADTaskProfileNodeResponse node : response.getNodes()) { + ADTaskProfile taskProfile = node.getAdTaskProfile(); + if (taskProfile != null) { + if (taskProfile.getNodeId() != null) { + // HC detector: task profile from coordinating node + // Single entity detector: task profile from worker node + detectorTaskProfile.setTaskId(taskProfile.getTaskId()); + detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); + detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); + detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); + detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); + detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); + detectorTaskProfile.setNodeId(taskProfile.getNodeId()); + detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); + detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); + detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); + detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); + detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); + detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType()); + } + if (taskProfile.getEntityTaskProfiles() != null) { + adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); } } - if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { - detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); - } - listener.onResponse(detectorTaskProfile); - }, e -> { - logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e); - listener.onFailure(e); - })); - }, listener); + } + if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { + detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); + } + listener.onResponse(detectorTaskProfile); + }, e -> { + logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e); + listener.onFailure(e); + })); + // }, listener); } @@ -1596,7 +1602,7 @@ private void updateLatestFlagOfOldTasksAndCreateNewTask( // coordinating node once realtime job starts. // For historical analysis, this method will be called on coordinating node, so we can set coordinating // node as local node. - String coordinatingNode = detectionDateRange == null ? null : clusterService.localNode().getId(); + String coordinatingNode = detectionDateRange == null ? null : sdkClusterService.localNode().getId(); createNewADTask(detector, detectionDateRange, user, coordinatingNode, listener); } else { logger @@ -1666,7 +1672,7 @@ public void createADTaskDirectly(ADTask adTask, Consumer func request .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { + sdkRestClient.index(request, ActionListener.wrap(r -> function.accept(r), e -> { logger.error("Failed to create AD task for detector " + adTask.getDetectorId(), e); listener.onFailure(e); })); @@ -1823,7 +1829,7 @@ protected void deleteTaskDocs( } }); - client.search(searchRequest, searchListener); + sdkRestClient.search(searchRequest, searchListener); } /** @@ -1874,7 +1880,7 @@ public void cleanChildTasksAndADResultsOfDeletedTask() { } private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { - client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { + sdkRestClient.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; logger .info( @@ -1965,7 +1971,7 @@ public void updateADTask(String taskId, Map updatedFields, Actio updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); updateRequest.doc(updatedContent); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.update(updateRequest, listener); + sdkRestClient.update(updateRequest, listener); } /** @@ -1992,7 +1998,7 @@ public void deleteADTask(String taskId) { */ public void deleteADTask(String taskId, ActionListener listener) { DeleteRequest deleteRequest = new DeleteRequest(DETECTION_STATE_INDEX, taskId); - client.delete(deleteRequest, listener); + sdkRestClient.delete(deleteRequest, listener); } /** @@ -2190,7 +2196,7 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( return; } Map updatedFields = new HashMap<>(); - updatedFields.put(COORDINATING_NODE_FIELD, clusterService.localNode().getId()); + updatedFields.put(COORDINATING_NODE_FIELD, sdkClusterService.localNode().getId()); if (initProgress != null) { updatedFields.put(INIT_PROGRESS_FIELD, initProgress); updatedFields.put(ESTIMATED_MINUTES_LEFT_FIELD, Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * detectorIntervalInMinutes); @@ -2251,7 +2257,7 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( detector, null, detector.getUser(), - clusterService.localNode().getId(), + sdkClusterService.localNode().getId(), ActionListener.wrap(r -> { logger.info("Recreate realtime task successfully for detector {}", detectorId); adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getDetectorIntervalInMilliseconds()); @@ -2266,7 +2272,7 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( } ADTask adTask = adTaskOptional.get(); - String localNodeId = clusterService.localNode().getId(); + String localNodeId = sdkClusterService.localNode().getId(); String oldCoordinatingNode = adTask.getCoordinatingNode(); if (oldCoordinatingNode != null && !localNodeId.equals(oldCoordinatingNode)) { logger @@ -2535,7 +2541,7 @@ public void countEntityTasksByState(String detectorTaskId, List tas SearchRequest request = new SearchRequest(); request.source(sourceBuilder); request.indices(DETECTION_STATE_INDEX); - client.search(request, ActionListener.wrap(r -> { + sdkRestClient.search(request, ActionListener.wrap(r -> { TotalHits totalHits = r.getHits().getTotalHits(); listener.onResponse(totalHits.value); }, e -> listener.onFailure(e))); @@ -2652,7 +2658,7 @@ public void runNextEntityForHCADHistorical( listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.ACCEPTED)); return; } - client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { + sdkRestClient.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; logger .info( @@ -2764,7 +2770,7 @@ public int detectorTaskSlotScaleDelta(String detectorId) { /* @anomaly-detection commented until we have support for the hashring: https://github.com/opensearch-project/opensearch-sdk-java/issues/200 DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalAdVersion(); */ - DiscoveryNode[] eligibleDataNodes = { clusterService.localNode() }; + DiscoveryNode[] eligibleDataNodes = { sdkClusterService.localNode() }; int unfinishedEntities = adTaskCacheManager.getUnfinishedEntityCount(detectorId); int totalTaskSlots = eligibleDataNodes.length * maxAdBatchTaskPerNode; int taskLaneLimit = Math.min(unfinishedEntities, Math.min(totalTaskSlots, maxRunningEntitiesPerDetector)); @@ -2808,7 +2814,7 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(detectorId); ADTaskProfile detectorTaskProfile = null; - String localNodeId = clusterService.localNode().getId(); + String localNodeId = sdkClusterService.localNode().getId(); if (adTaskCacheManager.isHCTaskRunning(detectorId)) { detectorTaskProfile = new ADTaskProfile(); if (adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { @@ -3015,7 +3021,7 @@ public Entity parseEntityFromString(String entityValue, ADTask adTask) { */ public void getADTask(String taskId, ActionListener> listener) { GetRequest request = new GetRequest(DETECTION_STATE_INDEX, taskId); - client.get(request, ActionListener.wrap(r -> { + sdkRestClient.get(request, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry.getRegistry(), r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -3133,10 +3139,12 @@ public void maintainRunningHistoricalTasks(TransportService transportService, in // Clean expired HC batch task run state cache. adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ // Find owning node with highest AD version to make sure we only have 1 node maintain running historical tasks // and we use the latest logic. - Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); - if (!owningNode.isPresent() || !clusterService.localNode().getId().equals(owningNode.get().getId())) { + // Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); + Optional owningNode = Optional.ofNullable(sdkClusterService.localNode()); + if (!owningNode.isPresent() || !sdkClusterService.localNode().getId().equals(owningNode.get().getId())) { return; } logger.info("Start to maintain running historical tasks"); @@ -3152,7 +3160,7 @@ public void maintainRunningHistoricalTasks(TransportService transportService, in searchRequest.source(sourceBuilder); searchRequest.indices(DETECTION_STATE_INDEX); - client.search(searchRequest, ActionListener.wrap(r -> { + sdkRestClient.search(searchRequest, ActionListener.wrap(r -> { if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { return; } diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java index f4124c460..c828dd760 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java @@ -12,75 +12,83 @@ package org.opensearch.ad.transport; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import org.opensearch.Version; +import org.opensearch.action.ActionListener; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.action.support.TransportAction; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.sdk.SDKClusterService; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskManager; -public class ADTaskProfileTransportAction extends - TransportNodesAction { +import com.google.inject.Inject; + +// TODO: https://github.com/opensearch-project/opensearch-sdk-java/issues/683 (multi node support needed for extensions. +// Previously, the class used to extend TransportNodesAction by which request is sent to multiple nodes. +// For extensions as of now we only have one node support. In order to test multinode feature we need to add multinode support equivalent for SDK ) +public class ADTaskProfileTransportAction extends TransportAction { + // TransportNodesAction { private ADTaskManager adTaskManager; - private HashRing hashRing; + + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // private HashRing hashRing; + + private final SDKClusterService sdkClusterService; @Inject public ADTaskProfileTransportAction( - ThreadPool threadPool, - ClusterService clusterService, - TransportService transportService, + SDKClusterService sdkClusterService, ActionFilters actionFilters, ADTaskManager adTaskManager, - HashRing hashRing + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // HashRing hashRing, + TaskManager taskManager ) { - super( - ADTaskProfileAction.NAME, - threadPool, - clusterService, - transportService, - actionFilters, - ADTaskProfileRequest::new, - ADTaskProfileNodeRequest::new, - ThreadPool.Names.MANAGEMENT, - ADTaskProfileNodeResponse.class - ); + super(ADTaskProfileAction.NAME, actionFilters, taskManager); this.adTaskManager = adTaskManager; - this.hashRing = hashRing; + /* MultiNode support https://github.com/opensearch-project/opensearch-sdk-java/issues/200 */ + // this.hashRing = hashRing; + this.sdkClusterService = sdkClusterService; } - @Override protected ADTaskProfileResponse newResponse( ADTaskProfileRequest request, List responses, List failures ) { - return new ADTaskProfileResponse(clusterService.getClusterName(), responses, failures); + return new ADTaskProfileResponse(sdkClusterService.state().getClusterName(), responses, failures); } - @Override protected ADTaskProfileNodeRequest newNodeRequest(ADTaskProfileRequest request) { return new ADTaskProfileNodeRequest(request); } - @Override protected ADTaskProfileNodeResponse newNodeResponse(StreamInput in) throws IOException { return new ADTaskProfileNodeResponse(in); } @Override - protected ADTaskProfileNodeResponse nodeOperation(ADTaskProfileNodeRequest request) { + protected void doExecute(Task task, ADTaskProfileRequest request, ActionListener actionListener) { + /* @anomaly.detection Commented until we have extension support for hashring : https://github.com/opensearch-project/opensearch-sdk-java/issues/200 String remoteNodeId = request.getParentTask().getNodeId(); Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId); + */ + Version remoteAdVersion = Version.CURRENT; ADTaskProfile adTaskProfile = adTaskManager.getLocalADTaskProfilesByDetectorId(request.getDetectorId()); - return new ADTaskProfileNodeResponse(clusterService.localNode(), adTaskProfile, remoteAdVersion); + actionListener + .onResponse( + newResponse( + request, + new ArrayList<>(List.of(new ADTaskProfileNodeResponse(sdkClusterService.localNode(), adTaskProfile, remoteAdVersion))), + new ArrayList<>() + ) + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java index a003a8059..68d5475da 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java @@ -45,29 +45,25 @@ public class ProfileTransportAction extends TransportAction this.numModelsToReturn = it); + this.sdkClusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); } private ProfileResponse newResponse(ProfileRequest request, List responses, List failures) { - return new ProfileResponse(clusterService.state().getClusterName(), responses, failures); + return new ProfileResponse(sdkClusterService.state().getClusterName(), responses, failures); } @Override @@ -133,7 +129,7 @@ protected void doExecute(Task task, ProfileRequest request, ActionListener { +public class RCFPollingTransportAction extends TransportAction { private static final Logger LOG = LogManager.getLogger(RCFPollingTransportAction.class); static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node"; @@ -49,29 +50,30 @@ public class RCFPollingTransportAction extends HandledTransportAction rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + /* Commenting the below piece of code as we do not have support for multinode + https://github.com/opensearch-project/opensearch-sdk-java/issues/200 + * */ + // Optional rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + + Optional rcfNode = Optional.ofNullable(sdkClusterService.localNode()); if (!rcfNode.isPresent()) { listener.onFailure(new AnomalyDetectionException(adID, NO_NODE_FOUND_MSG)); return; @@ -89,7 +96,7 @@ protected void doExecute(Task task, RCFPollingRequest request, ActionListener TransportResponseHandler rcfRollingHandler(TransportResponseHandler handler) { return new TransportResponseHandler() { @Override @@ -260,6 +214,7 @@ public String executor() { * @param handler callback handler * @return handlder that would return a connection failure */ +/* private TransportResponseHandler rcfFailureRollingHandler(TransportResponseHandler handler) { return new TransportResponseHandler() { @Override @@ -363,3 +318,4 @@ public void testNullDetectorId() { assertTrue(emptyRequest.validate() != null); } } +*/