Skip to content

Commit

Permalink
Revise to look for case in jupysql and close_match for SqlMagic
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejb committed Jan 25, 2024
1 parent c5aeba7 commit 858108e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 25 deletions.
68 changes: 43 additions & 25 deletions src/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,33 @@ def get_user_configs(primary_path, alternate_path):
section_found = False
if file_path and file_path.exists():
data = load_toml(file_path)
section_names = ["tool", "jupysql", "SqlMagic"]

# Look for SqlMagic section in toml file
data = get_nested(data, section_names)
data = data.get("tool")

# Look for jupysql section under tool
if data:
keys = data.keys()
data = data.get("jupysql")
if data is None:
similar_key = case_insensitive_match("jupysql", keys)
if similar_key:
display.message(
f"Hint: We found 'tool.{similar_key}' in {file_path}. "
f"Did you mean 'tool.jupysql'?"
)

# Look for SqlMagic section under jupysql
if data:
keys = data.keys()
data = data.get("SqlMagic")
if data is None:
similar_key_list = find_close_match("SqlMagic", keys)
if similar_key_list:
raise exceptions.ConfigurationError(
f"[tool.jupysql.{similar_key_list[0]}] is an invalid section "
f"name in {file_path}. "
f"Did you mean [tool.jupysql.SqlMagic]?"
)

if data is None:
if display_tip:
Expand Down Expand Up @@ -559,35 +582,30 @@ def enclose_table_with_double_quotations(table, conn):
return _table


def get_nested(data, keys):
def case_insensitive_match(target, string_list):
"""
Retrieve nested data by following a sequence of keys.
Perform a case-insensitive match of a target string against a list of strings.
Parameters
----------
data : dict
The dictionary to retrieve data from.
keys : iterable
An iterable of keys.
Each key is used to step one level down in the data dictionary.
target : str
The target string to match.
string_list : list of str
The list of strings to search through.
Returns
-------
value : any
The value at the end of the sequence of keys.
If any key in the sequence does
not exist in the data dictionary, the function returns `None`.
str or None
The first matching string from the list, preserving its original case,
or None if there is no match.
Examples
--------
>>> data = {'a': {'b': {'c': 1}}}
>>> keys = ['a', 'b', 'c']
>>> get_nested(data, keys)
1
"""
for key in keys:
if isinstance(data, dict) and key in data:
data = data[key]
else:
return None
return data
>>> case_insensitive_match('foo', ['bar', 'FOO'])
'FOO'
"""
target_lower = target.lower()
for string in string_list:
if string.lower() == target_lower:
return string
return None
14 changes: 14 additions & 0 deletions src/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,13 @@ def test_load_toml_user_configurations_not_specified(
),
(
"""
[tool.jupysql.SQLMagics]
autocommit = true
""",
"[tool.jupysql.SQLMagics] is an invalid section name in {path}. Did you mean [tool.jupysql.SqlMagic]?",
),
(
"""
[tool.jupysql.SqlMagic]
autocommit = True
""",
Expand Down Expand Up @@ -577,6 +584,13 @@ def test_toml_optional_message(tmp_empty, monkeypatch, ip, capsys):
"[tool.jupysql.SqlMagic] present in {pyproject_path} but empty.",
],
),
(
"[tool.JupySQL.SqlMagic]",
"",
[
"Hint: We found 'tool.JupySQL' in {pyproject_path}. Did you mean 'tool.jupysql'?",
],
),
],
)
def test_user_config_load_sequence_and_messages(
Expand Down

0 comments on commit 858108e

Please sign in to comment.