
feat: Automatically add new columns to Redshift table during write operation
jack-dell committed Sep 10, 2024
1 parent 8cd0df8 commit 2cbab2a
Showing 2 changed files with 91 additions and 92 deletions.
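In practice, the new behavior lets a write call evolve the target table's schema instead of failing on a column-count mismatch. A minimal usage sketch, assuming an open redshift_connector connection `con` and an existing table `public.my_table` with columns (foo, bar); the connection and table names are illustrative, not from this commit:

import awswrangler as wr
import pandas as pd

# The incoming frame carries one column ("baz") that the table lacks.
df = pd.DataFrame({"foo": ["a"], "bar": ["b"], "baz": ["c"]})

# With add_new_columns=True, the missing column is added to the table
# (via ALTER TABLE ... ADD COLUMN) before the rows are written.
wr.redshift.to_sql(
    df=df,
    con=con,
    schema="public",
    table="my_table",
    mode="append",
    add_new_columns=True,
)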
2 changes: 1 addition & 1 deletion awswrangler/redshift/_utils.py
@@ -119,7 +119,7 @@ def _add_table_columns(
     cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
 ) -> None:
     for column_name, column_type in new_columns.items():
-        sql = f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}\n ADD COLUMN {column_name} {column_type};"
+        sql = f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}\n ADD COLUMN {_identifier(column_name)} {column_type};"
         _logger.debug("Executing alter query:\n%s", sql)
         cursor.execute(sql)

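The one-line fix above routes the column name through _identifier(), the same escaping already applied to the schema and table names. Assuming _identifier() double-quotes its argument in the usual Redshift/PostgreSQL style (a guess at the helper's behavior, not code from this commit), the point of the change looks like this:

def _identifier(name: str) -> str:
    # Assumed behavior: escape embedded double quotes and wrap the name,
    # so mixed-case, reserved-word, or spaced names survive verbatim.
    return '"' + name.replace('"', '""') + '"'

# Unquoted, a column named "group" (reserved word) or "My Col" (embedded
# space) would produce invalid SQL; quoted, both are legal identifiers.
print(f'ALTER TABLE {_identifier("public")}.{_identifier("t")} ADD COLUMN {_identifier("My Col")} VARCHAR(256);')
# ALTER TABLE "public"."t" ADD COLUMN "My Col" VARCHAR(256);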
181 changes: 90 additions & 91 deletions tests/unit/test_redshift.py
@@ -1447,52 +1447,46 @@ def test_copy_add_new_columns(
     mode: str,
     overwrite_method: str,
 ) -> None:
+    schema = "public"
     df = pd.DataFrame({"foo": ["a", "b", "c"], "bar": ["c", "d", "e"]})
-    wr.redshift.copy(
-        df=df,
-        path=path,
-        con=redshift_con,
-        schema="public",
-        table=redshift_table,
-        iam_role=databases_parameters["redshift"]["role"],
-        add_new_columns=True,
-        mode=mode,
-        overwrite_method=overwrite_method,
-    )
+    copy_kwargs = {
+        "df": df,
+        "path": path,
+        "con": redshift_con,
+        "schema": schema,
+        "table": redshift_table,
+        "iam_role": databases_parameters["redshift"]["role"],
+        "primary_keys": ["foo"] if mode == "upsert" else None,
+        "overwrite_method": overwrite_method,
+    }
 
-    # Add new columns
-    df["abc"] = ["f", "g", "h"]
-    df["bce"] = ["j", "k", "l"]
-    wr.redshift.copy(
-        df=df,
-        path=path,
-        con=redshift_con,
-        schema="public",
-        table=redshift_table,
-        iam_role=databases_parameters["redshift"]["role"],
-        add_new_columns=True,
-        mode=mode,
-        overwrite_method=overwrite_method,
-    )
-    df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=redshift_con)
-    assert df2.columns.tolist() == df.columns.tolist()
-    assert df2["abc"].tolist() == ["f", "g", "h"]
-    assert df2["bce"].tolist() == ["j", "k", "l"]
+    # Create table
+    wr.redshift.copy(**copy_kwargs, add_new_columns=True, mode="overwrite")
+    copy_kwargs["mode"] = mode
 
-    # Assert error when trying to add a new column with 'add_new_columns' set to False
-    df["cde"] = ["m", "n", "o"]
-    with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
-        wr.redshift.copy(
-            df=df,
-            path=path,
-            con=redshift_con,
-            schema="public",
-            table=redshift_table,
-            iam_role=databases_parameters["redshift"]["role"],
-            mode=mode,
-            overwrite_method=overwrite_method,
-        )
-    assert "unmatched number of columns" in str(exc_info.value).lower()
+    # Add new columns
+    df["xoo"] = ["f", "g", "h"]
+    df["baz"] = ["j", "k", "l"]
+    wr.redshift.copy(**copy_kwargs, add_new_columns=True)
+
+    sql = f"SELECT * FROM {schema}.{redshift_table}"
+    if mode == "append":
+        sql += "\nWHERE xoo IS NOT NULL AND baz IS NOT NULL"
+    df2 = wr.redshift.read_sql_query(sql=sql, con=redshift_con)
+    df2 = df2.sort_values(by=df2.columns.to_list())
+    assert df.values.tolist() == df2.values.tolist()
+    assert df.columns.tolist() == df2.columns.tolist()
+
+    # Assert error when trying to add a new column without 'add_new_columns' parameter (False by default) in "append"
+    # or "upsert". No errors are expected with the ('drop', 'cascade') overwrite_method values
+    df["abc"] = ["m", "n", "o"]
+    if overwrite_method in ("drop", "cascade"):
+        wr.redshift.copy(**copy_kwargs)
+    else:
+        with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
+            wr.redshift.copy(**copy_kwargs)
+        assert "ProgrammingError" == exc_info.typename
+        assert "unmatched number of columns" in str(exc_info.value).lower()


@pytest.mark.parametrize(
@@ -1514,86 +1508,91 @@ def test_to_sql_add_new_columns(
     mode: str,
     overwrite_method: str,
 ) -> None:
+    schema = "public"
     df = pd.DataFrame({"foo": ["a", "b", "c"], "bar": ["c", "d", "e"]})
-    wr.redshift.to_sql(
-        df=df,
-        con=redshift_con,
-        table=redshift_table,
-        schema="public",
-        add_new_columns=True,
-        mode=mode,
-        overwrite_method=overwrite_method,
-    )
+    to_sql_kwargs = {
+        "df": df,
+        "con": redshift_con,
+        "schema": schema,
+        "table": redshift_table,
+        "primary_keys": ["foo"] if mode == "upsert" else None,
+        "overwrite_method": overwrite_method,
+    }
 
-    # Add new columns
-    df["cba"] = ["f", "g", "h"]
-    df["ebc"] = ["j", "k", "l"]
-    wr.redshift.to_sql(
-        df=df,
-        con=redshift_con,
-        table=redshift_table,
-        schema="public",
-        add_new_columns=True,
-        mode=mode,
-        overwrite_method=overwrite_method,
-    )
+    # Create table
+    wr.redshift.to_sql(**to_sql_kwargs, add_new_columns=True, mode="overwrite")
+    to_sql_kwargs["mode"] = mode
 
-    df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=redshift_con)
-    assert df2.columns.tolist() == df.columns.tolist()
-    assert df2["cba"].tolist() == ["f", "g", "h"]
-    assert df2["ebc"].tolist() == ["j", "k", "l"]
+    # Add new columns
+    df["xoo"] = ["f", "g", "h"]
+    df["baz"] = ["j", "k", "l"]
+    wr.redshift.to_sql(**to_sql_kwargs, add_new_columns=True)
+
+    sql = f"SELECT * FROM {schema}.{redshift_table}"
+    if mode == "append":
+        sql += "\nWHERE xoo IS NOT NULL AND baz IS NOT NULL"
+    df2 = wr.redshift.read_sql_query(sql=sql, con=redshift_con)
+    df2 = df2.sort_values(by=df2.columns.to_list())
+    assert df.values.tolist() == df2.values.tolist()
+    assert df.columns.tolist() == df2.columns.tolist()
+
+    # Assert error when trying to add a new column without 'add_new_columns' parameter (False by default) in "append"
+    # or "upsert". No errors are expected with the ('drop', 'cascade') overwrite_method values
+    df["abc"] = ["m", "n", "o"]
+    if overwrite_method in ("drop", "cascade"):
+        wr.redshift.to_sql(**to_sql_kwargs)
+    else:
+        with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
+            wr.redshift.to_sql(**to_sql_kwargs)
+        assert "ProgrammingError" == exc_info.typename
+        assert "insert has more expressions than target columns" in str(exc_info.value).lower()
 
-    # Assert error when trying to add a new column with 'add_new_columns' set to False
-    df["cde"] = ["m", "n", "o"]
-    with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
-        wr.redshift.to_sql(
-            df=df,
-            con=redshift_con,
-            table=redshift_table,
-            schema="public",
-            mode=mode,
-            overwrite_method=overwrite_method,
-        )
-    assert "unmatched number of columns" in str(exc_info.value).lower()
+    with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
+        wr.redshift.to_sql(**to_sql_kwargs, use_column_names=True)
+    assert "ProgrammingError" == exc_info.typename
+    assert 'column "abc" of relation' in str(exc_info.value).lower()


 def test_add_new_columns_case_sensitive(
     path: str, redshift_table: str, redshift_con: redshift_connector.Connection, databases_parameters: dict[str, Any]
 ) -> None:
+    schema = "public"
+
     df = pd.DataFrame({"foo": ["a", "b", "c"]})
-    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema="public", add_new_columns=True)
+    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True)
 
     # Set enable_case_sensitive_identifier to False (default value)
     with redshift_con.cursor() as cursor:
         cursor.execute("SET enable_case_sensitive_identifier TO off;")
     redshift_con.commit()
 
     df["Boo"] = ["f", "g", "h"]
-    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema="public", add_new_columns=True)
-    df_check = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=redshift_con)
+    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True)
+    df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM {schema}.{redshift_table}", con=redshift_con)
 
     # Since 'enable_case_sensitive_identifier' is set to False, the column 'Boo' is automatically written as 'boo' by
     # Redshift
-    assert df_check.columns.tolist() == [x.lower() for x in df.columns]
-    assert "boo" in df_check.columns
+    assert df2.columns.tolist() == [x.lower() for x in df.columns]
+    assert "boo" in df2.columns
 
     # Trying to add a new column 'BOO' causes an exception because Redshift attempts to lowercase it, resulting in a
     # column mismatch between the DataFrame and the table schema
     df["BOO"] = ["j", "k", "l"]
     with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info:
-        wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema="public", add_new_columns=True)
+        wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True)
     assert "insert has more expressions than target columns" in str(exc_info.value).lower()
 
     # Enable enable_case_sensitive_identifier
     with redshift_con.cursor() as cursor:
         cursor.execute("SET enable_case_sensitive_identifier TO on;")
         redshift_con.commit()
-        wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema="public", add_new_columns=True)
+        wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True)
         cursor.execute("RESET enable_case_sensitive_identifier;")
     redshift_con.commit()
 
     # Ensure that the new uppercase column has been added correctly
-    df_check = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=redshift_con)
-    assert df_check.columns.tolist() == df.columns.tolist()
-    assert "boo" in df_check.columns
-    assert "BOO" in df_check.columns
+    df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM {schema}.{redshift_table}", con=redshift_con)
+    assert df2.columns.tolist() == df.columns.tolist()
+    assert "foo" in df2.columns
+    assert "boo" in df2.columns
+    assert "BOO" in df2.columns
