Configure the model architecture and training parameters
This tutorial provides a step-by-step guide to configuring the model architecture and training hyperparameters, and to analyzing a single time-series dataset with UNAGI. We demonstrate the capabilities of UNAGI by applying it to scRNA-seq data sampled from a single-nucleus RNA sequencing dataset.
import warnings
warnings.filterwarnings('ignore')
from UNAGI import UNAGI
unagi = UNAGI()
Part 1: Setup and load the datasets
After loading the UNAGI package, we need to set up the data for UNAGI training.
Specify the data path of your h5ad files after stage segmentation, e.g. ‘../data/small/0.h5ad’. UNAGI then loads all h5ad files in the target directory.
UNAGI requires the total number of time points in the dataset as input, e.g. total_stage=4.
UNAGI requires the key of the time-point attribute in the AnnData.obs table.
If the dataset is not split into individual stages, set splited_dataset to False so that UNAGI segments the dataset.
To build the K-Nearest Neighbors (KNN) connectivity matrix used in graph convolution training, define the number of neighbors for KNN. The default value is 25.
You can also specify how many threads UNAGI uses. The default number of threads is 20.
unagi.setup_data('../UNAGI/data/example',total_stage=4,stage_key='stage')
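Before calling setup_data, it can help to confirm that the directory holds one file per stage. The following is a minimal sketch, assuming the stage files are named 0.h5ad, 1.h5ad, and so on, as the example path above suggests. The helper missing_stage_files is hypothetical and is not part of UNAGI.

```python
from pathlib import Path


def missing_stage_files(data_dir, total_stage):
    """Return the expected stage files that are absent from data_dir.

    Assumption (not confirmed by UNAGI's documentation): there is one file
    per stage, named '<stage>.h5ad' and numbered from 0 to total_stage - 1.
    """
    data_dir = Path(data_dir)
    expected = [f"{stage}.h5ad" for stage in range(total_stage)]
    return [name for name in expected if not (data_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_stage_files("../UNAGI/data/example", total_stage=4)
    if missing:
        print("Missing stage files:", ", ".join(missing))
    else:
        print("All stage files found.")
```

If the list is non-empty, either total_stage does not match the data or the dataset has not been split into stages yet.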
Part 2: Configure the model architecture of UNAGI and training hyper-parameters
First, you must specify the task you are executing (e.g. we call the example dataset task=’small_sample’). The task is the identifier of your experiment, and you can retrieve the trained model and the results of each iteration from the ‘../data/task/’ directory.
Next, specify the distribution of your single-cell data. UNAGI provides negative binomial (NB), zero-inflated negative binomial (ZINB), zero-inflated log-normal, and normal distributions to model your single-cell data.
You can use the device keyword to specify the device you want to use for training.
‘epoch_initial’: the number of training epochs for the first iteration.
‘epoch_iter’: the number of training epochs for the iterative training.
‘max_iter’: the total number of iterations UNAGI will run
‘BATCHSIZE’: the batch size of a mini-batch
‘lr’: the learning rate of Graph VAE
‘lr_dis’: the learning rate of the adversarial discriminator
‘latent_dim’: the dimension of Z space
‘hiddem_dim’: the number of neurons in each fully connected layer
‘graph_dim’: the dimension of graph representation
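To see how the three epoch settings interact, here is a small sketch. It assumes that the first iteration runs epoch_initial epochs and that each later iteration runs epoch_iter epochs, which is how the descriptions above read. The helper epoch_schedule is hypothetical and is not part of UNAGI.

```python
def epoch_schedule(epoch_initial, epoch_iter, max_iter):
    """List the number of training epochs run in each iteration.

    Assumption: iteration 0 uses epoch_initial epochs and every later
    iteration uses epoch_iter epochs.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    return [epoch_initial] + [epoch_iter] * (max_iter - 1)


# Values taken from this tutorial's example call to setup_training.
schedule = epoch_schedule(epoch_initial=1, epoch_iter=5, max_iter=3)
print(schedule)       # [1, 5, 5]
print(sum(schedule))  # 11
```

Under this assumption, the example settings train for 11 epochs in total across the three iterations.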
After setting the training hyperparameters and model architecture, call unagi.run_UNAGI() to start training.
unagi.setup_training('example',dist='ziln',device='cuda:0',GPU=True,epoch_iter=5,epoch_initial=1,max_iter=3,BATCHSIZE=560)
unagi.run_UNAGI(idrem_dir = '../../idrem')
...
0
loss 9241.198552284153
[epoch 000] average training loss: 9241.1986
(13550, 2484)
top gene
done
write stageadata
update done
(4195, 2484)
top gene
done
write stageadata
update done
(3152, 2484)
top gene
done
write stageadata
update done
(6750, 2484)
top gene
done
write stageadata
update done
edges updated
[[[3], [0], [2, 3, 15], [0, 6, 7, 8, 12]], [[5], [5], [4], [5]]]
b''
[[[3], [0], [2, 3, 15], [0, 6, 7, 8, 12]], [[5], [5], [4], [5]]]
['3', '0', '2n3n15', '0n6n7n8n12']
['5', '5', '4', '5']
['5-5-4-5.txt', '3-0-2n3n15-0n6n7n8n12.txt', '3-0-0n11-0n10n15n16n22.txt', '3-1-1n3-0n4n5n7.txt', '4-5-6-6.txt', '6-10-5-6.txt', '9-14-12-4.txt', '5-4-4-3n9.txt', '7-8n17-6n7-7.txt', '6-6-7-8.txt']
b"java.lang.IllegalArgumentException: WARNING: 'TF-Minimum_Absolute_Log_Ratio_Expression' is an unrecognized variable.\n\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.parseDefaults(DREM_IO_Batch.java:916)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:233)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:206)\n\tat edu.cmu.cs.sb.drem.DREM_IO.main(DREM_IO.java:5613)\n"
b"java.lang.IllegalArgumentException: WARNING: 'TF-Minimum_Absolute_Log_Ratio_Expression' is an unrecognized variable.\n\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.parseDefaults(DREM_IO_Batch.java:916)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:233)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:206)\n\tat edu.cmu.cs.sb.drem.DREM_IO.main(DREM_IO.java:5613)\n"
b'java.lang.IllegalArgumentException: All Genes Filtered\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesgeneral(DataSetCore.java:902)\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesthreshold2change(DataSetCore.java:1011)\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesthreshold2(DataSetCore.java:979)\n\tat edu.cmu.cs.sb.drem.DREM_IO.buildset(DREM_IO.java:1920)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.clusterscript(DREM_IO_Batch.java:1313)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:264)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:206)\n\tat edu.cmu.cs.sb.drem.DREM_IO.main(DREM_IO.java:5613)\n'
b'0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\nwriting Json..\nTime: 30157ms\n'
b'java.lang.IllegalArgumentException: All Genes Filtered\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesgeneral(DataSetCore.java:902)\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesthreshold2change(DataSetCore.java:1011)\n\tat edu.cmu.cs.sb.core.DataSetCore.filtergenesthreshold2(DataSetCore.java:979)\n\tat edu.cmu.cs.sb.drem.DREM_IO.buildset(DREM_IO.java:1920)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.clusterscript(DREM_IO_Batch.java:1313)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:264)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:206)\n\tat edu.cmu.cs.sb.drem.DREM_IO.main(DREM_IO.java:5613)\n'
b'0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\nwriting Json..\nTime: 27588ms\n'
b'0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\nwriting Json..\nTime: 30188ms\n'
b'0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\nwriting Json..\nTime: 33155ms\n'
b'## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\n## NaN 0 0.0 0.0 0.0 0.0 -10.0 NaN\njava.lang.ArrayIndexOutOfBoundsException: Index -1 out of bounds for length 3\n\tat 
edu.cmu.cs.sb.drem.DREM_Timeiohmm.deleteMinPath(DREM_Timeiohmm.java:2382)\n\tat edu.cmu.cs.sb.drem.DREM_Timeiohmm.<init>(DREM_Timeiohmm.java:676)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.clusterscript(DREM_IO_Batch.java:1359)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:264)\n\tat edu.cmu.cs.sb.drem.DREM_IO_Batch.<init>(DREM_IO_Batch.java:206)\n\tat edu.cmu.cs.sb.drem.DREM_IO.main(DREM_IO.java:5613)\n'
b'0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\nwriting Json..\nTime: 32050ms\n'
/mnt/md0/yumin/to_upload/UNAGI/tutorials
b''
b''
b''
idrem Done
getting TFs from 3-0-0n11-0n10n15n16n22.txt_viz
getting TFs from 5-4-4-3n9.txt_viz
getting TFs from 9-14-12-4.txt_viz
getting TFs from 6-10-5-6.txt_viz
getting TFs from 7-8n17-6n7-7.txt_viz
getting Target genes from 3-0-0n11-0n10n15n16n22.txt_viz
getting Target genes from 5-4-4-3n9.txt_viz
getting Target genes from 9-14-12-4.txt_viz
getting Target genes from 6-10-5-6.txt_viz
getting Target genes from 7-8n17-6n7-7.txt_viz
number of idrem file 0
stage 1
number of idrem file 1
stage 1
number of idrem file 2
stage 1
number of idrem file 3
stage 1
number of idrem file 4
stage 1
number of idrem file 4
stage 1
number of idrem file 0
stage 2
number of idrem file 0
stage 2
number of idrem file 1
stage 2
number of idrem file 2
stage 2
number of idrem file 3
stage 2
number of idrem file 4
stage 2
number of idrem file 4
stage 2
number of idrem file 0
stage 3
number of idrem file 0
stage 3
number of idrem file 0
stage 3
number of idrem file 0
stage 3
number of idrem file 0
stage 3
number of idrem file 1
stage 3
number of idrem file 1
stage 3
number of idrem file 2
stage 3
number of idrem file 3
stage 3
number of idrem file 4
stage 3
27646
...
load last iteration model.....
0
loss 3684.8589019604296
[epoch 000] average training loss: 3684.8589
1
loss 2963.5102881867833
[epoch 001] average training loss: 2963.5103
2
loss 2694.852858131081
[epoch 002] average training loss: 2694.8529
3
loss 2546.214904080913
[epoch 003] average training loss: 2546.2149
4
loss 2461.9431945599886
[epoch 004] average training loss: 2461.9432
(13550, 2484)
top gene
done
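The training output above reports one average loss per epoch. If you save that output as text, a short script can pull out the loss values to check that training is converging. This is a minimal sketch that assumes the line format shown above ("[epoch NNN] average training loss: X"). The function parse_epoch_losses is a hypothetical helper, not part of UNAGI.

```python
import re

# Matches lines such as "[epoch 001] average training loss: 2963.5103".
LOSS_LINE = re.compile(r"\[epoch (\d+)\] average training loss: ([0-9.]+)")


def parse_epoch_losses(log_text):
    """Return (epoch, loss) pairs in the order they appear in log_text."""
    return [(int(m.group(1)), float(m.group(2)))
            for m in LOSS_LINE.finditer(log_text)]


# A few lines copied from the output above.
log_text = """\
[epoch 000] average training loss: 3684.8589
[epoch 001] average training loss: 2963.5103
[epoch 002] average training loss: 2694.8529
"""

for epoch, loss in parse_epoch_losses(log_text):
    print(epoch, loss)
```

Because the epoch counter restarts at 000 in every iteration, the pairs are returned in log order rather than keyed by epoch number.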
Part 3: Perform in-silico perturbations and downstream analysis
After training the UNAGI model, you can perform downstream tasks, including hierarchical static marker discovery. Parameters:
‘data_path’: the directory of the dataset generated by UNAGI
‘iteration’: the iteration the dataset belongs to
‘progressionmarker_background_sampling_times’: the number of sampling times used to generate the dynamic marker backgrounds
‘target_dir’: the directory in which to store the downstream analysis results and h5ad files
‘customized_drug’: the path to a customized drug profile
‘cmap_dir’: the path to the precomputed CMAP database, which contains the drugs/compounds, their regulated genes, and the directions of regulation
import warnings
warnings.filterwarnings('ignore')
from UNAGI import UNAGI
unagi = UNAGI()
unagi.analyse_UNAGI('../UNAGI/data/example/2/stagedata/org_dataset.h5ad',2,10,target_dir=None,customized_drug='../UNAGI/data/jasper_target_pair.npy',cmap_dir='../../CMAPDirectionDf.npy')
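The data path in the call above embeds the iteration number. The sketch below builds that path from its parts so the path and the iteration argument stay in sync. It assumes the layout <data_dir>/<iteration>/stagedata/org_dataset.h5ad, which is inferred from the example call only. The helper stage_dataset_path is hypothetical and is not part of UNAGI.

```python
from pathlib import PurePosixPath


def stage_dataset_path(data_dir, iteration):
    """Build the path to the dataset written for a given iteration.

    Assumption: the layout is
    <data_dir>/<iteration>/stagedata/org_dataset.h5ad,
    inferred from the example call in this tutorial.
    """
    path = PurePosixPath(data_dir) / str(iteration) / "stagedata" / "org_dataset.h5ad"
    return str(path)


iteration = 2
data_path = stage_dataset_path("../UNAGI/data/example", iteration)
print(data_path)

# The analysis call from above, reusing the same values (requires UNAGI,
# so it is left commented out here):
# unagi.analyse_UNAGI(data_path, iteration, 10, target_dir=None,
#                     customized_drug='../UNAGI/data/jasper_target_pair.npy',
#                     cmap_dir='../../CMAPDirectionDf.npy')
```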