Add cosine_distance for sparse vectors #24027

Open
wants to merge 1 commit into
base: master

mosabua
Member

@mosabua mosabua commented Nov 4, 2024

Description

Still have to test this locally...

Additional context and related issues

Follow-up to #24005.

Release notes

(*) Release notes are required, with the following suggested text:

## General 

* Add the {func}`cosine_distance` function for sparse vectors. ({issue}`tbd`)

@mosabua mosabua requested a review from dain November 4, 2024 18:50
@cla-bot cla-bot bot added the cla-signed label Nov 4, 2024
@mosabua mosabua marked this pull request as draft November 4, 2024 18:50
@github-actions github-actions bot added the docs label Nov 4, 2024
@mosabua mosabua force-pushed the sparse-distance branch 4 times, most recently from 4b126bf to cb8d293 November 4, 2024 19:54
@mosabua mosabua marked this pull request as ready for review November 8, 2024 17:10
@mosabua mosabua requested a review from martint November 8, 2024 17:10
assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isNull();

//assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))

Member Author

This fails, but I have not looked into it yet.

Member Author

In my opinion, the failure was an NPE caused by invalid input parameters, so I deleted that assertion.

Member

Was the actual error an uncaught NPE? That would be a bug. A function should throw a TrinoException with an INVALID_ARGUMENT error code if the inputs are invalid.

In that test case, there doesn't seem to be anything wrong with the input aside from the null value. The function should return NULL instead of failing.
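
The requested behavior can be sketched outside of Trino. This is a minimal illustration under stated assumptions, not the PR's implementation: the class `SparseCosineSketch`, its methods, and the use of `java.util.Map<String, Double>` in place of Trino's `SqlMap` are all hypothetical. It treats cosine distance as one minus cosine similarity and returns `null` whenever an argument is null or holds a null value, instead of letting an exception escape.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for sparse-vector functions; not Trino code.
class SparseCosineSketch {
    // L2 norm of a sparse vector, or null if any value is null.
    static Double l2Norm(Map<String, Double> vector) {
        double sum = 0.0;
        for (Double value : vector.values()) {
            if (value == null) {
                return null;
            }
            sum += value * value;
        }
        return Math.sqrt(sum);
    }

    // Cosine similarity, or null if an input is null or contains a null value.
    static Double cosineSimilarity(Map<String, Double> left, Map<String, Double> right) {
        if (left == null || right == null) {
            return null;
        }
        Double leftNorm = l2Norm(left);
        Double rightNorm = l2Norm(right);
        if (leftNorm == null || rightNorm == null) {
            return null;
        }
        double dot = 0.0;
        for (Map.Entry<String, Double> entry : left.entrySet()) {
            Double other = right.get(entry.getKey());
            if (other != null) {
                // Only keys present in both vectors contribute to the dot product.
                dot += entry.getValue() * other;
            }
        }
        return dot / (leftNorm * rightNorm);
    }

    // Cosine distance with an explicit null check before the subtraction.
    static Double cosineDistance(Map<String, Double> left, Map<String, Double> right) {
        Double similarity = cosineSimilarity(left, right);
        if (similarity == null) {
            return null;
        }
        return 1.0 - similarity;
    }

    public static void main(String[] args) {
        Map<String, Double> withNull = new HashMap<>();
        withNull.put("a", 1.0);
        withNull.put("b", null);
        Map<String, Double> other = new HashMap<>();
        other.put("c", 1.0);
        other.put("b", 3.0);
        System.out.println(cosineDistance(withNull, other));
        System.out.println(cosineDistance(other, other));
    }
}
```

The inputs in `main` mirror the commented-out assertion. Whether a null map value should produce NULL or an error is the design question raised in this thread; the sketch takes the NULL route.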

Member Author

Hm, I added this back in. Here is the stack trace. I won't be able to look at it for a bit, but at a quick glance I suspect that the null return value we get here also affects the original PR for sparse cosine similarity from @dain. I could add tests for that as well and see whether we hit the same issue; then we probably have to change the implementation to handle null properly.


io.trino.testing.QueryFailedException: Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null

	at io.trino.testing.TestingDirectTrinoClient.toMaterializedRows(TestingDirectTrinoClient.java:80)
	at io.trino.testing.TestingDirectTrinoClient.lambda$execute$0(TestingDirectTrinoClient.java:67)
	at io.trino.testing.StandaloneQueryRunner.executeWithPlan(StandaloneQueryRunner.java:115)
	at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.run(QueryAssertions.java:873)
	at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.evaluate(QueryAssertions.java:838)
	at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.assertThat(QueryAssertions.java:880)
	at io.trino.sql.query.QueryAssertions$ExpressionAssertProvider.assertThat(QueryAssertions.java:793)
	at org.assertj.core.api.AssertionsForInterfaceTypes.assertThat(AssertionsForInterfaceTypes.java:82)
	at org.assertj.core.api.Assertions.assertThat(Assertions.java:3409)
	at io.trino.operator.scalar.TestMathFunctions.testCosineDistance(TestMathFunctions.java:3475)
	at java.base/java.lang.reflect.Method.invoke(Method.java:580)
	at java.base/java.util.concurrent.ForkJoinTask.doExec$$$capture(ForkJoinTask.java:507)
	at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java)
	at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1458)
	at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:2034)
	at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:189)
Caused by: java.lang.NullPointerException: Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null
	at io.trino.operator.scalar.MathFunctions.cosineDistance(MathFunctions.java:1422)
	at io.trino.$gen.PageProjectionWork_20250129_195047_37.evaluate(Unknown Source)
	at io.trino.$gen.PageProjectionWork_20250129_195047_37.process(Unknown Source)
	at io.trino.operator.project.PageProcessor$ProjectSelectedPositions.processBatch(PageProcessor.java:315)
	at io.trino.operator.project.PageProcessor$ProjectSelectedPositions.process(PageProcessor.java:199)
	at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
	at io.trino.operator.WorkProcessorUtils.lambda$flatten$6(WorkProcessorUtils.java:317)
	at io.trino.operator.WorkProcessorUtils$3.process(WorkProcessorUtils.java:359)
	at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
	at io.trino.operator.WorkProcessorUtils$3.process(WorkProcessorUtils.java:346)
	at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
	at io.trino.operator.WorkProcessorUtils.getNextState(WorkProcessorUtils.java:261)
	at io.trino.operator.WorkProcessorUtils$BlockingProcess.process(WorkProcessorUtils.java:207)
	at io.trino.operator.WorkProcessorUtils$ProcessWorkProcessor.process(WorkProcessorUtils.java:423)
	at io.trino.operator.WorkProcessorOperatorAdapter.getOutput(WorkProcessorOperatorAdapter.java:150)
	at io.trino.operator.Driver.processInternal(Driver.java:403)
	at io.trino.operator.Driver.lambda$process$8(Driver.java:306)
	at io.trino.operator.Driver.tryWithLock(Driver.java:709)
	at io.trino.operator.Driver.process(Driver.java:298)
	at io.trino.operator.Driver.processForDuration(Driver.java:269)
	at io.trino.execution.SqlTaskExecution$DriverSplitRunner.processFor(SqlTaskExecution.java:890)
	at io.trino.execution.executor.dedicated.SplitProcessor.run(SplitProcessor.java:77)
	at io.trino.execution.executor.dedicated.TaskEntry$VersionEmbedderBridge.lambda$run$0(TaskEntry.java:201)
	at io.trino.$gen.Trino_testversion____20250129_195040_1.run(Unknown Source)
	at io.trino.execution.executor.dedicated.TaskEntry$VersionEmbedderBridge.run(TaskEntry.java:202)
	at io.trino.execution.executor.scheduler.FairScheduler.runTask(FairScheduler.java:177)
	at io.trino.execution.executor.scheduler.FairScheduler.lambda$submit$0(FairScheduler.java:164)
	at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
	at com.google.common.util.concurrent.TrustedListenableFutureTask$TrustedFutureInterruptibleTask.runInterruptibly(TrustedListenableFutureTask.java:131)
	at com.google.common.util.concurrent.InterruptibleTask.run(InterruptibleTask.java:75)
	at com.google.common.util.concurrent.TrustedListenableFutureTask.run(TrustedListenableFutureTask.java:82)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1575)
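
The "Caused by" line describes an auto-unboxing failure: arithmetic is applied to a boxed `Double` that is null. A stripped-down illustration of that failure mode follows; the names `UnboxingSketch`, `similarityOrNull`, and `distanceUnsafe` are hypothetical and this is not the Trino source.

```java
// Hypothetical reproduction of the unboxing failure; not Trino code.
class UnboxingSketch {
    // Models a similarity function that yields null for some inputs.
    static Double similarityOrNull(boolean inputHasNull) {
        if (inputHasNull) {
            return null;
        }
        return 0.25;
    }

    // The subtraction implicitly calls Double.doubleValue() on the result,
    // which throws NullPointerException when the helper returns null.
    static double distanceUnsafe(boolean inputHasNull) {
        return 1.0 - similarityOrNull(inputHasNull);
    }

    public static void main(String[] args) {
        System.out.println(distanceUnsafe(false));
        try {
            distanceUnsafe(true);
        }
        catch (NullPointerException e) {
            System.out.println("NPE: " + e.getMessage());
        }
    }
}
```

Assigning the result to a `Double` and checking it for null before subtracting avoids the exception.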

github-actions bot commented Dec 2, 2024

This pull request has gone a while without any activity. Tagging for triage help: @mosabua

@github-actions github-actions bot added the stale label Dec 2, 2024
@mosabua
Member Author

mosabua commented Dec 2, 2024

I need to look at the test that failed and chat some more with @martint afterwards. I will keep it open and update after Trino Summit.

@github-actions github-actions bot removed the stale label Dec 3, 2024

This pull request has gone a while without any activity. Tagging for triage help: @mosabua

@github-actions github-actions bot added the stale label Dec 25, 2024
@mosabua mosabua force-pushed the sparse-distance branch 2 times, most recently from 74e4109 to c0211d2 January 10, 2025 21:39
@mosabua
Member Author

mosabua commented Jan 10, 2025

> I need to look at the test that failed and chat some more with @martint afterwards. I will keep it open and update after Trino Summit.

@martint and @dain .. I think this is ready now

3 participants