seismometer.plot.mpl.cohorts_vertical
- seismometer.plot.mpl.cohorts_vertical(df, plot_func, gs=None, labels=None, func_kws=None)
- Uses a passed plotting function to plot a line per given split.
- Parameters:
- df (pd.DataFrame) – Data in the format of get_cohort_data[0]. Currently expects a pandas DataFrame with three columns: split, true, and prob.
- plot_func (Callable (@model_plot decorator compatible)) – Plotting function that takes y_true and y_proba as its first two arguments and accepts the target axis through the keyword argument axis.
- gs (Optional[Axes], default=None) – Specific gridspec subplot spec on which to plot; a new figure is created if None.
- labels (Optional[list[str]], default=None) – Optional list of labels to pass to the plot_func callable; the function must accept a ‘label’ keyword argument.
- func_kws (Optional[dict], default=None) – Dictionary of additional keyword arguments to pass to plot_func; the function must accept every key it contains.
 
- Return type:
- Figure
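The parameter contract above can be illustrated with a minimal, self-contained sketch. `cohorts_vertical_sketch` and `roc_line` are hypothetical names written only for this example. They mimic the documented inputs (a DataFrame with split, true, and prob columns; a plot function taking y_true and y_proba plus an axis keyword, optionally label and extra keywords). This is not seismometer's implementation, whose internals may differ, for example in how labels are matched to splits (here, assumed to follow the order in which splits first appear).

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def roc_line(y_true, y_proba, *, axis=None, label=None, **plot_kws):
    """Hypothetical plot_func: draws a simple ROC curve (ties are not merged)."""
    if axis is None:
        _, axis = plt.subplots()
    order = np.argsort(-np.asarray(y_proba))  # sort by descending score
    y = np.asarray(y_true)[order]
    tpr = np.concatenate([[0.0], np.cumsum(y) / y.sum()])
    fpr = np.concatenate([[0.0], np.cumsum(1 - y) / (1 - y).sum()])
    axis.plot(fpr, tpr, label=label, **plot_kws)
    return axis.figure


def cohorts_vertical_sketch(df, plot_func, gs=None, labels=None, func_kws=None):
    """Hypothetical stand-in: one plot_func call (one line) per split, on one axis."""
    func_kws = func_kws or {}
    if gs is None:
        fig, axis = plt.subplots()  # no subplot spec given: create a new figure
    else:
        fig = gs.get_gridspec().figure or plt.gcf()
        axis = fig.add_subplot(gs)
    for i, (_, group) in enumerate(df.groupby("split", sort=False)):
        kws = dict(func_kws)
        if labels is not None:
            kws["label"] = labels[i]  # assumption: labels follow split order
        plot_func(group["true"].to_numpy(), group["prob"].to_numpy(), axis=axis, **kws)
    if labels is not None:
        axis.legend()
    return fig


# Synthetic data in the documented split / true / prob layout.
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({"split": np.repeat(["A", "B"], n // 2), "true": rng.integers(0, 2, n)})
df["prob"] = np.clip(0.6 * df["true"] + rng.normal(0.2, 0.2, n), 0, 1)

fig = cohorts_vertical_sketch(df, roc_line, labels=["A", "B"], func_kws={"linestyle": "--"})
```

Because `roc_line` forwards extra keywords to `Axes.plot`, anything in `func_kws` (here `linestyle`) reaches matplotlib. A plot function that did not accept `label` or one of the `func_kws` keys would raise a `TypeError`, which is why the parameter descriptions require it to handle them.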