Source code for epysurv.visualization.model_diagnostics

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as gg
import seaborn as sns


def plot_confusion_matrix(
    confusion_matrix: np.ndarray, class_names: list, ax: matplotlib.axes.Axes = None
) -> matplotlib.axes.Axes:
    """Draw a confusion matrix as an annotated heatmap.

    Based on https://gist.github.com/shaypal5/94c53d765083101efc0240d776a23823

    Arguments
    ---------
    confusion_matrix
        Two-dimensional array of counts, for example the result of
        ``sklearn.metrics.confusion_matrix``. Both its rows and its columns
        are labelled with ``class_names``.
    class_names
        Class names, in the order in which they index the confusion matrix.
    ax
        Axes to draw the heatmap on. When omitted, a new figure and axes are
        created with ``plt.subplots()``.

    Returns
    -------
    The axes the heatmap was drawn on.
    """
    # Attach the class names to both dimensions so seaborn uses them as tick labels.
    labelled_matrix = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names
    )
    if ax is None:
        _, ax = plt.subplots()
    drawn_axes = sns.heatmap(labelled_matrix, annot=True, cmap="Blues", ax=ax)
    # Rows are labelled as the true classes, columns as the predicted ones.
    drawn_axes.set(ylabel="True label", xlabel="Predicted label")
    return ax
def plot_prediction(
    train_data, test_data, prediction, ax: matplotlib.axes.Axes = None
) -> matplotlib.axes.Axes:
    """Plot case counts as a step line and mark outbreaks and alarms with triangles.

    Parameters
    ----------
    train_data
        Frame with an ``n_cases`` column; drawn as the first part of the step line.
    test_data
        Frame with ``n_cases``, ``outbreak`` and ``n_outbreak_cases`` columns.
        It continues the step line, and its rows selected by ``outbreak`` are
        marked with red downward triangles at the height of ``n_outbreak_cases``.
    prediction
        Frame with an ``alarm`` column. Rows where ``alarm == 1`` are marked
        with green upward triangles on the zero line.
    ax
        Axes to draw on. When omitted, a new 12x8 figure is created.

    Returns
    -------
    The axes the plot was drawn on.
    """
    combined = pd.concat((train_data, test_data), sort=False)
    label_size = 20
    marker_style = {"markersize": 12}
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    # The case-count line is kept out of the legend; only the markers are listed.
    ax.step(
        x=combined.index,
        y=combined.n_cases,
        where="mid",
        color="blue",
        label="_nolegend_",
    )

    raised_alarms = prediction.query("alarm == 1")
    ax.plot(
        raised_alarms.index,
        [0] * len(raised_alarms),
        "g^",
        label="alarm",
        **marker_style,
    )

    true_outbreaks = test_data.query("outbreak")
    ax.plot(
        true_outbreaks.index,
        true_outbreaks.n_outbreak_cases,
        "rv",
        label="outbreak",
        **marker_style,
    )

    ax.set_xlabel("time", fontsize=label_size)
    ax.set_ylabel("cases", fontsize=label_size)
    ax.legend(fontsize="xx-large")
    return ax
def _with_ghozzi_weights(prediction_result: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``prediction_result`` with the three ghozzi weighting columns added."""
    alarm = prediction_result.alarm
    outbreak = prediction_result.outbreak
    n_outbreak_cases = prediction_result.n_outbreak_cases
    # A false alarm has no outbreak cases of its own, so it is weighted with the
    # mean outbreak size over all outbreak rows.
    # NOTE: without any outbreak row this mean, and with it every false-positive
    # weight, is NaN.
    mean_outbreak_size = np.mean(prediction_result.query("outbreak").n_outbreak_cases)
    # ``assign`` builds a new frame, so the caller's DataFrame is left untouched.
    return prediction_result.assign(
        # Outbreaks that were recognized.
        weighted_true_positives=alarm * outbreak * n_outbreak_cases,
        # Outbreaks that were missed.
        weighted_false_negatives=(1 - alarm) * outbreak * n_outbreak_cases,
        # Alarms that were falsely raised.
        weighted_false_positives=alarm * (outbreak != alarm) * mean_outbreak_size,
    )


def ghozzi_score_plot(prediction_result: pd.DataFrame, filename: str):
    """Plots case counts and detector predictions with ghozzi weighting.

    One facet is drawn per (county, pathogen) combination. Each facet shows the
    case counts as grey bars, overlaid with bars for the weighted true
    positives, false negatives and false positives.

    Parameters
    ----------
    prediction_result
        DataFrame containing 'alarm', 'county', 'pathogen', 'n_cases', 'n_outbreak_cases', 'outbreak'.
        Its index supplies the x axis (it is passed to a date scale). The frame
        is not modified.
    filename
        File name to write the plot to.
    """
    weighting_columns = [
        "weighted_true_positives",
        "weighted_false_negatives",
        "weighted_false_positives",
    ]
    # Expose the index as a regular "date" column so every layer can map to it
    # by name instead of evaluating a Python expression inside ``aes``.
    weighted = (
        _with_ghozzi_weights(prediction_result)
        .reset_index()
        .rename(columns={"index": "date"})
    )
    melted_prediction_result = weighted.melt(
        id_vars=[
            "date",
            "county",
            "pathogen",
            "n_cases",
            "n_outbreak_cases",
            "outbreak",
            "alarm",
        ],
        # Listed explicitly so unrelated extra columns of the input are not
        # melted into additional "prediction" categories.
        value_vars=weighting_columns,
        var_name="prediction",
        value_name="weighting",
    )

    case_color = "grey"
    n_cols = 4
    n_filter_combinations = len(
        prediction_result[["county", "pathogen"]].drop_duplicates()
    )
    # facet_wrap starts a new row after every ``n_cols`` panels; a partly filled
    # last row still needs the full row height.
    n_rows = max(1, int(np.ceil(n_filter_combinations / n_cols)))

    chart = (
        gg.ggplot(melted_prediction_result, gg.aes(x="date"))
        # Case counts come from the un-melted frame, one bar per original row.
        + gg.geom_bar(
            mapping=gg.aes(y="n_cases"),
            data=weighted,
            fill=case_color,
            stat="identity",
        )
        + gg.geom_line(gg.aes(y=0), color=case_color)
        + gg.geom_bar(gg.aes(y="weighting", fill="prediction"), stat="identity")
        + gg.facet_wrap(["county", "pathogen"], ncol=n_cols)
        + gg.scale_x_date(date_breaks="4 month", date_labels="%Y-%m")
        + gg.ylab("# cases")
        + gg.scale_fill_manual(name="weighting", values=["red", "orange", "green"])
        # The complete theme has to come first: adding it after ``theme(...)``
        # would replace the customisation below.
        + gg.theme_light()
        + gg.theme(panel_grid_minor=gg.element_blank())
    )
    chart.save(
        filename,
        width=5 * n_cols,
        height=4 * n_rows,
        # plotnine's keyword is ``units``; a misspelled keyword is forwarded to
        # matplotlib instead of setting the size unit.
        units="cm",
        limitsize=False,
    )