@@ -239,6 +239,56 @@ def make_mapping(args, variable):
239239 )
240240
241241
def lowess(options, x, y, x_label, y_label, non_missing):
    """Compute a LOWESS trendline through the (x, y) points.

    Parameters
    ----------
    options : dict
        Trendline options; only ``"frac"`` is read (defaults to 0.6666666)
        and forwarded to statsmodels' ``lowess``.
    x, y : array-like
        Coordinates of the points to smooth.
    x_label, y_label, non_missing :
        Accepted so every trendline function shares one signature;
        not used here.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)`` where ``y_out`` is column 1 of the
        array returned by ``lowess`` and ``hover_header`` is the HTML
        header for the hover label. There is no fit-results object.
    """
    # imported lazily so statsmodels is only needed when a trendline is drawn
    import statsmodels.api as sm

    smoothing_fraction = options.get("frac", 0.6666666)
    # missing='drop' is the default value for lowess but not for OLS (None)
    # we force it here in case statsmodels change their defaults
    smoothed = sm.nonparametric.lowess(
        y, x, missing="drop", frac=smoothing_fraction
    )
    return smoothed[:, 1], "<b>LOWESS trendline</b><br><br>", None
251+
252+
def ma(options, x, y, x_label, y_label, non_missing):
    """Compute a moving-average trendline.

    Parameters
    ----------
    options : dict
        Keyword arguments forwarded to ``pandas.Series.rolling``. When
        ``"window"`` is absent a window of 3 is used, so an empty dict
        is valid. The dict is not modified.
    x, y : array-like
        Coordinates of the points; ``x`` becomes the index of the series
        built from ``y``.
    x_label, y_label :
        Accepted so every trendline function shares one signature;
        not used here.
    non_missing : boolean mask
        Selects which entries of the rolling mean are returned.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)`` where ``y_out`` is the rolling
        mean restricted to ``non_missing`` and ``hover_header`` is the
        HTML header for the hover label. There is no fit-results object.
    """
    # rolling() requires a window; without a fallback an empty options dict
    # fails with an opaque TypeError. Copy so the caller's dict is untouched.
    rolling_options = dict(window=3)
    rolling_options.update(options)
    y_out = pd.Series(y, index=x).rolling(**rolling_options).mean()[non_missing]
    hover_header = "<b>Moving Average trendline</b><br><br>"
    return y_out, hover_header, None
257+
258+
def ewm(options, x, y, x_label, y_label, non_missing):
    """Compute an exponentially-weighted moving-average trendline.

    Parameters
    ----------
    options : dict
        Keyword arguments forwarded to ``pandas.Series.ewm``. When none of
        ``"com"``, ``"span"``, ``"halflife"`` or ``"alpha"`` is given,
        ``alpha=0.5`` is used, so an empty dict is valid. The dict is not
        modified.
    x, y : array-like
        Coordinates of the points; ``x`` becomes the index of the series
        built from ``y``.
    x_label, y_label :
        Accepted so every trendline function shares one signature;
        not used here.
    non_missing : boolean mask
        Selects which entries of the weighted mean are returned.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)`` where ``y_out`` is the weighted
        mean restricted to ``non_missing`` and ``hover_header`` is the
        HTML header for the hover label. There is no fit-results object.
    """
    ewm_options = dict(options)  # copy: never mutate the caller's dict
    decay_keys = ("com", "span", "halflife", "alpha")
    # pandas requires exactly one decay parameter; only inject the default
    # when the caller supplied none, otherwise pandas would reject the mix
    if all(ewm_options.get(key) is None for key in decay_keys):
        ewm_options["alpha"] = 0.5
    y_out = pd.Series(y, index=x).ewm(**ewm_options).mean()[non_missing]
    hover_header = "<b>EWM trendline</b><br><br>"
    return y_out, hover_header, None
263+
264+
def ols(options, x, y, x_label, y_label, non_missing):
    """Fit an ordinary-least-squares trendline of y against x.

    Parameters
    ----------
    options : dict
        Trendline options; only ``"add_constant"`` is read (defaults to
        True). When true, ``x`` is passed through ``sm.add_constant``
        before fitting.
    x, y : array-like
        Coordinates of the points to fit; missing values are dropped by
        statsmodels (``missing="drop"``).
    x_label, y_label : str
        Names used for x and y in the equation shown in the hover label.
    non_missing :
        Accepted so every trendline function shares one signature;
        not used here.

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)`` where ``y_out`` is
        ``fit_results.predict()``, ``hover_header`` is the HTML header
        (fitted equation plus R squared) and ``fit_results`` is the fitted
        statsmodels results object.
    """
    # imported lazily so statsmodels is only needed when a trendline is drawn
    import statsmodels.api as sm

    add_constant = options.get("add_constant", True)
    design = sm.add_constant(x) if add_constant else x
    fit_results = sm.OLS(y, design, missing="drop").fit()
    params = fit_results.params

    # pick the equation text matching the shape of the fitted model
    if len(params) == 2:
        equation = "%s = %g * %s + %g<br>" % (
            y_label,
            params[1],
            x_label,
            params[0],
        )
    elif not add_constant:
        equation = "%s = %g* %s<br>" % (y_label, params[0], x_label)
    else:
        equation = "%s = %g<br>" % (y_label, params[0])

    goodness_of_fit = "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
    hover_header = "<b>OLS trendline</b><br>" + equation + goodness_of_fit
    return fit_results.predict(), hover_header, fit_results
