|
2 | 2 | import plotly.io as pio |
3 | 3 | from collections import namedtuple, OrderedDict |
4 | 4 | from ._special_inputs import IdentityMap, Constant, Range |
| 5 | +from .trendline_functions import ols, lowess, ma, ewm |
5 | 6 |
|
6 | 7 | from _plotly_utils.basevalidators import ColorscaleValidator |
7 | 8 | from plotly.colors import qualitative, sequential |
@@ -239,65 +240,6 @@ def make_mapping(args, variable): |
239 | 240 | ) |
240 | 241 |
|
241 | 242 |
|
242 | | -def lowess(options, x, y, x_label, y_label, non_missing): |
243 | | - import statsmodels.api as sm |
244 | | - |
245 | | - frac = options.get("frac", 0.6666666) |
246 | | - # missing ='drop' is the default value for lowess but not for OLS (None) |
247 | | - # we force it here in case statsmodels change their defaults |
248 | | - y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] |
249 | | - hover_header = "<b>LOWESS trendline</b><br><br>" |
250 | | - return y_out, hover_header, None |
251 | | - |
252 | | - |
253 | | -def ma(options, x, y, x_label, y_label, non_missing): |
254 | | - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] |
255 | | - hover_header = "<b>Moving Average trendline</b><br><br>" |
256 | | - return y_out, hover_header, None |
257 | | - |
258 | | - |
259 | | -def ewm(options, x, y, x_label, y_label, non_missing): |
260 | | - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] |
261 | | - hover_header = "<b>EWM trendline</b><br><br>" |
262 | | - return y_out, hover_header, None |
263 | | - |
264 | | - |
265 | | -def ols(options, x, y, x_label, y_label, non_missing): |
266 | | - import statsmodels.api as sm |
267 | | - |
268 | | - add_constant = options.get("add_constant", True) |
269 | | - log_x = options.get("log_x", False) |
270 | | - log_y = options.get("log_y", False) |
271 | | - |
272 | | - if log_y: |
273 | | - y = np.log(y) |
274 | | - if log_x: |
275 | | - x = np.log(x) |
276 | | - if add_constant: |
277 | | - x = sm.add_constant(x) |
278 | | - fit_results = sm.OLS(y, x, missing="drop").fit() |
279 | | - y_out = fit_results.predict() |
280 | | - if log_y: |
281 | | - y_out = np.exp(y_out) |
282 | | - hover_header = "<b>OLS trendline</b><br>" |
283 | | - if len(fit_results.params) == 2: |
284 | | - hover_header += "%s = %g * %s + %g<br>" % ( |
285 | | - y_label, |
286 | | - fit_results.params[1], |
287 | | - x_label, |
288 | | - fit_results.params[0], |
289 | | - ) |
290 | | - elif not add_constant: |
291 | | - hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,) |
292 | | - else: |
293 | | - hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],) |
294 | | - hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared |
295 | | - return y_out, hover_header, fit_results |
296 | | - |
297 | | - |
298 | | -trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
299 | | - |
300 | | - |
301 | 243 | def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): |
302 | 244 | """Populates a dict with arguments to update trace |
303 | 245 |
|
@@ -371,6 +313,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): |
371 | 313 | if trace_spec.constructor == go.Histogram: |
372 | 314 | mapping_labels["count"] = "%{x}" |
373 | 315 | elif attr_name == "trendline": |
| 316 | + trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
374 | 317 | if ( |
375 | 318 | attr_value in trendline_functions |
376 | 319 | and args["x"] |
|
0 commit comments