Skip to content

Commit 9f5fa20

Browse files
committed
fix: use coalesce instead of drop_duplicate_keys for join
closes #1305
1 parent c141dd3 commit 9f5fa20

File tree

3 files changed

+70
-23
lines changed

3 files changed

+70
-23
lines changed

python/datafusion/dataframe.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def join(
778778
left_on: None = None,
779779
right_on: None = None,
780780
join_keys: None = None,
781-
drop_duplicate_keys: bool = True,
781+
coalesce_duplicate_keys: bool = True,
782782
) -> DataFrame: ...
783783

784784
@overload
@@ -791,7 +791,7 @@ def join(
791791
left_on: str | Sequence[str],
792792
right_on: str | Sequence[str],
793793
join_keys: tuple[list[str], list[str]] | None = None,
794-
drop_duplicate_keys: bool = True,
794+
coalesce_duplicate_keys: bool = True,
795795
) -> DataFrame: ...
796796

797797
@overload
@@ -804,7 +804,7 @@ def join(
804804
join_keys: tuple[list[str], list[str]],
805805
left_on: None = None,
806806
right_on: None = None,
807-
drop_duplicate_keys: bool = True,
807+
coalesce_duplicate_keys: bool = True,
808808
) -> DataFrame: ...
809809

810810
def join(
@@ -816,7 +816,7 @@ def join(
816816
left_on: str | Sequence[str] | None = None,
817817
right_on: str | Sequence[str] | None = None,
818818
join_keys: tuple[list[str], list[str]] | None = None,
819-
drop_duplicate_keys: bool = True,
819+
coalesce_duplicate_keys: bool = True,
820820
) -> DataFrame:
821821
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
822822
@@ -829,9 +829,9 @@ def join(
829829
"right", "full", "semi", "anti".
830830
left_on: Join column of the left dataframe.
831831
right_on: Join column of the right dataframe.
832-
drop_duplicate_keys: When True, the columns from the right DataFrame
833-
that have identical names in the ``on`` fields to the left DataFrame
834-
will be dropped.
832+
coalesce_duplicate_keys: When True, the join key columns
833+
from the left and right DataFrames that share a name
834+
in the ``on`` fields are coalesced into a single column.
835835
join_keys: Tuple of two lists of column names to join on. [Deprecated]
836836
837837
Returns:
@@ -879,7 +879,7 @@ def join(
879879
right_on = [right_on]
880880

881881
return DataFrame(
882-
self.df.join(right.df, how, left_on, right_on, drop_duplicate_keys)
882+
self.df.join(right.df, how, left_on, right_on, coalesce_duplicate_keys)
883883
)
884884

885885
def join_on(

python/tests/test_dataframe.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def test_join():
663663
df1 = ctx.create_dataframe([[batch]], "r")
664664

665665
df2 = df.join(df1, on="a", how="inner")
666-
df2 = df2.sort(column("l.a"))
666+
df2 = df2.sort(column("a"))
667667
table = pa.Table.from_batches(df2.collect())
668668

669669
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
@@ -673,16 +673,18 @@ def test_join():
673673
# Since we may have a duplicate column name and pa.Table()
674674
# hides the fact, instead we need to explicitly check the
675675
# resultant arrays.
676-
df2 = df.join(df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=True)
677-
df2 = df2.sort(column("l.a"))
676+
df2 = df.join(
677+
df1, left_on="a", right_on="a", how="inner", coalesce_duplicate_keys=True
678+
)
679+
df2 = df2.sort(column("a"))
678680
result = df2.collect()[0]
679681
assert result.num_columns == 3
680682
assert result.column(0) == pa.array([1, 2], pa.int64())
681683
assert result.column(1) == pa.array([4, 5], pa.int64())
682684
assert result.column(2) == pa.array([8, 10], pa.int64())
683685

684686
df2 = df.join(
685-
df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=False
687+
df1, left_on="a", right_on="a", how="inner", coalesce_duplicate_keys=False
686688
)
687689
df2 = df2.sort(column("l.a"))
688690
result = df2.collect()[0]
@@ -695,7 +697,7 @@ def test_join():
695697
# Verify we don't make a breaking change to pre-43.0.0
696698
# where users would pass join_keys as a positional argument
697699
df2 = df.join(df1, (["a"], ["a"]), how="inner")
698-
df2 = df2.sort(column("l.a"))
700+
df2 = df2.sort(column("a"))
699701
table = pa.Table.from_batches(df2.collect())
700702

701703
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
@@ -720,7 +722,7 @@ def test_join_invalid_params():
720722
with pytest.deprecated_call():
721723
df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
722724
df2.show()
723-
df2 = df2.sort(column("l.a"))
725+
df2 = df2.sort(column("a"))
724726
table = pa.Table.from_batches(df2.collect())
725727

726728
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
@@ -778,6 +780,35 @@ def test_join_on():
778780
assert table.to_pydict() == expected
779781

780782

783+
def test_join_full_with_coalesce_duplicate_keys():
784+
ctx = SessionContext()
785+
786+
batch = pa.RecordBatch.from_arrays(
787+
[pa.array([1, 3, 5, 7, 9]), pa.array([True, True, True, True, True])],
788+
names=["log_time", "key_frame"],
789+
)
790+
key_frame = ctx.create_dataframe([[batch]])
791+
792+
batch = pa.RecordBatch.from_arrays(
793+
[pa.array([2, 4, 6, 8, 10])],
794+
names=["log_time"],
795+
)
796+
query_times = ctx.create_dataframe([[batch]])
797+
798+
merged = query_times.join(
799+
key_frame,
800+
left_on="log_time",
801+
right_on="log_time",
802+
how="full",
803+
coalesce_duplicate_keys=True,
804+
)
805+
merged = merged.sort(column("log_time"))
806+
result = merged.collect()[0]
807+
808+
assert result.num_columns == 2
809+
assert result.column(0).to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
810+
811+
781812
def test_join_on_invalid_expr():
782813
ctx = SessionContext()
783814

src/dataframe.rs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ impl PyDataFrame {
649649
how: &str,
650650
left_on: Vec<PyBackedStr>,
651651
right_on: Vec<PyBackedStr>,
652-
drop_duplicate_keys: bool,
652+
coalesce_keys: bool,
653653
) -> PyDataFusionResult<Self> {
654654
let join_type = match how {
655655
"inner" => JoinType::Inner,
@@ -676,23 +676,24 @@ impl PyDataFrame {
676676
None,
677677
)?;
678678

679-
if drop_duplicate_keys {
679+
if coalesce_keys {
680680
let mutual_keys = left_keys
681681
.iter()
682682
.zip(right_keys.iter())
683683
.filter(|(l, r)| l == r)
684684
.map(|(key, _)| *key)
685685
.collect::<Vec<_>>();
686686

687-
let fields_to_drop = mutual_keys
687+
let fields_to_coalesce = mutual_keys
688688
.iter()
689689
.map(|name| {
690-
df.logical_plan()
690+
let qualified_fields = df
691+
.logical_plan()
691692
.schema()
692-
.qualified_fields_with_unqualified_name(name)
693+
.qualified_fields_with_unqualified_name(name);
694+
(*name, qualified_fields)
693695
})
694-
.filter(|r| r.len() == 2)
695-
.map(|r| r[1])
696+
.filter(|(_, fields)| fields.len() == 2)
696697
.collect::<Vec<_>>();
697698

698699
let expr: Vec<Expr> = df
@@ -702,8 +703,23 @@ impl PyDataFrame {
702703
.into_iter()
703704
.enumerate()
704705
.map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
705-
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
706-
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
706+
.filter_map(|(qualifier, field)| {
707+
if let Some((key_name, qualified_fields)) = fields_to_coalesce
708+
.iter()
709+
.find(|(_, qf)| qf.contains(&(qualifier, field)))
710+
{
711+
// Only add the coalesce expression once (when we encounter the first field)
712+
// Skip the second field (it is already folded into the coalesce expression)
713+
if (qualifier, field) == qualified_fields[0] {
714+
let left_col = Expr::Column(Column::from(qualified_fields[0]));
715+
let right_col = Expr::Column(Column::from(qualified_fields[1]));
716+
return Some(coalesce(vec![left_col, right_col]).alias(*key_name));
717+
}
718+
None
719+
} else {
720+
Some(Expr::Column(Column::from((qualifier, field))))
721+
}
722+
})
707723
.collect();
708724
df = df.select(expr)?;
709725
}

0 commit comments

Comments
 (0)