Skip to content

Commit

Permalink
feat: add impersonate user by id feature
Browse files Browse the repository at this point in the history
  • Loading branch information
egor-romanov committed Nov 14, 2023
1 parent ec961c4 commit 5579edf
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
17 changes: 13 additions & 4 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Client:
vx.disconnect()
"""

def __init__(self, connection_string: str):
def __init__(self, connection_string: str, skip_auth: bool = True, user_id: Optional[str] = None):
"""
Initialize a Client instance.
Expand All @@ -58,8 +58,13 @@ def __init__(self, connection_string: str):
None
"""
self.engine = create_engine(connection_string)
self.skip_auth = skip_auth
self.user_id = user_id
self.meta = MetaData(schema="vecs")
self.Session = sessionmaker(self.engine)

if not self.skip_auth:
return

with self.Session() as sess:
with sess.begin():
Expand Down Expand Up @@ -113,6 +118,8 @@ def get_or_create_collection(
dimension=dimension or adapter_dimension, # type: ignore
client=self,
adapter=adapter,
skip_auth=self.skip_auth,
user_id=self.user_id,
)

return collection._create_if_not_exists()
Expand All @@ -134,7 +141,7 @@ def create_collection(self, name: str, dimension: int) -> Collection:
"""
from vecs.collection import Collection

return Collection(name, dimension, self)._create()
return Collection(name, dimension, self, skip_auth=self.skip_auth, user_id=self.user_id)._create()

@deprecated("use Client.get_or_create_collection")
def get_collection(self, name: str) -> Collection:
Expand Down Expand Up @@ -180,6 +187,8 @@ def get_collection(self, name: str) -> Collection:
name,
dimension,
self,
skip_auth=self.skip_auth,
user_id=self.user_id,
)

def list_collections(self) -> List["Collection"]:
Expand All @@ -191,7 +200,7 @@ def list_collections(self) -> List["Collection"]:
"""
from vecs.collection import Collection

return Collection._list_collections(self)
return Collection._list_collections(self, skip_auth=self.skip_auth, user_id=self.user_id)

def delete_collection(self, name: str) -> None:
"""
Expand All @@ -207,7 +216,7 @@ def delete_collection(self, name: str) -> None:
"""
from vecs.collection import Collection

Collection(name, -1, self)._drop()
Collection(name, -1, self, skip_auth=self.skip_auth, user_id=self.user_id)._drop()
return

def disconnect(self) -> None:
Expand Down
67 changes: 50 additions & 17 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def __init__(
dimension: int,
client: Client,
adapter: Optional[Adapter] = None,
*,
skip_auth: bool = True,
user_id: Optional[str] = None,
):
"""
Initializes a new instance of the `Collection` class.
Expand All @@ -177,6 +180,8 @@ def __init__(
self.table = build_table(name, client.meta, dimension)
self._index: Optional[str] = None
self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)])
self.skip_auth = skip_auth
self.user_id = user_id

reported_dimensions = set(
[
Expand Down Expand Up @@ -213,6 +218,7 @@ def __len__(self) -> int:
"""
with self.client.Session() as sess:
with sess.begin():
self._add_auth(sess)
stmt = select(func.count()).select_from(self.table)
return sess.execute(stmt).scalar() or 0

Expand Down Expand Up @@ -243,12 +249,14 @@ def _create_if_not_exists(self):
"""
).bindparams(name=self.name)
with self.client.Session() as sess:
query_result = sess.execute(query).fetchone()
with sess.begin():
self._add_auth(sess)
query_result = sess.execute(query).fetchone()

if query_result:
_, collection_dimension = query_result
else:
collection_dimension = None
if query_result:
_, collection_dimension = query_result
else:
collection_dimension = None

reported_dimensions = set(
[x for x in [self.dimension, collection_dimension] if x is not None]
Expand Down Expand Up @@ -285,15 +293,17 @@ def _create(self):

unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
with self.client.Session() as sess:
sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
with sess.begin():
self._add_auth(sess)
sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
)
)
return self

def _drop(self):
Expand All @@ -309,8 +319,10 @@ def _drop(self):
from sqlalchemy.schema import DropTable

with self.client.Session() as sess:
sess.execute(DropTable(self.table, if_exists=True))
sess.commit()
with sess.begin():
self._add_auth(sess)
sess.execute(DropTable(self.table, if_exists=True))
sess.commit()

return self

Expand Down Expand Up @@ -341,6 +353,7 @@ def upsert(

with self.client.Session() as sess:
with sess.begin():
self._add_auth(sess)
for chunk in pipeline:
stmt = postgresql.insert(self.table).values(chunk)
stmt = stmt.on_conflict_do_update(
Expand Down Expand Up @@ -369,6 +382,7 @@ def fetch(self, ids: Iterable[str]) -> List[Record]:
records = []
with self.client.Session() as sess:
with sess.begin():
self._add_auth(sess)
for id_chunk in flu(ids).chunk(chunk_size):
stmt = select(self.table).where(self.table.c.id.in_(id_chunk))
chunk_records = sess.execute(stmt)
Expand All @@ -394,6 +408,7 @@ def delete(self, ids: Iterable[str]) -> List[str]:
ids = []
with self.client.Session() as sess:
with sess.begin():
self._add_auth(sess)
for id_chunk in flu(del_ids).chunk(chunk_size):
stmt = (
delete(self.table)
Expand Down Expand Up @@ -533,12 +548,29 @@ def query(
ef_search=ef_search
)
)
self._add_auth(sess)
if len(cols) == 1:
return [str(x) for x in sess.scalars(stmt).fetchall()]
return sess.execute(stmt).fetchall() or []

def _add_auth(self, sess):
if not self.skip_auth:
if self.user_id:
sess.execute(
text("set local request.jwt.claim.sub = :user_id").bindparams(
user_id=self.user_id
)
)
sess.execute(
text("set local role authenticated;")
)
else:
sess.execute(
text("set local role anon;")
)

@classmethod
def _list_collections(cls, client: "Client") -> List["Collection"]:
def _list_collections(cls, client: "Client", skip_auth: bool = True, user_id: Optional[str] = None) -> List["Collection"]:
"""
PRIVATE
Expand Down Expand Up @@ -570,7 +602,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]:
xc = []
with client.Session() as sess:
for name, dimension in sess.execute(query):
existing_collection = cls(name, dimension, client)
existing_collection = cls(name, dimension, client, skip_auth=skip_auth, user_id=user_id)
xc.append(existing_collection)
return xc

Expand Down Expand Up @@ -734,6 +766,7 @@ def create_index(

with self.client.Session() as sess:
with sess.begin():
self._add_auth(sess)
if self.index is not None:
if replace:
sess.execute(text(f'drop index vecs."{self.index}";'))
Expand Down

0 comments on commit 5579edf

Please sign in to comment.