Skip to content

Commit

Permalink
fix review6: validate OnnxTensor ThrMsgs
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgeniiMunin committed Sep 24, 2024
1 parent aef819c commit 452e928
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private boolean isPC() {
|| PC_OS_FAMILIES.contains(osFamily())
|| ("Windows".equals(osFamily()) && "ME".equals(osMajor()))
|| ("Mac OS X".equals(osFamily()) && !userAgent.contains("Silk"))
|| userAgent.contains("Linux") && userAgent.contains("X11"))
|| (userAgent.contains("Linux") && userAgent.contains("X11")))
.orElse(false);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import org.prebid.server.exception.PreBidException;
Expand Down Expand Up @@ -49,14 +50,42 @@ private Map<String, Map<String, Boolean>> processModelResults(
List<ThrottlingMessage> throttlingMessages,
Double threshold) {

validateThrottlingMessages(throttlingMessages);

return StreamSupport.stream(results.spliterator(), false)
.filter(onnxItem -> Objects.equals(onnxItem.getKey(), "probabilities"))
.filter(onnxItem -> {
validateOnnxTensor(onnxItem);
return Objects.equals(onnxItem.getKey(), "probabilities");
})
.map(onnxItem -> (OnnxTensor) onnxItem.getValue())
.map(tensor -> extractAndProcessProbabilities(tensor, throttlingMessages, threshold))
.map(tensor -> {
validateTensorSize(tensor, throttlingMessages.size());
return extractAndProcessProbabilities(tensor, throttlingMessages, threshold);
})
.flatMap(map -> map.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

private void validateThrottlingMessages(List<ThrottlingMessage> throttlingMessages) {
if (throttlingMessages == null || throttlingMessages.isEmpty()) {
throw new PreBidException("throttlingMessages cannot be null or empty");
}
}

private void validateOnnxTensor(Map.Entry<String, OnnxValue> onnxItem) {
if (!(onnxItem.getValue() instanceof OnnxTensor)) {
throw new PreBidException("Expected OnnxTensor for 'probabilities', but found: "
+ onnxItem.getValue().getClass().getName());
}
}

private void validateTensorSize(OnnxTensor tensor, int expectedSize) {
final long[] tensorShape = tensor.getInfo().getShape();
if (tensorShape.length == 0 || tensorShape[0] != expectedSize) {
throw new PreBidException("Mismatch between tensor size and throttlingMessages size");
}
}

private Map<String, Map<String, Boolean>> extractAndProcessProbabilities(
OnnxTensor tensor,
List<ThrottlingMessage> throttlingMessages,
Expand Down

0 comments on commit 452e928

Please sign in to comment.