Skip to content

Commit

Permalink
Refactor search index to improve connection handling (#192)
Browse files Browse the repository at this point in the history
We were leaning on a hack to run some async code in a sync setting. This
was both dangerous and an anti-pattern. We also needed to refactor how
shared and non-shared logic is divided between BaseSearchIndex and its
derivatives.
  • Loading branch information
tylerhutcherson authored Jul 30, 2024
1 parent cb61457 commit 67eee3d
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 184 deletions.
2 changes: 1 addition & 1 deletion docs/examples/openai_qna.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@
"client = redis.Redis.from_url(\"redis://localhost:6379\")\n",
"schema = IndexSchema.from_yaml(\"wiki_schema.yaml\")\n",
"\n",
"index = AsyncSearchIndex(schema, client)\n",
"index = await AsyncSearchIndex(schema).set_client(client)\n",
"\n",
"await index.create()"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/getting_started_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@
"client = Redis.from_url(\"redis://localhost:6379\")\n",
"\n",
"index = AsyncSearchIndex.from_dict(schema)\n",
"index.set_client(client)"
"await index.set_client(client)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion redisvl/extensions/session_manager/standard_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self._client = RedisConnectionFactory.get_redis_connection(
redis_url, **connection_kwargs
)
RedisConnectionFactory.validate_redis(self._client)
RedisConnectionFactory.validate_sync_redis(self._client)

self.set_scope(session_tag, user_tag)

Expand Down
168 changes: 105 additions & 63 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,24 @@ def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
RedisConnectionFactory.validate_redis(self._redis_client, self._lib_name)
RedisConnectionFactory.validate_sync_redis(
self._redis_client, self._lib_name
)
return result

return wrapper

return decorator


def setup_async_redis():
def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
result = await func(self, *args, **kwargs)
await RedisConnectionFactory.validate_async_redis(
self._redis_client, self._lib_name
)
return result

return wrapper
Expand Down Expand Up @@ -140,41 +157,10 @@ class BaseSearchIndex:
StorageType.JSON: JsonStorage,
}

def __init__(
self,
schema: IndexSchema,
redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None,
redis_url: Optional[str] = None,
connection_args: Dict[str, Any] = {},
**kwargs,
):
"""Initialize the RedisVL search index with a schema, Redis client
(or URL string with other connection args), connection_args, and other
kwargs.
Args:
schema (IndexSchema): Index schema object.
redis_client(Union[redis.Redis, aredis.Redis], optional): An
instantiated redis client.
redis_url (str, optional): The URL of the Redis server to
connect to.
connection_args (Dict[str, Any], optional): Redis client connection
args.
"""
# final validation on schema object
if not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid IndexSchema object")

self.schema = schema

self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
schema: IndexSchema

# set up redis connection
self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None
if redis_client is not None:
self.set_client(redis_client)
elif redis_url is not None:
self.connect(redis_url, **connection_args)
def __init__(*args, **kwargs):
pass

@property
def _storage(self) -> BaseStorage:
Expand Down Expand Up @@ -237,8 +223,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
Args:
schema_dict (Dict[str, Any]): A dictionary containing the schema.
connection_args (Dict[str, Any], optional): Redis client connection
args.
Returns:
SearchIndex: A RedisVL SearchIndex object.
Expand All @@ -262,14 +246,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
schema = IndexSchema.from_dict(schema_dict)
return cls(schema=schema, **kwargs)

def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to Redis at a given URL."""
raise NotImplementedError

def set_client(self, client: Union[redis.Redis, aredis.Redis]):
"""Manually set the Redis client to use with the search index."""
raise NotImplementedError

def disconnect(self):
"""Disconnect from the Redis database."""
self._redis_client = None
Expand Down Expand Up @@ -323,6 +299,43 @@ class SearchIndex(BaseSearchIndex):
"""

def __init__(
    self,
    schema: IndexSchema,
    redis_client: Optional[redis.Redis] = None,
    redis_url: Optional[str] = None,
    connection_args: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """Initialize the RedisVL search index with a schema and, optionally,
    a Redis connection.

    If ``redis_client`` is given it is installed via ``set_client()``;
    otherwise, if ``redis_url`` is given, ``connect()`` is called with that
    URL and ``connection_args``. When neither is supplied the index is
    created without a connection.

    Args:
        schema (IndexSchema): Index schema object.
        redis_client (Optional[redis.Redis]): An instantiated redis client.
            Takes precedence over ``redis_url`` when both are provided.
        redis_url (Optional[str]): The URL of the Redis server to
            connect to.
        connection_args (Optional[Dict[str, Any]]): Redis client connection
            args forwarded to ``connect()`` as keyword arguments. Only used
            together with ``redis_url``. Defaults to no extra arguments.
        **kwargs: ``lib_name`` is popped and stored; any other keyword
            arguments are not used by this constructor.

    Raises:
        ValueError: If ``schema`` is not an ``IndexSchema`` instance.
    """
    # final validation on schema object
    if not isinstance(schema, IndexSchema):
        raise ValueError("Must provide a valid IndexSchema object")

    self.schema = schema

    self._lib_name: Optional[str] = kwargs.pop("lib_name", None)

    # start without a connection; set_client()/connect() fill this in
    self._redis_client: Optional[redis.Redis] = None

    if redis_client is not None:
        self.set_client(redis_client)
    elif redis_url is not None:
        # a None sentinel replaces the shared mutable ``{}`` default
        self.connect(redis_url, **(connection_args or {}))

@classmethod
def from_existing(
cls,
Expand All @@ -342,7 +355,7 @@ def from_existing(
)

# Validate modules
installed_modules = RedisConnectionFactory._get_modules(redis_client)
installed_modules = RedisConnectionFactory.get_modules(redis_client)
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])

