Skip to content

Commit

Permalink
fix concurrent modification issue in thread context (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#14084)

Signed-off-by: Chenyang Ji <cyji@amazon.com>
  • Loading branch information
ansjcy authored Jun 10, 2024
1 parent 42d6af6 commit c8f0b6d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -480,28 +480,37 @@ public <T> T getTransient(String key) {
* @param value the header value
*/
public void addResponseHeader(final String key, final String value) {
addResponseHeader(key, value, v -> v);
updateResponseHeader(key, value, v -> v, false);
}

/**
* Remove the {@code value} for the specified {@code key}.
* Update the {@code value} for the specified {@code key}
*
* @param key the header name
* @param value the header value
*/
public void removeResponseHeader(final String key) {
threadLocal.get().responseHeaders.remove(key);
public void updateResponseHeader(final String key, final String value) {
updateResponseHeader(key, value, v -> v, true);
}

/**
* Add the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate
* Update the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate
* {@code value} after applying {@code uniqueValue} is ignored.
*
* @param key the header name
* @param value the header value
* @param uniqueValue the function that produces de-duplication values
*/
public void addResponseHeader(final String key, final String value, final Function<String, String> uniqueValue) {
threadLocal.set(threadLocal.get().putResponse(key, value, uniqueValue, maxWarningHeaderCount, maxWarningHeaderSize));
* @param replaceExistingKey whether to replace the existing header if it already exists
*/
public void updateResponseHeader(
final String key,
final String value,
final Function<String, String> uniqueValue,
final boolean replaceExistingKey
) {
threadLocal.set(
threadLocal.get().putResponse(key, value, uniqueValue, maxWarningHeaderCount, maxWarningHeaderSize, replaceExistingKey)
);
}

/**
Expand Down Expand Up @@ -726,7 +735,8 @@ private ThreadContextStruct putResponse(
final String value,
final Function<String, String> uniqueValue,
final int maxWarningHeaderCount,
final long maxWarningHeaderSize
final long maxWarningHeaderSize,
final boolean replaceExistingKey
) {
assert value != null;
long newWarningHeaderSize = warningHeadersSize;
Expand Down Expand Up @@ -768,8 +778,13 @@ private ThreadContextStruct putResponse(
if (existingValues.contains(uniqueValue.apply(value))) {
return this;
}
// preserve insertion order
final Set<String> newValues = Stream.concat(existingValues.stream(), Stream.of(value)).collect(LINKED_HASH_SET_COLLECTOR);
Set<String> newValues;
if (replaceExistingKey) {
newValues = Stream.of(value).collect(LINKED_HASH_SET_COLLECTOR);
} else {
// preserve insertion order
newValues = Stream.concat(existingValues.stream(), Stream.of(value)).collect(LINKED_HASH_SET_COLLECTOR);
}
newResponseHeaders = new HashMap<>(responseHeaders);
newResponseHeaders.put(key, Collections.unmodifiableSet(newValues));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,7 @@ public void writeTaskResourceUsage(SearchShardTask task, String nodeId) {
)
.build();
// Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request.
synchronized (this) {
threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE);
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
}
threadPool.getThreadContext().updateResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
} catch (Exception e) {
logger.debug("Error during writing task resource usage: ", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,16 @@ public void testResponseHeaders() {
}

final String value = HeaderWarning.formatWarning("qux");
threadContext.addResponseHeader("baz", value, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false));
threadContext.updateResponseHeader("baz", value, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false), false);
// pretend that another thread created the same response at a different time
if (randomBoolean()) {
final String duplicateValue = HeaderWarning.formatWarning("qux");
threadContext.addResponseHeader("baz", duplicateValue, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false));
threadContext.updateResponseHeader(
"baz",
duplicateValue,
s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false),
false
);
}

threadContext.addResponseHeader("Warning", "One is the loneliest number");
Expand Down

0 comments on commit c8f0b6d

Please sign in to comment.