
Commit

Changing security so that it uses the credentials of the user who requested the reindex
masseyke committed Nov 25, 2024
1 parent da43acb commit 645a69d
Showing 8 changed files with 240 additions and 23 deletions.
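
Taken together, the change works in three steps: the transport action captures the requesting user's security headers from the thread context, the persistent task params carry those headers (surviving serialization and node restarts), and the task executor replays them through a header-aware client so every reindex sub-request is authorized as the original caller rather than as a fixed internal origin. A minimal sketch of that flow, using the types introduced in this diff inside a hypothetical wrapper method that is not part of the commit:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.datastreams.GetDataStreamAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ClientHelperService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.datastreams.task.ReindexDataStreamClient;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.Map;

// Hypothetical illustration of the credential flow wired up by this commit.
final class UserCredentialFlowSketch {
    static void executeAsRequestingUser(
        Client client,
        ClientHelperService clientHelperService,
        ClusterService clusterService,
        ThreadPool threadPool,
        GetDataStreamAction.Request request,
        ActionListener<GetDataStreamAction.Response> listener
    ) {
        // 1. At request time: capture only the persistable security headers.
        Map<String, String> headers = clientHelperService.getPersistableSafeSecurityHeaders(
            threadPool.getThreadContext(),
            clusterService.state()
        );
        // 2. In the real code these headers are stored in
        //    ReindexDataStreamTaskParams and travel with the persistent task.
        // 3. At execution time: wrap the node client so sub-requests run
        //    with the caller's credentials instead of a fixed origin.
        Client asUser = new ReindexDataStreamClient(client, headers);
        asUser.execute(GetDataStreamAction.INSTANCE, request, listener);
    }
}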
DataStreamsPlugin.java
@@ -120,7 +120,6 @@
import java.util.function.Predicate;
import java.util.function.Supplier;

import static org.elasticsearch.action.datastreams.ReindexDataStreamAction.REINDEX_DATA_STREAM_ORIGIN;
import static org.elasticsearch.cluster.metadata.DataStreamLifecycle.DATA_STREAM_LIFECYCLE_ORIGIN;

public class DataStreamsPlugin extends Plugin implements ActionPlugin, HealthPlugin, PersistentTaskPlugin {
@@ -382,13 +381,6 @@ public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(
SettingsModule settingsModule,
IndexNameExpressionResolver expressionResolver
) {
return List.of(
new ReindexDataStreamPersistentTaskExecutor(
new OriginSettingClient(client, REINDEX_DATA_STREAM_ORIGIN),
clusterService,
ReindexDataStreamTask.TASK_NAME,
threadPool
)
);
return List.of(new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool));
}
}
ReindexDataStreamTransportAction.java
@@ -17,6 +17,7 @@
import org.elasticsearch.action.datastreams.ReindexDataStreamAction.ReindexDataStreamResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.ClientHelperService;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
@@ -36,13 +37,15 @@ public class ReindexDataStreamTransportAction extends HandledTransportAction<Rei
private final PersistentTasksService persistentTasksService;
private final TransportService transportService;
private final ClusterService clusterService;
private final ClientHelperService clientHelperService;

@Inject
public ReindexDataStreamTransportAction(
TransportService transportService,
ActionFilters actionFilters,
PersistentTasksService persistentTasksService,
ClusterService clusterService
ClusterService clusterService,
ClientHelperService clientHelperService
) {
super(
ReindexDataStreamAction.NAME,
@@ -55,6 +58,7 @@ public ReindexDataStreamTransportAction(
this.transportService = transportService;
this.persistentTasksService = persistentTasksService;
this.clusterService = clusterService;
this.clientHelperService = clientHelperService;
}

@Override
@@ -75,7 +79,11 @@ protected void doExecute(Task task, ReindexDataStreamRequest request, ActionList
sourceDataStreamName,
transportService.getThreadPool().absoluteTimeInMillis(),
totalIndices,
totalIndicesToBeUpgraded
totalIndicesToBeUpgraded,
clientHelperService.getPersistableSafeSecurityHeaders(
transportService.getThreadPool().getThreadContext(),
clusterService.state()
)
);
String persistentTaskId = getPersistentTaskId(sourceDataStreamName);
persistentTasksService.sendStartRequest(
ReindexDataStreamClient.java
@@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.datastreams.task;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.support.AbstractClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Assertions;

import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class ReindexDataStreamClient extends AbstractClient {
private static final String RUN_AS_USER_HEADER = "es-security-runas-user";
private static final String AUTHENTICATION_KEY = "_xpack_security_authentication";
private static final String THREAD_CTX_KEY = "_xpack_security_secondary_authc";

private static final Set<String> SECURITY_HEADER_FILTERS = Set.of(RUN_AS_USER_HEADER, AUTHENTICATION_KEY, THREAD_CTX_KEY);

private final Client client;
private final Map<String, String> headers;

public ReindexDataStreamClient(Client client, Map<String, String> headers) {
super(client.settings(), client.threadPool());
this.client = client;
this.headers = headers;
}

@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
executeWithHeadersAsync(headers, client, action, request, listener);
}

private static <Request extends ActionRequest, Response extends ActionResponse> void executeWithHeadersAsync(
Map<String, String> headers,
Client client,
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
executeWithHeadersAsync(client.threadPool().getThreadContext(), headers, request, listener, (r, l) -> client.execute(action, r, l));
}

private static <Request, Response> void executeWithHeadersAsync(
ThreadContext threadContext,
Map<String, String> headers,
Request request,
ActionListener<Response> listener,
BiConsumer<Request, ActionListener<Response>> consumer
) {
// No need to rewrite authentication header because it will be handled by Security Interceptor
final Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
// No headers (e.g. security not installed/in use) so execute as origin
if (filteredHeaders.isEmpty()) {
consumer.accept(request, listener);
} else {
// Otherwise stash the context and copy in the saved headers before executing
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = stashWithHeaders(threadContext, filteredHeaders)) {
consumer.accept(request, new ContextPreservingActionListener<>(supplier, listener));
}
}
}

private static Map<String, String> filterSecurityHeaders(Map<String, String> headers) {
if (SECURITY_HEADER_FILTERS.containsAll(headers.keySet())) {
// fast-track to skip the artifice below
return headers;
} else {
return Objects.requireNonNull(headers)
.entrySet()
.stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}

private static ThreadContext.StoredContext stashWithHeaders(ThreadContext threadContext, Map<String, String> headers) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
assertNoAuthorizationHeader(headers);
threadContext.copyHeaders(headers.entrySet());
return storedContext;
}

private static final Pattern authorizationHeaderPattern = Pattern.compile(
"\\s*" + Pattern.quote("Authorization") + "\\s*",
Pattern.CASE_INSENSITIVE
);

private static void assertNoAuthorizationHeader(Map<String, String> headers) {
if (Assertions.ENABLED) {
for (String header : headers.keySet()) {
if (authorizationHeaderPattern.matcher(header).find()) {
assert false : "headers contain \"Authorization\"";
}
}
}
}
}
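
For reference, the class's core mechanism is ThreadContext's stash-and-replay idiom: stashContext() swaps in a clean context, copyHeaders() reinstates the persisted security headers, and the ContextPreservingActionListener restores the caller's context for the response. A self-contained sketch of that idiom (hypothetical demo class, not part of the commit):

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.Map;

// Hypothetical demo of the stash-and-replay idiom used for each sub-request.
public class ThreadContextReplayDemo {
    public static void main(String[] args) {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        Map<String, String> saved = Map.of("_xpack_security_authentication", "caller-auth-blob");

        // stashContext() swaps in a clean context; copying the saved security
        // headers means anything executed inside this block runs as the
        // original caller. Closing the StoredContext restores the old context.
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            threadContext.copyHeaders(saved.entrySet());
            System.out.println(threadContext.getHeader("_xpack_security_authentication"));
        }
        // Outside the block the replayed header is gone again.
        assert threadContext.getHeader("_xpack_security_authentication") == null;
    }
}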
ReindexDataStreamPersistentTaskExecutor.java
@@ -72,7 +72,8 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask
GetDataStreamAction.Request request = new GetDataStreamAction.Request(TimeValue.MAX_VALUE, new String[] { sourceDataStream });
assert task instanceof ReindexDataStreamTask;
final ReindexDataStreamTask reindexDataStreamTask = (ReindexDataStreamTask) task;
client.execute(GetDataStreamAction.INSTANCE, request, ActionListener.wrap(response -> {
ReindexDataStreamClient reindexClient = new ReindexDataStreamClient(client, params.headers());
reindexClient.execute(GetDataStreamAction.INSTANCE, request, ActionListener.wrap(response -> {
List<GetDataStreamAction.Response.DataStreamInfo> dataStreamInfos = response.getDataStreams();
if (dataStreamInfos.size() == 1) {
List<Index> indices = dataStreamInfos.getFirst().getDataStream().getIndices();
@@ -89,7 +90,7 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask
// TODO: put all these on a queue, only process N from queue at a time
for (Index index : indicesToBeReindexed) {
reindexDataStreamTask.incrementInProgressIndicesCount();
client.execute(
reindexClient.execute(
ReindexDataStreamIndexAction.INSTANCE,
new ReindexDataStreamIndexAction.Request(index.getName()),
ActionListener.wrap(response1 -> {
@@ -99,7 +100,7 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask
}, exception -> {
reindexDataStreamTask.reindexFailed(index.getName(), exception);
listener.onResponse(null);
}));
}), reindexClient);

}, exception -> {
reindexDataStreamTask.reindexFailed(index.getName(), exception);
@@ -114,8 +115,14 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask
}, exception -> completeFailedPersistentTask(reindexDataStreamTask, exception)));
}

private void updateDataStream(String dataStream, String oldIndex, String newIndex, ActionListener<Void> listener) {
client.execute(
private void updateDataStream(
String dataStream,
String oldIndex,
String newIndex,
ActionListener<Void> listener,
ReindexDataStreamClient reindexClient
) {
reindexClient.execute(
SwapDataStreamIndexAction.INSTANCE,
new SwapDataStreamIndexAction.Request(dataStream, oldIndex, newIndex),
new ActionListener<>() {
ReindexDataStreamTaskParams.java
@@ -15,37 +15,58 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.persistent.PersistentTaskParams;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public record ReindexDataStreamTaskParams(String sourceDataStream, long startTime, int totalIndices, int totalIndicesToBeUpgraded)
implements
PersistentTaskParams {
public record ReindexDataStreamTaskParams(
String sourceDataStream,
long startTime,
int totalIndices,
int totalIndicesToBeUpgraded,
Map<String, String> headers
) implements PersistentTaskParams {

public static final String NAME = ReindexDataStreamTask.TASK_NAME;
private static final String SOURCE_DATA_STREAM_FIELD = "source_data_stream";
private static final String START_TIME_FIELD = "start_time";
private static final String TOTAL_INDICES_FIELD = "total_indices";
private static final String TOTAL_INDICES_TO_BE_UPGRADED_FIELD = "total_indices_to_be_upgraded";
private static final String HEADERS_FIELD = "headers";
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ReindexDataStreamTaskParams, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
args -> new ReindexDataStreamTaskParams((String) args[0], (long) args[1], (int) args[2], (int) args[3])
args -> new ReindexDataStreamTaskParams(
(String) args[0],
(long) args[1],
(int) args[2],
(int) args[3],
(Map<String, String>) args[4]
)
);
static {
PARSER.declareString(constructorArg(), new ParseField(SOURCE_DATA_STREAM_FIELD));
PARSER.declareLong(constructorArg(), new ParseField(START_TIME_FIELD));
PARSER.declareInt(constructorArg(), new ParseField(TOTAL_INDICES_FIELD));
PARSER.declareInt(constructorArg(), new ParseField(TOTAL_INDICES_TO_BE_UPGRADED_FIELD));
PARSER.declareField(
ConstructingObjectParser.constructorArg(),
XContentParser::mapStrings,
new ParseField(HEADERS_FIELD),
ObjectParser.ValueType.OBJECT
);
}

@SuppressWarnings("unchecked")
public ReindexDataStreamTaskParams(StreamInput in) throws IOException {
this(in.readString(), in.readLong(), in.readInt(), in.readInt());
this(in.readString(), in.readLong(), in.readInt(), in.readInt(), (Map<String, String>) in.readGenericValue());
}

@Override
@@ -64,6 +85,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeLong(startTime);
out.writeInt(totalIndices);
out.writeInt(totalIndicesToBeUpgraded);
out.writeGenericValue(headers);
}

@Override
@@ -73,6 +95,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(START_TIME_FIELD, startTime)
.field(TOTAL_INDICES_FIELD, totalIndices)
.field(TOTAL_INDICES_TO_BE_UPGRADED_FIELD, totalIndicesToBeUpgraded)
.stringStringMap(HEADERS_FIELD, headers)
.endObject();
}

@@ -83,4 +106,8 @@ public String getSourceDataStream() {
public static ReindexDataStreamTaskParams fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

public Map<String, String> getHeaders() {
return headers;
}
}
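
The headers map rides through writeGenericValue on the wire and through the headers object field in XContent, so a captured map should survive both transport between nodes and persistence in the cluster state. A hedged round-trip sketch (hypothetical demo class; run with assertions enabled):

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.datastreams.task.ReindexDataStreamTaskParams;

import java.util.Map;

// Hypothetical demo: wire round-trip of task params carrying security headers.
public class TaskParamsRoundTripDemo {
    public static void main(String[] args) throws Exception {
        ReindexDataStreamTaskParams original = new ReindexDataStreamTaskParams(
            "my-data-stream",
            System.currentTimeMillis(),
            5,
            3,
            Map.of("_xpack_security_authentication", "caller-auth-blob")
        );
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            original.writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                ReindexDataStreamTaskParams copy = new ReindexDataStreamTaskParams(in);
                // Records compare component-wise, so this also checks the headers map.
                assert original.equals(copy) : "headers should survive serialization";
            }
        }
    }
}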
ReindexDataStreamTaskParamsTests.java
@@ -31,7 +31,13 @@ protected Writeable.Reader<ReindexDataStreamTaskParams> instanceReader() {

@Override
protected ReindexDataStreamTaskParams createTestInstance() {
return new ReindexDataStreamTaskParams(randomAlphaOfLength(50), randomLong(), randomNonNegativeInt(), randomNonNegativeInt());
return new ReindexDataStreamTaskParams(
randomAlphaOfLength(50),
randomLong(),
randomNonNegativeInt(),
randomNonNegativeInt(),
Map.of()
);
}

@Override
@@ -47,7 +53,7 @@ protected ReindexDataStreamTaskParams mutateInstance(ReindexDataStreamTaskParams
case 3 -> totalIndicesToBeUpgraded = totalIndicesToBeUpgraded + 1;
default -> throw new UnsupportedOperationException();
}
return new ReindexDataStreamTaskParams(sourceDataStream, startTime, totalIndices, totalIndicesToBeUpgraded);
return new ReindexDataStreamTaskParams(sourceDataStream, startTime, totalIndices, totalIndicesToBeUpgraded, Map.of());
}

@Override
ClientHelperService.java
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.client.internal;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.Map;

public interface ClientHelperService {
Map<String, String> getPersistableSafeSecurityHeaders(ThreadContext threadContext, ClusterState clusterState);

<Request extends ActionRequest, Response extends ActionResponse> void executeWithHeadersAsync(
Map<String, String> headers,
String origin,
Client client,
ActionType<Response> action,
Request request,
ActionListener<Response> listener
);
}
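
The diff shown here does not include a production implementation of this interface (presumably it is among the files not rendered above). A minimal sketch, assuming the interface is backed by the static helpers on x-pack's ClientHelper; the class name XPackClientHelperService is hypothetical:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ClientHelperService;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.xpack.core.ClientHelper;

import java.util.Map;

// Hypothetical implementation sketch: delegates both interface methods to
// the existing static helpers in org.elasticsearch.xpack.core.ClientHelper.
public class XPackClientHelperService implements ClientHelperService {
    @Override
    public Map<String, String> getPersistableSafeSecurityHeaders(ThreadContext threadContext, ClusterState clusterState) {
        return ClientHelper.getPersistableSafeSecurityHeaders(threadContext, clusterState);
    }

    @Override
    public <Request extends ActionRequest, Response extends ActionResponse> void executeWithHeadersAsync(
        Map<String, String> headers,
        String origin,
        Client client,
        ActionType<Response> action,
        Request request,
        ActionListener<Response> listener
    ) {
        ClientHelper.executeWithHeadersAsync(headers, origin, client, action, request, listener);
    }
}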
