seismometer.plot.mpl.cohorts_overlay

seismometer.plot.mpl.cohorts_overlay(data, plot_func, axis=None, labels=None, func_kws=None, censor_threshold=None)

Plots one line per cohort split on a shared axis, using the supplied plotting function.

Parameters:
  • data (pd.DataFrame) – Data in the format of either the first element returned by get_cohort_data or the output of get_cohort_performance_data.

  • plot_func (Callable (@model_plot decorator compatible)) – Plotting function. It should accept the data as its first parameter and all other parameters as keyword arguments. The axis must be passed by the keyword axis, and receives special handling.

  • axis (Optional[Axes], default=None) – Matplotlib axis on which to plot; a new figure is created if None.

  • labels (Optional[list[str]], default=None) – Labels to pass to the plot_func callable. If provided, the function must accept a keyword argument named ‘label’.

  • func_kws (Optional[dict], default=None) – Dictionary of keyword arguments to pass to plot_func. The function must accept every keyword it contains.

  • censor_threshold (int, default=None) – Minimum number of samples required to plot a line; a split that does not meet this minimum is censored.

Return type:

Figure
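
Example (illustrative sketch):

The calling contract described in the parameters can be sketched as follows. This is a stand-in written for illustration, not seismometer's implementation: overlay_per_cohort and plot_line are hypothetical names, the column names cohort, threshold, and sensitivity are assumptions rather than the actual layout returned by get_cohort_data or get_cohort_performance_data, and the censoring rule (skip a split with fewer rows than the threshold) is one plausible reading of "minimum number of samples".

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def plot_line(data, *, axis, label=None, **kwargs):
    """Hypothetical plot_func: data first, everything else by keyword."""
    axis.plot(data["threshold"], data["sensitivity"], label=label, **kwargs)


def overlay_per_cohort(data, plot_func, axis=None, labels=None,
                       func_kws=None, censor_threshold=None):
    """Illustrative stand-in for the documented behavior (assumed logic)."""
    if axis is None:
        _, axis = plt.subplots()  # create a new figure when no axis is given
    func_kws = func_kws or {}

    # Assumed data layout: one row per point, with a "cohort" column.
    groups = list(data.groupby("cohort", sort=False))
    if labels is None:
        labels = [str(name) for name, _ in groups]

    for (_, cohort_data), label in zip(groups, labels):
        # Assumed censoring rule: skip splits with too few rows.
        if censor_threshold is not None and len(cohort_data) < censor_threshold:
            continue
        # The axis is always passed by keyword, as the contract requires.
        plot_func(cohort_data, axis=axis, label=label, **func_kws)

    return axis.figure


# Made-up data: cohort "C" has a single row and falls below the threshold.
data = pd.DataFrame({
    "cohort": ["A"] * 3 + ["B"] * 3 + ["C"],
    "threshold": [0.2, 0.5, 0.8, 0.2, 0.5, 0.8, 0.5],
    "sensitivity": [0.9, 0.7, 0.4, 0.8, 0.6, 0.3, 0.5],
})

fig = overlay_per_cohort(
    data,
    plot_line,
    func_kws={"linestyle": "--"},
    censor_threshold=2,
)
drawn = [line.get_label() for line in fig.axes[0].get_lines()]
print(drawn)
```

In this sketch every keyword in func_kws is forwarded unchanged, so the plotting function must accept all of them; a function without a label keyword would fail as soon as labels are passed.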