Hi LChart,
Thanks for the detailed post—this is a great question, as SysVI is still relatively new and finicky compared to scVI or Harmony, especially for cross-protocol integrations like yours (10X vs. split-pool). I've run into similar "wonky" cycle losses and arbitrary embeddings myself when tweaking it for multi-study scRNA-seq data. You're on the right track with subsetting to shared HVGs across studies (nice use of highly_variable_nbatches > 1), but I suspect the core issue is in your data prep for SysVI specifically. Let me break it down.
Quick Diagnosis
Input data mismatch: Unlike standard scVI (which loves raw counts for its negative binomial likelihood), SysVI is designed for normalized + log-transformed data in adata.X (assuming Gaussian noise). From the official SysVI tutorial:
"For scRNA-seq data the integration should be performed on normalized and log-transformed data, with normalization being set to a fixed number of counts per cell. [...] The here-used example data was already preprocessed accordingly and has normalized data in X."
In your pipeline, you normalize/log the full adata early on (good!), compute HVGs on that, but then override asub.X = asub.layers['counts'].copy() before setup. This feeds raw counts into SysVI, which can cause unstable training, erratic cycle losses (that jagged plot screams likelihood mismatch), and poor integration. Raw counts expect NB modeling, but SysVI defaults to Gaussian on normalized data.
Hyperparams amplifying the issue: Your z_distance_cycle_weight=30 is aggressive (tutorial suggests 2–10 for balance; up to 50 only if batches are extremely disparate). Combined with raw data, it over-pulls on cycle consistency, leading to arbitrary mixing. kl_weight=0.5 is fine for preservation, but start with defaults.
Other nits: embed_categorical_covariates=False skips embedding your target/group keys into the decoder; set it to True if those have many levels (saves memory and improves conditioning). n_layers=3 might overparameterize, so try the library default first. n_prior_components=7 is reasonable for VampPrior, but defaults work well too.
Harmony doing "quite a bit" but leaving residuals makes sense here—split-pool vs. 10X often has deeper technical diffs (e.g., UMI saturation, empty droplets) that SysVI targets via its system-specific priors.
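If you want a quick sanity check on what actually sits in X before setup, here is a minimal NumPy-only sketch. The helper names (lognormalize, looks_like_raw_counts) are mine, not part of scanpy or scvi-tools, and lognormalize only mimics the fixed-target normalization plus log1p described in the tutorial quote. It assumes a dense cells x genes array, so pass something like asub.X.toarray() if your matrix is sparse.

```python
import numpy as np

def lognormalize(counts, target_sum=1e4):
    """Scale each cell (row) to target_sum total counts, then apply log1p."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid division by zero for empty cells
    return np.log1p(counts / totals * target_sum)

def looks_like_raw_counts(X):
    """True if every entry is a non-negative integer (a heuristic, not a proof)."""
    X = np.asarray(X, dtype=float)
    return bool(X.size > 0 and X.min() >= 0 and np.all(X == np.round(X)))

# Toy data standing in for a cells x genes matrix (hypothetical, for illustration)
rng = np.random.default_rng(0)
toy_counts = rng.poisson(2.0, size=(20, 30))
toy_lognorm = lognormalize(toy_counts)

print(looks_like_raw_counts(toy_counts))   # True: integer counts, what scVI expects
print(looks_like_raw_counts(toy_lognorm))  # False: log-normalized, what SysVI expects
```

If the check returns True for the matrix you hand to SysVI.setup_anndata, the counts override is probably still in your pipeline.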
Recommended Fix: Adjusted Pipeline
Keep your early steps (normalize/log/HVG/PCA/UMAP for viz), but tweak the SysVI block. Here's a minimal revision—assumes your full adata has lognorm in .X and raw in layers['counts'] post your initial sc.pp calls. Subset after confirming .X is normalized.
import scvi
from scvi.external import SysVI
import scanpy as sc
import matplotlib.pyplot as plt
# Your early steps unchanged...
sc.pp.normalize_total(adata, target_sum=1e4) # Explicit target for reproducibility
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=4500, batch_key='study_id', flavor='seurat_v3', layer='counts') # seurat_v3 expects raw counts, so point it at the counts layer
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X_pca', n_pcs=15)
sc.tl.umap(adata)
adata.obsm['X_umap_pca'] = adata.obsm['X_umap']
# Subset to shared HVGs *without* overriding X (keep lognorm!)
asub = adata[:, adata.var.highly_variable_nbatches > 1].copy() # ~2800 genes
# NO: asub.X = asub.layers['counts'].copy() # <-- This is the culprit—remove it!
# Optional: ensure raw counts are in layers for later if needed
# (the subset copy above normally carries layers over, so this is only a safeguard)
if 'counts' not in asub.layers and 'counts' in adata.layers:
    asub.layers['counts'] = adata[:, asub.var_names].layers['counts'].copy()
# Setup (your study_id as 'system' batch)
scvi.settings.dl_num_workers = 16
scvi.settings.seed = 0 # For repro
SysVI.setup_anndata(
    adata=asub,
    batch_key='study_id',
    categorical_covariate_keys=['target', 'group']
)
# Model init—embed covars, drop extra layers, use defaults
model = SysVI(
    adata=asub,
    embed_categorical_covariates=True, # Key for your covars
    # n_prior_components=7, # Optional; try the library default first, tune if losses diverge
    # n_layers=3, # Left out on purpose; 3 may overfit small data
)
# Train with milder weights—monitor losses!
max_epochs = 200
model.train(
    max_epochs=max_epochs,
    check_val_every_n_epoch=10, # Less frequent to save time
    plan_kwargs={
        "kl_weight": 1.0, # Default for balance
        "z_distance_cycle_weight": 5 # Start low; bump to 10 if under-corrected
    }
)
# Quick loss inspection (adapt from tutorial)
fig, axs = plt.subplots(2, 3, figsize=(12, 6))
losses = ["reconstruction_loss", "kl_local", "cycle_loss"]
for i, name in enumerate(losses):
    for j, split in enumerate(["_train", "_validation"]):
        key = name + split
        if key in model.history:
            values = model.history[key]
            axs[j, i].plot(values)
            axs[j, i].set_title(f"{name}{split}")
plt.tight_layout()
plt.show() # Zoom on last 50 epochs—should stabilize, train/val close
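To put a number on "should stabilize" rather than eyeballing the plot, a rough sketch like the following may help. has_stabilized is a hypothetical helper of mine (not an scvi-tools function), and the window and tolerance values are arbitrary starting points. It accepts any 1-D sequence of per-epoch loss values, so something like model.history["cycle_loss_train"] should work once converted to an array.

```python
import numpy as np

def has_stabilized(values, window=50, rel_tol=0.05):
    """Compare the mean of the last `window` epochs with the window before it.

    Returns True when the relative change between the two means is at most
    rel_tol. Returns False if there are fewer than 2 * window values.
    """
    values = np.asarray(values, dtype=float).ravel()
    if values.size < 2 * window:
        return False
    last = values[-window:].mean()
    prev = values[-2 * window:-window].mean()
    return bool(abs(last - prev) <= rel_tol * abs(prev))

# Synthetic curves standing in for real loss histories
epochs = np.arange(200)
converged = 100.0 * np.exp(-epochs / 10.0) + 5.0  # fast decay to a plateau
still_falling = 200.0 - epochs                    # keeps dropping linearly

print(has_stabilized(converged))      # True
print(has_stabilized(still_falling))  # False
```

A curve that passes this check can still be jagged within the window, so treat it as a complement to the plot, not a replacement.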
- Post-training: Grab latents with latent = model.get_latent_representation(), store them with asub.obsm['X_sysvi'] = latent, then sc.pp.neighbors(asub, use_rep='X_sysvi'); sc.tl.umap(asub); plot colored by study_id/target/leiden. If cycle loss still jumps, drop to z_distance_cycle_weight=2 or disable (=0) to revert to scVI-like cVAE.
- Seeds matter: SysVI is stochastic—run 3x with different seeds, pick the best by kBET/silhouette on held-out batch mixing.
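- Scale: With ~2800 genes and multi-study cells, 200 epochs should converge in <30min on GPU. If OOM, lower the batch size, e.g. model.train(batch_size=64). Separately, model.train(early_stopping=True) can cut training short once the validation loss plateaus.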
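For comparing seeds without pulling in extra packages, here is a small NumPy sketch of a kNN batch-mixing score. To be clear, this is a simple stand-in I am suggesting, not kBET and not a silhouette score: it reports the average fraction of each cell's k nearest neighbours that come from a different batch. The function name and the choice of k are assumptions, and it builds a full pairwise distance matrix, so subsample to a few thousand cells first.

```python
import numpy as np

def knn_batch_mixing(latent, batches, k=15):
    """Mean fraction of each cell's k nearest neighbours from another batch.

    0.0 means batches are fully separated; values near the share of
    other-batch cells in the data indicate good mixing.
    """
    latent = np.asarray(latent, dtype=float)
    batches = np.asarray(batches)
    # Squared Euclidean distances between all pairs of cells
    sq = np.sum(latent ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * latent @ latent.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbour
    nn = np.argsort(d2, axis=1)[:, :k]
    other = batches[nn] != batches[:, None]
    return float(other.mean())

# Toy latents: two batches far apart vs. two batches drawn from the same cloud
rng = np.random.default_rng(0)
labels = np.repeat(["a", "b"], 200)
separated = rng.normal(size=(400, 10))
separated[200:] += 100.0
mixed = rng.normal(size=(400, 10))

print(knn_batch_mixing(separated, labels))  # 0.0, no cross-batch neighbours
print(knn_batch_mixing(mixed, labels))      # roughly 0.5 for two equal-sized batches
```

You would pass the latent from each seed together with asub.obs['study_id'].to_numpy() and compare scores, while also checking that target/group structure is not being mixed away.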
This should smooth the cycle loss and give cleaner separation (e.g., biology by target/group, not clumped by study_id). If it still looks off (e.g., over-mixing cell types), your batches might not be "system-level" disparate enough; in that case fall back to scVI (swap in scvi.model.SCVI and feed it raw counts, e.g. via layer='counts' in setup_anndata). I've had success with scVI + totalVI for mixed protocols too.
Kevin