Skip to content

Commit 878d4db

Browse files
authored
refactor checking for df and bump version
2 parents 4d62e73 + 1429e6f commit 878d4db

17 files changed

+101
-63
lines changed

packages/python/plotly/optional-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ ipython
3939

4040
## pandas deps for some matplotlib functionality ##
4141
pandas
42-
narwhals>=1.11.0
42+
narwhals>=1.12.0
4343

4444
## scipy deps for some FigureFactory functions ##
4545
scipy

packages/python/plotly/plotly/express/_core.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,48 +1421,41 @@ def build_dataframe(args, constructor):
14211421
# Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
14221422
df_provided = args["data_frame"] is not None
14231423
is_pd_like = False
1424+
needs_interchanging = False
14241425
if df_provided:
14251426

1426-
if nw.dependencies.is_polars_dataframe(
1427-
args["data_frame"]
1428-
) or nw.dependencies.is_pyarrow_table(args["data_frame"]):
1429-
args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
1430-
columns = args["data_frame"].columns
1431-
1432-
elif nw.dependencies.is_polars_series(
1433-
args["data_frame"]
1434-
) or nw.dependencies.is_pyarrow_chunked_array(args["data_frame"]):
1435-
args["data_frame"] = nw.from_native(
1436-
args["data_frame"],
1437-
series_only=True,
1438-
).to_frame()
1439-
columns = args["data_frame"].columns
1440-
1441-
elif nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
1427+
if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
14421428

14431429
columns = args["data_frame"].columns # This can be multi index
1444-
args["data_frame"] = nw.from_native(args["data_frame"])
1430+
args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
14451431
is_pd_like = True
14461432

14471433
elif nw.dependencies.is_pandas_like_series(args["data_frame"]):
14481434

14491435
args["data_frame"] = nw.from_native(
1450-
args["data_frame"],
1451-
series_only=True,
1436+
args["data_frame"], series_only=True
14521437
).to_frame()
14531438
columns = args["data_frame"].columns
14541439
is_pd_like = True
14551440

1456-
elif hasattr(args["data_frame"], "__dataframe__"):
1457-
# data_frame supports interchange protocol
1458-
args["data_frame"] = nw.from_native(
1459-
nw.from_native(
1460-
args["data_frame"], eager_or_interchange_only=True
1461-
).to_pandas(), # Converts to pandas
1462-
eager_only=True,
1463-
)
1441+
elif isinstance(
1442+
data_frame := nw.from_native(
1443+
args["data_frame"], eager_or_interchange_only=True, strict=False
1444+
),
1445+
nw.DataFrame,
1446+
):
1447+
args["data_frame"] = data_frame
1448+
needs_interchanging = nw.get_level(data_frame) == "interchange"
1449+
columns = args["data_frame"].columns
1450+
1451+
elif isinstance(
1452+
series := nw.from_native(
1453+
args["data_frame"], series_only=True, strict=False
1454+
),
1455+
nw.Series,
1456+
):
1457+
args["data_frame"] = series.to_frame()
14641458
columns = args["data_frame"].columns
1465-
is_pd_like = True
14661459

14671460
elif hasattr(args["data_frame"], "toPandas"):
14681461
# data_frame is PySpark: it does not support interchange and it is not
@@ -1498,11 +1491,16 @@ def build_dataframe(args, constructor):
14981491
columns = None # no data_frame
14991492

15001493
df_input: nw.DataFrame | None = args["data_frame"]
1501-
index = nw.maybe_get_index(df_input) if df_provided else None
1502-
1503-
# This is safe since at this point `_compliant_frame` is one of the "full" level
1504-
# support dataframe(s)
1505-
native_namespace = nw.get_native_namespace(df_input) if df_provided else None
1494+
index = (
1495+
nw.maybe_get_index(df_input)
1496+
if df_provided and not needs_interchanging
1497+
else None
1498+
)
1499+
native_namespace = (
1500+
nw.get_native_namespace(df_input)
1501+
if df_provided and not needs_interchanging
1502+
else None
1503+
)
15061504

15071505
# now we handle special cases like wide-mode or x-xor-y specification
15081506
# by rearranging args to tee things up for process_args_into_dataframe to work
@@ -1575,6 +1573,32 @@ def build_dataframe(args, constructor):
15751573
value_name = _escape_col_name(columns, "value", [])
15761574
var_name = _escape_col_name(columns, var_name, [])
15771575

