Demo example

Date: 27-09-2024

Author: Martin Proks

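import os
# `!export` runs in a throw-away subshell and does not affect the kernel;
# set the variable in-process, before torch/scvi initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import warnings
from numba.core.errors import NumbaDeprecationWarning
# Silence numba deprecation warnings before scvi is imported
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)

from scvi.hub import HubModel
from scanvi_explainer import SCANVIDeep, SCANVIBoostrapper
from scanvi_explainer.plots import feature_plot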
hmo = HubModel.pull_from_huggingface_hub(
    repo_name="brickmanlab/mouse-scanvi",
    cache_dir="/tmp/mouse_scanvi",
    revision="v1.0",
)
lvae = hmo.model
lvae
INFO     Loading model...
INFO     File /tmp/mouse_scanvi/models--brickmanlab--mouse-scanvi/snapshots/122feddff5447c62e8a0b320650dbb6c7a1d764a/model.pt already downloaded
ScanVI Model with the following params:
unlabeled_category: Unknown, n_hidden: 128, n_latent: 10, n_layers: 2, dropout_rate: 0.1, dispersion: gene, gene_likelihood: nb
Training status: Trained
Model's adata is minified?: False

e = SCANVIDeep(lvae, train_size=0.8, batch_size=128)
e
SCANVIDeep with the following parameters:
train_size=0.8, test_size=0.2, batch_size=128, labels_key=ct, layers_key=counts
training_on=cuda:0
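The printed parameters show that the cells are split into a training portion (80%) and a test portion (20%). The test portion, `e.test['X']`, is what gets passed to `shap.summary_plot` further down. The sketch below assumes a plain random split of cell indices. The real `SCANVIDeep` may split differently, for example stratified by `ct`. `split_indices` is a hypothetical helper:

```python
import numpy as np


def split_indices(n_cells: int, train_size: float = 0.8, seed: int = 0):
    """Hypothetical helper: randomly split cell indices into train/test sets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_cells)            # shuffle all cell indices
    n_train = int(round(train_size * n_cells))
    return perm[:n_train], perm[n_train:]      # disjoint train / test indices


train_idx, test_idx = split_indices(1000, train_size=0.8)
print(len(train_idx), len(test_idx))
```

With SHAP's deep explainers, the training portion would typically act as the background set. That is an assumption about how `SCANVIDeep` uses it.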

import shap

shap_values = e.shap_values()  # SHAP values for the cells in the test split

shap.summary_plot(
    shap_values,
    e.test['X'],
    feature_names=lvae.adata.var_names,
    class_names=lvae.adata.obs.ct.cat.categories,
)
[Figure: SHAP summary plot]
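Some SHAP versions accept `shap_values` as a list with one array per class. In that case `shap.summary_plot` draws a bar chart that orders genes by their mean absolute SHAP value summed over all classes. The sketch below reproduces that ranking with NumPy on toy numbers. Two things here are assumptions: that `e.shap_values()` returns such a list, and the helper `rank_genes` itself, which is hypothetical.

```python
import numpy as np


def rank_genes(shap_values, feature_names, top_n=20):
    """Hypothetical helper: rank genes by mean |SHAP| per class, summed over classes.

    Assumes `shap_values` is a list with one (n_cells, n_genes) array per class.
    """
    stacked = np.abs(np.stack(shap_values))    # (n_classes, n_cells, n_genes)
    per_class = stacked.mean(axis=1)           # mean |SHAP| per class and gene
    total = per_class.sum(axis=0)              # importance summed over classes
    order = np.argsort(total)[::-1][:top_n]    # most important genes first
    return [feature_names[i] for i in order], per_class[:, order]


# Toy example: 2 classes, 3 cells, 3 genes
toy = [
    np.array([[0.1, -2.0, 0.0], [0.3, 1.0, 0.0], [-0.2, 0.0, 0.1]]),
    np.array([[0.0, 0.5, 1.0], [0.0, -0.5, 1.0], [0.0, 0.5, -1.0]]),
]
genes, scores = rank_genes(toy, ["sox17", "prdm14", "xkr9"], top_n=3)
print(genes)  # ['prdm14', 'xkr9', 'sox17']
```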
lvae.adata.var
          gene_ids            gene_symbol  mt     n_cells_by_counts  mean_counts  pct_dropout_by_counts  total_counts  n_cells  highly_variable  means     dispersions  dispersions_norm  highly_variable_nbatches  highly_variable_intersection
sox17     ENSMUSG00000025902  Sox17        False  641                20.303165    68.298714              41053.0       641      True             0.383795  0.916023     3.415429          7                         False
ppp1r42   ENSMUSG00000025916  Ppp1r42      False  50                 0.649357     97.527201              1313.0        50       True             0.010606  0.317620     1.000011          4                         False
arfgef1   ENSMUSG00000067851  Arfgef1      False  1595               200.979723   21.117705              406381.0      1595     True             1.193781  0.573535     1.177822          4                         False
prdm14    ENSMUSG00000042414  Prdm14       False  639                18.401088    68.397626              37207.0       639      True             0.333239  0.770153     1.246952          4                         False
xkr9      ENSMUSG00000067813  Xkr9         False  327                12.995054    83.827893              26276.0       327      True             0.222491  0.630573     1.196979          3                         False
...       ...                 ...          ...    ...                ...          ...                    ...           ...      ...              ...       ...          ...               ...                       ...
habp2     ENSMUSG00000025075  Habp2        False  158                2.715628     92.185955              5491.0        158      True             0.059447  0.638976     1.476901          3                         False
ccdc186   ENSMUSG00000035173  Ccdc186      False  1231               331.236400   39.119683              669760.0      1231     True             1.094922  0.594682     1.020226          4                         False
afap1l2   ENSMUSG00000025083  Afap1l2      False  251                56.485658    87.586548              114214.0      251      True             0.275932  0.754868     1.354747          5                         False
pnlip     ENSMUSG00000046008  Pnlip        False  140                2.630564     93.076162              5319.0        140      True             0.041573  0.695832     0.864209          4                         False
pnliprp2  ENSMUSG00000025091  Pnliprp2     False  799                27.727498    60.484669              56065.0       799      True             0.457133  0.893608     2.415332          6                         False

3000 rows × 14 columns
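Several column names match those produced by scanpy's `sc.pp.calculate_qc_metrics`: `n_cells_by_counts`, `mean_counts`, `pct_dropout_by_counts` and `total_counts`. Assume scanpy's definitions hold. Then `mean_counts` is the mean over all cells, and `pct_dropout_by_counts` is the percentage of cells in which the gene is not detected. The sketch below uses these assumed definitions to back-calculate the number of cells (`n_obs`) at the time the metrics were computed, taking three rows copied from the table above:

```python
# Values copied from three rows of the lvae.adata.var table above.
rows = {
    "sox17": dict(n_cells_by_counts=641, total_counts=41053.0,
                  mean_counts=20.303165, pct_dropout_by_counts=68.298714),
    "ppp1r42": dict(n_cells_by_counts=50, total_counts=1313.0,
                    mean_counts=0.649357, pct_dropout_by_counts=97.527201),
    "arfgef1": dict(n_cells_by_counts=1595, total_counts=406381.0,
                    mean_counts=200.979723, pct_dropout_by_counts=21.117705),
}

inferred = {}
for gene, r in rows.items():
    # Assumed (scanpy-style) definitions:
    #   mean_counts           = total_counts / n_obs
    #   pct_dropout_by_counts = 100 * (1 - n_cells_by_counts / n_obs)
    n_obs_from_mean = r["total_counts"] / r["mean_counts"]
    n_obs_from_dropout = r["n_cells_by_counts"] / (1 - r["pct_dropout_by_counts"] / 100)
    inferred[gene] = (round(n_obs_from_mean), round(n_obs_from_dropout))

print(inferred)  # every gene gives (2022, 2022)
```

Both formulas give the same cell count for every row, which supports the assumed definitions.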

feature_plot(e, shap_values, subset=True, top_n=10)
[Figure: feature_plot of the top 10 genes]
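The `top_n=10` argument limits `feature_plot` to a handful of genes, but this page does not say how those genes are chosen. One plausible rule, used here purely as an assumption, is to keep the `top_n` genes with the largest mean SHAP value for each cell type. In the sketch below, `top_n_per_class` is a hypothetical helper, the class names `"A"` and `"B"` are placeholders, and the input is assumed to be a list of per-class arrays:

```python
import numpy as np
import pandas as pd


def top_n_per_class(shap_values, feature_names, class_names, top_n=10):
    """Hypothetical helper: top-n genes per class by mean SHAP value.

    Assumes `shap_values` is a list of (n_cells, n_genes) arrays, one per class.
    """
    means = pd.DataFrame(
        [sv.mean(axis=0) for sv in shap_values],  # mean SHAP per gene, per class
        index=class_names,
        columns=feature_names,
    )
    # For every class, keep the genes with the largest mean contribution.
    return {ct: means.loc[ct].nlargest(top_n).index.tolist() for ct in class_names}


toy = [
    np.array([[1.0, 0.0, 2.0], [3.0, 0.0, 0.0]]),    # class "A"
    np.array([[0.0, 1.0, -1.0], [0.0, 2.0, -1.0]]),  # class "B"
]
result = top_n_per_class(toy, ["sox17", "prdm14", "xkr9"], ["A", "B"], top_n=2)
print(result)  # {'A': ['sox17', 'xkr9'], 'B': ['prdm14', 'sox17']}
```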

Bootstrapper

To make the predicted features (genes) more robust, we have also implemented a bootstrapping approach. The plots below show the \(\mu\) (mean) value computed in each bootstrap. To adjust the parameters, please refer to the documentation.
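The sketch below illustrates the general idea using one common form of bootstrapping: resampling cells with replacement. It is not the `SCANVIBoostrapper` implementation, which may instead draw new train/test splits in each round. In every round it records the per-gene mean score \(\mu\). It then summarises how stable each \(\mu\) is across rounds. `bootstrap_means` and the `scores_fn` callback are hypothetical stand-ins for the explain step:

```python
import numpy as np


def bootstrap_means(scores_fn, n_cells, n_bootstraps=10, seed=0):
    """Collect the per-gene mean score (mu) from each bootstrap round.

    `scores_fn(idx)` is a hypothetical stand-in that returns an
    (len(idx), n_genes) array of SHAP-like scores for the sampled cells.
    """
    rng = np.random.default_rng(seed)
    mus = []
    for _ in range(n_bootstraps):
        idx = rng.choice(n_cells, size=n_cells, replace=True)  # resample cells
        mus.append(scores_fn(idx).mean(axis=0))                # mu for this round
    mus = np.vstack(mus)                                       # (n_bootstraps, n_genes)
    return mus.mean(axis=0), mus.std(axis=0), mus


# Toy data: gene 0 has a consistently large score, gene 1 is noise around zero.
rng = np.random.default_rng(1)
toy_scores = np.column_stack([rng.normal(2.0, 0.1, 500), rng.normal(0.0, 1.0, 500)])
mu_mean, mu_std, mus = bootstrap_means(lambda idx: toy_scores[idx], n_cells=500)
print(mus.shape, mu_mean.round(2), mu_std.round(3))
```

A gene whose \(\mu\) stays large with a small spread across rounds is a more trustworthy feature than one whose \(\mu\) fluctuates around zero.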

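# 10 bootstrap rounds, each explaining the model with an 80/20 train/test split
bootstrapper = SCANVIBoostrapper(lvae, n_bootstraps=10)
shap_values = bootstrapper.run(train_size=0.8, batch_size=64)
# Save the bootstrapped SHAP values as a Feather file for later reuse
bootstrapper.save(shap_values, './bootstrapped_shaps.feather')
bootstrapper.feature_plot(shap_values)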
[Figure: bootstrapped feature plot]
bootstrapper.feature_plot(shap_values, kind="barplot", gene_symbols='gene_ids')
[Figure: bootstrapped feature bar plot, genes labelled by gene_ids]
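The `gene_symbols='gene_ids'` argument presumably tells the plot to label genes with the `gene_ids` column of `lvae.adata.var` (Ensembl IDs) instead of the lower-case index. The same lookup can be done by hand with pandas. The small frame below copies three rows from the `lvae.adata.var` table above, and `top_genes` is a hypothetical list of genes picked from a plot:

```python
import pandas as pd

# Three rows copied from lvae.adata.var; the index holds lower-case gene names.
var = pd.DataFrame(
    {
        "gene_ids": ["ENSMUSG00000025902", "ENSMUSG00000042414", "ENSMUSG00000067813"],
        "gene_symbol": ["Sox17", "Prdm14", "Xkr9"],
    },
    index=["sox17", "prdm14", "xkr9"],
)

top_genes = ["prdm14", "sox17"]  # hypothetical genes picked from a plot
# Look up Ensembl IDs and symbols for those genes, keeping their order.
print(var.loc[top_genes, ["gene_ids", "gene_symbol"]])
```

On the real model the equivalent call would be `lvae.adata.var.loc[top_genes, "gene_ids"]`.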