seismometer.plot.mpl.cohorts_vertical

seismometer.plot.mpl.cohorts_vertical(df, plot_func, gs=None, labels=None, func_kws=None)

Plots one line per cohort split by calling the supplied plotting function once for each split.

Parameters:
  • df (pd.DataFrame) – Data in the format of get_cohort_data[0]. Currently expects a pandas DataFrame with three columns: split, true, and prob.

  • plot_func (Callable (@model_plot decorator compatible)) – Plotting function that takes y_true and y_proba as its first two arguments and accepts the target axis through the keyword argument axis.

  • gs (Optional[Axes], default=None) – Specific gridspec subplot spec on which to plot; if None, a new figure is created.

  • labels (Optional[list[str]], default=None) – Optional list of labels to pass to plot_func; the function must accept a ‘label’ keyword argument.

  • func_kws (Optional[dict], default=None) – Additional keyword arguments to pass to plot_func; the function must accept every keyword in the dictionary.

Return type:

Figure
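The calling contract described by the parameters (one plot_func call per split, y_true and y_proba first, the axis passed as the keyword axis, and labels and func_kws forwarded as keywords) can be illustrated with a minimal sketch, assuming pandas, NumPy and matplotlib are installed. Both functions are hypothetical: sensitivity_curve is an example plot_func, and plot_lines_per_split only mimics the documented behavior. It is not seismometer's implementation, does not use the @model_plot decorator, ignores gs, and always creates a new figure.

```python
from typing import Callable, Optional, Sequence

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure


def sensitivity_curve(y_true, y_proba, *, axis, label=None, n_thresholds=11):
    """Hypothetical plot_func: sensitivity (true positive rate) at evenly spaced thresholds."""
    y_true = np.asarray(y_true, dtype=bool)
    y_proba = np.asarray(y_proba, dtype=float)
    thresholds = np.linspace(0.0, 1.0, n_thresholds)
    positives = max(int(y_true.sum()), 1)  # avoid dividing by zero for a cohort with no positives
    tpr = [np.sum(y_true & (y_proba >= t)) / positives for t in thresholds]
    axis.plot(thresholds, tpr, label=label)
    return axis


def plot_lines_per_split(
    df: pd.DataFrame,
    plot_func: Callable,
    labels: Optional[Sequence[str]] = None,
    func_kws: Optional[dict] = None,
) -> Figure:
    """Hypothetical stand-in for the documented behavior: one plot_func call per split."""
    fig, ax = plt.subplots()
    for i, split in enumerate(df["split"].unique()):
        cohort = df[df["split"] == split]  # rows belonging to this split
        kws = dict(func_kws or {})  # copy so the caller's dict is never mutated
        if labels is not None:
            kws["label"] = labels[i]  # plot_func must accept a `label` keyword
        plot_func(cohort["true"], cohort["prob"], axis=ax, **kws)
    if labels is not None:
        ax.legend()
    return fig


df = pd.DataFrame({
    "split": ["A", "A", "A", "B", "B", "B"],
    "true": [1, 0, 1, 1, 0, 0],
    "prob": [0.9, 0.2, 0.6, 0.4, 0.8, 0.1],
})
fig = plot_lines_per_split(
    df, sensitivity_curve, labels=["Cohort A", "Cohort B"], func_kws={"n_thresholds": 5}
)
```

Passing labels and func_kws through as plain keyword arguments keeps plot_func independent of the cohort logic: any function with the (y_true, y_proba, *, axis, ...) shape can be reused per split.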