From bf2576430ca890950c4a9905efd27887d61c268b Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 9 Jun 2023 13:33:35 -0700 Subject: [PATCH 1/2] Fix the vector tests to match the final PR --- tests/test_vector.py | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/tests/test_vector.py b/tests/test_vector.py index a96119d3..559992c2 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -36,90 +36,80 @@ def setUp(self): if not self.client.query_required_single(''' select exists ( - select sys::ExtensionPackage filter .name = 'vector' + select sys::ExtensionPackage filter .name = 'pgvector' ) '''): self.skipTest("feature not implemented") self.client.execute(''' - create extension vector version '1.0' + create extension pgvector version '0.4.0' ''') def tearDown(self): try: self.client.execute(''' - drop extension vector version '1.0' + drop extension pgvector version '0.4.0' ''') finally: super().tearDown() async def test_vector_01(self): - # if not self.client.query_required_single(''' - # select exists ( - # select sys::ExtensionPackage filter .name = 'vector' - # ) - # '''): - # self.skipTest("feature not implemented") - - # self.client.execute(''' - # create extension vector version '1.0' - # ''') - val = self.client.query_single(''' - select '[1.5,2.0,3.8]' + select [1.5,2.0,3.8] ''') self.assertTrue(isinstance(val, array.array)) self.assertEqual(val, array.array('f', [1.5, 2.0, 3.8])) val = self.client.query_single( ''' - select $0 + select $0 ''', [3.0, 9.0, -42.5], ) - self.assertEqual(val, '[3,9,-42.5]') + self.assertEqual(val, '[3, 9, -42.5]') val = self.client.query_single( ''' - select $0 + select $0 ''', array.array('f', [3.0, 9.0, -42.5]) ) - self.assertEqual(val, '[3,9,-42.5]') + self.assertEqual(val, '[3, 9, -42.5]') val = self.client.query_single( ''' - select $0 + select $0 ''', array.array('i', [1, 2, 3]), ) - self.assertEqual(val, '[1,2,3]') + self.assertEqual(val, '[1, 2, 3]') # Test that the fast-path works: if the encoder tries to # call __getitem__ on this brokenarray, it will fail. val = self.client.query_single( ''' - select $0 + select $0 ''', brokenarray('f', [3.0, 9.0, -42.5]) ) - self.assertEqual(val, '[3,9,-42.5]') + self.assertEqual(val, '[3, 9, -42.5]') # I don't think it's worth adding a dependency to test this, # but this works too: # import numpy as np # val = self.client.query_single( # ''' - # select $0 + # select $0 # ''', # np.asarray([3.0, 9.0, -42.5], dtype=np.float32), # ) + # self.assertEqual(val, '[3,9,-42.5]') # Some sad path tests with self.assertRaises(edgedb.InvalidArgumentError): self.client.query_single( ''' - select $0 + select $0 ''', [3.0, None, -42.5], ) @@ -127,7 +117,7 @@ async def test_vector_01(self): with self.assertRaises(edgedb.InvalidArgumentError): self.client.query_single( ''' - select $0 + select $0 ''', [3.0, 'x', -42.5], ) @@ -135,7 +125,7 @@ async def test_vector_01(self): with self.assertRaises(edgedb.InvalidArgumentError): self.client.query_single( ''' - select $0 + select $0 ''', 'foo', ) From c75a259bb77b4901245baa217e0d646f7e52609d Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 9 Jun 2023 14:56:34 -0700 Subject: [PATCH 2/2] don't be explicit about the version --- tests/test_vector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vector.py b/tests/test_vector.py index 559992c2..ede4a3d0 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -42,13 +42,13 @@ def setUp(self): self.skipTest("feature not implemented") self.client.execute(''' - create extension pgvector version '0.4.0' + create extension pgvector; ''') def tearDown(self): try: self.client.execute(''' - drop extension pgvector version '0.4.0' + drop extension pgvector; ''') finally: super().tearDown()