Skip to content

Commit

Permalink
Improve auto-determination of ResponseEntity type; Default Accept hea…
Browse files Browse the repository at this point in the history
…der to text for hardened fallback logic
  • Loading branch information
D-Pow committed Aug 18, 2024
1 parent f7ceb4e commit 3928c25
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ public ResponseEntity<Resource> getProxiedVideoStream(
proxyHeaders.set(HttpHeaders.ACCEPT_RANGES, "bytes");
proxyHeaders.set(HttpHeaders.RANGE, String.format("bytes=%d-%d", ranges.get(0), ranges.get(1)));

return (ResponseEntity<Resource>) CorsProxy.doCorsRequest(
return CorsProxy.doCorsRequest(
HttpMethod.GET,
URI.create(videoUrl),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public VideoSearchResult getVideosForEpisode(String url) {
getVideoUrlFromHostHeaders
).getHeaders().getLocation().toString();

ResponseEntity<String> videoResponse = (ResponseEntity<String>) CorsProxy.doCorsRequest(
ResponseEntity<String> videoResponse = CorsProxy.doCorsRequest(
HttpMethod.GET,
URI.create(videoUrl),
URI.create(UriParser.getOrigin(videoUrl)),
Expand All @@ -247,7 +247,7 @@ public VideoSearchResult getVideosForEpisode(String url) {
if (videoResponse.getHeaders().getLocation() != null) {
videoUrl = videoResponse.getHeaders().getLocation().toString();

videoResponse = (ResponseEntity<String>) CorsProxy.doCorsRequest(
videoResponse = CorsProxy.doCorsRequest(
HttpMethod.GET,
URI.create(videoUrl),
URI.create(UriParser.getOrigin(videoUrl)),
Expand Down
30 changes: 23 additions & 7 deletions server/src/main/java/org/animeatsume/api/utils/http/CorsProxy.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,23 @@
public class CorsProxy {
private static final boolean defaultFollowRedirects = true;

public static ResponseEntity<?> doCorsRequest(
public static <T> ResponseEntity<T> doCorsRequest(
HttpMethod method,
String url,
@Nullable String origin,
@Nullable Object body
) {
return doCorsRequest(
method,
url,
origin,
body,
null,
defaultFollowRedirects
);
}

public static <T> ResponseEntity<T> doCorsRequest(
HttpMethod method,
String url,
@Nullable String origin,
Expand All @@ -33,7 +49,7 @@ public static ResponseEntity<?> doCorsRequest(
);
}

public static ResponseEntity<?> doCorsRequest(
public static <T> ResponseEntity<T> doCorsRequest(
HttpMethod method,
String url,
@Nullable String origin,
Expand Down Expand Up @@ -66,7 +82,7 @@ public static ResponseEntity<?> doCorsRequest(
);
}

public static ResponseEntity<?> doCorsRequest(
public static <T> ResponseEntity<T> doCorsRequest(
HttpMethod method,
URI url,
@Nullable URI origin,
Expand All @@ -76,7 +92,7 @@ public static ResponseEntity<?> doCorsRequest(
return doCorsRequest(method, url, origin, body, headers, defaultFollowRedirects);
}

public static ResponseEntity<?> doCorsRequest(
public static <T> ResponseEntity<T> doCorsRequest(
HttpMethod method,
URI url,
@Nullable URI origin,
Expand All @@ -89,7 +105,7 @@ public static ResponseEntity<?> doCorsRequest(
List<MediaType> requestAcceptHeaders = corsEntity.getHeaders().getAccept();

if (requestAcceptHeaders.size() == 0) {
return ResponseEntity
return (ResponseEntity<T>) ResponseEntity
.status(HttpStatus.NOT_ACCEPTABLE)
.body("You must add a value for the 'Accept' header");
}
Expand All @@ -101,7 +117,7 @@ public static ResponseEntity<?> doCorsRequest(
Requests.addAcceptableMediaTypes(restTemplate, MediaType.APPLICATION_FORM_URLENCODED);
restTemplate.getMessageConverters().add(new FormHttpMessageConverter());

ResponseEntity<?> response = Requests.doRequestWithFallback(restTemplate, url, method, corsEntity, responseClass);
ResponseEntity<T> response = Requests.<T>doRequestWithFallback(restTemplate, url, method, corsEntity, responseClass);
Object responseBody = response.getBody();
HttpHeaders responseHeaders = Requests.copyHttpHeaders(response.getHeaders());

Expand All @@ -111,7 +127,7 @@ public static ResponseEntity<?> doCorsRequest(
responseHeaders.remove(HttpHeaders.CONTENT_ENCODING);
responseHeaders.remove(HttpHeaders.TRANSFER_ENCODING);

return new ResponseEntity<>(responseBody, responseHeaders, HttpStatus.OK);
return new ResponseEntity<T>((T) responseBody, responseHeaders, HttpStatus.OK);
}

public static <T> HttpEntity<T> getCorsEntity(T body, String origin, String referer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public static Class<?> getClassFromContentTypeHeader(String contentTypeHeader) {
return Resource.class;
}

public static ResponseEntity<?> doRequestWithFallback(
public static <T> ResponseEntity<T> doRequestWithFallback(
RestTemplate restTemplate,
URI url,
HttpMethod method,
Expand Down Expand Up @@ -190,9 +190,11 @@ public static ResponseEntity<?> doRequestWithFallback(

HttpHeaders responseHeaders = headForHeadersWithAcceptAllFallback(url, restTemplate, requestEntity);
List<MediaType> headersAccept = responseHeaders.getAccept();
String contentTypeHeader = headersAccept
.get(headersAccept.size() - 1)
.toString();
String contentTypeHeader = headersAccept.isEmpty()
? MediaType.TEXT_PLAIN_VALUE
: headersAccept
.get(headersAccept.size() - 1)
.toString();
Class<?> actualResponseTypeClass = getClassFromContentTypeHeader(contentTypeHeader);

response = restTemplate.exchange(
Expand Down Expand Up @@ -222,10 +224,10 @@ public static ResponseEntity<?> doRequestWithFallback(
e.getMessage()
);

return new ResponseEntity<>(e.getResponseBodyAsString(), e.getResponseHeaders(), e.getStatusCode());
return new ResponseEntity<>((T) e.getResponseBodyAsString(), e.getResponseHeaders(), e.getStatusCode());
}

return new ResponseEntity<>(body, response.getHeaders(), response.getStatusCode());
return new ResponseEntity<>((T) body, response.getHeaders(), response.getStatusCode());
}

/**
Expand Down

0 comments on commit 3928c25

Please sign in to comment.