seismometer.plot.mpl.cohorts_vertical
- seismometer.plot.mpl.cohorts_vertical(df, plot_func, gs=None, labels=None, func_kws=None)
Uses the passed plotting function to draw one line per split in the data.
- Parameters:
df (pd.DataFrame) – Data in the format of get_cohort_data[0]. Currently expects a pandas DataFrame with three columns: split, true, and prob.
plot_func (Callable (@model_plot decorator compatible)) – Plotting function that takes y_true and y_proba as its first two arguments and accepts the target axis through the keyword argument axis.
gs (Optional[Axes], default=None) – Specific gridspec subplot spec on which to plot; a new figure is created if None.
labels (Optional[list[str]], default=None) – List of labels to optionally pass to the plot_func callable; the function must accept a keyword argument named label.
func_kws (Optional[dict], default=None) – Dictionary of additional keyword arguments to pass to plot_func; the function must accept every keyword it contains.
- Return type:
Figure
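The one-line-per-split behavior and the plot_func contract described above can be illustrated with a small, self-contained sketch. This is not seismometer's implementation: cohorts_vertical_sketch and threshold_recall are hypothetical names, the sketch assumes labels (when given) are ordered like the splits' first appearance in df, and it always creates a new figure instead of handling gs.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def threshold_recall(y_true, y_proba, axis=None, label=None, **plot_kws):
    """Hypothetical plot_func: recall (sensitivity) at 11 evenly spaced thresholds.

    Follows the documented contract: y_true and y_proba come first, the target
    axis arrives via the `axis` keyword, and `label` is accepted as a keyword.
    Any extra keywords (e.g. from func_kws) are forwarded to Axes.plot.
    """
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba)
    thresholds = np.linspace(0.0, 1.0, 11)
    positives = max(int(y_true.sum()), 1)  # avoid dividing by zero for all-negative cohorts
    recall = [((y_proba >= t) & (y_true == 1)).sum() / positives for t in thresholds]
    axis.plot(thresholds, recall, label=label, **plot_kws)


def cohorts_vertical_sketch(df, plot_func, labels=None, func_kws=None):
    """Hypothetical illustration of one-line-per-split plotting (not seismometer code).

    Assumes df has columns split, true, prob. Always creates a new figure.
    """
    func_kws = func_kws or {}
    fig, axis = plt.subplots()
    splits = list(pd.unique(df["split"]))  # order of first appearance
    if labels is None:
        labels = [str(split) for split in splits]
    for split, label in zip(splits, labels):
        cohort = df[df["split"] == split]
        # One plot_func call per split, all drawing onto the shared axis.
        plot_func(cohort["true"], cohort["prob"], axis=axis, label=label, **func_kws)
    axis.legend()
    return fig


# Usage with synthetic data: two splits of 50 rows each.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "split": np.repeat(["A", "B"], 50),
    "true": rng.integers(0, 2, 100),
    "prob": rng.random(100),
})
fig = cohorts_vertical_sketch(df, threshold_recall, func_kws={"linestyle": "--"})
```

Keeping the per-split loop separate from the metric lets any callable with the same signature be swapped in, which mirrors why plot_func must accept axis, label, and every key in func_kws.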