From f5f35c63a6118907db71fec03856876a4692f4c4 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 15 Nov 2023 08:46:00 -0500 Subject: [PATCH] fix(ir): ensure that join projection columns are all always nullable --- ibis/expr/operations/relations.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index d379ad8675a9..7d6fb0761ddb 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -27,6 +27,8 @@ import pandas as pd import pyarrow as pa + import ibis.expr.types as ir + _table_names = (f"unbound_table_{i:d}" for i in itertools.count()) @@ -658,7 +660,7 @@ def schema(self): return backend._get_schema_using_query(self.query) -def _dedup_join_columns(expr, lname: str, rname: str): +def _dedup_join_columns(expr: ir.Table, lname: str, rname: str): from ibis.expr.operations.generic import TableColumn from ibis.expr.operations.logical import Equals @@ -692,18 +694,22 @@ def _dedup_join_columns(expr, lname: str, rname: str): # Rename columns in the left table that overlap, unless they're known to be # equal to a column in the right left_projections = [ - left[column].name(lname.format(name=column) if lname else column) + left[column] + .cast(left[column].type().copy(nullable=True)) + .name(lname.format(name=column) if lname else column) if column in overlap and column not in equal - else left[column] + else left[column].cast(left[column].type().copy(nullable=True)).name(column) for column in left.columns ] # Rename columns in the right table that overlap, dropping any columns that # are known to be equal to those in the left table right_projections = [ - right[column].name(rname.format(name=column) if rname else column) + right[column] + .cast(right[column].type().copy(nullable=True)) + .name(rname.format(name=column) if rname else column) if column in overlap - else right[column] + else right[column].cast(right[column].type().copy(nullable=True)).name(column) for column in right.columns if column not in equal ]