2424import pandas as pd
2525import pymc as pm
2626import xarray as xr
27- from pymc .backends import NDArray
28- from pymc .backends .base import MultiTrace
2927from pymc .util import RandomState
3028
3129# If scikit-learn is available, use its data validator
@@ -427,7 +425,6 @@ def fit(
427425 self ,
428426 X : pd .DataFrame ,
429427 y : Optional [pd .Series ] = None ,
430- fit_method = "mcmc" ,
431428 progressbar : bool = True ,
432429 predictor_names : List [str ] = None ,
433430 random_seed : RandomState = None ,
@@ -444,8 +441,6 @@ def fit(
444441 The training input samples.
445442 y : array-like if sklearn is available, otherwise array, shape (n_obs,)
446443 The target values (real numbers).
447- fit_method : str
448- Which method to use to infer model parameters. One of ["mcmc", "MAP"].
449444 progressbar : bool
450445 Specifies whether the fit progressbar should be displayed
451446 predictor_names: List[str] = None,
@@ -454,14 +449,19 @@ def fit(
454449 random_seed : RandomState
455450 Provides sampler with initial random seed for obtaining reproducible samples
456451 **kwargs : Any
457- Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for
458- method-specific parameters.
452+ Custom sampler settings can be provided in the form of keyword arguments.
453+
454+ Returns
455+ -------
456+ idata : az.InferenceData
457+ Inference data of the fitted model.
458+ Examples
459+ --------
460+ >>> model = MyModel()
461+ >>> idata = model.fit(data)
462+ Auto-assigning NUTS sampler...
463+ Initializing NUTS using jitter+adapt_diag...
459464 """
460- available_methods = ["mcmc" , "MAP" ]
461- if fit_method not in available_methods :
462- raise ValueError (
463- f"Inference method { fit_method } not found. Choose one of { available_methods } ."
464- )
465465 if predictor_names is None :
466466 predictor_names = []
467467 if y is None :
@@ -474,74 +474,14 @@ def fit(
474474 sampler_config ["progressbar" ] = progressbar
475475 sampler_config ["random_seed" ] = random_seed
476476 sampler_config .update (** kwargs )
477-
478- if fit_method == "mcmc" :
479- self .idata = self .sample_model (** sampler_config )
480- elif fit_method == "MAP" :
481- self .idata = self ._fit_MAP (** sampler_config )
477+ self .idata = self .sample_model (** sampler_config )
482478
483479 X_df = pd .DataFrame (X , columns = X .columns )
484480 combined_data = pd .concat ([X_df , y ], axis = 1 )
485481 assert all (combined_data .columns ), "All columns must have non-empty names"
486482 self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
487483 return self .idata # type: ignore
488484
489- def _fit_MAP (
490- self ,
491- ** kwargs ,
492- ):
493- """Find model maximum a posteriori using scipy optimizer"""
494-
495- model = self .model
496- find_MAP_args = {** self .sampler_config , ** kwargs }
497- if "random_seed" in find_MAP_args :
498- # find_MAP takes a different argument name for seed than sample_* do.
499- find_MAP_args ["seed" ] = find_MAP_args ["random_seed" ]
500- # Extra unknown arguments cause problems for SciPy minimize
501- allowed_args = [ # find_MAP args
502- "start" ,
503- "vars" ,
504- "method" ,
505- # "return_raw", # probably causes a problem if set spuriously
506- # "include_transformed", # probably causes a problem if set spuriously
507- "progressbar" ,
508- "maxeval" ,
509- "seed" ,
510- ]
511- allowed_args += [ # scipy.optimize.minimize args
512- # "fun", # used by find_MAP
513- # "x0", # used by find_MAP
514- "args" ,
515- "method" ,
516- # "jac", # used by find_MAP
517- # "hess", # probably causes a problem if set spuriously
518- # "hessp", # probably causes a problem if set spuriously
519- "bounds" ,
520- "constraints" ,
521- "tol" ,
522- "callback" ,
523- "options" ,
524- ]
525- for arg in list (find_MAP_args ):
526- if arg not in allowed_args :
527- del find_MAP_args [arg ]
528-
529- map_res = pm .find_MAP (model = model , ** find_MAP_args )
530- # Filter non-value variables
531- value_vars_names = {v .name for v in model .value_vars }
532- map_res = {k : v for k , v in map_res .items () if k in value_vars_names }
533-
534- # Convert map result to InferenceData
535- map_strace = NDArray (model = model )
536- map_strace .setup (draws = 1 , chain = 0 )
537- map_strace .record (map_res )
538- map_strace .close ()
539- trace = MultiTrace ([map_strace ])
540- idata = pm .to_inference_data (trace , model = model )
541- self .set_idata_attrs (idata )
542-
543- return idata
544-
545485 def predict (
546486 self ,
547487 X_pred : Union [np .ndarray , pd .DataFrame , pd .Series ],
0 commit comments