"""Source code for besca.tl._annot_compare."""

import csv
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc

from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    make_scorer,
    adjusted_mutual_info_score,
    adjusted_rand_score,
    silhouette_score
)

from sklearn.metrics.cluster import pair_confusion_matrix

from besca.pl._riverplot import riverplot_2categories


def report(
    adata_pred,
    celltype,
    method,
    analysis_name,
    train_datasets=None,
    test_dataset="",
    merge="",
    name_prediction="auto_annot",
    name_report="auto_annot",
    use_raw=False,
    genes_to_use="",
    remove_nonshared=False,
    clustering="leiden",
    asymmetric_matrix=True,
    results_folder="./",
    delimiter=",",
    verbose=False,
):
    """reports basic metrics, produces confusion matrices and plots umap of prediction

    Writes out a csv file containing all accuracy and f1 scores.
    Writes normalized and absolute confusion matrices, as well as umap prediction
    comparisons to ./figures.

    parameters
    ----------
    adata_pred: AnnData
        original adata object with name_prediction column
    celltype: `str`
        celltype column on which the prediction was performed
    method: `str`
        method that was used for prediction.
    analysis_name: `str`
        name of the analysis, used for writing files
    train_datasets: `list` | default = None
        list of used training datasets (treated as an empty list when None)
    test_dataset: `str`
        name of test dataset
    merge: `str`
        what merging was performed
    name_prediction : "auto_annot"| default = "auto_annot"
        observation name containing the prediction to compare with.
    name_report : "auto_annot"| default = "auto_annot"
        prefix of the report
    use_raw: `bool` | default = False
        if anndata.raw was used
    genes_to_use: `list` or `string` | default = 'all'
        what geneset was used
    remove_nonshared: `bool`|default = False
    clustering: `str` | default = leiden
        clustering that was used in original analysis of testing set,
        needed for umap plotting
    asymmetric_matrix: `bool` | default = True
        if False returns square confusion matrix, if True it only shows
        possible combinations
    results_folder: `str` | default = './'
        output directory. A figures folder will be generated within it.
    delimiter: `str` | default = ','
        separator between fields in the csv/txt report file
    verbose: `bool` | default = False
        print verbose messages to standard out

    returns
    -------
    Figure
        A matplotlib figure element containing the riverplot generated for
        interactive display.
    """
    # avoid the mutable-default-argument pitfall while keeping the report
    # output identical to the historical `train_datasets=[]` default
    train_datasets = [] if train_datasets is None else train_datasets

    # calculate umaps for plot
    if "X_umap" not in adata_pred.obsm:
        sc.tl.umap(adata_pred)

    if name_prediction not in adata_pred.obs.keys():
        sys.exit(
            name_prediction
            + " label not found in the predicted dataset (should be in obs)"
        )

    # get acc
    acc = accuracy_score(adata_pred.obs[celltype], adata_pred.obs[name_prediction])

    # get f1 -- restrict scoring to the unique true celltypes; passing the raw
    # observation series (with duplicates) would make the "macro" average
    # silently weight each class by its frequency in the data
    f1 = f1_score(
        adata_pred.obs[celltype],
        adata_pred.obs[name_prediction],
        labels=np.unique(adata_pred.obs[celltype]),
        average="macro",
    )

    if verbose:
        print("acc: " + str(round(acc, 2)))
        print("f1: " + str(round(f1, 2)))

    # get per-class precision/recall/f1 report as a DataFrame
    class_report = classification_report(
        adata_pred.obs[celltype], adata_pred.obs[name_prediction], output_dict=True
    )
    sklearn_report = round(pd.DataFrame(class_report).transpose(), 2)

    # get clustering scores
    ami = adjusted_mutual_info_score(
        adata_pred.obs[celltype], adata_pred.obs[name_prediction]
    )
    ari = adjusted_rand_score(adata_pred.obs[celltype], adata_pred.obs[name_prediction])
    # silhouettes are computed in UMAP space, not on the expression matrix
    silhouette_celltype = silhouette_score(
        adata_pred.obsm["X_umap"], adata_pred.obs.get(celltype)
    )
    silhouette_pred = silhouette_score(
        adata_pred.obsm["X_umap"], adata_pred.obs.get(name_prediction)
    )
    pair_conf_m = pair_confusion_matrix(
        adata_pred.obs[celltype], adata_pred.obs[name_prediction]
    )

    if verbose:
        print("ami: " + str(round(ami, 2)))
        print("ari: " + str(round(ari, 2)))
        print("silhouette " + celltype + ": " + str(round(silhouette_celltype, 2)))
        print("silhouette " + name_prediction + ": " + str(round(silhouette_pred, 2)))
        print("pair confusion matrix:\n" + str(pd.DataFrame(pair_conf_m)))

    # csv file with important metrics; a comma delimiter yields .csv, anything
    # else (e.g. tab) yields .txt
    file_ending = ".txt"
    if delimiter == ",":
        file_ending = ".csv"
    with open(
        os.path.join(
            results_folder, name_report + "_report_" + analysis_name + file_ending
        ),
        mode="w",
    ) as report_file:
        report_writer = csv.writer(
            report_file, delimiter=delimiter, quotechar='"', quoting=csv.QUOTE_MINIMAL
        )
        # NOTE: "test_datset" typo kept verbatim for backward compatibility of
        # the report files consumed downstream
        report_writer.writerow(
            ["train_dataset = ", train_datasets, "test_datset = ", test_dataset]
        )
        report_writer.writerow(["celltype = ", celltype, "method = ", method])
        report_writer.writerow(
            [
                "remove_nonshared = ",
                remove_nonshared,
                "merge = ",
                merge,
                "use_raw = ",
                use_raw,
                "genes_to_use = ",
                genes_to_use,
            ]
        )
        report_writer.writerow(["accuracy=", round(acc, 2), "f1=", round(f1, 2)])
        report_writer.writerow(["clustering report"])
        report_writer.writerow(["ari=", round(ari, 2), "ami=", round(ami, 2)])
        report_writer.writerow(
            [
                "silhouette_celltype=",
                round(silhouette_celltype, 2),
                "silhouette_pred=",
                round(silhouette_pred, 2),
            ]
        )
        report_writer.writerow(["pair confusion matrix"])
        pd.DataFrame(pair_conf_m).to_csv(report_file, header=True, sep=delimiter)
        report_writer.writerow(["classification report"])
        sklearn_report.to_csv(report_file, header=True, sep=delimiter)

    # make umap plots (scanpy saves them into its own figures directory)
    sc.settings.set_figure_params(dpi=120)
    sc.pl.umap(
        adata_pred,
        color=[celltype, name_prediction, clustering],
        legend_loc="on data",
        legend_fontsize=7,
        frameon=False,
        save=".ondata_" + analysis_name + ".png",
    )
    for col in [celltype, name_prediction, clustering]:
        sc.pl.umap(
            adata_pred,
            color=col,
            wspace=1.5,
            frameon=False,
            save="." + analysis_name + "_" + col + ".png",
        )
    sc.settings.set_figure_params(dpi=60)

    os.makedirs(os.path.join(results_folder, "figures"), exist_ok=True)

    # plot basic riverplot comparing true vs predicted categories
    riverplot = riverplot_2categories(
        adata=adata_pred, categories=[celltype, name_prediction]
    )
    riverplot.show()
    riverplot.write_image(
        os.path.join(
            results_folder,
            "figures",
            method
            + "_riverplot_"
            + analysis_name
            + "_"
            + celltype
            + "_"
            + name_prediction
            + ".svg",
        )
    )

    # make confusion matrices: absolute and row-normalized
    class_names = np.unique(
        np.concatenate((adata_pred.obs[celltype], adata_pred.obs[name_prediction]))
    )
    np.set_printoptions(precision=2)

    # Plot non-normalized confusion matrix
    fig = plot_confusion_matrix(
        adata_pred.obs[celltype],
        adata_pred.obs[name_prediction],
        classes=class_names,
        celltype=celltype,
        name_prediction=name_prediction,
        title="Confusion matrix, without normalization",
        numbers=False,
        adata_predicted=adata_pred,
        asymmetric_matrix=asymmetric_matrix,
    )
    fig.show()
    fig.savefig(
        os.path.join(
            results_folder,
            "figures",
            method + "_confusion_matrix_" + analysis_name + "_" + celltype + ".svg",
        )
    )

    # Plot normalized confusion matrix
    fig = plot_confusion_matrix(
        adata_pred.obs[celltype],
        adata_pred.obs[name_prediction],
        classes=class_names,
        celltype=celltype,
        name_prediction=name_prediction,
        normalize=True,
        title="Normalized confusion matrix",
        numbers=False,
        adata_predicted=adata_pred,
        asymmetric_matrix=asymmetric_matrix,
    )
    fig.show()
    fig.savefig(
        os.path.join(
            results_folder,
            "figures",
            method
            + "_confusion_matrix_norm_"
            + analysis_name
            + "_"
            + celltype
            + ".svg",
        )
    )

    return riverplot
