diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 0b1aa9a4a759a..b9c7177da828b 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -480,28 +480,37 @@ public T getTransient(String key) { * @param value the header value */ public void addResponseHeader(final String key, final String value) { - addResponseHeader(key, value, v -> v); + updateResponseHeader(key, value, v -> v, false); } /** - * Remove the {@code value} for the specified {@code key}. + * Update the {@code value} for the specified {@code key} * * @param key the header name + * @param value the header value */ - public void removeResponseHeader(final String key) { - threadLocal.get().responseHeaders.remove(key); + public void updateResponseHeader(final String key, final String value) { + updateResponseHeader(key, value, v -> v, true); } /** - * Add the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate + * Update the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate * {@code value} after applying {@code uniqueValue} is ignored. * * @param key the header name * @param value the header value * @param uniqueValue the function that produces de-duplication values - */ - public void addResponseHeader(final String key, final String value, final Function uniqueValue) { - threadLocal.set(threadLocal.get().putResponse(key, value, uniqueValue, maxWarningHeaderCount, maxWarningHeaderSize)); + * @param replaceExistingKey whether to replace the existing header if it already exists + */ + public void updateResponseHeader( + final String key, + final String value, + final Function uniqueValue, + final boolean replaceExistingKey + ) { + threadLocal.set( + threadLocal.get().putResponse(key, value, uniqueValue, maxWarningHeaderCount, maxWarningHeaderSize, replaceExistingKey) + ); } /** @@ -726,7 +735,8 @@ private ThreadContextStruct putResponse( final String value, final Function uniqueValue, final int maxWarningHeaderCount, - final long maxWarningHeaderSize + final long maxWarningHeaderSize, + final boolean replaceExistingKey ) { assert value != null; long newWarningHeaderSize = warningHeadersSize; @@ -768,8 +778,13 @@ private ThreadContextStruct putResponse( if (existingValues.contains(uniqueValue.apply(value))) { return this; } - // preserve insertion order - final Set newValues = Stream.concat(existingValues.stream(), Stream.of(value)).collect(LINKED_HASH_SET_COLLECTOR); + Set newValues; + if (replaceExistingKey) { + newValues = Stream.of(value).collect(LINKED_HASH_SET_COLLECTOR); + } else { + // preserve insertion order + newValues = Stream.concat(existingValues.stream(), Stream.of(value)).collect(LINKED_HASH_SET_COLLECTOR); + } newResponseHeaders = new HashMap<>(responseHeaders); newResponseHeaders.put(key, Collections.unmodifiableSet(newValues)); } else { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 564eff6c10df6..ca1957cdb1633 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -322,10 +322,7 @@ public void writeTaskResourceUsage(SearchShardTask task, String nodeId) { ) .build(); // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. - synchronized (this) { - threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); - threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); - } + threadPool.getThreadContext().updateResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); } catch (Exception e) { logger.debug("Error during writing task resource usage: ", e); } diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 10669ca1a805b..e6d07c5630541 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -344,11 +344,16 @@ public void testResponseHeaders() { } final String value = HeaderWarning.formatWarning("qux"); - threadContext.addResponseHeader("baz", value, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false)); + threadContext.updateResponseHeader("baz", value, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false), false); // pretend that another thread created the same response at a different time if (randomBoolean()) { final String duplicateValue = HeaderWarning.formatWarning("qux"); - threadContext.addResponseHeader("baz", duplicateValue, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false)); + threadContext.updateResponseHeader( + "baz", + duplicateValue, + s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false), + false + ); } threadContext.addResponseHeader("Warning", "One is the loneliest number");