287+
288+
# registry mapping each supported trendline name to its implementation
trendline_functions = {
    "lowess": lowess,
    "ma": ma,
    "ewm": ewm,
    "ols": ols,
}
290+
291+
242292def make_trace_kwargs (args , trace_spec , trace_data , mapping_labels , sizeref ):
243293 """Populates a dict with arguments to update trace
244294
@@ -313,12 +363,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
313363 mapping_labels ["count" ] = "%{x}"
314364 elif attr_name == "trendline" :
315365 if (
316- attr_value [ 0 ] in [ "ols" , "lowess" , "ma" , "ewm" ]
366+ attr_value in trendline_functions
317367 and args ["x" ]
318368 and args ["y" ]
319369 and len (trace_data [[args ["x" ], args ["y" ]]].dropna ()) > 1
320370 ):
321- import statsmodels .api as sm
322371
323372 # sorting is bad but trace_specs with "trendline" have no other attrs
324373 sorted_trace_data = trace_data .sort_values (by = args ["x" ])
@@ -349,56 +398,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
349398 np .logical_or (np .isnan (y ), np .isnan (x ))
350399 )
351400 trace_patch ["x" ] = sorted_trace_data [args ["x" ]][non_missing ]
352-
353- if attr_value [0 ] == "lowess" :
354- alpha = attr_value [1 ] or 0.6666666
355- # missing ='drop' is the default value for lowess but not for OLS (None)
356- # we force it here in case statsmodels change their defaults
357- trendline = sm .nonparametric .lowess (
358- y , x , missing = "drop" , frac = alpha
359- )
360- trace_patch ["y" ] = trendline [:, 1 ]
361- hover_header = "<b>LOWESS trendline</b><br><br>"
362- elif attr_value [0 ] == "ma" :
363- trace_patch ["y" ] = (
364- pd .Series (y [non_missing ])
365- .rolling (window = attr_value [1 ] or 3 )
366- .mean ()
367- )
368- elif attr_value [0 ] == "ewm" :
369- trace_patch ["y" ] = (
370- pd .Series (y [non_missing ])
371- .ewm (alpha = attr_value [1 ] or 0.5 )
372- .mean ()
373- )
374- elif attr_value [0 ] == "ols" :
375- add_constant = attr_value [1 ] is not False
376- fit_results = sm .OLS (
377- y , sm .add_constant (x ) if add_constant else x , missing = "drop"
378- ).fit ()
379- trace_patch ["y" ] = fit_results .predict ()
380- hover_header = "<b>OLS trendline</b><br>"
381- if len (fit_results .params ) == 2 :
382- hover_header += "%s = %g * %s + %g<br>" % (
383- args ["y" ],
384- fit_results .params [1 ],
385- args ["x" ],
386- fit_results .params [0 ],
387- )
388- elif not add_constant :
389- hover_header += "%s = %g* %s<br>" % (
390- args ["y" ],
391- fit_results .params [0 ],
392- args ["x" ],
393- )
394- else :
395- hover_header += "%s = %g<br>" % (
396- args ["y" ],
397- fit_results .params [0 ],
398- )
399- hover_header += (
400- "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
401- )
401+ trendline_function = trendline_functions [attr_value ]
402+ y_out , hover_header , fit_results = trendline_function (
403+ args ["trendline_options" ],
404+ x ,
405+ y ,
406+ args ["x" ],
407+ args ["y" ],
408+ non_missing ,
409+ )
410+ assert len (y_out ) == len (
411+ trace_patch ["x" ]
412+ ), "missing-data-handling failure in trendline code"
413+ trace_patch ["y" ] = y_out
402414 mapping_labels [get_label (args , args ["x" ])] = "%{x}"
403415 mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
404416 elif attr_name .startswith ("error" ):
@@ -1850,9 +1862,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18501862 ):
18511863 args ["facet_col_wrap" ] = 0
18521864
1853- if args .get ("trendline" , None ) is not None :
1854- if isinstance (args ["trendline" ], str ):
1855- args ["trendline" ] = (args ["trendline" ], None )
1865+ if "trendline_options" in args and args ["trendline_options" ] is None :
1866+ args ["trendline_options" ] = dict ()
18561867
18571868 # Compute applicable grouping attributes
18581869 for k in group_attrables :
0 commit comments