Skip to content

Commit

Permalink
injecting ClientHelperService into ReindexDataStreamPersistentTaskExe…
Browse files Browse the repository at this point in the history
…cutor
  • Loading branch information
masseyke committed Nov 25, 2024
1 parent 645a69d commit 783b955
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 84 deletions.
1 change: 1 addition & 0 deletions modules/data-streams/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
requires org.apache.lucene.core;

exports org.elasticsearch.datastreams.action to org.elasticsearch.server;
exports org.elasticsearch.datastreams.task to org.elasticsearch.server;
exports org.elasticsearch.datastreams.lifecycle.action to org.elasticsearch.server;
exports org.elasticsearch.datastreams.lifecycle;
exports org.elasticsearch.datastreams.options.action to org.elasticsearch.server;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.health.HealthIndicatorService;
import org.elasticsearch.index.IndexSettingProvider;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.persistent.PersistentTaskParams;
import org.elasticsearch.persistent.PersistentTaskState;
import org.elasticsearch.persistent.PersistentTasksExecutor;
Expand Down Expand Up @@ -173,6 +174,7 @@ public static TimeValue getLookAheadTime(Settings settings) {
private final SetOnce<DataStreamLifecycleHealthInfoPublisher> dataStreamLifecycleErrorsPublisher = new SetOnce<>();
private final SetOnce<DataStreamLifecycleHealthIndicatorService> dataStreamLifecycleHealthIndicatorService = new SetOnce<>();
private final Settings settings;
private PersistentTasksExecutor<?> persistentTaskExecutor;

public DataStreamsPlugin(Settings settings) {
this.settings = settings;
Expand Down Expand Up @@ -250,6 +252,13 @@ public Collection<?> createComponents(PluginServices services) {
components.add(errorStoreInitialisationService.get());
components.add(dataLifecycleInitialisationService.get());
components.add(dataStreamLifecycleErrorsPublisher.get());
persistentTaskExecutor = new ReindexDataStreamPersistentTaskExecutor(
services.client(),
services.clusterService(),
ReindexDataStreamTask.TASK_NAME,
services.threadPool()
);
components.add(new PluginComponentBinding<>(ReindexDataStreamPersistentTaskExecutor.class, persistentTaskExecutor));
return components;
}

Expand Down Expand Up @@ -381,6 +390,6 @@ public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(
SettingsModule settingsModule,
IndexNameExpressionResolver expressionResolver
) {
return List.of(new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool));
return List.of(persistentTaskExecutor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,21 @@
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ClientHelperService;
import org.elasticsearch.client.internal.support.AbstractClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Assertions;

import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class ReindexDataStreamClient extends AbstractClient {
private static final String RUN_AS_USER_HEADER = "es-security-runas-user";
private static final String AUTHENTICATION_KEY = "_xpack_security_authentication";
private static final String THREAD_CTX_KEY = "_xpack_security_secondary_authc";

private static final Set<String> SECURITY_HEADER_FILTERS = Set.of(RUN_AS_USER_HEADER, AUTHENTICATION_KEY, THREAD_CTX_KEY);

private final ClientHelperService clientHelperService;
private final Client client;
private final Map<String, String> headers;

public ReindexDataStreamClient(Client client, Map<String, String> headers) {
public ReindexDataStreamClient(ClientHelperService clientHelperService, Client client, Map<String, String> headers) {
super(client.settings(), client.threadPool());
this.clientHelperService = clientHelperService;
this.client = client;
this.headers = headers;
}
Expand All @@ -49,73 +38,7 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
Request request,
ActionListener<Response> listener
) {
executeWithHeadersAsync(headers, client, action, request, listener);
clientHelperService.executeWithHeadersAsync(headers, "", client, action, request, listener);
}

private static <Request extends ActionRequest, Response extends ActionResponse> void executeWithHeadersAsync(
Map<String, String> headers,
Client client,
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
executeWithHeadersAsync(client.threadPool().getThreadContext(), headers, request, listener, (r, l) -> client.execute(action, r, l));
}

private static <Request, Response> void executeWithHeadersAsync(
ThreadContext threadContext,
Map<String, String> headers,
Request request,
ActionListener<Response> listener,
BiConsumer<Request, ActionListener<Response>> consumer
) {
// No need to rewrite authentication header because it will be handled by Security Interceptor
final Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
filteredHeaders.forEach((k, v) -> System.out.printf("%-15s : %s%n", k, v));
// No headers (e.g. security not installed/in use) so execute as origin
if (filteredHeaders.isEmpty()) {
consumer.accept(request, listener);
} else {
// Otherwise stash the context and copy in the saved headers before executing
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = stashWithHeaders(threadContext, filteredHeaders)) {
consumer.accept(request, new ContextPreservingActionListener<>(supplier, listener));
}
}
}

private static Map<String, String> filterSecurityHeaders(Map<String, String> headers) {
if (SECURITY_HEADER_FILTERS.containsAll(headers.keySet())) {
// fast-track to skip the artifice below
return headers;
} else {
return Objects.requireNonNull(headers)
.entrySet()
.stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}

private static ThreadContext.StoredContext stashWithHeaders(ThreadContext threadContext, Map<String, String> headers) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
assertNoAuthorizationHeader(headers);
threadContext.copyHeaders(headers.entrySet());
return storedContext;
}

private static final Pattern authorizationHeaderPattern = Pattern.compile(
"\\s*" + Pattern.quote("Authorization") + "\\s*",
Pattern.CASE_INSENSITIVE
);

private static void assertNoAuthorizationHeader(Map<String, String> headers) {
if (Assertions.ENABLED) {
for (String header : headers.keySet()) {
if (authorizationHeaderPattern.matcher(header).find()) {
assert false : "headers contain \"Authorization\"";
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import org.elasticsearch.action.datastreams.SwapDataStreamIndexAction;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ClientHelperService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.persistent.PersistentTaskState;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
Expand All @@ -34,6 +36,7 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec
private final Client client;
private final ClusterService clusterService;
private final ThreadPool threadPool;
private ClientHelperService clientHelperService;

public ReindexDataStreamPersistentTaskExecutor(Client client, ClusterService clusterService, String taskName, ThreadPool threadPool) {
super(taskName, threadPool.generic());
Expand All @@ -42,6 +45,11 @@ public ReindexDataStreamPersistentTaskExecutor(Client client, ClusterService clu
this.threadPool = threadPool;
}

@Inject
public void initialize(ClientHelperService clientHelperService) {
this.clientHelperService = clientHelperService;
}

@Override
protected ReindexDataStreamTask createTask(
long id,
Expand Down Expand Up @@ -72,7 +80,7 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask
GetDataStreamAction.Request request = new GetDataStreamAction.Request(TimeValue.MAX_VALUE, new String[] { sourceDataStream });
assert task instanceof ReindexDataStreamTask;
final ReindexDataStreamTask reindexDataStreamTask = (ReindexDataStreamTask) task;
ReindexDataStreamClient reindexClient = new ReindexDataStreamClient(client, params.headers());
ReindexDataStreamClient reindexClient = new ReindexDataStreamClient(clientHelperService, client, params.headers());
reindexClient.execute(GetDataStreamAction.INSTANCE, request, ActionListener.wrap(response -> {
List<GetDataStreamAction.Response.DataStreamInfo> dataStreamInfos = response.getDataStreams();
if (dataStreamInfos.size() == 1) {
Expand Down

0 comments on commit 783b955

Please sign in to comment.