Skip to content

Commit

Permalink
feat(postgres): support loading tables with pgvector column types (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored Apr 25, 2024
1 parent 2c1a58e commit 8846514
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 15 deletions.
30 changes: 15 additions & 15 deletions ci/schema/postgres.sql
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,6 @@ CREATE TABLE awards_players (

-- Bulk-load awards_players from the CSV file at /data/awards_players.csv.
-- The file is comma-delimited and its first line is a header row, not data.
COPY awards_players FROM '/data/awards_players.csv' WITH (FORMAT CSV, HEADER TRUE, DELIMITER ',');

-- Recreate the composite type named "vector": drop any existing definition
-- (CASCADE also removes objects that depend on it), then define it as three
-- double-precision fields x, y and z.
DROP TYPE IF EXISTS vector CASCADE;

CREATE TYPE vector AS (
    x DOUBLE PRECISION,
    y DOUBLE PRECISION,
    z DOUBLE PRECISION
);

-- View over awards_players that appends two extra columns:
--   search: the notes column converted to a tsvector with the 'simple'
--           configuration and labelled with weight 'A'
--   simvec: a NULL placeholder typed as vector
DROP VIEW IF EXISTS awards_players_special_types CASCADE;

CREATE VIEW awards_players_special_types AS
SELECT
    *,
    CAST(setweight(to_tsvector('simple', notes), 'A') AS TSVECTOR) AS search,
    CAST(NULL AS vector) AS simvec
FROM awards_players;

DROP TABLE IF EXISTS functional_alltypes CASCADE;

CREATE TABLE functional_alltypes (
Expand Down Expand Up @@ -302,3 +287,18 @@ DROP TABLE IF EXISTS topk;

-- Single-column fixture table holding a duplicated value and a NULL.
CREATE TABLE topk (x BIGINT);

-- Name the target column explicitly so the insert does not depend on the
-- table's column order or count.
INSERT INTO topk (x)
VALUES
    (1),
    (1),
    (NULL);

-- Enable the pgvector extension (no-op if it is already installed). It supplies
-- the "vector" type referenced by the view and the items table defined below.
CREATE EXTENSION IF NOT EXISTS vector;

-- View over awards_players that appends two extra columns:
--   search: the notes column converted to a tsvector with the 'simple'
--           configuration and labelled with weight 'A'
--   simvec: a NULL placeholder typed as vector
DROP VIEW IF EXISTS awards_players_special_types CASCADE;

CREATE VIEW awards_players_special_types AS
SELECT
    *,
    CAST(setweight(to_tsvector('simple', notes), 'A') AS TSVECTOR) AS search,
    CAST(NULL AS vector) AS simvec
FROM awards_players;


-- Fixture table with an auto-incrementing key and a 3-dimensional vector
-- column, seeded with two embeddings.
DROP TABLE IF EXISTS items CASCADE;

CREATE TABLE items (
    id BIGSERIAL PRIMARY KEY,
    embedding vector(3)
);

INSERT INTO items (embedding)
VALUES
    ('[1,2,3]'),
    ('[4,5,6]');
14 changes: 14 additions & 0 deletions docker/postgres/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
# Builder stage: compile and install pgvector v0.6.2 on the same base image as
# the final stage, so the artifacts copied out of it match that image.
FROM postgis/postgis:15-3.3-alpine AS pgvector-builder

# Install the whole build toolchain in a single layer.
# clang15/llvm15 are presumably what produce the bitcode index that the final
# stage copies out of this one -- TODO confirm.
RUN apk add --no-cache \
        build-base \
        clang15 \
        git \
        llvm15 \
        llvm15-dev

WORKDIR /tmp

# Only the tagged release is needed to build, so skip the repository history.
RUN git clone --depth 1 --branch v0.6.2 https://github.com/pgvector/pgvector.git

WORKDIR /tmp/pgvector

# Build and install in one layer; a failed build stops before the install.
RUN make && make install

# Final image: the PostGIS base plus the PL/Python package and the pgvector
# artifacts installed by the pgvector-builder stage.
FROM postgis/postgis:15-3.3-alpine
RUN apk add --no-cache postgresql15-plpython3
# Copy the installed pgvector files out of the builder stage: the bitcode index,
# the shared library, and the extension directory.
# NOTE(review): only the bitcode index file is copied, not any per-module
# bitcode directory -- confirm that is sufficient if JIT inlining of pgvector
# functions matters.
COPY --from=pgvector-builder /usr/local/lib/postgresql/bitcode/vector.index.bc /usr/local/lib/postgresql/bitcode/vector.index.bc
COPY --from=pgvector-builder /usr/local/lib/postgresql/vector.so /usr/local/lib/postgresql/vector.so
COPY --from=pgvector-builder /usr/local/share/postgresql/extension /usr/local/share/postgresql/extension
38 changes: 38 additions & 0 deletions ibis/backends/postgres/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import os
import random

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -257,3 +258,40 @@ def test_port():
# check that we parse and use the port (and then of course fail cuz it's bogus)
with pytest.raises(PsycoPg2OperationalError):
ibis.connect("postgresql://postgres:postgres@localhost:1337/ibis_testing")


def test_pgvector_type_load(con):
    """Tables with a pgvector ``vector`` column load with that column as ``unknown``.

    The ``items`` table read here is expected to have been created as::

        CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3));
        INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]');

    A second, scratch table with a randomly chosen vector dimension is then
    created to check that the schema is reported the same way for dimensions
    other than 3.
    """
    # Both tables are expected to report exactly this schema.
    expected_schema = ibis.schema(
        {
            "id": dt.int64(nullable=False),
            "embedding": dt.unknown,
        }
    )

    t = con.table("items")
    assert t.schema() == expected_schema

    # The vector values come back as their text representation.
    result = ["[1,2,3]", "[4,5,6]"]
    assert t.to_pyarrow().column("embedding").to_pylist() == result

    query = f"""
    DROP TABLE IF EXISTS itemsvrandom;
    CREATE TABLE itemsvrandom (id bigserial PRIMARY KEY, embedding vector({random.randint(4, 1000)}));
    """

    with con.raw_sql(query):
        pass

    # Drop the scratch table even when an assertion fails, so a failing run
    # does not leave it behind in the test database.
    try:
        t = con.table("itemsvrandom")
        assert t.schema() == expected_schema
    finally:
        con.drop_table("itemsvrandom")
10 changes: 10 additions & 0 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
raise com.IbisTypeError("Postgres only supports string values in maps")
return sge.DataType(this=typecode.HSTORE)

@classmethod
def from_string(cls, text: str, nullable: bool | None = None) -> dt.DataType:
    """Convert a type string into an ibis data type.

    Any string that starts with ``vector`` (case-insensitive), for example
    ``vector(3)``, is collapsed to the bare name ``vector`` before lookup.

    If the lowercased name has a truthy entry in ``cls.unknown_type_strings``,
    that entry is returned as-is and ``nullable`` is not applied. Otherwise
    the string is parsed with ``sg.parse_one`` using ``cls.dialect`` and
    converted through ``cls.to_ibis`` with the given ``nullable``.
    """
    lowered = text.lower()
    if lowered.startswith("vector"):
        # Drop any dimension suffix so every vector type shares one lookup key.
        text = lowered = "vector"

    known = cls.unknown_type_strings.get(lowered)
    if known:
        return known

    parsed = sg.parse_one(text, into=sge.DataType, read=cls.dialect)
    return cls.to_ibis(parsed, nullable=nullable)


class RisingWaveType(PostgresType):
dialect = "risingwave"
Expand Down

0 comments on commit 8846514

Please sign in to comment.