# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations
import logging
import re
import warnings
from datetime import datetime
from re import Match, Pattern
from typing import (
Any,
Callable,
cast,
ContextManager,
NamedTuple,
TYPE_CHECKING,
TypedDict,
Union,
)
from urllib.parse import urlencode, urljoin
from uuid import uuid4
import pandas as pd
import requests
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask import current_app, g, url_for
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import column, select, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
from superset import db, sql_parse
from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
from superset.sql_parse import ParsedQuery
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
OAuth2TokenResponse,
ResultSetColumnType,
SQLAColumnType,
)
from superset.utils import core as utils, json
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.hashing import md5_sha_from_str
from superset.utils.json import redact_sensitive, reveal_sensitive
from superset.utils.network import is_hostname_valid, is_port_open
from superset.utils.oauth2 import encode_oauth2_state
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
from superset.databases.schemas import TableMetadataResponse
from superset.models.core import Database
from superset.models.sql_lab import Query
ColumnTypeMapping = tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
]
logger = logging.getLogger()
# When connecting to a database it's hard to catch specific exceptions, since we support
# more than 50 different database drivers. Usually the try/except block will catch the
# generic `Exception` class, which requires a pylint disable comment. To make it clear
# that we know this is a necessary evil we create an alias, and catch it instead.
GenericDBException = Exception
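# For instance, a DB-agnostic call site catches it like this (a sketch mirroring
# the pattern used in `fetch_data` below):
#
#     try:
#         data = cursor.fetchall()
#     except GenericDBException as ex:
#         raise cls.get_dbapi_mapped_exception(ex) from ex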
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
result_set_columns: list[ResultSetColumnType] = []
for col in cols:
result_set_columns.append({"column_name": col.get("name"), **col}) # type: ignore
return result_set_columns
class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
function: str
duration: str | None
builtin_time_grains: dict[str | None, str] = {
TimeGrainConstants.SECOND: _("Second"),
TimeGrainConstants.FIVE_SECONDS: _("5 second"),
TimeGrainConstants.THIRTY_SECONDS: _("30 second"),
TimeGrainConstants.MINUTE: _("Minute"),
TimeGrainConstants.FIVE_MINUTES: _("5 minute"),
TimeGrainConstants.TEN_MINUTES: _("10 minute"),
TimeGrainConstants.FIFTEEN_MINUTES: _("15 minute"),
TimeGrainConstants.THIRTY_MINUTES: _("30 minute"),
TimeGrainConstants.HOUR: _("Hour"),
TimeGrainConstants.SIX_HOURS: _("6 hour"),
TimeGrainConstants.DAY: _("Day"),
TimeGrainConstants.WEEK: _("Week"),
TimeGrainConstants.MONTH: _("Month"),
TimeGrainConstants.QUARTER: _("Quarter"),
TimeGrainConstants.YEAR: _("Year"),
TimeGrainConstants.WEEK_STARTING_SUNDAY: _("Week starting Sunday"),
TimeGrainConstants.WEEK_STARTING_MONDAY: _("Week starting Monday"),
TimeGrainConstants.WEEK_ENDING_SATURDAY: _("Week ending Saturday"),
TimeGrainConstants.WEEK_ENDING_SUNDAY: _("Week ending Sunday"),
}
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression.
:param expr: Sql expression with '{col}' denoting the locations where the col
object will be rendered.
:param col: the target column
"""
super().__init__(expr, **kwargs)
self.col = col
@property
def _constructor(self) -> ColumnClause:
# Needed to ensure that the column label is rendered correctly when
# proxied to the outer query.
# See https://github.com/sqlalchemy/sqlalchemy/issues/4730
return ColumnClause
@compiles(TimestampExpression)
def compile_timegrain_expression(
element: TimestampExpression, compiler: Compiled, **kwargs: Any
) -> str:
return element.name.replace("{col}", compiler.process(element.col, **kwargs))
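# A minimal usage sketch (the expression and column name are illustrative):
#
#     ts = TimestampExpression("DATE_TRUNC('day', {col})", column("created_at"))
#     # When compiled, "{col}" is replaced with the dialect-quoted column name,
#     # rendering e.g.: DATE_TRUNC('day', created_at)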
class LimitMethod: # pylint: disable=too-few-public-methods
"""Enum the ways that limits can be applied"""
FETCH_MANY = "fetch_many"
WRAP_SQL = "wrap_sql"
FORCE_LIMIT = "force_limit"
class MetricType(TypedDict, total=False):
"""
Type for metrics returned by `get_metrics`.
"""
metric_name: str
expression: str
verbose_name: str | None
metric_type: str | None
description: str | None
d3format: str | None
currency: str | None
warning_text: str | None
extra: str | None
class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""Abstract class for database engine specific configurations
Attributes:
allows_alias_to_source_column: Whether the engine is able to pick the
source column for aggregation clauses
used in ORDER BY when a column in SELECT
has an alias that is the same as a source
column.
allows_hidden_orderby_agg: Whether the engine allows ORDER BY to
directly use aggregation clauses, without
having to add the same aggregation in SELECT.
"""
engine_name: str | None = None # for user messages, overridden in child classes
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: set[str] = set()
drivers: dict[str, str] = {}
default_driver: str | None = None
# placeholder with the SQLAlchemy URI template
sqlalchemy_uri_placeholder = (
"engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
)
disable_ssh_tunneling = False
_date_trunc_functions: dict[str, str] = {}
_time_grain_expressions: dict[str | None, str] = {}
_default_column_type_mappings: tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^n((var)?char|text)", re.IGNORECASE),
types.UnicodeText(),
GenericDataType.STRING,
),
(
re.compile(r"^(var)?char", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^(tiny|medium|long)?text", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^int(eger)?", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^long", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^numeric", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^float", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double", re.IGNORECASE),
types.Float(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^real", re.IGNORECASE),
types.REAL,
GenericDataType.NUMERIC,
),
(
re.compile(r"^smallserial", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^serial", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigserial", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^money", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^datetime", re.IGNORECASE),
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^date", re.IGNORECASE),
types.Date(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^time", re.IGNORECASE),
types.Time(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^bool(ean)?", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
)
# engine-specific type mappings to check prior to the defaults
column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
# type-specific functions to mutate values received from the database.
# Needed on certain databases that return values in an unexpected format
column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {}
# Does the database support join-free timeslot grouping?
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
supports_multivalues_insert = False
allows_joins = True
allows_subqueries = True
allows_alias_in_select = True
allows_alias_in_orderby = True
allows_sql_comments = True
allows_escaped_colons = True
# Whether ORDER BY clause can use aliases created in SELECT
# that are the same as a source column
allows_alias_to_source_column = True
# Whether an aggregation used in ORDER BY must also appear in SELECT;
# if True, it doesn't have to.
allows_hidden_orderby_agg = True
# Whether ORDER BY can use a SQL-calculated expression;
# if True, the alias of the SELECT column is used in ORDER BY.
# True is safe for most databases, but for backward
# compatibility it defaults to False.
allows_hidden_cc_in_orderby = False
# Whether a CTE can be used inside a subquery, or only as a regular CTE;
# if True, CTEs are allowed in subqueries,
# if False they are kept as regular (top-level) CTEs.
allows_cte_in_subquery = True
# Define alias for CTE
cte_alias = "__cte"
# Whether the engine supports the LIMIT clause;
# if True, row limits are applied via LIMIT,
# if False they are applied via TOP.
allow_limit_clause = True
# Keywords that identify SELECT statements, used when parsing
# SQL for engines that apply limits via TOP
select_keywords: set[str] = {"SELECT"}
# Keywords that identify row-limit clauses, used when parsing
# SQL for engines that apply limits via TOP
top_keywords: set[str] = {"TOP"}
# A set of disallowed connection query parameters by driver name
disallow_uri_query_params: dict[str, set[str]] = {}
# A dict of query parameters, by driver name, that will always be used on
# every connection
enforce_uri_query_params: dict[str, dict[str, Any]] = {}
force_column_alias_quotes = False
arraysize = 0
max_column_name_length: int | None = None
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
custom_errors: dict[
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
] = {}
# List of JSON paths to fields in `encrypted_extra` that should be masked when the
# database is edited. By default everything is masked.
# pylint: disable=invalid-name
encrypted_extra_sensitive_fields: set[str] = {"$.*"}
# Whether the engine supports file uploads
# if True, database will be listed as option in the upload file form
supports_file_upload = True
# Is the DB engine spec able to change the default schema? This requires implementing
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False
# Does the DB support catalogs? A catalog here is a group of schemas, and has
# different names depending on the DB: BigQuery calls it a "project", Postgres calls
# it a "database", Trino calls it a "catalog", etc.
#
# When this is changed to true in a DB engine spec it MUST support the
# `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write
# a database migration updating any existing schema permissions using the helper
# `upgrade_catalog_perms`.
supports_catalog = False
# Can the catalog be changed on a per-query basis?
supports_dynamic_catalog = False
# Does the engine support OAuth 2.0? This requires logic to be added to one of
# the user impersonation methods to handle personal tokens.
supports_oauth2 = False
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
# Is the query ID tied to the connection?
# The default value is True, which means that the query ID is determined when
# the connection is created.
# When this is changed to False in a DB engine spec, the query ID is only
# determined after the specific query is executed, and it will update
# the `cancel_query` value in the `extra` field of the `query` object.
has_query_id_before_execute = True
@classmethod
def is_oauth2_enabled(cls) -> bool:
return (
cls.supports_oauth2
and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"]
)
@classmethod
def start_oauth2_dance(cls, database: Database) -> None:
"""
Start the OAuth2 dance.
This method will raise a custom exception that is captured by the frontend to
start the OAuth2 authentication. The frontend will open a new tab where the user
can authorize Superset to access the database. Once the user has authorized, the
tab sends a message to the original tab informing that authorization was
successful (or not), and then closes. The original tab will automatically
re-run the query after authorization.
"""
tab_id = str(uuid4())
default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True)
# The state is passed to the OAuth2 provider, and sent back to Superset after
# the user authorizes the access. The redirect endpoint in Superset can then
# inspect the state to figure out which user/database the access token
# belongs to.
state: OAuth2State = {
# Database ID and user ID are the primary key associated with the token.
"database_id": database.id,
"user_id": g.user.id,
# In multi-instance deployments there might be a single proxy handling
# redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2
# application requires every redirect URL to be registered a priori, this
# allows OAuth2 to be used where new instances are being constantly
# deployed. The proxy can extract `default_redirect_uri` from the state and
# then forward the token to the instance that initiated the authentication.
"default_redirect_uri": default_redirect_uri,
# When OAuth2 is complete the browser tab where OAuth2 happened will send a
# message to the original browser tab informing that the process was
# successful. To allow cross-tab communication in a safe way we assign a
# UUID to the original tab, and the second tab will use it when sending the
# message.
"tab_id": tab_id,
}
oauth2_config = database.get_oauth2_config()
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
@classmethod
def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
"""
Build the DB engine spec level OAuth2 client config.
"""
oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"]
if cls.engine_name not in oauth2_config:
return None
db_engine_spec_config = oauth2_config[cls.engine_name]
redirect_uri = current_app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
config: OAuth2ClientConfig = {
"id": db_engine_spec_config["id"],
"secret": db_engine_spec_config["secret"],
"scope": db_engine_spec_config.get("scope") or cls.oauth2_scope,
"redirect_uri": redirect_uri,
"authorization_request_uri": db_engine_spec_config.get(
"authorization_request_uri",
cls.oauth2_authorization_request_uri,
),
"token_request_uri": db_engine_spec_config.get(
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}
return config
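# An illustrative `DATABASE_OAUTH2_CLIENTS` entry in superset_config.py; the
# values are placeholders, and the keys mirror the lookups above (the outer
# dict is keyed by `engine_name`):
#
#     DATABASE_OAUTH2_CLIENTS = {
#         "Some Engine": {
#             "id": "client-id",
#             "secret": "client-secret",
#             "scope": "openid email",
#             "authorization_request_uri": "https://example.org/authorize",
#             "token_request_uri": "https://example.org/token",
#             "request_content_type": "data",
#         },
#     }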
@classmethod
def get_oauth2_authorization_uri(
cls,
config: OAuth2ClientConfig,
state: OAuth2State,
) -> str:
"""
Return URI for initial OAuth2 request.
"""
uri = config["authorization_request_uri"]
params = {
"scope": config["scope"],
"access_type": "offline",
"include_granted_scopes": "false",
"response_type": "code",
"state": encode_oauth2_state(state),
"redirect_uri": config["redirect_uri"],
"client_id": config["id"],
"prompt": "consent",
}
return urljoin(uri, "?" + urlencode(params))
@classmethod
def get_oauth2_token(
cls,
config: OAuth2ClientConfig,
code: str,
) -> OAuth2TokenResponse:
"""
Exchange authorization code for refresh/access tokens.
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_oauth2_fresh_token(
cls,
config: OAuth2ClientConfig,
refresh_token: str,
) -> OAuth2TokenResponse:
"""
Refresh an access token that has expired.
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_allows_alias_in_select(
cls,
database: Database, # pylint: disable=unused-argument
) -> bool:
"""
Method for dynamic `allows_alias_in_select`.
In Dremio this attribute is version-dependent, so Superset needs to inspect the
database configuration in order to determine it. This method allows engine specs
to define dynamic values for the attribute.
"""
return cls.allows_alias_in_select
@classmethod
def supports_url(cls, url: URL) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy URL.
As an example, if a given DB engine spec has:
class PostgresDBEngineSpec:
engine = "postgresql"
engine_aliases = "postgres"
drivers = {
"psycopg2": "The default Postgres driver",
"asyncpg": "An asynchronous Postgres driver",
}
It would be used for all the following SQLAlchemy URIs:
- postgres://user:password@host/db
- postgresql://user:password@host/db
- postgres+asyncpg://user:password@host/db
- postgres+psycopg2://user:password@host/db
- postgresql+asyncpg://user:password@host/db
- postgresql+psycopg2://user:password@host/db
Note that SQLAlchemy has a default driver even if one is not specified:
>>> from sqlalchemy.engine.url import make_url
>>> make_url('postgres://').get_driver_name()
'psycopg2'
"""
backend = url.get_backend_name()
driver = url.get_driver_name()
return cls.supports_backend(backend, driver)
@classmethod
def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
"""
# check the backend first
if backend != cls.engine and backend not in cls.engine_aliases:
return False
# originally DB engine specs didn't declare any drivers and the check was made
# only on the engine; if that's the case, ignore the driver for backwards
# compatibility
if not cls.drivers or driver is None:
return True
return driver in cls.drivers
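# Using the Postgres example from the `supports_url` docstring (sketch):
#
#     PostgresDBEngineSpec.supports_backend("postgres", "asyncpg")  # True
#     PostgresDBEngineSpec.supports_backend("mysql")  # False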
@classmethod
def get_default_catalog(
cls,
database: Database, # pylint: disable=unused-argument
) -> str | None:
"""
Return the default catalog for a given database.
"""
return None
@classmethod
def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
"""
Return the default schema for a catalog in a given database.
"""
with database.get_inspector(catalog=catalog) as inspector:
return inspector.default_schema_name
@classmethod
def get_schema_from_engine_params( # pylint: disable=unused-argument
cls,
sqlalchemy_uri: URL,
connect_args: dict[str, Any],
) -> str | None:
"""
Return the schema configured in a SQLAlchemy URI and connection arguments, if any.
"""
return None
@classmethod
def get_default_schema_for_query(
cls,
database: Database,
query: Query,
) -> str | None:
"""
Return the default schema for a given query.
This is used to determine the schema of tables that aren't fully qualified, e.g.:
SELECT * FROM foo;
In the example above, the schema where the `foo` table lives depends on a few
factors:
1. For DB engine specs that allow dynamically changing the schema based on the
query we should use the query schema.
2. For DB engine specs that don't support dynamically changing the schema and
have the schema hardcoded in the SQLAlchemy URI we should use the schema
from the URI.
3. For DB engine specs that don't connect to a specific schema and can't
change it dynamically we need to probe the database for the default schema.
Determining the correct schema is crucial for managing access to data, so please
make sure you understand this logic when working on a new DB engine spec.
"""
# dynamic schema varies on a per-query basis
if cls.supports_dynamic_schema:
return query.schema
# check if the schema is stored in the SQLAlchemy URI or connection arguments
try:
connect_args = database.get_extra()["engine_params"]["connect_args"]
except KeyError:
connect_args = {}
sqlalchemy_uri = make_url_safe(database.sqlalchemy_uri)
if schema := cls.get_schema_from_engine_params(sqlalchemy_uri, connect_args):
return schema
# return the default schema of the database
return cls.get_default_schema(database, query.catalog)
@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
"""
Each engine can implement and converge its own specific exceptions into
Superset DBAPI exceptions.
Note: on Python 3.9 this method can be changed to a classmethod property
without needing to implement a metaclass type
:return: A map of driver specific exception to superset custom exceptions
"""
return {}
@classmethod
def parse_error_exception(cls, exception: Exception) -> Exception:
"""
Each engine can implement and converge its own specific parser method
:return: An Exception with a parsed string from the original exception
"""
return exception
@classmethod
def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
"""
Get a superset custom DBAPI exception from the driver specific exception.
Override if the engine needs to perform extra changes to the exception, for
example to change the exception message or implement more complex custom logic
:param exception: The driver specific exception
:return: Superset custom DBAPI exception
"""
new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
if not new_exception:
return cls.parse_error_exception(exception)
return new_exception(str(exception))
@classmethod
def get_allow_cost_estimate( # pylint: disable=unused-argument
cls,
extra: dict[str, Any],
) -> bool:
return False
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
SQLAlchemy wrapper to ensure text clauses are escaped properly
:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters
"""
if cls.allows_escaped_colons:
clause = clause.replace(":", "\\:")
return text(clause)
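# For example (sketch), get_text_clause("ts > '12:34:56'") returns a TextClause
# for "ts > '12\:34\:56'", so SQLAlchemy won't treat ":34" and ":56" as bind
# parameters.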
@classmethod
def get_engine(
cls,
database: Database,
catalog: str | None = None,
schema: str | None = None,
source: utils.QuerySource | None = None,
) -> ContextManager[Engine]:
"""
Return an engine context manager.
>>> with DBEngineSpec.get_engine(database, catalog, schema, source) as engine:
... connection = engine.connect()
... connection.execute(sql)
"""
return database.get_sqla_engine(catalog=catalog, schema=schema, source=source)
@classmethod
def get_timestamp_expr(
cls,
col: ColumnClause,
pdf: str | None,
time_grain: str | None,
) -> TimestampExpression:
"""
Construct a TimestampExpression to be used in a SQLAlchemy query.
:param col: Target column for the TimestampExpression
:param pdf: date format (seconds or milliseconds)
:param time_grain: time grain, e.g. P1Y for 1 year
:return: TimestampExpression object
"""
if time_grain:
type_ = str(getattr(col, "type", ""))
time_expr = cls.get_time_grain_expressions().get(time_grain)
if not time_expr:
raise NotImplementedError(
f"No grain spec for {time_grain} for database {cls.engine}"
)
if type_ and "{func}" in time_expr:
date_trunc_function = cls._date_trunc_functions.get(type_)
if date_trunc_function:
time_expr = time_expr.replace("{func}", date_trunc_function)
if type_ and "{type}" in time_expr:
date_trunc_function = cls._date_trunc_functions.get(type_)
if date_trunc_function:
time_expr = time_expr.replace("{type}", type_)
else:
time_expr = "{col}"
# if epoch, translate to DATE using db specific conf
if pdf == "epoch_s":
time_expr = time_expr.replace("{col}", cls.epoch_to_dttm())
elif pdf == "epoch_ms":
time_expr = time_expr.replace("{col}", cls.epoch_ms_to_dttm())
return TimestampExpression(time_expr, col, type_=col.type)
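# A minimal sketch (the grain expression is illustrative):
#
#     # assuming cls.get_time_grain_expressions()["P1D"] == "DATE_TRUNC('day', {col})"
#     expr = cls.get_timestamp_expr(column("ts"), None, "P1D")
#     # compiling `expr` renders: DATE_TRUNC('day', ts)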
@classmethod
def get_time_grains(cls) -> tuple[TimeGrain, ...]:
"""
Generate a tuple of supported time grains.
:return: All time grains supported by the engine
"""
ret_list = []
time_grains = builtin_time_grains.copy()
time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
for duration, func in cls.get_time_grain_expressions().items():
if duration in time_grains:
name = time_grains[duration]
ret_list.append(TimeGrain(name, _(name), func, duration))
return tuple(ret_list)
@classmethod
def _sort_time_grains(
cls, val: tuple[str | None, str], index: int
) -> float | int | str:
"""
Return an ordered time-based value of a portion of a time grain
for sorting.
Values are expected to be either None or to start with P or PT,
have a numerical value in the middle, and end with
a unit for the time interval.
They can also start or end with an epoch start time denoting a range,
e.g. a week beginning or ending on a given day.
"""
pos = {
"FIRST": 0,
"SECOND": 1,
"THIRD": 2,
"LAST": 3,
}
if val[0] is None:
return pos["FIRST"]
prog = re.compile(r"(.*\/)?(P|PT)([0-9\.]+)(S|M|H|D|W|M|Y)(\/.*)?")
result = prog.match(val[0])
# for any time grains that don't match the format, put them at the end
if result is None:
return pos["LAST"]
second_minute_hour = ["S", "M", "H"]
day_week_month_year = ["D", "W", "M", "Y"]
is_less_than_day = result.group(2) == "PT"
interval = result.group(4)
epoch_time_start_string = result.group(1) or result.group(5)
has_starting_or_ending = bool(len(epoch_time_start_string or ""))
def sort_day_week() -> int:
if has_starting_or_ending:
return pos["LAST"]
if is_less_than_day:
return pos["SECOND"]
return pos["THIRD"]
def sort_interval() -> float:
if is_less_than_day:
return second_minute_hour.index(interval)
return day_week_month_year.index(interval)
# 0: all "PT" values should come before "P" values (i.e, PT10M)
# 1: order values within the above arrays ("D" before "W")
# 2: sort by numeric value (PT10M before PT15M)
# 3: sort by any week starting/ending values
plist = {
0: sort_day_week(),
1: pos["SECOND"] if is_less_than_day else pos["THIRD"],
2: sort_interval(),
3: float(result.group(3)),
}
return plist.get(index, 0)
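# For intuition, the composite key sorts grains roughly as follows
# (illustrative, assuming the built-in grains):
#     None (no grain) < PT5S < PT5M < PT1H < P1D < P1W < P1M < P1Y
#     < range-style grains such as "1969-12-28T00:00:00Z/P1W"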
@classmethod
def get_time_grain_expressions(cls) -> dict[str | None, str]:
"""
Return a dict of all supported time grains including any potential added grains
but excluding any potentially disabled grains in the config file.
:return: All time grain expressions supported by the engine
"""
# TODO: use @memoize decorator or similar to avoid recomputation on every call
time_grain_expressions = cls._time_grain_expressions.copy()
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key, None)
return dict(
sorted(
time_grain_expressions.items(),
key=lambda x: (
cls._sort_time_grains(x, 0),
cls._sort_time_grains(x, 1),
cls._sort_time_grains(x, 2),
cls._sort_time_grains(x, 3),
),
)
)
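# Illustrative config additions in superset_config.py (keys are ISO 8601
# durations; addon expressions are keyed by engine, and each expression value
# is a SQL snippet containing a "{col}" placeholder):
#
#     TIME_GRAIN_ADDON_EXPRESSIONS = {
#         "postgresql": {"P2W": "<SQL expression using {col}>"},
#     }
#     TIME_GRAIN_DENYLIST = ["PT5S", "PT30S"]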
@classmethod
def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
"""
:param cursor: Cursor instance
:param limit: Maximum number of rows to be returned by the cursor
:return: Result of query
"""
if cls.arraysize:
cursor.arraysize = cls.arraysize
try:
if cls.limit_method == LimitMethod.FETCH_MANY and limit:
return cursor.fetchmany(limit)
data = cursor.fetchall()
description = cursor.description or []
# Create a mapping between column name and a mutator function to normalize
# values with. The first two items in the description row are
# the column name and type.
column_mutators = {
row[0]: func
for row in description
if (
func := cls.column_type_mutators.get(
type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
)
)
}
if column_mutators:
indexes = {row[0]: idx for idx, row in enumerate(description)}
for row_idx, row in enumerate(data):
new_row = list(row)
for col, func in column_mutators.items():
col_idx = indexes[col]
new_row[col_idx] = func(row[col_idx])
data[row_idx] = tuple(new_row)
return data
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
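# Sketch of a subclass opting into value mutation (hypothetical engine; the
# dict is keyed by the SQLAlchemy type class, matching the lookup above):
#
#     class ExampleEngineSpec(BaseEngineSpec):
#         engine = "example"
#         column_type_mutators = {
#             types.Numeric: lambda val: float(val) if val is not None else None,
#         }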