Skip to content

Commit

Permalink
feat(query): add time_zone param
Browse files Browse the repository at this point in the history
  • Loading branch information
aniaan committed Sep 16, 2021
1 parent ea7a093 commit 4cf73f2
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
run: |
export ES_URI="http://localhost:9200"
export ES_PORT=9200
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Run tests on Elasticsearch 7.10.X
run: |
Expand All @@ -97,6 +98,7 @@ jobs:
export ES_PORT=19200
export ES_SCHEME=https
export ES_USER=admin
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Run tests on Opendistro 13
run: |
Expand All @@ -107,6 +109,7 @@ jobs:
export ES_SCHEME=https
export ES_USER=admin
export ES_V2=True
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Upload code coverage
run: |
Expand Down
11 changes: 7 additions & 4 deletions es/baseapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_description_from_columns(

class BaseConnection(object):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down Expand Up @@ -192,6 +192,7 @@ def __init__(self, url: str, es: Elasticsearch, **kwargs):
self.es = es
self.sql_path = kwargs.get("sql_path", DEFAULT_SQL_PATH)
self.fetch_size = kwargs.get("fetch_size", DEFAULT_FETCH_SIZE)
self.time_zone: Optional[str] = kwargs.get("time_zone")
# This read/write attribute specifies the number of rows to fetch at a
# time with .fetchmany(). It defaults to 1 meaning to fetch a single
# row at a time.
Expand All @@ -218,7 +219,7 @@ def custom_sql_to_method_dispatcher(self, command: str) -> Optional["BaseCursor"
@check_result
@check_closed
def rowcount(self) -> int:
""" Counts the number of rows on a result """
"""Counts the number of rows on a result"""
if self._results:
return len(self._results)
return 0
Expand All @@ -230,7 +231,7 @@ def close(self) -> None:

@check_closed
def execute(self, operation, parameters=None) -> "BaseCursor":
""" Children must implement their own custom execute """
"""Children must implement their own custom execute"""
raise NotImplementedError # pragma: no cover

@check_closed
Expand Down Expand Up @@ -311,11 +312,13 @@ def elastic_query(self, query: str) -> Dict[str, Any]:
payload = {"query": query}
if self.fetch_size is not None:
payload["fetch_size"] = self.fetch_size
if self.time_zone is not None:
payload["time_zone"] = self.time_zone
path = f"/{self.sql_path}/"
try:
response = self.es.transport.perform_request("POST", path, body=payload)
except es_exceptions.ConnectionError:
raise exceptions.OperationalError(f"Error connecting to Elasticsearch")
raise exceptions.OperationalError("Error connecting to Elasticsearch")
except es_exceptions.RequestError as ex:
raise exceptions.ProgrammingError(
f"Error ({ex.error}): {ex.info['error']['reason']}"
Expand Down
2 changes: 1 addition & 1 deletion es/elastic/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def connect(

class Connection(BaseConnection):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion es/opendistro/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def connect(

class Connection(BaseConnection):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down
90 changes: 78 additions & 12 deletions es/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,35 @@
from es.opendistro.api import connect as open_connect


def convert_bool(value: str) -> bool:
return True if value == "True" else False


class TestDBAPI(unittest.TestCase):
def setUp(self):
self.driver_name = os.environ.get("ES_DRIVER", "elasticsearch")
host = os.environ.get("ES_HOST", "localhost")
port = int(os.environ.get("ES_PORT", 9200))
scheme = os.environ.get("ES_SCHEME", "http")
verify_certs = os.environ.get("ES_VERIFY_CERTS", False)
user = os.environ.get("ES_USER", None)
password = os.environ.get("ES_PASSWORD", None)
self.host = os.environ.get("ES_HOST", "localhost")
self.port = int(os.environ.get("ES_PORT", 9200))
self.scheme = os.environ.get("ES_SCHEME", "http")
self.verify_certs = os.environ.get("ES_VERIFY_CERTS", False)
self.user = os.environ.get("ES_USER", None)
self.password = os.environ.get("ES_PASSWORD", None)
self.v2 = bool(os.environ.get("ES_V2", False))
self.support_datetime_parse = convert_bool(
os.environ.get("ES_SUPPORT_DATETIME_PARSE", "True")
)

if self.driver_name == "elasticsearch":
self.connect_func = elastic_connect
else:
self.connect_func = open_connect
self.conn = self.connect_func(
host=host,
port=port,
scheme=scheme,
verify_certs=verify_certs,
user=user,
password=password,
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
)
self.cursor = self.conn.cursor()
Expand Down Expand Up @@ -213,3 +220,62 @@ def test_https(self, mock_elasticsearch):
mock_elasticsearch.assert_called_once_with(
"https://localhost:9200/", http_auth=("user", "password")
)

def test_simple_search_with_time_zone(self):
"""
DBAPI: Test simple search with time zone
UTC -> CST
2019-10-13T00:00:00.000Z => 2019-10-13T08:00:00.000+08:00
2019-10-13T00:00:01.000Z => 2019-10-13T08:01:00.000+08:00
2019-10-13T00:00:02.000Z => 2019-10-13T08:02:00.000+08:00
"""

if not self.support_datetime_parse:
return

conn = self.connect_func(
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
time_zone="Asia/Shanghai",
)
cursor = conn.cursor()
pattern = "yyyy-MM-dd HH:mm:ss"
sql = f"""
SELECT timestamp FROM data1
WHERE timestamp >= DATETIME_PARSE('2019-10-13 00:08:00', '{pattern}')
"""

rows = cursor.execute(sql).fetchall()
self.assertEqual(len(rows), 3)

def test_simple_search_without_time_zone(self):
"""
DBAPI: Test simple search without time zone
"""

if not self.support_datetime_parse:
return

conn = self.connect_func(
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
)
cursor = conn.cursor()
pattern = "yyyy-MM-dd HH:mm:ss"
sql = f"""
SELECT * FROM data1
WHERE timestamp >= DATETIME_PARSE('2019-10-13 08:00:00', '{pattern}')
"""

rows = cursor.execute(sql).fetchall()
self.assertEqual(len(rows), 0)

0 comments on commit 4cf73f2

Please sign in to comment.