-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add cosine_distance for sparse vectors #24027
base: master
Are you sure you want to change the base?
Conversation
core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java
Show resolved
Hide resolved
4b126bf
to
cb8d293
Compare
assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])")) | ||
.isNull(); | ||
|
||
//assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails but have not looked into it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The failure was an NPE due to invalid input parameters in my opinion.. so I deleted that assertion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was the actual error an uncaught NPE? That would be a bug. A functions should throw a TrinoException with an INVALID_ARGUMENT error code if the inputs are invalid.
In that test case, there doesn't seem to be anything wrong with the input aside from the null value. The function should return NULL instead of failing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm ... I added this back in. Here is the stack trace. I wont be able to look at it for a bit but at a quick glance I suspect that the problem of the null return value we get here also applies to the original PR for sparse cosine similarity from @dain .. I could try to add tests for that as well and see if we get the same issue .. and then we probably have to change implementation to treat null properly
io.trino.testing.QueryFailedException: Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null
at io.trino.testing.TestingDirectTrinoClient.toMaterializedRows(TestingDirectTrinoClient.java:80)
at io.trino.testing.TestingDirectTrinoClient.lambda$execute$0(TestingDirectTrinoClient.java:67)
at io.trino.testing.StandaloneQueryRunner.executeWithPlan(StandaloneQueryRunner.java:115)
at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.run(QueryAssertions.java:873)
at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.evaluate(QueryAssertions.java:838)
at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.assertThat(QueryAssertions.java:880)
at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.assertThat(QueryAssertions.java:793)
at org.assertj.core.api.AssertionsForInterfaceTypes.assertThat(AssertionsForInterfaceTypes.java:82)
at org.assertj.core.api.Assertions.assertThat(Assertions.java:3409)
at io.trino.operator.scalar.TestMathFunctions.testCosineDistance(TestMathFunctions.java:3475)
at java.base/java.lang.reflect.Method.invoke(Method.java:580)
at java.base/java.util.concurrent.ForkJoinTask.doExec$$$capture(ForkJoinTask.java:507)
at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java)
at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1458)
at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:2034)
at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:189)
Caused by: java.lang.NullPointerException: Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null
at io.trino.operator.scalar.MathFunctions.cosineDistance(MathFunctions.java:1422)
at io.trino.$gen.PageProjectionWork_20250129_195047_37.evaluate(Unknown Source)
at io.trino.$gen.PageProjectionWork_20250129_195047_37.process(Unknown Source)
at io.trino.operator.project.PageProcessor$ProjectSelectedPositions.processBatch(PageProcessor.java:315)
at io.trino.operator.project.PageProcessor$ProjectSelectedPositions.process(PageProcessor.java:199)
at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
at io.trino.operator.WorkProcessorUtils.lambda$flatten$6(WorkProcessorUtils.java:317)
at io.trino.operator.WorkProcessorUtils$3.process(WorkProcessorUtils.java:359)
at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
at io.trino.operator.WorkProcessorUtils$3.process(WorkProcessorUtils.java:346)
at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
at io.trino.operator.WorkProcessorUtils.getNextState(WorkProcessorUtils.java:261)
at io.trino.operator.WorkProcessorUtils$BlockingProcess.process(WorkProcessorUtils.java:207)
at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
at io.trino.operator.WorkProcessorOperatorAdapter.getOutput(WorkProcessorOperatorAdapter.java:150)
at io.trino.operator.Driver.processInternal(Driver.java:403)
at io.trino.operator.Driver.lambda$process$8(Driver.java:306)
at io.trino.operator.Driver.tryWithLock(Driver.java:709)
at io.trino.operator.Driver.process(Driver.java:298)
at io.trino.operator.Driver.processForDuration(Driver.java:269)
at io.trino.execution.SqlTaskExecution$DriverSplitRunner.processFor(SqlTaskExecution.java:890)
at io.trino.execution.executor.dedicated.SplitProcessor.run(SplitProcessor.java:77)
at io.trino.execution.executor.dedicated.TaskEntry$VersionEmbedderBridge.lambda$run$0(TaskEntry.java:201)
at io.trino.$gen.Trino_testversion____20250129_195040_1.run(Unknown Source)
at io.trino.execution.executor.dedicated.TaskEntry$VersionEmbedderBridge.run(TaskEntry.java:202)
at io.trino.execution.executor.scheduler.FairScheduler.runTask(FairScheduler.java:177)
at io.trino.execution.executor.scheduler.FairScheduler.lambda$submit$0(FairScheduler.java:164)
at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
at com.google.common.util.concurrent.TrustedListenableFutureTask$TrustedFutureInterruptibleTask.runInterruptibly(TrustedListenableFutureTask.java:131)
at com.google.common.util.concurrent.InterruptibleTask.run(InterruptibleTask.java:75)
at com.google.common.util.concurrent.TrustedListenableFutureTask.run(TrustedListenableFutureTask.java:82)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
at java.base/java.lang.Thread.run(Thread.java:1575)
This pull request has gone a while without any activity. Tagging for triage help: @mosabua |
I need to look at the test that failed and chat some more with @martint afterwards. I will keep it open and update after Trino Summit. |
This pull request has gone a while without any activity. Tagging for triage help: @mosabua |
74e4109
to
c0211d2
Compare
c0211d2
to
4db1191
Compare
4db1191
to
1def2e3
Compare
Description
Still have to test this locally...
Additional context and related issues
follow up to #24005
Release notes
(*) Release notes are required, with the following suggested text: