diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 39cdc7b..6727919 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,6 +83,7 @@ jobs: run: | export ES_URI="http://localhost:9200" export ES_PORT=9200 + export ES_SUPPORT_DATETIME_PARSE=False nosetests -v --with-coverage --cover-package=es es.tests - name: Run tests on Elasticsearch 7.10.X run: | @@ -97,6 +98,7 @@ jobs: export ES_PORT=19200 export ES_SCHEME=https export ES_USER=admin + export ES_SUPPORT_DATETIME_PARSE=False nosetests -v --with-coverage --cover-package=es es.tests - name: Run tests on Opendistro 13 run: | @@ -107,6 +109,7 @@ jobs: export ES_SCHEME=https export ES_USER=admin export ES_V2=True + export ES_SUPPORT_DATETIME_PARSE=False nosetests -v --with-coverage --cover-package=es es.tests - name: Upload code coverage run: | diff --git a/es/baseapi.py b/es/baseapi.py index 6cb4daa..a092c5a 100644 --- a/es/baseapi.py +++ b/es/baseapi.py @@ -109,7 +109,7 @@ def get_description_from_columns( class BaseConnection(object): - """Connection to an ES Cluster """ + """Connection to an ES Cluster""" def __init__( self, @@ -192,6 +192,7 @@ def __init__(self, url: str, es: Elasticsearch, **kwargs): self.es = es self.sql_path = kwargs.get("sql_path", DEFAULT_SQL_PATH) self.fetch_size = kwargs.get("fetch_size", DEFAULT_FETCH_SIZE) + self.time_zone: Optional[str] = kwargs.get("time_zone") # This read/write attribute specifies the number of rows to fetch at a # time with .fetchmany(). It defaults to 1 meaning to fetch a single # row at a time. @@ -218,7 +219,7 @@ def custom_sql_to_method_dispatcher(self, command: str) -> Optional["BaseCursor" @check_result @check_closed def rowcount(self) -> int: - """ Counts the number of rows on a result """ + """Counts the number of rows on a result""" if self._results: return len(self._results) return 0 @@ -230,7 +231,7 @@ def close(self) -> None: @check_closed def execute(self, operation, parameters=None) -> "BaseCursor": - """ Children must implement their own custom execute """ + """Children must implement their own custom execute""" raise NotImplementedError # pragma: no cover @check_closed @@ -311,11 +312,13 @@ def elastic_query(self, query: str) -> Dict[str, Any]: payload = {"query": query} if self.fetch_size is not None: payload["fetch_size"] = self.fetch_size + if self.time_zone is not None: + payload["time_zone"] = self.time_zone path = f"/{self.sql_path}/" try: response = self.es.transport.perform_request("POST", path, body=payload) except es_exceptions.ConnectionError: - raise exceptions.OperationalError(f"Error connecting to Elasticsearch") + raise exceptions.OperationalError("Error connecting to Elasticsearch") except es_exceptions.RequestError as ex: raise exceptions.ProgrammingError( f"Error ({ex.error}): {ex.info['error']['reason']}" diff --git a/es/elastic/api.py b/es/elastic/api.py index ed1a4c4..13e41f8 100644 --- a/es/elastic/api.py +++ b/es/elastic/api.py @@ -38,7 +38,7 @@ def connect( class Connection(BaseConnection): - """Connection to an ES Cluster """ + """Connection to an ES Cluster""" def __init__( self, diff --git a/es/opendistro/api.py b/es/opendistro/api.py index b505c82..a27c14b 100644 --- a/es/opendistro/api.py +++ b/es/opendistro/api.py @@ -42,7 +42,7 @@ def connect( class Connection(BaseConnection): - """Connection to an ES Cluster """ + """Connection to an ES Cluster""" def __init__( self, diff --git a/es/tests/test_dbapi.py b/es/tests/test_dbapi.py index e42a306..5968b70 100644 --- a/es/tests/test_dbapi.py +++ b/es/tests/test_dbapi.py @@ -7,28 +7,35 @@ from es.opendistro.api import connect as open_connect +def convert_bool(value: str) -> bool: + return True if value == "True" else False + + class TestDBAPI(unittest.TestCase): def setUp(self): self.driver_name = os.environ.get("ES_DRIVER", "elasticsearch") - host = os.environ.get("ES_HOST", "localhost") - port = int(os.environ.get("ES_PORT", 9200)) - scheme = os.environ.get("ES_SCHEME", "http") - verify_certs = os.environ.get("ES_VERIFY_CERTS", False) - user = os.environ.get("ES_USER", None) - password = os.environ.get("ES_PASSWORD", None) + self.host = os.environ.get("ES_HOST", "localhost") + self.port = int(os.environ.get("ES_PORT", 9200)) + self.scheme = os.environ.get("ES_SCHEME", "http") + self.verify_certs = os.environ.get("ES_VERIFY_CERTS", False) + self.user = os.environ.get("ES_USER", None) + self.password = os.environ.get("ES_PASSWORD", None) self.v2 = bool(os.environ.get("ES_V2", False)) + self.support_datetime_parse = convert_bool( + os.environ.get("ES_SUPPORT_DATETIME_PARSE", "True") + ) if self.driver_name == "elasticsearch": self.connect_func = elastic_connect else: self.connect_func = open_connect self.conn = self.connect_func( - host=host, - port=port, - scheme=scheme, - verify_certs=verify_certs, - user=user, - password=password, + host=self.host, + port=self.port, + scheme=self.scheme, + verify_certs=self.verify_certs, + user=self.user, + password=self.password, v2=self.v2, ) self.cursor = self.conn.cursor() @@ -213,3 +220,62 @@ def test_https(self, mock_elasticsearch): mock_elasticsearch.assert_called_once_with( "https://localhost:9200/", http_auth=("user", "password") ) + + def test_simple_search_with_time_zone(self): + """ + DBAPI: Test simple search with time zone + UTC -> CST + 2019-10-13T00:00:00.000Z => 2019-10-13T08:00:00.000+08:00 + 2019-10-13T00:00:01.000Z => 2019-10-13T08:01:00.000+08:00 + 2019-10-13T00:00:02.000Z => 2019-10-13T08:02:00.000+08:00 + """ + + if not self.support_datetime_parse: + return + + conn = self.connect_func( + host=self.host, + port=self.port, + scheme=self.scheme, + verify_certs=self.verify_certs, + user=self.user, + password=self.password, + v2=self.v2, + time_zone="Asia/Shanghai", + ) + cursor = conn.cursor() + pattern = "yyyy-MM-dd HH:mm:ss" + sql = f""" + SELECT timestamp FROM data1 + WHERE timestamp >= DATETIME_PARSE('2019-10-13 00:08:00', '{pattern}') + """ + + rows = cursor.execute(sql).fetchall() + self.assertEqual(len(rows), 3) + + def test_simple_search_without_time_zone(self): + """ + DBAPI: Test simple search without time zone + """ + + if not self.support_datetime_parse: + return + + conn = self.connect_func( + host=self.host, + port=self.port, + scheme=self.scheme, + verify_certs=self.verify_certs, + user=self.user, + password=self.password, + v2=self.v2, + ) + cursor = conn.cursor() + pattern = "yyyy-MM-dd HH:mm:ss" + sql = f""" + SELECT * FROM data1 + WHERE timestamp >= DATETIME_PARSE('2019-10-13 08:00:00', '{pattern}') + """ + + rows = cursor.execute(sql).fetchall() + self.assertEqual(len(rows), 0)