diff --git a/airflow/providers/pinecone/CHANGELOG.rst b/airflow/providers/pinecone/CHANGELOG.rst
index 7b2a20deb023c..a1482f953467a 100644
--- a/airflow/providers/pinecone/CHANGELOG.rst
+++ b/airflow/providers/pinecone/CHANGELOG.rst
@@ -20,6 +20,36 @@ Changelog
 ---------
 
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+  This release of the provider contains breaking changes from previous versions. The changes follow the
+  Pinecone Python SDK v3.0.0 migration guide:
+  https://canyon-quilt-082.notion.site/Pinecone-Python-SDK-v3-0-0-Migration-Guide-056d3897d7634bf7be399676a4757c7b
+
+* The ``log_level`` field is removed from the connection as it is no longer used by the provider.
+* ``PineconeHook.get_conn`` is removed in favor of the ``conn`` property, which returns the Connection
+  object. Use the ``pinecone_client`` property to access the Pinecone client.
+* The following ``PineconeHook`` methods have been converted from static methods to instance methods,
+  so a hook instance is now required to use them:
+
+  + ``PineconeHook.list_indexes``
+  + ``PineconeHook.upsert``
+  + ``PineconeHook.create_index``
+  + ``PineconeHook.describe_index``
+  + ``PineconeHook.delete_index``
+  + ``PineconeHook.configure_index``
+  + ``PineconeHook.create_collection``
+  + ``PineconeHook.delete_collection``
+  + ``PineconeHook.describe_collection``
+  + ``PineconeHook.list_collections``
+  + ``PineconeHook.query_vector``
+  + ``PineconeHook.describe_index_stats``
+
+* ``PineconeHook.create_index`` is updated to accept a ``ServerlessSpec`` or ``PodSpec`` instead of
+  directly accepting index-related configuration.
+* To initialize a ``PineconeHook`` object, the API key must be set in the connection; ``environment``
+  and ``region`` can be passed as hook arguments or set in the connection.
+
 1.1.2
 .....
 
diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py
index 3d11c74b647a2..a04ae60ce8391 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -20,9 +20,11 @@
 from __future__ import annotations
 
 import itertools
+import os
+from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
-import pinecone
+from pinecone import Pinecone, PodSpec, ServerlessSpec
 
 from airflow.hooks.base import BaseHook
 
@@ -30,6 +32,8 @@
     from pinecone.core.client.model.sparse_values import SparseValues
     from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse
 
+    from airflow.models.connection import Connection
+
 
 class PineconeHook(BaseHook):
     """
@@ -49,10 +53,11 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
         """Return connection widgets to add to connection form."""
         from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
         from flask_babel import lazy_gettext
-        from wtforms import StringField
+        from wtforms import BooleanField, StringField
 
         return {
-            "log_level": StringField(lazy_gettext("Log Level"), widget=BS3TextFieldWidget(), default=None),
+            "region": StringField(lazy_gettext("Pinecone Region"), widget=BS3TextFieldWidget(), default=None),
+            "debug_curl": BooleanField(lazy_gettext("PINECONE_DEBUG_CURL"), default=False),
             "project_id": StringField(
                 lazy_gettext("Project ID"),
                 widget=BS3TextFieldWidget(),
@@ -64,43 +69,73 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Return custom field behaviour."""
         return {
             "hidden_fields": ["port", "schema"],
-            "relabeling": {"login": "Pinecone Environment", "password": "Pinecone API key"},
+            "relabeling": {
+                "login": "Pinecone Environment",
+                "host": "Pinecone Host",
+                "password": "Pinecone API key",
+            },
         }
 
-    def __init__(self, conn_id: str = default_conn_name) -> None:
+    def __init__(
+        self, conn_id: str = default_conn_name, environment: str | None = None, region: str | None = None
+    ) -> None:
         self.conn_id = conn_id
-        self.get_conn()
-
-    def get_conn(self) -> None:
-        pinecone_connection = self.get_connection(self.conn_id)
-        api_key = pinecone_connection.password
-        pinecone_environment = pinecone_connection.login
-        pinecone_host = pinecone_connection.host
-        extras = pinecone_connection.extra_dejson
+        self._environment = environment
+        self._region = region
+
+    @property
+    def api_key(self) -> str:
+        key = self.conn.password
+        if not key:
+            raise LookupError("Pinecone API Key not found in connection")
+        return key
+
+    @cached_property
+    def environment(self) -> str:
+        if self._environment:
+            return self._environment
+        env = self.conn.login
+        if not env:
+            raise LookupError("Pinecone environment not found in connection")
+        return env
+
+    @cached_property
+    def region(self) -> str:
+        if self._region:
+            return self._region
+        region = self.conn.extra_dejson.get("region")
+        if not region:
+            raise LookupError("Pinecone region not found in connection")
+        return region
+
+    @cached_property
+    def pinecone_client(self) -> Pinecone:
+        """Pinecone object to interact with Pinecone."""
+        pinecone_host = self.conn.host
+        extras = self.conn.extra_dejson
         pinecone_project_id = extras.get("project_id")
-        log_level = extras.get("log_level", None)
-        pinecone.init(
-            api_key=api_key,
-            environment=pinecone_environment,
-            host=pinecone_host,
-            project_name=pinecone_project_id,
-            log_level=log_level,
-        )
+        enable_curl_debug = extras.get("debug_curl")
+        if enable_curl_debug:
+            os.environ["PINECONE_DEBUG_CURL"] = "true"
+        return Pinecone(api_key=self.api_key, host=pinecone_host, project_id=pinecone_project_id)
+
+    @cached_property
+    def conn(self) -> Connection:
+        return self.get_connection(self.conn_id)
 
     def test_connection(self) -> tuple[bool, str]:
         try:
-            self.list_indexes()
+            self.pinecone_client.list_indexes()
             return True, "Connection established"
         except Exception as e:
             return False, str(e)
 
-    @staticmethod
-    def list_indexes() -> Any:
+    def list_indexes(self) -> Any:
         """Retrieve a list of all indexes in your project."""
-        return pinecone.list_indexes()
+        return self.pinecone_client.list_indexes()
 
-    @staticmethod
     def upsert(
+        self,
         index_name: str,
         vectors: list[Any],
         namespace: str = "",
@@ -126,7 +161,7 @@ def upsert(
         :param show_progress: Whether to show a progress bar using tqdm. Applied only if batch_size
             is provided.
         """
-        index = pinecone.Index(index_name)
+        index = self.pinecone_client.Index(index_name)
         return index.upsert(
             vectors=vectors,
             namespace=namespace,
@@ -135,75 +170,93 @@ def upsert(
             **kwargs,
         )
 
-    @staticmethod
+    def get_pod_spec_obj(
+        self,
+        *,
+        replicas: int | None = None,
+        shards: int | None = None,
+        pods: int | None = None,
+        pod_type: str | None = "p1.x1",
+        metadata_config: dict | None = None,
+        source_collection: str | None = None,
+        environment: str | None = None,
+    ) -> PodSpec:
+        """
+        Get a PodSpec object.
+
+        :param replicas: The number of replicas.
+        :param shards: The number of shards.
+        :param pods: The number of pods.
+        :param pod_type: The type of pod.
+        :param metadata_config: The metadata configuration.
+        :param source_collection: The source collection.
+        :param environment: The environment to use when creating the index.
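+
+        A minimal usage sketch (the environment and index values below are placeholder assumptions;
+        the environment may instead come from the connection or the hook argument):
+
+        .. code-block:: python
+
+            hook = PineconeHook(environment="us-west1-gcp")
+            pod_spec = hook.get_pod_spec_obj(replicas=1, shards=1, pods=1, pod_type="p1.x1")
+            hook.create_index(index_name="example-index", dimension=128, spec=pod_spec)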
+ """ + return PodSpec( + environment=environment or self.environment, + replicas=replicas, + shards=shards, + pods=pods, + pod_type=pod_type, + metadata_config=metadata_config, + source_collection=source_collection, + ) + + def get_serverless_spec_obj(self, *, cloud: str, region: str | None = None) -> ServerlessSpec: + """ + Get a ServerlessSpec object. + + :param cloud: The cloud provider. + :param region: The region to use when creating the index. + """ + return ServerlessSpec(cloud=cloud, region=region or self.region) + def create_index( + self, index_name: str, dimension: int, - index_type: str | None = "approximated", + spec: ServerlessSpec | PodSpec, metric: str | None = "cosine", - replicas: int | None = 1, - shards: int | None = 1, - pods: int | None = 1, - pod_type: str | None = "p1", - index_config: dict[str, str] | None = None, - metadata_config: dict[str, str] | None = None, - source_collection: str | None = "", timeout: int | None = None, ) -> None: """ Create a new index. - .. seealso:: https://docs.pinecone.io/reference/create_index/ - - :param index_name: The name of the index to create. - :param dimension: the dimension of vectors that would be inserted in the index - :param index_type: type of index, one of {"approximated", "exact"}, defaults to "approximated". - :param metric: type of metric used in the vector index, one of {"cosine", "dotproduct", "euclidean"} - :param replicas: the number of replicas, defaults to 1. - :param shards: the number of shards per index, defaults to 1. - :param pods: Total number of pods to be used by the index. pods = shard*replicas - :param pod_type: the pod type to be used for the index. can be one of p1 or s1. - :param index_config: Advanced configuration options for the index - :param metadata_config: Configuration related to the metadata index - :param source_collection: Collection name to create the index from - :param timeout: Timeout for wait until index gets ready. + :param index_name: The name of the index. + :param dimension: The dimension of the vectors to be indexed. + :param spec: Pass a `ServerlessSpec` object to create a serverless index or a `PodSpec` object to create a pod index. + ``get_serverless_spec_obj`` and ``get_pod_spec_obj`` can be used to create the Spec objects. + :param metric: The metric to use. + :param timeout: The timeout to use. """ - pinecone.create_index( + self.pinecone_client.create_index( name=index_name, - timeout=timeout, - index_type=index_type, dimension=dimension, + spec=spec, metric=metric, - pods=pods, - replicas=replicas, - shards=shards, - pod_type=pod_type, - metadata_config=metadata_config, - source_collection=source_collection, - index_config=index_config, + timeout=timeout, ) - @staticmethod - def describe_index(index_name: str) -> Any: + def describe_index(self, index_name: str) -> Any: """ Retrieve information about a specific index. :param index_name: The name of the index to describe. """ - return pinecone.describe_index(name=index_name) + return self.pinecone_client.describe_index(name=index_name) - @staticmethod - def delete_index(index_name: str, timeout: int | None = None) -> None: + def delete_index(self, index_name: str, timeout: int | None = None) -> None: """ Delete a specific index. :param index_name: the name of the index. :param timeout: Timeout for wait until index gets ready. 
""" - pinecone.delete_index(name=index_name, timeout=timeout) + self.pinecone_client.delete_index(name=index_name, timeout=timeout) - @staticmethod - def configure_index(index_name: str, replicas: int | None = None, pod_type: str | None = "") -> None: + def configure_index( + self, index_name: str, replicas: int | None = None, pod_type: str | None = "" + ) -> None: """ Change the current configuration of the index. @@ -211,43 +264,39 @@ def configure_index(index_name: str, replicas: int | None = None, pod_type: str :param replicas: The new number of replicas. :param pod_type: the new pod_type for the index. """ - pinecone.configure_index(name=index_name, replicas=replicas, pod_type=pod_type) + self.pinecone_client.configure_index(name=index_name, replicas=replicas, pod_type=pod_type) - @staticmethod - def create_collection(collection_name: str, index_name: str) -> None: + def create_collection(self, collection_name: str, index_name: str) -> None: """ Create a new collection from a specified index. :param collection_name: The name of the collection to create. :param index_name: The name of the source index. """ - pinecone.create_collection(name=collection_name, source=index_name) + self.pinecone_client.create_collection(name=collection_name, source=index_name) - @staticmethod - def delete_collection(collection_name: str) -> None: + def delete_collection(self, collection_name: str) -> None: """ Delete a specific collection. :param collection_name: The name of the collection to delete. """ - pinecone.delete_collection(collection_name) + self.pinecone_client.delete_collection(collection_name) - @staticmethod - def describe_collection(collection_name: str) -> Any: + def describe_collection(self, collection_name: str) -> Any: """ Retrieve information about a specific collection. :param collection_name: The name of the collection to describe. """ - return pinecone.describe_collection(collection_name) + return self.pinecone_client.describe_collection(collection_name) - @staticmethod - def list_collections() -> Any: + def list_collections(self) -> Any: """Retrieve a list of all collections in the current project.""" - return pinecone.list_collections() + return self.pinecone_client.list_collections() - @staticmethod def query_vector( + self, index_name: str, vector: list[Any], query_id: str | None = None, @@ -275,7 +324,7 @@ def query_vector( :param sparse_vector: sparse values of the query vector. Expected to be either a SparseValues object or a dict of the form: {'indices': List[int], 'values': List[float]}, where the lists each have the same length. """ - index = pinecone.Index(index_name) + index = self.pinecone_client.Index(index_name) return index.query( vector=vector, id=query_id, @@ -313,7 +362,7 @@ def upsert_data_async( :param pool_threads: Number of threads for parallel upserting. If async_req is True, this must be provided. 
""" responses = [] - with pinecone.Index(index_name, pool_threads=pool_threads) as index: + with self.pinecone_client.Index(index_name, pool_threads=pool_threads) as index: if async_req and pool_threads: async_results = [index.upsert(vectors=chunk, async_req=True) for chunk in self._chunks(data)] responses = [async_result.get() for async_result in async_results] @@ -323,8 +372,8 @@ def upsert_data_async( responses.append(response) return responses - @staticmethod def describe_index_stats( + self, index_name: str, stats_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None, **kwargs: Any, @@ -340,5 +389,5 @@ def describe_index_stats( :param stats_filter: If this parameter is present, the operation only returns statistics for vectors that satisfy the filter. See https://www.pinecone.io/docs/metadata-filtering/ """ - index = pinecone.Index(index_name) + index = self.pinecone_client.Index(index_name) return index.describe_index_stats(filter=stats_filter, **kwargs) diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py index 1c757d8fa541c..8431276206e07 100644 --- a/airflow/providers/pinecone/operators/pinecone.py +++ b/airflow/providers/pinecone/operators/pinecone.py @@ -22,6 +22,7 @@ from airflow.models import BaseOperator from airflow.providers.pinecone.hooks.pinecone import PineconeHook +from airflow.utils.context import Context if TYPE_CHECKING: from airflow.utils.context import Context @@ -81,3 +82,132 @@ def execute(self, context: Context) -> None: ) self.log.info("Successfully ingested data into Pinecone index %s.", self.index_name) + + +class CreatePodIndexOperator(BaseOperator): + """ + Create a pod based index in Pinecone. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CreatePodIndexOperator` + + :param conn_id: The connection id to use when connecting to Pinecone. + :param index_name: Name of the Pinecone index. + :param dimension: The dimension of the vectors to be indexed. + :param environment: The environment to use when creating the index. + :param replicas: The number of replicas to use. + :param shards: The number of shards to use. + :param pods: The number of pods to use. + :param pod_type: The type of pod to use. + :param metadata_config: The metadata configuration to use. + :param source_collection: The source collection to use. + :param metric: The metric to use. + :param timeout: The timeout to use. 
+ """ + + def __init__( + self, + *, + conn_id: str = PineconeHook.default_conn_name, + index_name: str, + dimension: int, + environment: str | None = None, + replicas: int | None = None, + shards: int | None = None, + pods: int | None = None, + pod_type: str | None = None, + metadata_config: dict | None = None, + source_collection: str | None = None, + metric: str | None = None, + timeout: int | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.conn_id = conn_id + self.index_name = index_name + self.dimension = dimension + self.environment = environment + self.replicas = replicas + self.shards = shards + self.pods = pods + self.pod_type = pod_type + self.metadata_config = metadata_config + self.source_collection = source_collection + self.metric = metric + self.timeout = timeout + + @cached_property + def hook(self) -> PineconeHook: + return PineconeHook(conn_id=self.conn_id, environment=self.environment) + + def execute(self, context: Context) -> None: + pod_spec_obj = self.hook.get_pod_spec_obj( + replicas=self.replicas, + shards=self.shards, + pods=self.pods, + pod_type=self.pod_type, + metadata_config=self.metadata_config, + source_collection=self.source_collection, + environment=self.environment, + ) + self.hook.create_index( + index_name=self.index_name, + dimension=self.dimension, + spec=pod_spec_obj, + metric=self.metric, + timeout=self.timeout, + ) + + +class CreateServerlessIndexOperator(BaseOperator): + """ + Create a serverless index in Pinecone. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CreateServerlessIndexOperator` + + :param conn_id: The connection id to use when connecting to Pinecone. + :param index_name: Name of the Pinecone index. + :param dimension: The dimension of the vectors to be indexed. + :param cloud: The cloud to use when creating the index. + :param region: The region to use when creating the index. + :param metric: The metric to use. + :param timeout: The timeout to use. + """ + + def __init__( + self, + *, + conn_id: str = PineconeHook.default_conn_name, + index_name: str, + dimension: int, + cloud: str, + region: str | None = None, + metric: str | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.conn_id = conn_id + self.index_name = index_name + self.dimension = dimension + self.cloud = cloud + self.region = region + self.metric = metric + self.timeout = timeout + + @cached_property + def hook(self) -> PineconeHook: + return PineconeHook(conn_id=self.conn_id, region=self.region) + + def execute(self, context: Context) -> None: + serverless_spec_obj = self.hook.get_serverless_spec_obj(cloud=self.cloud, region=self.region) + self.hook.create_index( + index_name=self.index_name, + dimension=self.dimension, + spec=serverless_spec_obj, + metric=self.metric, + timeout=self.timeout, + ) diff --git a/airflow/providers/pinecone/provider.yaml b/airflow/providers/pinecone/provider.yaml index a48f041fa896b..0c6e3d9b4c586 100644 --- a/airflow/providers/pinecone/provider.yaml +++ b/airflow/providers/pinecone/provider.yaml @@ -42,10 +42,7 @@ integrations: dependencies: - apache-airflow>=2.7.0 - # Pinecone Python SDK v3.0.0 was released at 2024-01-16 and introduce some breaking changes. - # It's crucial to adhere to the v3.0.0 Migration Guide before the upper-bound limitation can be removed. 
- # https://canyon-quilt-082.notion.site/Pinecone-Python-SDK-v3-0-0-Migration-Guide-056d3897d7634bf7be399676a4757c7b - - pinecone-client>=2.2.4,<3.0 + - pinecone-client>=3.0.0 hooks: - integration-name: Pinecone diff --git a/docs/apache-airflow-providers-pinecone/connections.rst b/docs/apache-airflow-providers-pinecone/connections.rst index 07054a9388b79..50a72b133a950 100644 --- a/docs/apache-airflow-providers-pinecone/connections.rst +++ b/docs/apache-airflow-providers-pinecone/connections.rst @@ -33,11 +33,17 @@ Configuring the Connection Host (optional) Host URL to connect to a specific Pinecone index. -Pinecone Environment (required) - Specify your Pinecone environment to connect to. +Pinecone Environment (optional) + Specify your Pinecone environment for pod based indexes. Pinecone API key (required) Specify your Pinecone API Key to connect. -Project ID (required) +Project ID (optional) Project ID corresponding to your API Key. + +Pinecone Region (optional) + Specify the region for Serverless Indexes in Pinecone. + +PINECONE_DEBUG_CURL (optional) + Set to ``true`` to enable curl debug output. diff --git a/docs/apache-airflow-providers-pinecone/index.rst b/docs/apache-airflow-providers-pinecone/index.rst index d82935e3181c8..91913b916844c 100644 --- a/docs/apache-airflow-providers-pinecone/index.rst +++ b/docs/apache-airflow-providers-pinecone/index.rst @@ -69,7 +69,7 @@ Package apache-airflow-providers-pinecone `Pinecone `__ -Release: 1.1.2 +Release: 2.0.0 Provider package ---------------- @@ -93,5 +93,5 @@ The minimum Apache Airflow version supported by this provider package is ``2.6.0 PIP package Version required =================== ================== ``apache-airflow`` ``>=2.6.0`` -``pinecone-client`` ``>=2.2.4,<3.0`` +``pinecone-client`` ``>=3.0.0`` =================== ================== diff --git a/docs/apache-airflow-providers-pinecone/operators/pinecone.rst b/docs/apache-airflow-providers-pinecone/operators/pinecone.rst index 71f847919fa80..b50e5300f09a9 100644 --- a/docs/apache-airflow-providers-pinecone/operators/pinecone.rst +++ b/docs/apache-airflow-providers-pinecone/operators/pinecone.rst @@ -15,10 +15,13 @@ specific language governing permissions and limitations under the License. +Operators +--------- + .. _howto/operator:PineconeIngestOperator: -PineconeIngestOperator -====================== +Ingest data into a pinecone index +================================= Use the :class:`~airflow.providers.pinecone.operators.pinecone.PineconeIngestOperator` to interact with Pinecone APIs to ingest vectors. @@ -38,3 +41,48 @@ An example using the operator in this way: :dedent: 4 :start-after: [START howto_operator_pinecone_ingest] :end-before: [END howto_operator_pinecone_ingest] + +.. _howto/operator:CreatePodIndexOperator: + +Create a Pod based Index +======================== + +Use the :class:`~airflow.providers.pinecone.operators.pinecone.CreatePodIndexOperator` to +interact with Pinecone APIs to create a Pod based Index. + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +The ``CreatePodIndexOperator`` requires the index details as well as the pod configuration details. ``api_key``, ``environment`` can be +passed via arguments to the operator or via the connection. + +An example using the operator in this way: + +.. exampleinclude:: /../../tests/system/providers/pinecone/example_create_pod_index.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_create_pod_index] + :end-before: [END howto_operator_create_pod_index] + + +.. 
_howto/operator:CreateServerlessIndexOperator: + +Create a Serverless Index +========================= + +Use the :class:`~airflow.providers.pinecone.operators.pinecone.CreateServerlessIndexOperator` to +interact with Pinecone APIs to create a Pod based Index. + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +The ``CreateServerlessIndexOperator`` requires the index details as well as the Serverless configuration details. ``api_key``, ``environment`` can be +passed via arguments to the operator or via the connection. + +An example using the operator in this way: + +.. exampleinclude:: /../../tests/system/providers/pinecone/example_create_serverless_index.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_create_serverless_index] + :end-before: [END howto_operator_create_serverless_index] diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d91b48858029a..6b3d2e7d8158f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -910,7 +910,7 @@ "pinecone": { "deps": [ "apache-airflow>=2.7.0", - "pinecone-client>=2.2.4,<3.0" + "pinecone-client>=3.0.0" ], "devel-deps": [], "cross-providers-deps": [], diff --git a/tests/providers/pinecone/hooks/test_pinecone.py b/tests/providers/pinecone/hooks/test_pinecone.py index fb076cc0a38de..82a01e53198d9 100644 --- a/tests/providers/pinecone/hooks/test_pinecone.py +++ b/tests/providers/pinecone/hooks/test_pinecone.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import os from unittest.mock import Mock, patch from airflow.providers.pinecone.hooks.pinecone import PineconeHook @@ -28,13 +29,15 @@ def setup_method(self): with patch("airflow.models.Connection.get_connection_from_secrets") as mock_get_connection: mock_conn = Mock() mock_conn.host = "pinecone.io" - mock_conn.login = "test_user" - mock_conn.password = "test_password" + mock_conn.login = "us-west1-gcp" # Pinecone Environment + mock_conn.password = "test_password" # Pinecone API Key + mock_conn.extra_dejson = {"region": "us-east-1", "debug_curl": True} mock_get_connection.return_value = mock_conn self.pinecone_hook = PineconeHook() + self.pinecone_hook.conn self.index_name = "test_index" - @patch("airflow.providers.pinecone.hooks.pinecone.pinecone.Index") + @patch("airflow.providers.pinecone.hooks.pinecone.Pinecone.Index") def test_upsert(self, mock_index): """Test the upsert_data_async method of PineconeHook for correct data insertion asynchronously.""" data = [("id1", [1.0, 2.0, 3.0], {"meta": "data"})] @@ -49,11 +52,38 @@ def test_list_indexes(self, mock_list_indexes): self.pinecone_hook.list_indexes() mock_list_indexes.assert_called_once() + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_indexes") + def test_debug_curl_setting(self, mock_list_indexes): + """Test that the PINECONE_DEBUG_CURL environment variable is set when initializing Pinecone Object.""" + self.pinecone_hook.list_indexes() + mock_list_indexes.assert_called_once() + assert os.environ.get("PINECONE_DEBUG_CURL") == "true" + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index") + def test_create_index_for_pod_based(self, mock_create_index): + """Test that the create_index method of PineconeHook is called with correct arguments for pod based index.""" + pod_spec = self.pinecone_hook.get_pod_spec_obj() + self.pinecone_hook.create_index(index_name=self.index_name, dimension=128, spec=pod_spec) + mock_create_index.assert_called_once_with(index_name="test_index", 
dimension=128, spec=pod_spec) + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index") - def test_create_index(self, mock_create_index): - """Test that the create_index method of PineconeHook is called with correct arguments.""" - self.pinecone_hook.create_index(index_name=self.index_name, dimension=128) - mock_create_index.assert_called_once_with(index_name="test_index", dimension=128) + def test_create_index_for_serverless_based(self, mock_create_index): + """Test that the create_index method of PineconeHook is called with correct arguments for serverless index.""" + serverless_spec = self.pinecone_hook.get_serverless_spec_obj(cloud="aws") + self.pinecone_hook.create_index(index_name=self.index_name, dimension=128, spec=serverless_spec) + mock_create_index.assert_called_once_with( + index_name="test_index", dimension=128, spec=serverless_spec + ) + + def test_get_pod_spec_obj(self): + """Test that the get_pod_spec_obj method of PineconeHook returns the correct pod spec object.""" + pod_spec = self.pinecone_hook.get_pod_spec_obj() + assert pod_spec.environment == "us-west1-gcp" + + def test_get_serverless_spec_obj(self): + """Test that the get_serverless_spec_obj method of PineconeHook returns the correct serverless spec object.""" + serverless_spec = self.pinecone_hook.get_serverless_spec_obj(cloud="gcp") + assert serverless_spec.region == "us-east-1" @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index") def test_describe_index(self, mock_describe_index): diff --git a/tests/system/providers/pinecone/example_create_pod_index.py b/tests/system/providers/pinecone/example_create_pod_index.py new file mode 100644 index 0000000000000..9b6f7d7d8882d --- /dev/null +++ b/tests/system/providers/pinecone/example_create_pod_index.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
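+#
+# Note: this example assumes a Pinecone connection with a valid API key and environment
+# is configured; the index name can be overridden via the INDEX_NAME environment variable.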
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.pinecone.operators.pinecone import CreatePodIndexOperator
+
+index_name = os.getenv("INDEX_NAME", "test")
+
+
+with DAG(
+    "example_pinecone_create_pod_index",
+    schedule="@once",
+    start_date=datetime(2024, 1, 1),
+    catchup=False,
+) as dag:
+    # [START howto_operator_create_pod_index]
+    # reference: https://docs.pinecone.io/reference/api/control-plane/create_index
+    CreatePodIndexOperator(
+        task_id="pinecone_create_pod_index",
+        index_name=index_name,
+        dimension=3,
+        replicas=1,
+        shards=1,
+        pods=1,
+        pod_type="p1.x1",
+    )
+    # [END howto_operator_create_pod_index]
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/pinecone/example_create_serverless_index.py b/tests/system/providers/pinecone/example_create_serverless_index.py
new file mode 100644
index 0000000000000..a7924e63ef338
--- /dev/null
+++ b/tests/system/providers/pinecone/example_create_serverless_index.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.pinecone.operators.pinecone import CreateServerlessIndexOperator
+
+index_name = os.getenv("INDEX_NAME", "test")
+
+
+with DAG(
+    "example_pinecone_create_serverless_index",
+    schedule="@once",
+    start_date=datetime(2024, 1, 1),
+    catchup=False,
+) as dag:
+    # [START howto_operator_create_serverless_index]
+    # reference: https://docs.pinecone.io/reference/api/control-plane/create_index
+    CreateServerlessIndexOperator(
+        task_id="pinecone_create_serverless_index",
+        index_name=index_name,
+        dimension=128,
+        cloud="aws",
+        region="us-west-2",
+        metric="cosine",
+    )
+    # [END howto_operator_create_serverless_index]
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)