diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f0dca3e4..668a7bc3 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -15,6 +15,22 @@ jobs: runs-on: ubuntu-latest + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: lux + POSTGRES_DB: postgres + # Set health checks to wait until postgres has started + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + ports: + - 5432:5432 + steps: - uses: actions/checkout@v2 - name: Set up Python 3.7 @@ -28,6 +44,11 @@ jobs: pip install wheel pip install -r requirements.txt pip install -r requirements-dev.txt + pip install sqlalchemy + - name: Upload data to Postgres + run: | + python lux/data/upload_car_data.py + python lux/data/upload_aug_test_data.py - name: Lint check with black run: | black --target-version py37 --line-length 105 --check .
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ad2ce1a9..0fb1bdf7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -37,7 +37,7 @@ lux/ ``` # Code Formatting -In order to keep our codebase clean and readible, we are using PEP8 guidelines. To help us maintain and check code style, we are using [black](https://github.com/psf/black). Simply run `black .` before commiting. Failure to do so may fail the tests run on Travis. This package should have been installed for you as part of [requirements-dev](https://github.com/lux-org/lux/blob/master/requirements-dev.txt). +In order to keep our codebase clean and readable, we are using PEP8 guidelines. To help us maintain and check code style, we are using [black](https://github.com/psf/black). Simply run `black .` before committing. Without running black, the continuous integration checks can fail. `black` should be installed for you as part of [requirements-dev](https://github.com/lux-org/lux/blob/master/requirements-dev.txt). # Running the Test Suite @@ -55,8 +55,6 @@ To run a single test file, run: python -m pytest tests/.py ``` - - # Submitting a Pull Request You can commit your code and push to your forked repo. Once all of your local changes have been tested and formatted, you are ready to submit a PR. For Lux, we use the "Squash and Merge" strategy to merge in PRs, which means that even if you make a lot of small commits in your PR, they will all get squashed into a single commit associated with the PR. Please make sure that comments and unnecessary file changes are not committed as part of the PR by looking at the "File Changes" diff view on the pull request page.
diff --git a/doc/source/advanced/executor.rst b/doc/source/advanced/executor.rst index 93dadd84..e5347e98 100644 --- a/doc/source/advanced/executor.rst +++ b/doc/source/advanced/executor.rst @@ -11,12 +11,12 @@ Please refer to :mod:`lux.executor.Executor`, if you are interested in extending SQL Executor ============= -Lux extends its visualization exploration operations to data within SQL databases. By using the SQL Executor, users can specify a SQL database to connect a Lux Dataframe for generating all the visualizations recommended in Lux. +Lux extends its visualization exploration operations to data within SQL databases. By using the SQL Executor, users can specify a SQL database to connect to a LuxSQLTable for generating all the visualizations recommended in Lux.
Connecting Lux to a Database ---------------------------- -Before Lux can operate on data within a Postgresql database, users have to connect their Lux Dataframe to their database. +Before Lux can operate on data within a Postgresql database, users have to connect their LuxSQLTable to their database. To do this, users first need to specify a connection to their SQL database. This can be done using the psycopg2 package's functionality. .. code-block:: python import psycopg2 connection = psycopg2.connect("dbname=example_database user=example_user password=example_password") -Once this connection is created, users can connect their Lux Dataframe to the database using the Lux Dataframe's set_SQL_connection command. +Once this connection is created, users can connect the lux config to the database using the set_SQL_connection command. .. code-block:: python - lux_df.set_SQL_connection(connection, "my_table") + lux.config.set_SQL_connection(connection) -When the set_SQL_connection function is called, Lux will then populate the Dataframe with all the metadata it needs to run its intent from the database table. +When the set_SQL_connection function is called, Lux will then populate the LuxSQLTable with all the metadata it needs to run its intent from the database table. + +Connecting a LuxSQLTable to a Table/View +---------------------------------------- + +LuxSQLTables can be connected to individual tables or views created within your Postgresql database. This can be done by specifying the table/view name in the constructor. + +.. code-block:: python + + sql_tbl = LuxSQLTable(table_name = "my_table") + +You can also connect a LuxSQLTable to a table/view by using the set_SQL_table function. + +.. code-block:: python + + sql_tbl = LuxSQLTable() + sql_tbl.set_SQL_table("my_table") Choosing an Executor -------------------------- Once a user has created a connection to their Postgresql database, they need to change Lux's execution engine so that the system can collect and process the data properly. -By default Lux uses the Pandas executor to process local data in the Lux Dataframe, but users need to use the SQL executor when their Lux Dataframe is connected to a database. -Users can specify the executor that a Lux Dataframe will use via the set_executor_type function as follows: +By default Lux uses the Pandas executor to process local data in the LuxDataFrame, but users will use the SQL executor when their LuxSQLTable is connected to a database. +Users can specify the executor that Lux will use via the set_executor_type function as follows: .. code-block:: python lux.config.set_executor_type("SQL") -Once a Lux Dataframe has been connected to a Postgresql table and set to use the SQL Executor, users can take full advantage of Lux's visual exploration capabilities as-is. Users can set their intent to specify which variables they are most interested in and discover insightful visualizations from their database. +Once a LuxSQLTable has been connected to a Postgresql table and set to use the SQL Executor, users can take full advantage of Lux's visual exploration capabilities as-is. Users can set their intent to specify which variables they are most interested in and discover insightful visualizations from their database.
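Putting these pieces together, a minimal end-to-end session looks like the following sketch. The connection parameters and the table name ``my_table`` are placeholders; substitute your own.

.. code-block:: python

    import psycopg2
    import lux
    from lux.core.sqltable import LuxSQLTable

    # Register the connection with the lux config; set_SQL_connection
    # also switches the executor over to the SQL executor.
    connection = psycopg2.connect("dbname=example_database user=example_user password=example_password")
    lux.config.set_SQL_connection(connection)

    # Tie a LuxSQLTable to an existing table or view, then explore it as usual.
    sql_tbl = LuxSQLTable(table_name="my_table")
    sql_tbl.intent = ["some_attribute"]  # optional: steer the recommendations
    sql_tbl  # displaying the table triggers Lux's recommendations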
SQL Executor Limitations -------------------------- -While users can make full use of Lux's functionalities on data within a database table, they will not be able to use any of Pandas' Dataframe functions to manipulate the data. Since the Lux SQL Executor delegates most data processing to the Postgresql database, it does not pull in the entire dataset into the Lux Dataframe. As such there is no actual data within the Lux Dataframe to manipulate, only the relevant metadata required to for Lux to manage its intent. Thus, if users are interested in manipulating or querying their data, this needs to be done through SQL or an alternative RDBMS interface. \ No newline at end of file +While users can make full use of Lux's functionalities on data within a database table, they will not be able to use any of Pandas' DataFrame functions to manipulate the data in the LuxSQLTable object. Since the Lux SQL Executor delegates most data processing to the Postgresql database, it does not pull the entire dataset into the LuxSQLTable. As such, there is no actual data within the LuxSQLTable to manipulate, only the relevant metadata required for Lux to manage its intent. Thus, if users are interested in manipulating or querying their data, this needs to be done through SQL or an alternative RDBMS interface. \ No newline at end of file
diff --git a/lux/__init__.py b/lux/__init__.py index 4b135ace..7d865410 100644 --- a/lux/__init__.py +++ b/lux/__init__.py @@ -15,6 +15,7 @@ # Register the commonly used modules (similar to how pandas does it: https://github.com/pandas-dev/pandas/blob/master/pandas/__init__.py) from lux.vis.Clause import Clause from lux.core.frame import LuxDataFrame +from lux.core.sqltable import LuxSQLTable from ._version import __version__, version_info from lux._config import config from lux._config.config import warning_format
diff --git a/lux/_config/config.py b/lux/_config/config.py index 519f1f5d..59dd0ae5 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -343,25 +343,21 @@ def set_SQL_connection(self, connection): connection : SQLAlchemy connectable, str, or sqlite3 connection For more information, `see here `__ """ + self.set_executor_type("SQL") self.SQLconnection = connection def set_executor_type(self, exe): if exe == "SQL": - import pkgutil - - if pkgutil.find_loader("psycopg2") is None: - raise ImportError( - "psycopg2 is not installed. Run `pip install psycopg2' to install psycopg2 to enable the Postgres connection."
- ) - else: - import psycopg2 from lux.executor.SQLExecutor import SQLExecutor self.executor = SQLExecutor() - else: + elif exe == "Pandas": from lux.executor.PandasExecutor import PandasExecutor + self.SQLconnection = "" self.executor = PandasExecutor() + else: + raise ValueError("Executor type must be either 'Pandas' or 'SQL'") def warning_format(message, category, filename, lineno, file=None, line=None):
diff --git a/lux/action/correlation.py b/lux/action/correlation.py index 83af9f01..05a300cf 100644 --- a/lux/action/correlation.py +++ b/lux/action/correlation.py @@ -73,7 +73,6 @@ def correlation(ldf: LuxDataFrame, ignore_transpose: bool = True): ) msr1 = measures[0].attribute msr2 = measures[1].attribute - if ignore_transpose: check_transpose = check_transpose_not_computed(vlist, msr1, msr2) else:
diff --git a/lux/action/custom.py b/lux/action/custom.py index 4fa7b450..1fe01efc 100644 --- a/lux/action/custom.py +++ b/lux/action/custom.py @@ -64,7 +64,7 @@ def custom_actions(ldf): recommendations : Dict[str,obj] object with a collection of visualizations that were previously registered. """ - if len(lux.config.actions) > 0 and len(ldf) > 0: + if len(lux.config.actions) > 0 and (len(ldf) > 0 or lux.config.executor.name != "PandasExecutor"): recommendations = [] for action_name in lux.config.actions.keys(): display_condition = lux.config.actions[action_name].display_condition
diff --git a/lux/action/univariate.py b/lux/action/univariate.py index e44e8321..e94af9bd 100644 --- a/lux/action/univariate.py +++ b/lux/action/univariate.py @@ -84,13 +84,10 @@ def univariate(ldf, *args): examples = f" (e.g., {possible_attributes[0]})" intent = [lux.Clause("?", data_type="geographical"), lux.Clause("?", data_model="measure")] intent.extend(filter_specs) - long_description = f"Geographical displays choropleths for geographic attribute{examples}, with colors indicating the average measure values. " - if lux.config.plotting_backend == "matplotlib": - long_description += "The map visualizations from the 'Geographical' tab are rendered using Altair. Lux does not currently support geographical maps with Matplotlib. If you would like this feature, please leave us a comment at issue #310 to let us know!" recommendation = { "action": "Geographical", "description": "Show choropleth maps of geographic attributes", - "long_description": long_description, + "long_description": f"Occurrence displays choropleths of averages for some geographic attribute{examples}. Visualizations are ranked by diversity of the geographic attribute.", } elif data_type_constraint == "temporal": intent = [lux.Clause("?", data_type="temporal")]
diff --git a/lux/core/frame.py b/lux/core/frame.py index a57b3552..c1a12e0b 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -18,6 +18,7 @@ from lux.vis.Vis import Vis from lux.vis.VisList import VisList from lux.history.history import History +from lux.utils.date_utils import is_datetime_series from lux.utils.message import Message from lux.utils.utils import check_import_lux_widget from typing import Dict, Union, List, Callable @@ -57,8 +58,6 @@ class LuxDataFrame(pd.DataFrame): ] def __init__(self, *args, **kw): - from lux.executor.PandasExecutor import PandasExecutor - self._history = History() self._intent = [] self._inferred_intent = [] @@ -70,7 +69,14 @@ def __init__(self, *args, **kw): super(LuxDataFrame, self).__init__(*args, **kw) self.table_name = "" - lux.config.executor = PandasExecutor() + if lux.config.SQLconnection == "": + from lux.executor.PandasExecutor import PandasExecutor + + lux.config.executor = PandasExecutor() + else: + from lux.executor.SQLExecutor import SQLExecutor + + lux.config.executor = SQLExecutor() self._sampled = None self._toggle_pandas_display = True @@ -110,14 +116,25 @@ def data_type(self): return self._data_type def maintain_metadata(self): + is_sql_tbl = lux.config.executor.name == "SQLExecutor" + if lux.config.SQLconnection != "" and is_sql_tbl: + from lux.executor.SQLExecutor import SQLExecutor + + lux.config.executor = SQLExecutor() + # Check that metadata has not yet been computed if not hasattr(self, "_metadata_fresh") or not self._metadata_fresh: # only compute metadata information if the dataframe is non-empty - if len(self) > 0: - lux.config.executor.compute_stats(self) + if is_sql_tbl: lux.config.executor.compute_dataset_metadata(self) self._infer_structure() self._metadata_fresh = True + else: + if len(self) > 0: + lux.config.executor.compute_stats(self) + lux.config.executor.compute_dataset_metadata(self) + self._infer_structure() + self._metadata_fresh = True def expire_recs(self): """ @@ -168,12 +185,14 @@ def _infer_structure(self): # If the dataframe is very small and the index column is not a range index, then it is likely that this is an aggregated data is_multi_index_flag = self.index.nlevels != 1 not_int_index_flag = not pd.api.types.is_integer_dtype(self.index) - small_df_flag = len(self) < 100 + is_sql_tbl = lux.config.executor.name == "SQLExecutor" + + small_df_flag = len(self) < 100 and is_sql_tbl if self.pre_aggregated == None: self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag if "Number of Records" in self.columns: self.pre_aggregated = True - self.pre_aggregated = "groupby" in [event.name for event in self.history] + self.pre_aggregated = "groupby" in [event.name for event in self.history] and not is_sql_tbl @property def intent(self): @@ -317,110 +336,6 @@ def current_vis(self): def current_vis(self, current_vis: Dict): self._current_vis = current_vis - ####################################################### - ########## SQL Metadata, type, model schema ########### - ####################################################### - - def set_SQL_table(self, t_name): - self.table_name = t_name - self.compute_SQL_dataset_metadata() - - def compute_SQL_dataset_metadata(self): - self.get_SQL_attributes() - for attr in list(self.columns): - self[attr] = None - self._data_type = {} - #####NOTE: since we aren't expecting users to do much data processing with the SQL database, should we just keep this - ##### in the initialization and do it just once - self.compute_SQL_data_type() - self.compute_SQL_stats() - - def compute_SQL_stats(self): - # precompute statistics - self.unique_values = {} - self._min_max = {} - - self.get_SQL_unique_values() - # self.get_SQL_cardinality() - for attribute in self.columns: - if self._data_type[attribute] == "quantitative": - self._min_max[attribute] = ( - self[attribute].min(), - self[attribute].max(), - ) - - def get_SQL_attributes(self): - if "." in self.table_name: - table_name = self.table_name[self.table_name.index(".") + 1 :] - else: - table_name = self.table_name - query = f"SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS where TABLE_NAME = '{table_name}'" - attributes = list(pd.read_sql(query, lux.config.SQLconnection)["column_name"]) - for attr in attributes: - self[attr] = None - - def get_SQL_cardinality(self): - cardinality = {} - for attr in list(self.columns): - card_query = pd.read_sql( - f"SELECT Count(Distinct({attr})) FROM {self.table_name}", - lux.config.SQLconnection, - ) - cardinality[attr] = list(card_query["count"])[0] - self.cardinality = cardinality - - def get_SQL_unique_values(self): - unique_vals = {} - for attr in list(self.columns): - unique_query = pd.read_sql( - f"SELECT Distinct({attr}) FROM {self.table_name}", - lux.config.SQLconnection, - ) - unique_vals[attr] = list(unique_query[attr]) - self.unique_values = unique_vals - - def compute_SQL_data_type(self): - data_type = {} - sql_dtypes = {} - self.get_SQL_cardinality() - if "." in self.table_name: - table_name = self.table_name[self.table_name.index(".") + 1 :] - else: - table_name = self.table_name - # get the data types of the attributes in the SQL table - for attr in list(self.columns): - query = f"SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{table_name}' AND COLUMN_NAME = '{attr}'" - datatype = list(pd.read_sql(query, lux.config.SQLconnection)["data_type"])[0] - sql_dtypes[attr] = datatype - - for attr in list(self.columns): - if attr in self._type_override: - data_type[attr] = self._type_override[attr] - elif str(attr).lower() in ["month", "year"]: - data_type[attr] = "temporal" - elif sql_dtypes[attr] in [ - "character", - "character varying", - "boolean", - "uuid", - "text", - ]: - data_type[attr] = "nominal" - elif sql_dtypes[attr] in [ - "integer", - "real", - "smallint", - "smallserial", - "serial", - ]: - if self.cardinality[attr] < 13: - data_type[attr] = "nominal" - else: - data_type[attr] = "quantitative" - elif "time" in sql_dtypes[attr] or "date" in sql_dtypes[attr]: - data_type[attr] = "temporal" - self._data_type = data_type def _append_rec(self, rec_infolist, recommendations: Dict): if recommendations["collection"] is not None and len(recommendations["collection"]) > 0: rec_infolist.append(recommendations) @@ -470,6 +385,7 @@ def maintain_recs(self, is_series="DataFrame"): # Check that recs has not yet been computed if not hasattr(rec_df, "_recs_fresh") or not rec_df._recs_fresh: + is_sql_tbl = lux.config.executor.name == "SQLExecutor" rec_infolist = [] from lux.action.row_group import row_group from lux.action.column_group import column_group @@ -479,7 +395,7 @@ if rec_df.columns.name is not None: rec_df._append_rec(rec_infolist, row_group(rec_df)) rec_df._append_rec(rec_infolist, column_group(rec_df)) - elif not (len(rec_df) < 5 and not rec_df.pre_aggregated) and not ( + elif not (len(rec_df) < 5 and not rec_df.pre_aggregated and not is_sql_tbl) and not ( self.index.nlevels >= 2 or self.columns.nlevels >= 2 ): from lux.action.custom import custom_actions
diff --git a/lux/core/series.py b/lux/core/series.py index a5e2defd..2b902730 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -43,13 +43,13 @@ class LuxSeries(pd.Series): "_prev", "_history", "_saved_export", + "name", "_sampled", "_toggle_pandas_display", "_message", "_pandas_only", "pre_aggregated", "_type_override", - "name", ] _default_metadata = {
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py new file mode 100644 index 00000000..de426f70 --- /dev/null +++ b/lux/core/sqltable.py @@ -0,0 +1,186 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import pandas as pd +from lux.core.series import LuxSeries +from lux.vis.Clause import Clause +from lux.vis.Vis import Vis +from lux.vis.VisList import VisList +from lux.history.history import History +from lux.utils.date_utils import is_datetime_series +from lux.utils.message import Message +from lux.utils.utils import check_import_lux_widget +from typing import Dict, Union, List, Callable + +# from lux.executor.Executor import * +import warnings +import traceback +import lux + + +class LuxSQLTable(lux.LuxDataFrame): + """ + A subclass of Lux.LuxDataFrame that houses other variables and functions for generating visual recommendations. Does not support normal pandas functionality. + """ + + # MUST register here for new properties!! + _metadata = [ + "_intent", + "_inferred_intent", + "_data_type", + "unique_values", + "cardinality", + "_rec_info", + "_min_max", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + "_sampled", + "_toggle_pandas_display", + "_message", + "_pandas_only", + "pre_aggregated", + "_type_override", + "_length", + "_setup_done", + ] + + def __init__(self, *args, table_name="", **kw): + super(LuxSQLTable, self).__init__(*args, **kw) + from lux.executor.SQLExecutor import SQLExecutor + + lux.config.executor = SQLExecutor() + + self._length = 0 + self._setup_done = False + if table_name != "": + self.set_SQL_table(table_name) + warnings.formatwarning = lux.warning_format + + def __len__(self): + if self._setup_done: + return self._length + else: + return super(LuxSQLTable, self).__len__() + + def set_SQL_table(self, t_name): + # function that ties the LuxSQLTable to a SQL database table + if self.table_name != "": + warnings.warn( + f"\nThis LuxSQLTable is already tied to a database table. Please create a new LuxSQLTable and connect it to your table '{t_name}'.", + stacklevel=2, + ) + else: + self.table_name = t_name + import psycopg2 + + try: + lux.config.executor.compute_dataset_metadata(self) + except Exception as error: + error_str = str(error) + if f'relation "{t_name}" does not exist' in error_str: + warnings.warn( + f"\nThe table '{t_name}' does not exist in your database.", + stacklevel=2, + ) + + def _ipython_display_(self): + from IPython.display import HTML, Markdown, display + from IPython.display import clear_output + import ipywidgets as widgets + + try: + if self._pandas_only: + display(self.display_pandas()) + self._pandas_only = False + if not self.index.nlevels >= 2 or self.columns.nlevels >= 2: + self.maintain_metadata() + + if self._intent != [] and (not hasattr(self, "_compiled") or not self._compiled): + from lux.processor.Compiler import Compiler + + self.current_vis = Compiler.compile_intent(self, self._intent) + + if lux.config.default_display == "lux": + self._toggle_pandas_display = False + else: + self._toggle_pandas_display = True + + # df_to_display.maintain_recs() # compute the recommendations (TODO: This can be rendered in another thread in the background to populate self._widget) + self.maintain_recs() + + # Observers(callback_function, listen_to_this_variable) + self._widget.observe(self.remove_deleted_recs, names="deletedIndices") + self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex") + + button = widgets.Button( + description="Toggle Table/Lux", + layout=widgets.Layout(width="200px", top="6px", bottom="6px"), + ) + self.output = widgets.Output() + self._sampled = lux.config.executor.execute_preview(self) + display(button, self.output) + + def on_button_clicked(b): + with self.output: + if b: + self._toggle_pandas_display = not self._toggle_pandas_display + clear_output() + + # create connection string to display + connect_str = self.table_name + connection_type = str(type(lux.config.SQLconnection)) + if "psycopg2.extensions.connection" in connection_type: + connection_dsn = lux.config.SQLconnection.get_dsn_parameters() + host_name = connection_dsn["host"] + host_port = connection_dsn["port"] + dbname = connection_dsn["dbname"] + connect_str = host_name + ":" + host_port + "/" + dbname + + elif "sqlalchemy.engine.base.Engine" in connection_type: + db_connection = str(lux.config.SQLconnection) + db_start = db_connection.index("@") + 1 + db_end = len(db_connection) - 1 + connect_str = db_connection[db_start:db_end] + + if self._toggle_pandas_display: + notification = "Here is a preview of the **{}** database table: **{}**".format( + self.table_name, connect_str + ) + display(Markdown(notification), self._sampled.display_pandas()) + else: + # b.layout.display = "none" + display(self._widget) + # b.layout.display = "inline-block" + + button.on_click(on_button_clicked) + on_button_clicked(None) + + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + if lux.config.pandas_fallback: + warnings.warn( + "\nUnexpected error in rendering Lux widget and recommendations. " + "Falling back to Pandas display.\n" + "Please report the following issue on Github: https://github.com/lux-org/lux/issues \n", + stacklevel=2, + ) + warnings.warn(traceback.format_exc()) + display(self.display_pandas()) + else: + raise
diff --git a/lux/data/upload_aug_test_data.py b/lux/data/upload_aug_test_data.py new file mode 100644 index 00000000..a6982192 --- /dev/null +++ b/lux/data/upload_aug_test_data.py @@ -0,0 +1,44 @@ +import pandas as pd +import psycopg2 +import csv + +conn = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") +cur = conn.cursor() +cur.execute( + """ + DROP TABLE IF EXISTS aug_test_table + """ +) +# create aug_test table in postgres database +cur.execute( + """ + CREATE TABLE aug_test_table( + enrollee_id integer, + city text, + city_development_index numeric, + gender text, + relevent_experience text, + enrolled_university text, + education_level text, + major_discipline text, + experience text, + company_size text, + company_type text, + last_new_job text, + training_hours integer +) +""" +) + +# download aug_test.csv and read data into the database +import urllib.request + +target_url = "https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/aug_test.csv" +for line in urllib.request.urlopen(target_url): + decoded = line.decode("utf-8") + if "enrollee_id,city,city_development_index" not in decoded: + cur.execute( + "INSERT INTO aug_test_table VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", + decoded.split(","), + ) +conn.commit()
diff --git a/lux/data/upload_car_data.py b/lux/data/upload_car_data.py new file mode 100644 index 00000000..92894a4f --- /dev/null +++ b/lux/data/upload_car_data.py @@ -0,0 +1,46 @@ +import pandas as pd +import psycopg2 +import csv + +from sqlalchemy import create_engine + +data = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/car.csv") +engine = create_engine("postgresql://postgres:lux@localhost:5432") +data.to_sql(name="car", con=engine, if_exists="replace", index=False) + +conn = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") +cur = conn.cursor() +cur.execute( + """ + DROP TABLE IF EXISTS cars + """ +) +# create car table in postgres database +cur.execute( + """ + CREATE TABLE cars( + name text, + milespergal numeric, + cylinders integer, + displacement numeric, + horsepower integer, + weight integer, + acceleration numeric, + year integer, + origin text, + brand text +) +""" +) + +# download car.csv and read data into the database +import urllib.request + +target_url = "https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/car.csv" +for line in urllib.request.urlopen(target_url): + decoded = line.decode("utf-8") + if "Name,MilesPerGal,Cylinders" not in decoded: + cur.execute( + "INSERT INTO cars VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", decoded.split(",") + ) +conn.commit()
diff --git a/lux/data/upload_flights_data.py b/lux/data/upload_flights_data.py new file mode 100644 index 00000000..6a63e301 --- /dev/null +++ b/lux/data/upload_flights_data.py @@ -0,0 +1,43 @@ +import pandas as pd +from sqlalchemy import create_engine +import psycopg2 +import csv + +conn = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") +cur = conn.cursor() +cur.execute( + """ + DROP TABLE IF EXISTS flights + """ +) + +# create flights table in postgres database +cur.execute( + """ + CREATE TABLE flights( + year integer, + month text, + day integer, + weekday integer, + carrier text, + origin text,
+ destination text, + arrivaldelay integer, + depaturedelay integer, + weatherdelay integer, + distance integer +) +""" +) + +# download flights.csv and read data into the database +import urllib.request + +target_url = "https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/flights.csv" +for line in urllib.request.urlopen(target_url): + decoded = line.decode("utf-8") + if "day,weekday,carrier,origin" not in decoded: + cur.execute( + "INSERT INTO flights VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", decoded.split(",") + ) +conn.commit()
diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 33e41b70..4ecfb675 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -379,7 +379,7 @@ def execute_2D_binning(vis: Vis): (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]), ] ).reset_index() - elif color_attr.data_type == "quantitative": + elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal": # Compute the average of all values in the bin result = groups.agg( [("count", "count"), (color_attr.attribute, "mean")] ).reset_index() @@ -513,7 +513,8 @@ def _is_geographical_attribute(series): @staticmethod def _is_datetime_number(series): - if series.dtype == int: + is_int_dtype = pd.api.types.is_integer_dtype(series.dtype) + if is_int_dtype: try: temp = series.astype(str) pd.to_datetime(temp) @@ -527,6 +528,7 @@ def compute_stats(self, ldf: LuxDataFrame): ldf.unique_values = {} ldf._min_max = {} ldf.cardinality = {} + ldf._length = len(ldf) for attribute in ldf.columns:
diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py index 6fdab393..a2061d11 100644 --- a/lux/executor/SQLExecutor.py +++ b/lux/executor/SQLExecutor.py @@ -1,23 +1,12 @@ -# Copyright 2019-2020 The Lux Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
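A note on the ``_is_datetime_number`` change in PandasExecutor above: ``series.dtype == int`` only matches the platform-default integer dtype, so date-like columns stored with a sized dtype such as ``int32`` were previously missed, while ``pd.api.types.is_integer_dtype`` covers all integer dtypes. The sketch below illustrates the difference; the dtype and values are illustrative.

.. code-block:: python

    import pandas as pd

    # Date-like integers stored as int32, e.g. as returned by a database driver
    s = pd.Series([20200101, 20200102], dtype="int32")

    print(s.dtype == int)                          # False on most platforms (int maps to int64)
    print(pd.api.types.is_integer_dtype(s.dtype))  # True for any integer dtype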
- import pandas from lux.vis.VisList import VisList from lux.vis.Vis import Vis -from lux.core.frame import LuxDataFrame +from lux.core.sqltable import LuxSQLTable from lux.executor.Executor import Executor from lux.utils import utils +from lux.utils.utils import check_import_lux_widget, check_if_id_like +import lux + import math @@ -36,53 +25,145 @@ def __repr__(self): return f"<SQLExecutor>" @staticmethod - def execute(vislist: VisList, ldf: LuxDataFrame): - import pandas as pd - - """ - Given a VisList, fetch the data required to render the vis - 1) Apply filters - 2) Retreive relevant attribute - 3) return a DataFrame with relevant results - """ - for vis in vislist: - # Select relevant data based on attribute information - attributes = set([]) - for clause in vis._inferred_intent: - if clause.attribute: - if clause.attribute == "Record": - attributes.add(clause.attribute) - # else: - attributes.add(clause.attribute) - if vis.mark not in ["bar", "line", "histogram"]: - where_clause, filterVars = SQLExecutor.execute_filter(vis) - required_variables = attributes | set(filterVars) - required_variables = ",".join(required_variables) - row_count = list( - pd.read_sql( - f"SELECT COUNT(*) FROM {lux.config.table_name} {where_clause}", - ldf.SQLconnection, - )["count"] - )[0] - if row_count > 10000: - query = f"SELECT {required_variables} FROM {lux.config.table_name} {where_clause} ORDER BY random() LIMIT 10000" + def execute_preview(tbl: LuxSQLTable, preview_size=5): + output = pandas.read_sql( + "SELECT * from {} LIMIT {}".format(tbl.table_name, preview_size), lux.config.SQLconnection + ) + return output + + @staticmethod + def execute_sampling(tbl: LuxSQLTable): + SAMPLE_FLAG = lux.config.sampling + SAMPLE_START = lux.config.sampling_start + SAMPLE_CAP = lux.config.sampling_cap + SAMPLE_FRAC = 0.2 + + length_query = pandas.read_sql( + "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), + lux.config.SQLconnection, + ) + limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC + tbl._sampled = pandas.read_sql( + "SELECT * from {} LIMIT {}".format(tbl.table_name, str(limit)), lux.config.SQLconnection + ) + + @staticmethod + def execute(view_collection: VisList, tbl: LuxSQLTable): + """ + Given a VisList, fetch the data required to render the view + 1) Generate Necessary WHERE clauses + 2) Query necessary data, applying appropriate aggregation for the chart type + 3) populates vis' data with a DataFrame with relevant results + """ + + for view in view_collection: + # choose execution method depending on vis mark type + + # when mark is empty, deal with lazy execution by filling the data with a small sample of the dataframe + if view.mark == "": + SQLExecutor.execute_sampling(tbl) + view._vis_data = tbl._sampled + if view.mark == "scatter": + where_clause, filterVars = SQLExecutor.execute_filter(view) + length_query = pandas.read_sql( + "SELECT COUNT(1) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + view_data_length = list(length_query["length"])[0] + if len(view.get_attr_by_channel("color")) == 1 or view_data_length < 5000: + # NOTE: might want to have a check somewhere to not use categorical variables with greater than some number of categories as a Color variable---------------- + has_color = True + SQLExecutor.execute_scatter(view, tbl) else: - query = f"SELECT {required_variables} FROM {lux.config.table_name} {where_clause}" - data = pd.read_sql(query, ldf.SQLconnection) - vis._vis_data = utils.pandas_to_lux(data) - if vis.mark == "bar" or vis.mark == "line": - SQLExecutor.execute_aggregate(vis, ldf) - elif vis.mark == "histogram": - SQLExecutor.execute_binning(vis, ldf) + view._mark = "heatmap" + SQLExecutor.execute_2D_binning(view, tbl) + elif view.mark == "bar" or view.mark == "line": + SQLExecutor.execute_aggregate(view, tbl) + elif view.mark == "histogram": + SQLExecutor.execute_binning(view, tbl) @staticmethod - def execute_aggregate(vis: Vis, ldf: LuxDataFrame): - import pandas as pd + def execute_scatter(view: Vis, tbl: LuxSQLTable): + """ + Given a scatterplot vis and a LuxSQLTable, fetch the data required to render the vis. + 1) Generate WHERE clause for the SQL query + 2) Check number of datapoints to be included in the query + 3) If the number of datapoints exceeds 10000, perform a random sample from the original data + 4) Query datapoints needed for the scatterplot visualization + 5) return a DataFrame with relevant results + + Parameters + ---------- + view: lux.Vis + lux.Vis object that represents a visualization + tbl : lux.core.sqltable + LuxSQLTable with specified intent. + + Returns + ------- + None + """ + + attributes = set([]) + for clause in view._inferred_intent: + if clause.attribute: + if clause.attribute != "Record": + attributes.add(clause.attribute) + where_clause, filterVars = SQLExecutor.execute_filter(view) + + length_query = pandas.read_sql( + "SELECT COUNT(1) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + + def add_quotes(var_name): + return '"' + var_name + '"' + + required_variables = attributes | set(filterVars) + required_variables = map(add_quotes, required_variables) + required_variables = ",".join(required_variables) + row_count = list( + pandas.read_sql( + f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", + lux.config.SQLconnection, + )["count"] + )[0] + if row_count > lux.config.sampling_cap: + query = f"SELECT {required_variables} FROM {tbl.table_name} {where_clause} ORDER BY random() LIMIT 10000" + else: + query = "SELECT {} FROM {} {}".format(required_variables, tbl.table_name, where_clause) + data = pandas.read_sql(query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(data) + # view._vis_data.length = list(length_query["length"])[0] + + tbl._message.add_unique( + f"Large scatterplots detected: Lux is automatically binning scatterplots to heatmaps.", + priority=98, + ) @staticmethod + def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True): + """ + Aggregate data points on an axis for bar or line charts Parameters ---------- vis: lux.Vis lux.Vis object that represents a visualization tbl : lux.core.sqltable LuxSQLTable with specified intent.
+ isFiltered: boolean + boolean that represents whether a vis has had a filter applied to its data + Returns + ------- + None + """ + x_attr = view.get_attr_by_channel("x")[0] + y_attr = view.get_attr_by_channel("y")[0] + has_color = False groupby_attr = "" measure_attr = "" + if x_attr.aggregation is None or y_attr.aggregation is None: + return if y_attr.aggregation != "": groupby_attr = x_attr measure_attr = y_attr @@ -91,69 +172,248 @@ def execute_aggregate(vis: Vis, ldf: LuxDataFrame): groupby_attr = y_attr measure_attr = x_attr agg_func = x_attr.aggregation - + if groupby_attr.attribute in tbl.unique_values.keys(): + attr_unique_vals = tbl.unique_values[groupby_attr.attribute] + # checks if color is specified in the Vis + if len(view.get_attr_by_channel("color")) == 1: + color_attr = view.get_attr_by_channel("color")[0] + color_attr_vals = tbl.unique_values[color_attr.attribute] + color_cardinality = len(color_attr_vals) + # NOTE: might want to have a check somewhere to not use categorical variables with greater than some number of categories as a Color variable---------------- + has_color = True + else: + color_cardinality = 1 if measure_attr != "": # barchart case, need count data for each group if measure_attr.attribute == "Record": - where_clause, filterVars = SQLExecutor.execute_filter(vis) - count_query = f"SELECT {groupby_attr.attribute}, COUNT({groupby_attr.attribute}) FROM {lux.config.table_name} {where_clause} GROUP BY {groupby_attr.attribute}" - vis._vis_data = pd.read_sql(count_query, ldf.SQLconnection) - vis._vis_data = vis.data.rename(columns={"count": "Record"}) - vis._vis_data = utils.pandas_to_lux(vis.data) + where_clause, filterVars = SQLExecutor.execute_filter(view) + length_query = pandas.read_sql( + "SELECT COUNT(*) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + # generates query for colored barchart case + if has_color: + count_query = 'SELECT "{}", "{}", COUNT("{}") FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + groupby_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) + view._vis_data = pandas.read_sql(count_query, lux.config.SQLconnection) + view._vis_data = view._vis_data.rename(columns={"count": "Record"}) + view._vis_data = utils.pandas_to_lux(view._vis_data) + # generates query for normal barchart case + else: + count_query = 'SELECT "{}", COUNT("{}") FROM {} {} GROUP BY "{}"'.format( + groupby_attr.attribute, + groupby_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + ) + view._vis_data = pandas.read_sql(count_query, lux.config.SQLconnection) + view._vis_data = view._vis_data.rename(columns={"count": "Record"}) + view._vis_data = utils.pandas_to_lux(view._vis_data) + # view._vis_data.length = list(length_query["length"])[0] + # aggregate barchart case, need aggregate data (mean, sum, max) for each group else: - where_clause, filterVars = SQLExecutor.execute_filter(vis) - if agg_func == "mean": - mean_query = f"SELECT {groupby_attr.attribute}, AVG({measure_attr.attribute}) as {measure_attr.attribute} FROM {lux.config.table_name} {where_clause} GROUP BY {groupby_attr.attribute}" - vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection) - vis._vis_data = utils.pandas_to_lux(vis.data) - if agg_func == "sum": - mean_query = f"SELECT {groupby_attr.attribute}, SUM({measure_attr.attribute}) as {measure_attr.attribute} FROM {lux.config.table_name} {where_clause} GROUP BY {groupby_attr.attribute}" - vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection) - vis._vis_data = utils.pandas_to_lux(vis.data) - if agg_func == "max": - mean_query = f"SELECT {groupby_attr.attribute}, MAX({measure_attr.attribute}) as {measure_attr.attribute} FROM {lux.config.table_name} {where_clause} GROUP BY {groupby_attr.attribute}" - vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection) - vis._vis_data = utils.pandas_to_lux(vis.data) - - # pad empty categories with 0 counts after filter is applied - all_attr_vals = ldf.unique_values[groupby_attr.attribute] - result_vals = list(vis.data[groupby_attr.attribute]) - if len(result_vals) != len(all_attr_vals): - # For filtered aggregation that have missing groupby-attribute values, set these aggregated value as 0, since no datapoints - for vals in all_attr_vals: - if vals not in result_vals: - vis.data.loc[len(vis.data)] = [vals] + [0] * (len(vis.data.columns) - 1) + where_clause, filterVars = SQLExecutor.execute_filter(view) + + length_query = pandas.read_sql( + "SELECT COUNT(*) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + # generates query for colored barchart case + if has_color: + if agg_func == "mean": + agg_query = ( + 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + + view._vis_data = utils.pandas_to_lux(view._vis_data) + if agg_func == "sum": + agg_query = ( + 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(view._vis_data) + if agg_func == "max": + agg_query = ( + 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(view._vis_data) + # generates query for normal barchart case + else: + if agg_func == "mean": + agg_query = 'SELECT "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( + groupby_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(view._vis_data) + if agg_func == "sum": + agg_query = 'SELECT "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( + groupby_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(view._vis_data) + if agg_func == "max": + agg_query = 'SELECT "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( + groupby_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + ) + view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) + view._vis_data = utils.pandas_to_lux(view._vis_data) + result_vals = list(view._vis_data[groupby_attr.attribute]) + # create existing group by attribute combinations if color is specified + # this is needed to check what combinations of group_by_attr and color_attr values have a non-zero number of elements in them + if has_color: + res_color_combi_vals = [] + result_color_vals = list(view._vis_data[color_attr.attribute]) + for i in range(0, len(result_vals)): + res_color_combi_vals.append([result_vals[i], result_color_vals[i]]) + # For filtered aggregation that have missing groupby-attribute values, set these aggregated value as 0, since no datapoints + if isFiltered or has_color and attr_unique_vals: + N_unique_vals = len(attr_unique_vals) + if len(result_vals) != N_unique_vals * color_cardinality: + columns = view._vis_data.columns + if has_color: + df = pandas.DataFrame( + { + columns[0]: attr_unique_vals * color_cardinality, + columns[1]: pandas.Series(color_attr_vals).repeat(N_unique_vals), + } + ) + view._vis_data = view._vis_data.merge( + df, + on=[columns[0], columns[1]], + how="right", + suffixes=["", "_right"], + ) + for col in columns[2:]: + view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__ + assert len(list(view._vis_data[groupby_attr.attribute])) == N_unique_vals * len( + color_attr_vals + ), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute, color_attr.attribute}`." + view._vis_data = view._vis_data.iloc[ + :, :3 + ] # Keep only the three relevant columns not the *_right columns resulting from merge + else: + df = pandas.DataFrame({columns[0]: attr_unique_vals}) + + view._vis_data = view._vis_data.merge( + df, on=columns[0], how="right", suffixes=["", "_right"] + ) + + for col in columns[1:]: + view._vis_data[col] = view._vis_data[col].fillna(0) + assert ( + len(list(view._vis_data[groupby_attr.attribute])) == N_unique_vals + ), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute}`." + view._vis_data = view._vis_data.sort_values(by=groupby_attr.attribute, ascending=True) + view._vis_data = view._vis_data.reset_index() + view._vis_data = view._vis_data.drop(columns="index") + # view._vis_data.length = list(length_query["length"])[0] @staticmethod - def execute_binning(vis: Vis, ldf: LuxDataFrame): + def execute_binning(view: Vis, tbl: LuxSQLTable): + """ + Binning of data points for generating histograms + Parameters + ---------- + vis: lux.Vis + lux.Vis object that represents a visualization + tbl : lux.core.sqltable + LuxSQLTable with specified intent.
+ Returns + ------- + None + """ import numpy as np - import pandas as pd - - bin_attribute = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0] - if not math.isnan(vis.data.min_max[bin_attribute.attribute][0]) and math.isnan( - vis.data.min_max[bin_attribute.attribute][1] - ): - num_bins = bin_attribute.bin_size - attr_min = min(ldf.unique_values[bin_attribute.attribute]) - attr_max = max(ldf.unique_values[bin_attribute.attribute]) - attr_type = type(ldf.unique_values[bin_attribute.attribute][0]) - - # need to calculate the bin edges before querying for the relevant data - bin_width = (attr_max - attr_min) / num_bins - upper_edges = [] - for e in range(1, num_bins): - curr_edge = attr_min + e * bin_width - if attr_type == int: - upper_edges.append(str(math.ceil(curr_edge))) - else: - upper_edges.append(str(curr_edge)) - upper_edges = ",".join(upper_edges) - vis_filter, filter_vars = SQLExecutor.execute_filter(vis) - bin_count_query = f"SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({bin_attribute.attribute}, '{{{upper_edges}}}') FROM {lux.config.table_name}) as Buckets GROUP BY width_bucket ORDER BY width_bucket" - bin_count_data = pd.read_sql(bin_count_query, ldf.SQLconnection) - # counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size) + bin_attribute = list(filter(lambda x: x.bin_size != 0, view._inferred_intent))[0] + + num_bins = bin_attribute.bin_size + attr_min = tbl._min_max[bin_attribute.attribute][0] + attr_max = tbl._min_max[bin_attribute.attribute][1] + attr_type = type(tbl.unique_values[bin_attribute.attribute][0]) + + # get filters if available + where_clause, filterVars = SQLExecutor.execute_filter(view) + + length_query = pandas.read_sql( + "SELECT COUNT(1) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + # need to calculate the bin edges before querying for the relevant data + bin_width = (attr_max - attr_min) / num_bins + upper_edges = [] + for e in range(1, num_bins): + curr_edge = attr_min + e * bin_width + if attr_type == int: + upper_edges.append(str(math.ceil(curr_edge))) + else: + upper_edges.append(str(curr_edge)) + upper_edges = ",".join(upper_edges) + view_filter, filter_vars = SQLExecutor.execute_filter(view) + bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( + bin_attribute.attribute, + "{" + upper_edges + "}", + tbl.table_name, + where_clause, + ) + + bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection) + if not bin_count_data["width_bucket"].isnull().values.any(): + # np.histogram breaks if data contain NaN + + # counts,binEdges = np.histogram(tbl[bin_attribute.attribute],bins=bin_attribute.bin_size) # binEdges of size N+1, so need to compute binCenter as the bin location upper_edges = [float(i) for i in upper_edges.split(",")] if attr_type == int: @@ -177,37 +437,356 @@ for i in range(0, len(bin_centers)): if i not in bucket_lables: bin_count_data = bin_count_data.append( - pd.DataFrame([[i, 0]], columns=bin_count_data.columns) + pandas.DataFrame([[i, 0]], columns=bin_count_data.columns) ) - vis._vis_data = pd.DataFrame( + view._vis_data = pandas.DataFrame( np.array([bin_centers, list(bin_count_data["count"])]).T, columns=[bin_attribute.attribute, "Number of Records"], ) - vis._vis_data = utils.pandas_to_lux(vis.data) + view._vis_data = utils.pandas_to_lux(view.data) + # view._vis_data.length = list(length_query["length"])[0] + + @staticmethod + def execute_2D_binning(view: Vis, tbl: LuxSQLTable): + import numpy as np + + x_attribute = list(filter(lambda x: x.channel == "x", view._inferred_intent))[0] + + y_attribute = list(filter(lambda x: x.channel == "y", view._inferred_intent))[0] + + num_bins = lux.config.heatmap_bin_size + x_attr_min = tbl._min_max[x_attribute.attribute][0] + x_attr_max = tbl._min_max[x_attribute.attribute][1] + x_attr_type = type(tbl.unique_values[x_attribute.attribute][0]) + + y_attr_min = tbl._min_max[y_attribute.attribute][0] + y_attr_max = tbl._min_max[y_attribute.attribute][1] + y_attr_type = type(tbl.unique_values[y_attribute.attribute][0]) + + # get filters if available + where_clause, filterVars = SQLExecutor.execute_filter(view) + + # need to calculate the bin edges before querying for the relevant data + x_bin_width = (x_attr_max - x_attr_min) / num_bins + y_bin_width = (y_attr_max - y_attr_min) / num_bins + + x_upper_edges = [] + y_upper_edges = [] + for e in range(0, num_bins): + x_curr_edge = x_attr_min + e * x_bin_width + y_curr_edge = y_attr_min + e * y_bin_width + # get upper edges for x attribute bins + if x_attr_type == int: + x_upper_edges.append(math.ceil(x_curr_edge)) + else: + x_upper_edges.append(x_curr_edge) + # get upper edges for y attribute bins + if y_attr_type == int: + y_upper_edges.append(str(math.ceil(y_curr_edge))) + else: + y_upper_edges.append(str(y_curr_edge)) + x_upper_edges_string = [str(edge) for edge in x_upper_edges] + x_upper_edges_string = ",".join(x_upper_edges_string) + y_upper_edges_string = ",".join(y_upper_edges) + + bin_count_query = "SELECT width_bucket1, width_bucket2, count(*) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') as width_bucket1, width_bucket(CAST (\"{}\" AS FLOAT), '{}') as width_bucket2 FROM {} {}) as foo GROUP BY width_bucket1, width_bucket2".format( + x_attribute.attribute, + "{" + x_upper_edges_string + "}", + y_attribute.attribute, + "{" + y_upper_edges_string + "}", + tbl.table_name, + where_clause, + ) + + # data = pandas.read_sql(bin_count_query, lux.config.SQLconnection) + + data = pandas.read_sql(bin_count_query, lux.config.SQLconnection) + # data = data[data["width_bucket1"] != num_bins - 1] + # data = data[data["width_bucket2"] != num_bins - 1] + if len(data) > 0: + data["xBinStart"] = data.apply( + lambda row: float(x_upper_edges[int(row["width_bucket1"]) - 1]) - x_bin_width, axis=1 + ) + data["xBinEnd"] = data.apply( + lambda row: float(x_upper_edges[int(row["width_bucket1"]) - 1]), axis=1 + ) + data["yBinStart"] = data.apply( + lambda row: float(y_upper_edges[int(row["width_bucket2"]) - 1]) - y_bin_width, axis=1 + ) + data["yBinEnd"] = data.apply( + lambda row: float(y_upper_edges[int(row["width_bucket2"]) - 1]), axis=1 + ) + view._vis_data = utils.pandas_to_lux(data) @staticmethod - # takes in a vis and returns an appropriate SQL WHERE clause that based on the filters specified in the vis's _inferred_intent - def execute_filter(vis: Vis): + def execute_filter(view: Vis): + """ + Helper function to convert a Vis' filter specification to a SQL where clause. + Takes in a Vis object and returns an appropriate SQL WHERE clause based on the filters specified in the vis' _inferred_intent.
+ + Parameters + ---------- + vis: lux.Vis + lux.Vis object that represents a visualization + + Returns + ------- + where_clause: string + String representation of a SQL WHERE clause + filter_vars: list of strings + list of variables that have been used as filters + """ + filters = utils.get_filter_specs(view._inferred_intent) + return SQLExecutor.create_where_clause(filters, view=view) + + def create_where_clause(filter_specs, view=""): where_clause = [] - filters = utils.get_filter_specs(vis._inferred_intent) filter_vars = [] + filters = filter_specs if filters: for f in range(0, len(filters)): if f == 0: where_clause.append("WHERE") else: where_clause.append("AND") + curr_value = str(filters[f].value) + curr_value = curr_value.replace("'", "''") where_clause.extend( [ - str(filters[f].attribute), + '"' + str(filters[f].attribute) + '"', str(filters[f].filter_op), - "'" + str(filters[f].value) + "'", + "'" + curr_value + "'", ] ) if filters[f].attribute not in filter_vars: filter_vars.append(filters[f].attribute) + if view != "": + attributes = utils.get_attrs_specs(view._inferred_intent) + + # need to ensure that no null values are included in the data + # null values breaks binning queries + for a in attributes: + if a.attribute != "Record": + if where_clause == []: + where_clause.append("WHERE") + else: + where_clause.append("AND") + where_clause.extend( + [ + '"' + str(a.attribute) + '"', + "IS NOT NULL", + ] + ) + if where_clause == []: return ("", []) else: where_clause = " ".join(where_clause) return (where_clause, filter_vars) + + def get_filtered_size(filter_specs, tbl): + clause_info = SQLExecutor.create_where_clause(filter_specs=filter_specs, view="") + where_clause = clause_info[0] + filter_intents = filter_specs[0] + filtered_length = pandas.read_sql( + "SELECT COUNT(1) as length FROM {} {}".format(tbl.table_name, where_clause), + lux.config.SQLconnection, + ) + return list(filtered_length["length"])[0] + + ####################################################### + ########## Metadata, type, model schema ############### + ####################################################### + + def compute_dataset_metadata(self, tbl: LuxSQLTable): + """ + Function which computes the metadata required for the Lux recommendation system. + Populates the metadata parameters of the specified Lux DataFrame. + + Parameters + ---------- + tbl: lux.LuxSQLTable + lux.LuxSQLTable object whose metadata will be calculated + + Returns + ------- + None + """ + if not tbl._setup_done: + self.get_SQL_attributes(tbl) + tbl._data_type = {} + #####NOTE: since we aren't expecting users to do much data processing with the SQL database, should we just keep this + ##### in the initialization and do it just once + self.compute_data_type(tbl) + self.compute_stats(tbl) + + def get_SQL_attributes(self, tbl: LuxSQLTable): + """ + Retrieves the names of variables within a specified Lux DataFrame's Postgres SQL table. + Uses these variables to populate the Lux DataFrame's columns list. + + Parameters + ---------- + tbl: lux.LuxSQLTable + lux.LuxSQLTable object whose columns will be populated + + Returns + ------- + None + """ + if "." 
+
+    #######################################################
+    ########## Metadata, type, model schema ###############
+    #######################################################
+
+    def compute_dataset_metadata(self, tbl: LuxSQLTable):
+        """
+        Function which computes the metadata required for the Lux recommendation system.
+        Populates the metadata parameters of the specified LuxSQLTable.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose metadata will be calculated
+
+        Returns
+        -------
+        None
+        """
+        if not tbl._setup_done:
+            self.get_SQL_attributes(tbl)
+        tbl._data_type = {}
+        # NOTE: since we aren't expecting users to do much data processing with
+        # the SQL database, we may want to keep this in the initialization and
+        # do it just once
+        self.compute_data_type(tbl)
+        self.compute_stats(tbl)
+
+    def get_SQL_attributes(self, tbl: LuxSQLTable):
+        """
+        Retrieves the names of variables within a specified LuxSQLTable's Postgres SQL table.
+        Uses these variables to populate the LuxSQLTable's columns list.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose columns will be populated
+
+        Returns
+        -------
+        None
+        """
+        if "." in tbl.table_name:
+            table_name = tbl.table_name[tbl.table_name.index(".") + 1 :]
+        else:
+            table_name = tbl.table_name
+        attr_query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS where TABLE_NAME = '{}'".format(
+            table_name
+        )
+        attributes = list(pandas.read_sql(attr_query, lux.config.SQLconnection)["column_name"])
+        for attr in attributes:
+            tbl[attr] = None
+        tbl._setup_done = True
+
+    def compute_stats(self, tbl: LuxSQLTable):
+        """
+        Function which computes the min and max values for each variable within the specified LuxSQLTable's SQL table.
+        Populates the metadata parameters of the specified LuxSQLTable.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose metadata will be calculated
+
+        Returns
+        -------
+        None
+        """
+        # precompute statistics
+        tbl.unique_values = {}
+        tbl._min_max = {}
+        length_query = pandas.read_sql(
+            "SELECT COUNT(1) as length FROM {}".format(tbl.table_name),
+            lux.config.SQLconnection,
+        )
+        tbl._length = list(length_query["length"])[0]
+
+        self.get_unique_values(tbl)
+        for attribute in tbl.columns:
+            if tbl._data_type[attribute] == "quantitative":
+                min_max_query = pandas.read_sql(
+                    'SELECT MIN("{}") as min, MAX("{}") as max FROM {}'.format(
+                        attribute, attribute, tbl.table_name
+                    ),
+                    lux.config.SQLconnection,
+                )
+                tbl._min_max[attribute] = (
+                    list(min_max_query["min"])[0],
+                    list(min_max_query["max"])[0],
+                )
+
+    def get_cardinality(self, tbl: LuxSQLTable):
+        """
+        Function which computes the cardinality for each variable within the specified LuxSQLTable's SQL table.
+        Populates the metadata parameters of the specified LuxSQLTable.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose metadata will be calculated
+
+        Returns
+        -------
+        None
+        """
+        cardinality = {}
+        for attr in list(tbl.columns):
+            card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format(
+                attr, tbl.table_name, attr
+            )
+            card_data = pandas.read_sql(
+                card_query,
+                lux.config.SQLconnection,
+            )
+            cardinality[attr] = list(card_data["count"])[0]
+        tbl.cardinality = cardinality
+
+    def get_unique_values(self, tbl: LuxSQLTable):
+        """
+        Function which collects the unique values for each variable within the specified LuxSQLTable's SQL table.
+        Populates the metadata parameters of the specified LuxSQLTable.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose metadata will be calculated
+
+        Returns
+        -------
+        None
+        """
+        unique_vals = {}
+        for attr in list(tbl.columns):
+            unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format(
+                attr, tbl.table_name, attr
+            )
+            unique_data = pandas.read_sql(
+                unique_query,
+                lux.config.SQLconnection,
+            )
+            unique_vals[attr] = list(unique_data[attr])
+        tbl.unique_values = unique_vals
+
+    def compute_data_type(self, tbl: LuxSQLTable):
+        """
+        Function which computes the equivalent Pandas data type of each variable within the specified LuxSQLTable's SQL table.
+        Populates the metadata parameters of the specified LuxSQLTable.
+
+        Parameters
+        ----------
+        tbl: lux.LuxSQLTable
+            lux.LuxSQLTable object whose metadata will be calculated
+
+        Returns
+        -------
+        None
+        """
+        data_type = {}
+        self.get_cardinality(tbl)
+        if "." in tbl.table_name:
+            table_name = tbl.table_name[tbl.table_name.index(".") + 1 :]
+        else:
+            table_name = tbl.table_name
+        # get the data types of the attributes in the SQL table
+        for attr in list(tbl.columns):
+            datatype_query = "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{}' AND COLUMN_NAME = '{}'".format(
+                table_name, attr
+            )
+            datatype = list(pandas.read_sql(datatype_query, lux.config.SQLconnection)["data_type"])[0]
+            if str(attr).lower() in {"month", "year"} or "time" in datatype or "date" in datatype:
+                data_type[attr] = "temporal"
+            elif datatype in {
+                "character",
+                "character varying",
+                "boolean",
+                "uuid",
+                "text",
+            }:
+                data_type[attr] = "nominal"
+            elif datatype in {
+                "integer",
+                "numeric",
+                "decimal",
+                "bigint",
+                "real",
+                "smallint",
+                "smallserial",
+                "serial",
+                "double precision",
+            }:
+                if tbl.cardinality[attr] < 13:
+                    data_type[attr] = "nominal"
+                elif check_if_id_like(tbl, attr):
+                    data_type[attr] = "id"
+                else:
+                    data_type[attr] = "quantitative"
+
+        tbl._data_type = data_type
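Editor's note: the bin-edge arithmetic feeding Postgres' `width_bucket` in `execute_2D_binning` above is easiest to follow with concrete numbers. The sketch below is purely illustrative (the column names, value ranges, and bin count are invented) and just mirrors the edge loop in the diff:

.. code-block:: python

    # Illustrative sketch only: edges for 4 bins over a column "horsepower"
    # ranging 0-200, mirroring execute_2D_binning's loop (one edge per bin,
    # starting at the attribute minimum).
    num_bins = 4
    x_min, x_max = 0, 200
    x_bin_width = (x_max - x_min) / num_bins  # 50.0
    x_edges = [x_min + e * x_bin_width for e in range(num_bins)]
    x_edges_string = ",".join(str(edge) for edge in x_edges)  # "0.0,50.0,100.0,150.0"

    # The query assembled from these edges buckets every row into a 2D bin and
    # counts rows per bin (whitespace added for readability):
    #
    # SELECT width_bucket1, width_bucket2, count(*)
    # FROM (SELECT width_bucket(CAST ("horsepower" AS FLOAT), '{0.0,50.0,100.0,150.0}') as width_bucket1,
    #              width_bucket(CAST ("weight" AS FLOAT), '{...}') as width_bucket2
    #       FROM cars) as foo
    # GROUP BY width_bucket1, width_bucket2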
diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py
index 4eb61fe7..53725bc6 100644
--- a/lux/interestingness/interestingness.py
+++ b/lux/interestingness/interestingness.py
@@ -43,7 +43,6 @@ def interestingness(vis: Vis, ldf: LuxDataFrame) -> int:
     int
         Interestingness Score
     """
-
     if vis.data is None or len(vis.data) == 0:
         return -1
         # raise Exception("Vis.data needs to be populated before interestingness can be computed. Run Executor.execute(vis,ldf).")
@@ -225,13 +224,19 @@ def deviation_from_overall(
     int
         Score describing how different the vis is from the overall vis
     """
-    v_filter_size = get_filtered_size(filter_specs, ldf)
+    if lux.config.executor.name == "PandasExecutor":
+        if exclude_nan:
+            vdata = vis.data.dropna()
+        else:
+            vdata = vis.data
+        v_filter_size = get_filtered_size(filter_specs, ldf)
+        v_size = len(vis.data)
+    elif lux.config.executor.name == "SQLExecutor":
+        from lux.executor.SQLExecutor import SQLExecutor
 
-    if exclude_nan:
-        vdata = vis.data.dropna()
-    else:
+        v_filter_size = SQLExecutor.get_filtered_size(filter_specs, ldf)
+        v_size = len(ldf)
         vdata = vis.data
-    v_size = len(vdata)
     v_filter = vdata[msr_attribute]
     total = v_filter.sum()
     v_filter = v_filter / total  # normalize by total to get ratio
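Editor's note: the change above introduces a dispatch idiom worth calling out: rather than importing a concrete executor, code branches on `lux.config.executor.name`. The same pattern reappears in the Validator and utils changes below. A minimal standalone sketch of the idiom (the helper name `compute_size` is hypothetical):

.. code-block:: python

    import lux

    # Hypothetical helper mirroring how deviation_from_overall picks the
    # row-count source for each executor.
    def compute_size(vis, ldf):
        if lux.config.executor.name == "PandasExecutor":
            # local data: the vis carries a materialized DataFrame
            return len(vis.data)
        elif lux.config.executor.name == "SQLExecutor":
            # SQL-backed data: length comes from the table's precomputed metadata
            return len(ldf)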
diff --git a/lux/processor/Validator.py b/lux/processor/Validator.py
index 2550ac31..9d2d63b1 100644
--- a/lux/processor/Validator.py
+++ b/lux/processor/Validator.py
@@ -18,6 +18,7 @@ from typing import List
 from lux.utils.date_utils import is_datetime_series, is_datetime_string
 import warnings
+import pandas as pd
 import lux
 import lux.utils.utils
@@ -90,7 +91,10 @@ def validate_clause(clause):
         else:
             vals = [clause.value]
         for val in vals:
-            if val not in series.values:
+            if (
+                lux.config.executor.name == "PandasExecutor"
+                and val not in series.values
+            ):
                 warn_msg = f"\n- The input value '{val}' does not exist for the attribute '{clause.attribute}' for the DataFrame."
                 return warn_msg
diff --git a/lux/utils/utils.py b/lux/utils/utils.py
index 4ad187d3..0b38da79 100644
--- a/lux/utils/utils.py
+++ b/lux/utils/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import pandas as pd
 import matplotlib.pyplot as plt
+import lux
 
 
 def convert_to_list(x):
@@ -83,7 +84,12 @@ def check_if_id_like(df, attribute):
     if is_string:
         # For string IDs, usually serial numbers or codes with alphanumerics have a consistent length (eg., CG-39405) with little deviation. For a high cardinality string field but not ID field (like Name or Brand), there is less uniformity across the string lengths.
         if len(df) > 50:
-            sampled = df[attribute].sample(50, random_state=99)
+            if lux.config.executor.name == "PandasExecutor":
+                sampled = df[attribute].sample(50, random_state=99)
+            else:
+                from lux.executor.SQLExecutor import SQLExecutor
+
+                sampled = SQLExecutor.execute_preview(df, preview_size=50)
         else:
             sampled = df[attribute]
         str_length_uniformity = sampled.apply(lambda x: type(x) == str and len(x)).std() < 3
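Editor's note: the `check_if_id_like` change above swaps only the sampling step per executor; the string-length-uniformity heuristic itself is unchanged. A self-contained illustration of that heuristic (the sample values are invented):

.. code-block:: python

    import pandas as pd

    # Serial-number-like codes have near-constant length, so the sample
    # standard deviation of their lengths falls under the threshold of 3.
    ids = pd.Series(["CG-39405", "CG-41220", "CG-00017"])  # lengths 8, 8, 8
    names = pd.Series(["Al", "Bartholomew", "Chen-Yu Li"])  # lengths 2, 11, 10

    print(ids.apply(lambda x: type(x) == str and len(x)).std() < 3)    # True  -> ID-like
    print(names.apply(lambda x: type(x) == str and len(x)).std() < 3)  # False (std is about 4.9)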
diff --git a/lux/vis/Vis.py b/lux/vis/Vis.py
index a6a67ac3..77b26c38 100644
--- a/lux/vis/Vis.py
+++ b/lux/vis/Vis.py
@@ -351,6 +351,7 @@ def refresh_source(self, ldf):  # -> Vis:
             self._source = ldf
             self._inferred_intent = Parser.parse(self._intent)
             Validator.validate_intent(self._inferred_intent, ldf)
+            Compiler.compile_vis(ldf, self)
             lux.config.executor.execute([self], ldf)
diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py
index 6c8ddae2..2c1ab206 100644
--- a/lux/vislib/altair/AltairRenderer.py
+++ b/lux/vislib/altair/AltairRenderer.py
@@ -51,10 +51,9 @@ def create_vis(self, vis, standalone=True):
         """
         # Lazy Evaluation for 2D Binning
         if vis.mark == "scatter" and vis._postbin:
-            vis._mark = "heatmap"
-            from lux.executor.PandasExecutor import PandasExecutor
-
-            PandasExecutor.execute_2D_binning(vis)
+            if lux.config.executor.name == "PandasExecutor":
+                vis._mark = "heatmap"
+            lux.config.executor.execute_2D_binning(vis)
         # If a column has a Period dtype, or contains Period objects, convert it back to Datetime
         if vis.data is not None:
             for attr in list(vis.data.columns):
diff --git a/requirements-dev.txt b/requirements-dev.txt
index abb3f4aa..e3803b98 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,4 +4,7 @@ Sphinx>=3.0.2
 sphinx-rtd-theme>=0.4.3
 xlrd
 black
+# Install to use SQLExecutor
+psycopg2>=2.8.5
+psycopg2-binary>=2.8.5
 lxml
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d49ced00..0beb7050 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,9 +4,6 @@ numpy>=1.16.5
 pandas>=1.2.0
 scikit-learn>=0.22
 matplotlib>=3.0.0
-# Install only to use SQLExecutor
-# psycopg2>=2.8.5
-# psycopg2-binary>=2.8.5
 lux-widget>=0.1.4
 us
 iso3166
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 1ad16581..8ee3ddbb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -23,11 +23,11 @@ def global_var():
         "_prev",
         "_history",
         "_saved_export",
+        "name",
         "_sampled",
         "_toggle_pandas_display",
         "_message",
         "_pandas_only",
         "pre_aggregated",
         "_type_override",
-        "name",
     ]
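Editor's note: the test changes from here on all share the same connection boilerplate, so it is worth seeing once in isolation. The host, credentials, and `cars` table match the CI service container and upload scripts; in any other environment they are assumptions you would replace:

.. code-block:: python

    import psycopg2
    import lux

    # Connect Lux to Postgres (CI defaults; substitute your own credentials locally).
    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
    lux.config.set_SQL_connection(connection)

    # Point a LuxSQLTable at an existing table or view; metadata is computed on connection.
    sql_df = lux.LuxSQLTable(table_name="cars")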
"Weight")],df) # assert len(vlst) == 8 + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vlst = VisList( + [ + lux.Clause(attribute="origin", filter_op="=", value="?"), + lux.Clause(attribute="milespergal"), + ], + sql_df, + ) + assert len(vlst) == 3 + def test_sort_bar(global_var): from lux.processor.Compiler import Compiler from lux.vis.Vis import Vis + lux.config.set_executor_type("Pandas") df = pytest.car_df vis = Vis( [ @@ -153,8 +221,35 @@ def test_sort_bar(global_var): assert vis.mark == "bar" assert vis._inferred_intent[1].sort == "ascending" + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis( + [ + lux.Clause(attribute="acceleration", data_model="measure", data_type="quantitative"), + lux.Clause(attribute="origin", data_model="dimension", data_type="nominal"), + ], + sql_df, + ) + assert vis.mark == "bar" + assert vis._inferred_intent[1].sort == "" + + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis( + [ + lux.Clause(attribute="acceleration", data_model="measure", data_type="quantitative"), + lux.Clause(attribute="name", data_model="dimension", data_type="nominal"), + ], + sql_df, + ) + assert vis.mark == "bar" + assert vis._inferred_intent[1].sort == "ascending" + def test_specified_vis_collection(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -186,6 +281,7 @@ def test_specified_vis_collection(global_var): def test_specified_channel_enforced_vis_collection(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -198,6 +294,7 @@ def test_specified_channel_enforced_vis_collection(global_var): def test_autoencoding_scatter(global_var): + lux.config.set_executor_type("Pandas") # No channel specified df = pytest.car_df # change pandas dtype for the column "Year" to datetype @@ -238,8 +335,101 @@ def test_autoencoding_scatter(global_var): ) df.clear_intent() + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + visList = VisList( + [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], + sql_df, + ) + for vis in visList: + check_attribute_on_channel(vis, "milespergal", "x") + + +def test_autoencoding_scatter(): + lux.config.set_executor_type("Pandas") + # No channel specified + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + vis = Vis([lux.Clause(attribute="MilesPerGal"), lux.Clause(attribute="Weight")], df) + check_attribute_on_channel(vis, "MilesPerGal", "x") + check_attribute_on_channel(vis, "Weight", "y") + + # Partial channel specified + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight"), + ], + df, + ) + check_attribute_on_channel(vis, "MilesPerGal", "y") + check_attribute_on_channel(vis, 
"Weight", "x") + + # Full channel specified + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight", channel="x"), + ], + df, + ) + check_attribute_on_channel(vis, "MilesPerGal", "y") + check_attribute_on_channel(vis, "Weight", "x") + # Duplicate channel specified + with pytest.raises(ValueError): + # Should throw error because there should not be columns with the same channel specified + df.set_intent( + [ + lux.Clause(attribute="MilesPerGal", channel="x"), + lux.Clause(attribute="Weight", channel="x"), + ] + ) + + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")], sql_df) + check_attribute_on_channel(vis, "milespergal", "x") + check_attribute_on_channel(vis, "weight", "y") + + # Partial channel specified + vis = Vis( + [ + lux.Clause(attribute="milespergal", channel="y"), + lux.Clause(attribute="weight"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "milespergal", "y") + check_attribute_on_channel(vis, "weight", "x") + + # Full channel specified + vis = Vis( + [ + lux.Clause(attribute="milespergal", channel="y"), + lux.Clause(attribute="weight", channel="x"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "milespergal", "y") + check_attribute_on_channel(vis, "weight", "x") + # Duplicate channel specified + with pytest.raises(ValueError): + # Should throw error because there should not be columns with the same channel specified + sql_df.set_intent( + [ + lux.Clause(attribute="milespergal", channel="x"), + lux.Clause(attribute="weight", channel="x"), + ] + ) + def test_autoencoding_histogram(global_var): + lux.config.set_executor_type("Pandas") # No channel specified df = pytest.car_df # change pandas dtype for the column "Year" to datetype @@ -251,8 +441,21 @@ def test_autoencoding_histogram(global_var): assert vis.get_attr_by_channel("x")[0].attribute == "MilesPerGal" assert vis.get_attr_by_channel("y")[0].attribute == "Record" + # No channel specified + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis([lux.Clause(attribute="milespergal", channel="y")], sql_df) + check_attribute_on_channel(vis, "milespergal", "y") + + vis = Vis([lux.Clause(attribute="milespergal", channel="x")], sql_df) + assert vis.get_attr_by_channel("x")[0].attribute == "milespergal" + assert vis.get_attr_by_channel("y")[0].attribute == "Record" + def test_autoencoding_line_chart(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -292,8 +495,48 @@ def test_autoencoding_line_chart(global_var): ) df.clear_intent() + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis([lux.Clause(attribute="year"), lux.Clause(attribute="acceleration")], sql_df) + check_attribute_on_channel(vis, "year", "x") + check_attribute_on_channel(vis, "acceleration", "y") + + # Partial channel specified + vis = Vis( + [ + lux.Clause(attribute="year", channel="y"), + 
lux.Clause(attribute="acceleration"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "year", "y") + check_attribute_on_channel(vis, "acceleration", "x") + + # Full channel specified + vis = Vis( + [ + lux.Clause(attribute="year", channel="y"), + lux.Clause(attribute="acceleration", channel="x"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "year", "y") + check_attribute_on_channel(vis, "acceleration", "x") + + with pytest.raises(ValueError): + # Should throw error because there should not be columns with the same channel specified + sql_df.set_intent( + [ + lux.Clause(attribute="year", channel="x"), + lux.Clause(attribute="acceleration", channel="x"), + ] + ) + def test_autoencoding_color_line_chart(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -307,8 +550,23 @@ def test_autoencoding_color_line_chart(global_var): check_attribute_on_channel(vis, "Acceleration", "y") check_attribute_on_channel(vis, "Origin", "color") + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + intent = [ + lux.Clause(attribute="year"), + lux.Clause(attribute="acceleration"), + lux.Clause(attribute="origin"), + ] + vis = Vis(intent, sql_df) + check_attribute_on_channel(vis, "year", "x") + check_attribute_on_channel(vis, "acceleration", "y") + check_attribute_on_channel(vis, "origin", "color") + def test_autoencoding_color_scatter_chart(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -332,8 +590,33 @@ def test_autoencoding_color_scatter_chart(global_var): ) check_attribute_on_channel(vis, "Acceleration", "color") + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + vis = Vis( + [ + lux.Clause(attribute="horsepower"), + lux.Clause(attribute="acceleration"), + lux.Clause(attribute="origin"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "origin", "color") + + vis = Vis( + [ + lux.Clause(attribute="horsepower"), + lux.Clause(attribute="acceleration", channel="color"), + lux.Clause(attribute="origin"), + ], + sql_df, + ) + check_attribute_on_channel(vis, "acceleration", "color") + def test_populate_options(global_var): + lux.config.set_executor_type("Pandas") from lux.processor.Compiler import Compiler df = pytest.car_df @@ -361,8 +644,36 @@ def test_populate_options(global_var): ) df.clear_intent() + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + sql_df.set_intent([lux.Clause(attribute="?"), lux.Clause(attribute="milespergal")]) + col_set = set() + for specOptions in Compiler.populate_wildcard_options(sql_df._intent, sql_df)["attributes"]: + for clause in specOptions: + col_set.add(clause.attribute) + assert list_equal(list(col_set), list(sql_df.columns)) + + sql_df.set_intent( + [ + lux.Clause(attribute="?", data_model="measure"), + lux.Clause(attribute="milespergal"), + ] + ) + sql_df._repr_html_() + col_set = set() + for specOptions in 
Compiler.populate_wildcard_options(sql_df._intent, sql_df)["attributes"]: + for clause in specOptions: + col_set.add(clause.attribute) + assert list_equal( + list(col_set), + ["acceleration", "weight", "horsepower", "milespergal", "displacement"], + ) + def test_remove_all_invalid(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df df["Year"] = pd.to_datetime(df["Year"], format="%Y") # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): @@ -376,6 +687,20 @@ def test_remove_all_invalid(global_var): assert len(df.current_vis) == 0 df.clear_intent() + # test for sql executor + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + lux.config.set_SQL_connection(connection) + sql_df = lux.LuxSQLTable(table_name="cars") + # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): + sql_df.set_intent( + [ + lux.Clause(attribute="origin", filter_op="=", value="USA"), + lux.Clause(attribute="origin"), + ] + ) + sql_df._repr_html_() + assert len(sql_df.current_vis) == 0 + def list_equal(l1, l2): l1.sort() diff --git a/tests/test_config.py b/tests/test_config.py index 18c2da76..bd61366d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -50,6 +50,7 @@ def contain_horsepower(df): def test_default_actions_registered(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df df._ipython_display_() assert "Distribution" in df.recommendation diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py index 703fa44d..aac79256 100644 --- a/tests/test_interestingness.py +++ b/tests/test_interestingness.py @@ -16,6 +16,7 @@ import pytest import pandas as pd import numpy as np +import psycopg2 from lux.interestingness.interestingness import interestingness @@ -73,8 +74,26 @@ def test_interestingness_1_0_1(global_var): assert df.current_vis[0].score == 0 df.clear_intent() + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + tbl = lux.LuxSQLTable() + lux.config.set_SQL_connection(connection) + tbl.set_SQL_table("car") + + tbl.set_intent( + [ + lux.Clause(attribute="Origin", filter_op="=", value="USA"), + lux.Clause(attribute="Cylinders"), + ] + ) + tbl._repr_html_() + filter_score = tbl.recommendation["Filter"][0].score + assert tbl.current_vis[0].score == 0 + assert filter_score > 0 + tbl.clear_intent() + def test_interestingness_0_1_0(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -134,8 +153,25 @@ def test_interestingness_0_1_1(global_var): assert str(df.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA" df.clear_intent() + connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux") + tbl = lux.LuxSQLTable() + lux.config.set_SQL_connection(connection) + tbl.set_SQL_table("car") + + tbl.set_intent( + [ + lux.Clause(attribute="Origin", filter_op="=", value="?"), + lux.Clause(attribute="MilesPerGal"), + ] + ) + tbl._repr_html_() + assert interestingness(tbl.recommendation["Current Vis"][0], tbl) != None + assert str(tbl.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA" + tbl.clear_intent() + def test_interestingness_1_1_0(global_var): + lux.config.set_executor_type("Pandas") df = pytest.car_df df["Year"] = pd.to_datetime(df["Year"], format="%Y") @@ -204,12 +240,31 @@ def test_interestingness_1_1_1(global_var): assert interestingness(df.recommendation["Filter"][0], 
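Editor's note: the interestingness tests that follow all share one rhythm: set an intent on a table, render, then read scores off the resulting recommendations. That rhythm in isolation (connection details are the CI defaults; the table and clauses are the ones the tests use against the `car` table):

.. code-block:: python

    import psycopg2
    import lux

    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
    lux.config.set_SQL_connection(connection)

    tbl = lux.LuxSQLTable()
    tbl.set_SQL_table("car")
    tbl.set_intent(
        [
            lux.Clause(attribute="Origin", filter_op="=", value="USA"),
            lux.Clause(attribute="Cylinders"),
        ]
    )
    tbl._repr_html_()  # rendering populates tbl.recommendation
    filter_score = tbl.recommendation["Filter"][0].score  # score of the top Filter vis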
diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py
index 703fa44d..aac79256 100644
--- a/tests/test_interestingness.py
+++ b/tests/test_interestingness.py
@@ -16,6 +16,7 @@
 import pytest
 import pandas as pd
 import numpy as np
+import psycopg2
 
 from lux.interestingness.interestingness import interestingness
 
@@ -73,8 +74,26 @@ def test_interestingness_1_0_1(global_var):
     assert df.current_vis[0].score == 0
     df.clear_intent()
 
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    tbl.set_intent(
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+            lux.Clause(attribute="Cylinders"),
+        ]
+    )
+    tbl._repr_html_()
+    filter_score = tbl.recommendation["Filter"][0].score
+    assert tbl.current_vis[0].score == 0
+    assert filter_score > 0
+    tbl.clear_intent()
+
 
 def test_interestingness_0_1_0(global_var):
+    lux.config.set_executor_type("Pandas")
     df = pytest.car_df
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
 
@@ -134,8 +153,25 @@ def test_interestingness_0_1_1(global_var):
     assert str(df.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA"
     df.clear_intent()
 
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    tbl.set_intent(
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="?"),
+            lux.Clause(attribute="MilesPerGal"),
+        ]
+    )
+    tbl._repr_html_()
+    assert interestingness(tbl.recommendation["Current Vis"][0], tbl) != None
+    assert str(tbl.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA"
+    tbl.clear_intent()
+
 
 def test_interestingness_1_1_0(global_var):
+    lux.config.set_executor_type("Pandas")
     df = pytest.car_df
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
 
@@ -204,12 +240,31 @@ def test_interestingness_1_1_1(global_var):
     assert interestingness(df.recommendation["Filter"][0], df) != None
     df.clear_intent()
 
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    tbl.set_intent(
+        [
+            lux.Clause(attribute="Horsepower"),
+            lux.Clause(attribute="Origin", filter_op="=", value="USA", bin_size=20),
+        ]
+    )
+    tbl._repr_html_()
+    assert interestingness(tbl.recommendation["Enhance"][0], tbl) != None
+
+    # check that the top recommended Filter graph's score is not None
+    assert interestingness(tbl.recommendation["Filter"][0], tbl) != None
+    tbl.clear_intent()
+
 
 def test_interestingness_1_2_0(global_var):
     from lux.vis.Vis import Vis
     from lux.vis.Vis import Clause
     from lux.interestingness.interestingness import interestingness
 
+    lux.config.set_executor_type("Pandas")
     df = pytest.car_df
     y_clause = Clause(attribute="Name", channel="y")
     color_clause = Clause(attribute="Cylinders", channel="color")
@@ -301,6 +356,6 @@ def test_interestingness_deviation_nan():
     smaller_diff_score = interestingness(vis, test)
     bigger_diff_score = interestingness(vis2, test)
-    assert np.isclose(smaller_diff_score, 0.29, rtol=0.1)
-    assert np.isclose(bigger_diff_score, 0.94, rtol=0.1)
+    assert np.isclose(smaller_diff_score, 0.19, rtol=0.1)
+    assert np.isclose(bigger_diff_score, 0.62, rtol=0.1)
     assert smaller_diff_score < bigger_diff_score
diff --git a/tests/test_series.py b/tests/test_series.py
index 0c0ff668..83b6b818 100644
--- a/tests/test_series.py
+++ b/tests/test_series.py
@@ -40,13 +40,13 @@ def test_df_to_series():
         "_prev",
         "_history",
         "_saved_export",
+        "name",
         "_sampled",
         "_toggle_pandas_display",
         "_message",
         "_pandas_only",
         "pre_aggregated",
         "_type_override",
-        "name",
     ], "Metadata is lost when going from Dataframe to Series."
     assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series."
     assert series.name == "Weight", "Pandas Series original `name` property not retained."
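Editor's note: the new `tests/test_sql_executor.py` suite below pins down the executor's contracts, the subtlest being the `(where_clause, filter_vars)` tuple from the WHERE-clause builder. Based on `create_where_clause` as defined earlier in this diff, a standalone sketch (the attribute and value are invented; note the doubled quote produced by the escaping step):

.. code-block:: python

    import lux
    from lux.executor.SQLExecutor import SQLExecutor

    # create_where_clause only reads .attribute, .filter_op, and .value from
    # each spec, so a bare Clause suffices for illustration.
    spec = lux.Clause(attribute="Origin", filter_op="=", value="O'Brien")
    where_clause, filter_vars = SQLExecutor.create_where_clause([spec], view="")
    print(where_clause)  # WHERE "Origin" = 'O''Brien'
    print(filter_vars)   # ['Origin']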
diff --git a/tests/test_sql_executor.py b/tests/test_sql_executor.py
new file mode 100644
index 00000000..f8d8d907
--- /dev/null
+++ b/tests/test_sql_executor.py
@@ -0,0 +1,247 @@
+# Copyright 2019-2020 The Lux Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .context import lux
+import pytest
+import pandas as pd
+from lux.executor.SQLExecutor import SQLExecutor
+from lux.vis.Vis import Vis
+from lux.vis.VisList import VisList
+import psycopg2
+
+
+def test_lazy_execution():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [
+        lux.Clause(attribute="Horsepower", aggregation="mean"),
+        lux.Clause(attribute="Origin"),
+    ]
+    vis = Vis(intent)
+    # Check that the data field in the vis is empty before calling the executor
+    assert vis.data is None
+    SQLExecutor.execute([vis], tbl)
+    assert type(vis.data) == lux.core.frame.LuxDataFrame
+
+
+def test_selection():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [
+        lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]),
+        lux.Clause(attribute="Year"),
+    ]
+    vislist = VisList(intent, tbl)
+    assert all([type(vis.data) == lux.core.frame.LuxDataFrame for vis in vislist])
+    assert all(vislist[2].data.columns == ["Year", "Acceleration"])
+
+
+def test_aggregation():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [
+        lux.Clause(attribute="Horsepower", aggregation="mean"),
+        lux.Clause(attribute="Origin"),
+    ]
+    vis = Vis(intent, tbl)
+    result_df = vis.data
+    assert int(result_df[result_df["Origin"] == "USA"]["Horsepower"]) == 119
+
+    intent = [
+        lux.Clause(attribute="Horsepower", aggregation="sum"),
+        lux.Clause(attribute="Origin"),
+    ]
+    vis = Vis(intent, tbl)
+    result_df = vis.data
+    assert int(result_df[result_df["Origin"] == "Japan"]["Horsepower"]) == 6307
+
+    intent = [
+        lux.Clause(attribute="Horsepower", aggregation="max"),
+        lux.Clause(attribute="Origin"),
+    ]
+    vis = Vis(intent, tbl)
+    result_df = vis.data
+    assert int(result_df[result_df["Origin"] == "Europe"]["Horsepower"]) == 133
+
+
+def test_colored_bar_chart():
+    from lux.vis.Vis import Vis
+    from lux.vis.Vis import Clause
+
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    x_clause = Clause(attribute="MilesPerGal", channel="x")
+    y_clause = Clause(attribute="Origin", channel="y")
+    color_clause = Clause(attribute="Cylinders", channel="color")
+
+    new_vis = Vis([x_clause, y_clause, color_clause], tbl)
+    # make sure the dimensions of the data are correct
+    color_cardinality = len(tbl.unique_values["Cylinders"])
+    group_by_cardinality = len(tbl.unique_values["Origin"])
+    assert len(new_vis.data.columns) == 3
+    assert (
+        len(new_vis.data) == 15 > group_by_cardinality < color_cardinality * group_by_cardinality
+    )  # Not color_cardinality*group_by_cardinality since some combinations have 0 values
+
+
+def test_colored_line_chart():
+    from lux.vis.Vis import Vis
+    from lux.vis.Vis import Clause
+
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    x_clause = Clause(attribute="Year", channel="x")
+    y_clause = Clause(attribute="MilesPerGal", channel="y")
+    color_clause = Clause(attribute="Cylinders", channel="color")
+
+    new_vis = Vis([x_clause, y_clause, color_clause], tbl)
+
+    # make sure the dimensions of the data are correct
+    color_cardinality = len(tbl.unique_values["Cylinders"])
+    group_by_cardinality = len(tbl.unique_values["Year"])
+    assert len(new_vis.data.columns) == 3
+    assert (
+        len(new_vis.data) == 60 > group_by_cardinality < color_cardinality * group_by_cardinality
+    )  # Not color_cardinality*group_by_cardinality since some combinations have 0 values
+
+
+def test_filter():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [
+        lux.Clause(attribute="Horsepower"),
+        lux.Clause(attribute="Year"),
+        lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+    ]
+    vis = Vis(intent, tbl)
+    vis._vis_data = tbl
+    filter_output = SQLExecutor.execute_filter(vis)
+    where_clause = filter_output[0]
+    where_clause_list = where_clause.split(" AND ")
+    assert (
+        "WHERE \"Origin\" = 'USA'" in where_clause_list
+        and '"Horsepower" IS NOT NULL' in where_clause_list
+        and '"Year" IS NOT NULL' in where_clause_list
+    )
+    assert filter_output[1] == ["Origin"]
+
+
+def test_inequalityfilter():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    vis = Vis(
+        [
+            lux.Clause(attribute="Horsepower", filter_op=">", value=50),
+            lux.Clause(attribute="MilesPerGal"),
+        ]
+    )
+    vis._vis_data = tbl
+    filter_output = SQLExecutor.execute_filter(vis)
+    assert filter_output[0] == 'WHERE "Horsepower" > \'50\' AND "MilesPerGal" IS NOT NULL'
+    assert filter_output[1] == ["Horsepower"]
+
+    intent = [
+        lux.Clause(attribute="Horsepower", filter_op="<=", value=100),
+        lux.Clause(attribute="MilesPerGal"),
+    ]
+    vis = Vis(intent, tbl)
+    vis._vis_data = tbl
+    filter_output = SQLExecutor.execute_filter(vis)
+    assert filter_output[0] == 'WHERE "Horsepower" <= \'100\' AND "MilesPerGal" IS NOT NULL'
+    assert filter_output[1] == ["Horsepower"]
+
+
+def test_binning():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    vis = Vis([lux.Clause(attribute="Horsepower")], tbl)
+    nbins = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0].bin_size
+    assert len(vis.data) == nbins
+
+
+def test_record():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    vis = Vis([lux.Clause(attribute="Cylinders")], tbl)
+    assert len(vis.data) == len(tbl.unique_values["Cylinders"])
+
+
+def test_filter_aggregation_fillzero_aligned():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [
+        lux.Clause(attribute="Cylinders"),
+        lux.Clause(attribute="MilesPerGal"),
+        lux.Clause("Origin=Japan"),
+    ]
+    vis = Vis(intent, tbl)
+    result = vis.data
+    assert result[result["Cylinders"] == 5]["MilesPerGal"].values[0] == 0
+    assert result[result["Cylinders"] == 8]["MilesPerGal"].values[0] == 0
+
+
+def test_exclude_attribute():
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("car")
+
+    intent = [lux.Clause("?", exclude=["Name", "Year"]), lux.Clause("Horsepower")]
+    vislist = VisList(intent, tbl)
+    for vis in vislist:
+        assert vis.get_attr_by_channel("x")[0].attribute != "Year"
+        assert vis.get_attr_by_channel("x")[0].attribute != "Name"
+        assert vis.get_attr_by_channel("y")[0].attribute != "Year"
+        assert vis.get_attr_by_channel("y")[0].attribute != "Name"
+
+
+def test_null_values():
+    # checks that the SQLExecutor has filtered out any None or Null values from its metadata
+    connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
+    tbl = lux.LuxSQLTable()
+    lux.config.set_SQL_connection(connection)
+    tbl.set_SQL_table("aug_test_table")
+
+    assert None not in tbl.unique_values["enrolled_university"]
diff --git a/tests/test_type.py b/tests/test_type.py
index 8b57d003..5395c661 100644
--- a/tests/test_type.py
+++ b/tests/test_type.py
@@ -21,6 +21,7 @@
 
 # Suite of test that checks if data_type inferred correctly by Lux
 def test_check_cars():
+    lux.config.set_SQL_connection("")
     df = pd.read_csv("lux/data/car.csv")
     df.maintain_metadata()
     assert df.data_type["Name"] == "nominal"