55from datetime import datetime
66
77
8- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
9- def test_trendline_results_passthrough (mode ):
8+ @pytest .mark .parametrize (
9+ "mode,options" ,
10+ [
11+ ("ols" , None ),
12+ ("ols" , dict (log_x = True , log_y = True )),
13+ ("lowess" , None ),
14+ ("lowess" , dict (frac = 0.3 )),
15+ ("ma" , dict (window = 2 )),
16+ ("ewma" , dict (alpha = 0.5 )),
17+ ],
18+ )
19+ def test_trendline_results_passthrough (mode , options ):
1020 df = px .data .gapminder ().query ("continent == 'Oceania'" )
11- fig = px .scatter (df , x = "year" , y = "pop" , color = "country" , trendline = mode )
21+ fig = px .scatter (
22+ df ,
23+ x = "year" ,
24+ y = "pop" ,
25+ color = "country" ,
26+ trendline = mode ,
27+ trendline_options = options ,
28+ )
1229 assert len (fig .data ) == 4
1330 for trace in fig ["data" ][0 ::2 ]:
1431 assert "trendline" not in trace .hovertemplate
@@ -20,90 +37,161 @@ def test_trendline_results_passthrough(mode):
2037 if mode == "ols" :
2138 assert len (results ) == 2
2239 assert results ["country" ].values [0 ] == "Australia"
23- assert results ["country" ].values [0 ] == "Australia"
2440 au_result = results ["px_fit_results" ].values [0 ]
2541 assert len (au_result .params ) == 2
2642 else :
2743 assert len (results ) == 0
2844
2945
30- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
31- def test_trendline_enough_values (mode ):
32- fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = mode )
46+ @pytest .mark .parametrize (
47+ "mode,options" ,
48+ [
49+ ("ols" , None ),
50+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
51+ ("lowess" , None ),
52+ ("lowess" , dict (frac = 0.3 )),
53+ ("ma" , dict (window = 2 )),
54+ ("ewma" , dict (alpha = 0.5 )),
55+ ],
56+ )
57+ def test_trendline_enough_values (mode , options ):
58+ fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = mode , trendline_options = options )
3359 assert len (fig .data ) == 2
3460 assert len (fig .data [1 ].x ) == 2
35- fig = px .scatter (x = [0 ], y = [0 ], trendline = mode )
61+ fig = px .scatter (x = [0 ], y = [0 ], trendline = mode , trendline_options = options )
3662 assert len (fig .data ) == 2
3763 assert fig .data [1 ].x is None
38- fig = px .scatter (x = [0 , 1 ], y = [0 , None ], trendline = mode )
64+ fig = px .scatter (x = [0 , 1 ], y = [0 , None ], trendline = mode , trendline_options = options )
3965 assert len (fig .data ) == 2
4066 assert fig .data [1 ].x is None
41- fig = px .scatter (x = [0 , 1 ], y = np .array ([0 , np .nan ]), trendline = mode )
67+ fig = px .scatter (
68+ x = [0 , 1 ], y = np .array ([0 , np .nan ]), trendline = mode , trendline_options = options
69+ )
4270 assert len (fig .data ) == 2
4371 assert fig .data [1 ].x is None
44- fig = px .scatter (x = [0 , 1 , None ], y = [0 , None , 1 ], trendline = mode )
72+ fig = px .scatter (
73+ x = [0 , 1 , None ], y = [0 , None , 1 ], trendline = mode , trendline_options = options
74+ )
4575 assert len (fig .data ) == 2
4676 assert fig .data [1 ].x is None
4777 fig = px .scatter (
48- x = np .array ([0 , 1 , np .nan ]), y = np .array ([0 , np .nan , 1 ]), trendline = mode
78+ x = np .array ([0 , 1 , np .nan ]),
79+ y = np .array ([0 , np .nan , 1 ]),
80+ trendline = mode ,
81+ trendline_options = options ,
4982 )
5083 assert len (fig .data ) == 2
5184 assert fig .data [1 ].x is None
52- fig = px .scatter (x = [0 , 1 , None , 2 ], y = [1 , None , 1 , 2 ], trendline = mode )
85+ fig = px .scatter (
86+ x = [0 , 1 , None , 2 ], y = [1 , None , 1 , 2 ], trendline = mode , trendline_options = options
87+ )
5388 assert len (fig .data ) == 2
5489 assert len (fig .data [1 ].x ) == 2
5590 fig = px .scatter (
56- x = np .array ([0 , 1 , np .nan , 2 ]), y = np .array ([1 , np .nan , 1 , 2 ]), trendline = mode
91+ x = np .array ([0 , 1 , np .nan , 2 ]),
92+ y = np .array ([1 , np .nan , 1 , 2 ]),
93+ trendline = mode ,
94+ trendline_options = options ,
5795 )
5896 assert len (fig .data ) == 2
5997 assert len (fig .data [1 ].x ) == 2
6098
6199
62- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
63- def test_trendline_nan_values (mode ):
100+ @pytest .mark .parametrize (
101+ "mode,options" ,
102+ [
103+ ("ols" , None ),
104+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
105+ ("lowess" , None ),
106+ ("lowess" , dict (frac = 0.3 )),
107+ ("ma" , dict (window = 2 )),
108+ ("ewma" , dict (alpha = 0.5 )),
109+ ],
110+ )
111+ def test_trendline_nan_values (mode , options ):
64112 df = px .data .gapminder ().query ("continent == 'Oceania'" )
65113 start_date = 1970
66114 df ["pop" ][df ["year" ] < start_date ] = np .nan
67- fig = px .scatter (df , x = "year" , y = "pop" , color = "country" , trendline = mode )
115+ fig = px .scatter (
116+ df ,
117+ x = "year" ,
118+ y = "pop" ,
119+ color = "country" ,
120+ trendline = mode ,
121+ trendline_options = options ,
122+ )
68123 for trendline in fig ["data" ][1 ::2 ]:
69124 assert trendline .x [0 ] >= start_date
70125 assert len (trendline .x ) == len (trendline .y )
71126
72127
73- def test_no_slope_ols_trendline ():
128+ def test_ols_trendline_slopes ():
74129 fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = "ols" )
75- assert "y = 1" in fig .data [1 ].hovertemplate # then + x*(some small number)
130+ assert "y = 1 * x + 0<br> " in fig .data [1 ].hovertemplate
76131 results = px .get_trendline_results (fig )
77132 params = results ["px_fit_results" ].iloc [0 ].params
78133 assert np .all (np .isclose (params , [0 , 1 ]))
79134
135+ fig = px .scatter (x = [0 , 1 ], y = [1 , 2 ], trendline = "ols" )
136+ assert "y = 1 * x + 1<br>" in fig .data [1 ].hovertemplate
137+ results = px .get_trendline_results (fig )
138+ params = results ["px_fit_results" ].iloc [0 ].params
139+ assert np .all (np .isclose (params , [1 , 1 ]))
140+
141+ fig = px .scatter (
142+ x = [0 , 1 ], y = [1 , 2 ], trendline = "ols" , trendline_options = dict (add_constant = False )
143+ )
144+ assert "y = 2 * x<br>" in fig .data [1 ].hovertemplate
145+ results = px .get_trendline_results (fig )
146+ params = results ["px_fit_results" ].iloc [0 ].params
147+ assert np .all (np .isclose (params , [2 ]))
148+
149+ fig = px .scatter (
150+ x = [1 , 1 ], y = [0 , 0 ], trendline = "ols" , trendline_options = dict (add_constant = False )
151+ )
152+ assert "y = 0 * x<br>" in fig .data [1 ].hovertemplate
153+ results = px .get_trendline_results (fig )
154+ params = results ["px_fit_results" ].iloc [0 ].params
155+ assert np .all (np .isclose (params , [0 ]))
156+
80157 fig = px .scatter (x = [1 , 1 ], y = [0 , 0 ], trendline = "ols" )
81- assert "y = 0" in fig .data [1 ].hovertemplate
158+ assert "y = 0<br> " in fig .data [1 ].hovertemplate
82159 results = px .get_trendline_results (fig )
83160 params = results ["px_fit_results" ].iloc [0 ].params
84161 assert np .all (np .isclose (params , [0 ]))
85162
86163 fig = px .scatter (x = [1 , 2 ], y = [0 , 0 ], trendline = "ols" )
87- assert "y = 0" in fig .data [1 ].hovertemplate
164+ assert "y = 0 * x + 0<br> " in fig .data [1 ].hovertemplate
88165 fig = px .scatter (x = [0 , 0 ], y = [1 , 1 ], trendline = "ols" )
89- assert "y = 0 * x + 1" in fig .data [1 ].hovertemplate
166+ assert "y = 0 * x + 1<br> " in fig .data [1 ].hovertemplate
90167 fig = px .scatter (x = [0 , 0 ], y = [1 , 2 ], trendline = "ols" )
91- assert "y = 0 * x + 1.5" in fig .data [1 ].hovertemplate
168+ assert "y = 0 * x + 1.5<br> " in fig .data [1 ].hovertemplate
92169
93170
94- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
95- def test_trendline_on_timeseries (mode ):
171+ @pytest .mark .parametrize (
172+ "mode,options" ,
173+ [
174+ ("ols" , None ),
175+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
176+ ("lowess" , None ),
177+ ("lowess" , dict (frac = 0.3 )),
178+ ("ma" , dict (window = 2 )),
179+ ("ma" , dict (window = "10d" )),
180+ ("ewma" , dict (alpha = 0.5 )),
181+ ],
182+ )
183+ def test_trendline_on_timeseries (mode , options ):
96184 df = px .data .stocks ()
97185
98186 with pytest .raises (ValueError ) as err_msg :
99- px .scatter (df , x = "date" , y = "GOOG" , trendline = mode )
187+ px .scatter (df , x = "date" , y = "GOOG" , trendline = mode , trendline_options = options )
100188 assert "Could not convert value of 'x' ('date') into a numeric type." in str (
101189 err_msg .value
102190 )
103191
104192 df ["date" ] = pd .to_datetime (df ["date" ])
105193 df ["date" ] = df ["date" ].dt .tz_localize ("CET" ) # force a timezone
106- fig = px .scatter (df , x = "date" , y = "GOOG" , trendline = mode )
194+ fig = px .scatter (df , x = "date" , y = "GOOG" , trendline = mode , trendline_options = options )
107195 assert len (fig .data ) == 2
108196 assert len (fig .data [0 ].x ) == len (fig .data [1 ].x )
109197 assert type (fig .data [0 ].x [0 ]) == datetime
0 commit comments