# Fetch index info and convert to schema
Expand Down Expand Up @@ -380,15 +393,15 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
return self.set_client(client)

@setup_redis()
def set_client(self, client: redis.Redis, **kwargs):
def set_client(self, redis_client: redis.Redis, **kwargs):
"""Manually set the Redis client to use with the search index.
This method configures the search index to use a specific Redis or
Async Redis client. It is useful for cases where an external,
custom-configured client is preferred instead of creating a new one.
Args:
client (redis.Redis): A Redis or Async Redis
redis_client (redis.Redis): A Redis or Async Redis
client instance to be used for the connection.
Raises:
Expand All @@ -404,10 +417,10 @@ def set_client(self, client: redis.Redis, **kwargs):
index.set_client(client)
"""
if not isinstance(client, redis.Redis):
if not isinstance(redis_client, redis.Redis):
raise TypeError("Invalid Redis client instance")

self._redis_client = client
self._redis_client = redis_client

return self

Expand Down Expand Up @@ -759,7 +772,7 @@ class AsyncSearchIndex(BaseSearchIndex):
# initialize the index object with schema from file
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
index.connect(redis_url="redis://localhost:6379")
await index.connect(redis_url="redis://localhost:6379")
# create the index
await index.create(overwrite=True)
Expand All @@ -772,6 +785,34 @@ class AsyncSearchIndex(BaseSearchIndex):
"""

def __init__(self, schema: IndexSchema, **kwargs):
    """Initialize the RedisVL async search index with a schema.

    No Redis connection is established here. A connection has to be
    supplied afterwards through ``set_client()`` or ``connect()``; passing
    ``redis_client`` or ``redis_url`` as keyword arguments only logs a
    warning.

    Args:
        schema (IndexSchema): Index schema object.
        **kwargs: ``lib_name`` is popped and stored. ``redis_client`` and
            ``redis_url`` trigger a warning and are otherwise not used.

    Raises:
        ValueError: If ``schema`` is not an ``IndexSchema`` instance.
    """
    # reject anything that is not a real schema object up front
    if not isinstance(schema, IndexSchema):
        raise ValueError("Must provide a valid IndexSchema object")

    self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
    self.schema = schema

    # the async index always starts out unconnected
    self._redis_client: Optional[aredis.Redis] = None

    connection_keys = ("redis_client", "redis_url")
    if any(key in kwargs for key in connection_keys):
        logger.warning(
            "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
        )

@classmethod
async def from_existing(
cls,
Expand All @@ -791,18 +832,18 @@ async def from_existing(
)

# Validate modules
installed_modules = await RedisConnectionFactory._get_modules_async(
redis_client
)
installed_modules = await RedisConnectionFactory.get_modules_async(redis_client)
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])

# Fetch index info and convert to schema
index_info = await cls._info(name, redis_client)
schema_dict = convert_index_info_to_schema(index_info)
schema = IndexSchema.from_dict(schema_dict)
return cls(schema, redis_client, **kwargs)
index = cls(schema, **kwargs)
await index.set_client(redis_client)
return index

def connect(self, redis_url: Optional[str] = None, **kwargs):
async def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to a Redis instance using the provided `redis_url`, falling
back to the `REDIS_URL` environment variable (if available).
Expand All @@ -828,18 +869,18 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
client = RedisConnectionFactory.connect(
redis_url=redis_url, use_async=True, **kwargs
)
return self.set_client(client)
return await self.set_client(client)

@setup_redis()
def set_client(self, client: aredis.Redis):
@setup_async_redis()
async def set_client(self, redis_client: aredis.Redis):
"""Manually set the Redis client to use with the search index.
This method configures the search index to use a specific
Async Redis client. It is useful for cases where an external,
custom-configured client is preferred instead of creating a new one.
Args:
client (aredis.Redis): An Async Redis
redis_client (aredis.Redis): An Async Redis
client instance to be used for the connection.
Raises:
Expand All @@ -853,13 +894,13 @@ def set_client(self, client: aredis.Redis):
# async Redis client and index
client = aredis.Redis.from_url("redis://localhost:6379")
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
index.set_client(client)
await index.set_client(client)
"""
if not isinstance(client, aredis.Redis):
if not isinstance(redis_client, aredis.Redis):
raise TypeError("Invalid Redis client instance")

self._redis_client = client
self._redis_client = redis_client

return self

Expand Down Expand Up @@ -889,6 +930,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
await index.create(overwrite=True, drop=True)
"""
redis_fields = self.schema.redis_fields

if not redis_fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
Expand Down
Loading

0 comments on commit 67eee3d

Please sign in to comment.