scanvi_explainer.utils.train_test_group_split
- scanvi_explainer.utils.train_test_group_split(adata: AnnData, groupby: str, train_size: float = 0.8, batch_size: int = 128, layer: str = 'counts') → tuple[dict[str, Tensor], dict[str, Tensor]]
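Split an AnnData object into train and test sets within each group (80/20 by default, controlled by train_size), in the format required by the SCANVIDeep explainer.
Larger datasets might not fit into GPU memory. To overcome this, we recommend setting
batch_size to 128, so that each group uses only 128 randomly sampled cells. This speeds up the explainer at the cost of some accuracy, so we suggest repeating this process multiple times (bootstrapping).
- Parameters:
adata (AnnData)
groupby (str)
train_size (float, default 0.8)
batch_size (int, default 128)
layer (str, default 'counts')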
- Returns:
Train and test splits, each a dictionary mapping string keys to tensors
- Return type:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
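The per-group split and the suggested repetition can be illustrated with the minimal NumPy sketch below. It is not the library's implementation: the function names `group_split_sketch` and `repeat_sketch`, the use of NumPy arrays instead of `torch.Tensor`, the dictionary keys `"X"` and `"labels"`, and the choice to subsample each group to `batch_size` cells *before* splitting are all assumptions made for illustration.

```python
import numpy as np


def group_split_sketch(X, groups, train_size=0.8, batch_size=128, seed=0):
    """Hypothetical per-group train/test split (illustration only).

    X: (n_cells, n_features) array; groups: (n_cells,) array of group labels.
    Returns (train, test), each a dict with the assumed keys "X" and "labels".
    """
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for g in np.unique(groups):
        # Indices of all cells belonging to this group, in random order.
        idx = np.flatnonzero(groups == g)
        rng.shuffle(idx)
        # Assumed behaviour: keep at most batch_size randomly chosen cells per group.
        if batch_size is not None:
            idx = idx[:batch_size]
        # Split the (possibly subsampled) group train_size / (1 - train_size).
        n_train = int(round(train_size * len(idx)))
        train_idx.append(idx[:n_train])
        test_idx.append(idx[n_train:])
    train_idx = np.concatenate(train_idx)
    test_idx = np.concatenate(test_idx)
    train = {"X": X[train_idx], "labels": groups[train_idx]}
    test = {"X": X[test_idx], "labels": groups[test_idx]}
    return train, test


def repeat_sketch(stat_fn, X, groups, n_repeats=5, **split_kwargs):
    """Hypothetical helper: rerun the split with different seeds and
    summarise stat_fn(train, test) as (mean, standard deviation)."""
    values = [
        stat_fn(*group_split_sketch(X, groups, seed=s, **split_kwargs))
        for s in range(n_repeats)
    ]
    return float(np.mean(values)), float(np.std(values))
```

For example, with 200 cells in group `"a"` and 100 in group `"b"`, the default `batch_size=128` caps group `"a"` at 128 cells (102 train, 26 test) while group `"b"` is split 80/20 as-is. In practice `stat_fn` would stand in for running the explainer on each split; here it is only a placeholder. Note that repeating with fresh random subsamples, as sketched, is repeated subsampling rather than resampling with replacement.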