1576+
if isinstance(args["data_frame"], nw.DataFrame) and needs_interchanging:
1577+
# Interchange to PyArrow
1578+
if wide_mode:
1579+
args["data_frame"] = nw.from_native(
1580+
args["data_frame"].to_arrow(), eager_only=True
1581+
)
1582+
else:
1583+
# Save precious resources by only interchanging columns that are
1584+
# actually going to be plotted. This is tricky to do in the general case,
1585+
# because Plotly allows calls like `px.line(df, x='x', y=['y1', df['y1']])`,
1586+
# but interchange-only objects (e.g. DuckDB) don't typically have a concept
1587+
# of self-standing Series. It's more important to perform projection pushdown
1588+
# here seeing as we're materialising to an (eager) PyArrow table.
1589+
necessary_columns = {
1590+
i for i in args.values() if isinstance(i, str) and i in columns
1591+
}
1592+
for field in args:
1593+
if args[field] is not None and field in array_attrables:
1594+
necessary_columns.update(i for i in args[field] if i in columns)
1595+
columns = list(necessary_columns)
1596+
args["data_frame"] = nw.from_native(
1597+
args["data_frame"].select(columns).to_arrow(), eager_only=True
1598+
)
1599+
import pyarrow as pa
1600+
1601+
native_namespace = pa
15781602
missing_bar_dim = None
15791603
if (
15801604
constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import plotly.express as px
2+
import pyarrow as pa
23
import plotly.graph_objects as go
34
import narwhals.stable.v1 as nw
45
import numpy as np
@@ -290,35 +291,48 @@ def test_build_df_with_index():
290291

291292
def test_build_df_using_interchange_protocol_mock():
292293
class InterchangeDataFrame:
293-
def __init__(self, columns):
294-
self._columns = columns
294+
def __init__(self, df):
295+
self._df = df
295296

296-
def column_names(self):
297-
return self._columns
297+
def __dataframe__(self):
298+
return self
298299

299-
interchange_dataframe = InterchangeDataFrame(
300-
["petal_width", "sepal_length", "sepal_width"]
301-
)
300+
def column_names(self):
301+
return list(self._df._data.keys())
302+
303+
def select_columns_by_name(self, columns):
304+
return InterchangeDataFrame(
305+
CustomDataFrame(
306+
{
307+
key: value
308+
for key, value in self._df._data.items()
309+
if key in columns
310+
}
311+
)
312+
)
302313

303314
class CustomDataFrame:
304-
def __dataframe__(self):
305-
return interchange_dataframe
315+
def __init__(self, data):
316+
self._data = data
317+
318+
def __dataframe__(self, allow_copy: bool = True):
319+
return InterchangeDataFrame(self)
306320

307-
input_dataframe = CustomDataFrame()
321+
input_dataframe = CustomDataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
308322

309-
iris_pandas = px.data.iris()
323+
input_dataframe_pa = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
310324

311-
args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length")
325+
args = dict(data_frame=input_dataframe, x="a", y="b")
312326
with mock.patch(
313-
"narwhals._interchange.dataframe.InterchangeFrame.to_pandas",
314-
return_value=iris_pandas,
327+
"narwhals._interchange.dataframe.InterchangeFrame.to_arrow",
328+
return_value=input_dataframe_pa,
315329
) as mock_from_dataframe:
316330
out = build_dataframe(args, go.Scatter)
317331

318332
mock_from_dataframe.assert_called_once()
319333

320334
assert_frame_equal(
321-
iris_pandas.reset_index()[out["data_frame"].columns],
335+
input_dataframe_pa.select(out["data_frame"].columns).to_pandas(),
322336
out["data_frame"].to_pandas(),
323337
)
324338

packages/python/plotly/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
###################################################
77

88
## dataframe agnostic layer ##
9-
narwhals>=1.11.0
9+
narwhals>=1.12.0

packages/python/plotly/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def run(self):
603603
data_files=[
604604
("etc/jupyter/nbconfig/notebook.d", ["jupyterlab-plotly.json"]),
605605
],
606-
install_requires=["narwhals>=1.11.0", "packaging"],
606+
install_requires=["narwhals>=1.12.0", "packaging"],
607607
zip_safe=False,
608608
cmdclass=dict(
609609
build_py=js_prerelease(versioneer_cmds["build_py"]),
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
requests==2.25.1
22
pytest==7.4.4
3-
narwhals>=1.11.0
3+
narwhals>=1.12.0

packages/python/plotly/test_requirements/requirements_310_optional.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ kaleido
2121
orjson==3.8.12
2222
polars[timezone]
2323
pyarrow
24-
narwhals>=1.11.0
24+
narwhals>=1.12.0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
requests==2.25.1
22
pytest==7.4.4
3-
narwhals>=1.11.0
3+
narwhals>=1.12.0

packages/python/plotly/test_requirements/requirements_311_optional.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ kaleido
2121
orjson==3.8.12
2222
polars[timezone]
2323
pyarrow
24-
narwhals>=1.11.0
24+
narwhals>=1.12.0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
requests==2.25.1
22
pytest==7.4.4
3-
narwhals>=1.11.0
3+
narwhals>=1.12.0

0 commit comments

Comments
 (0)