def plot_confusion_matrix(
    y_true,
    y_pred,
    classes,
    celltype,
    name_prediction="auto_annot",
    normalize=False,
    title=None,
    numbers=False,
    cmap=plt.cm.Blues,
    adata_predicted=None,
    asymmetric_matrix=True,
):
    """plots confusion matrices

    returns a matplotlib confusion matrix

    parameters
    ----------
    y_true: pandas.core.series.Series
        ordered series of all true labels
    y_pred: pandas.core.series.Series
        ordered series of all predicted celltypes
    classes: numpy.ndarray
        union of true and predictable celltypes
    celltype: `str`
        celltype column on which the prediction was performed
    name_prediction : "auto_annot"| default = "auto_annot"
        observation name containing the prediction to compare with.
    normalize: `bool` | default = False
        whether to return absolute values or to value all celltypes equally
    title: `str` | default = None
        title to be given to confusion matrix figure in file.
    numbers: `bool`| default = False
        should the numbers be displayed in the plot.
        Note: is illegible in larger plots
    cmap: matplotlib.cm | default = plt.cm.Blues
        colour to be used for plotting
    adata_predicted: AnnData | default = None
        adata object carrying the celltype and prediction columns;
        required when asymmetric_matrix is True
    asymmetric_matrix: `bool` | default = True
        if False returns square confusion matrix, if True it only shows
        possible combinations

    returns
    -------
    matplotlib.pyplot.plot
        plot of confusion matrix
    """
    if not title:
        title = (
            "Normalized confusion matrix"
            if normalize
            else "Confusion matrix, without normalization"
        )

    # Compute confusion matrix over the sorted union of true/predicted labels
    cm = confusion_matrix(y_true, y_pred)

    if asymmetric_matrix:
        # keep only rows for observed true celltypes and columns for observed
        # predicted celltypes; searchsorted maps each subset back to its index
        # in the sorted union used by confusion_matrix
        class_names = np.unique(
            np.concatenate(
                (adata_predicted.obs[celltype], adata_predicted.obs[name_prediction])
            )
        )
        class_names_orig = np.unique(adata_predicted.obs[celltype])
        class_names_pred = np.unique(adata_predicted.obs[name_prediction])
        test_celltypes_ind = np.searchsorted(class_names, class_names_orig)
        train_celltypes_ind = np.searchsorted(class_names, class_names_pred)
        cm = cm[test_celltypes_ind, :][:, train_celltypes_ind]
        xticklabels, yticklabels = class_names_pred, class_names_orig
    else:
        xticklabels, yticklabels = classes, classes

    if normalize:
        # row-normalize so each true celltype sums to 1
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    fig, ax = plt.subplots(figsize=(15, 15))
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax, shrink=0.8)

    # show all ticks and label them with the respective celltypes
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )
    ax.grid(False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    if numbers:
        fmt = ".2f" if normalize else "d"
        thresh = cm.max() / 2.0
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(
                    j,
                    i,
                    format(cm[i, j], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i, j] > thresh else "black",
                )
    return fig