'''
This is the main module of UNAGI. It contains the UNAGI class, which is the main class of UNAGI, together with the functions to prepare the data, start the model training and analyse the perturbation results. First, use the `setup_data` function to prepare the data. Then, use the `setup_training` function to set up the training parameters. Next, use the `run_UNAGI` function to start the model training. After the model training is done, use the `analyse_UNAGI` function to start the perturbation analysis.
'''
import subprocess
from tracemalloc import start
import numpy as np
from .utils.attribute_utils import split_dataset_into_stage, get_all_adj_adata
import os
import scanpy as sc
import gc
from .utils.gcn_utils import get_gcn_exp
from .train.runner import UNAGI_runner
import torch
from .model.models import VAE,Discriminator,Plain_VAE
from .UNAGI_analyst import analyst
from .train.trainer import UNAGI_trainer
[docs]
class UNAGI:
    '''
    The UNAGI class is the main class of UNAGI. It contains the functions to prepare the data, start the model training and start analysing the perturbation results.

    Typical call order: `setup_data` -> `setup_training` -> (optional `register_*` calls) -> `run_UNAGI` -> `analyse_UNAGI`.
    '''

    def __init__(self,):
        # Optional settings; they stay None until the matching register_* method is called.
        self.CPO_parameters = None
        self.iDREM_parameters = None
        self.species = 'Human'
        # Number of features of the input data; filled in by setup_data.
        self.input_dim = None

    def setup_data(self,data_path,stage_key,total_stage,gcn_connectivities=False,neighbors=25,threads = 20):
        '''
        The function to specify the data directory, the attribute name of the stage information and the total number of time stages of the time-series single-cell data. The function can take either an h5ad file or a directory as the input. If `0.h5ad` already exists in the data folder, the data is treated as already split into stages and is used directly. Otherwise, the data is split into multiple h5ad files based on the stage information. The function also creates the `0`, `0/stagedata` and `model_save` folders inside the data folder and, when no cell graphs are found, calculates the cell graphs for each stage.

        parameters
        --------------
        data_path: str
            the directory of the h5ad file or the folder contains data.
        stage_key: str
            the attribute name of the stage information.
        total_stage: int
            the total number of time stages of the time-series single-cell data.
        gcn_connectivities: bool
            whether the cell graphs are already calculated. Default is False. Note that the passed value is overridden: it is re-derived from the presence of 'gcn_connectivities' in `0.h5ad` (already split data) or set to False (freshly split data).
        neighbors: int
            the number of neighbors for each cell used to construct the cell neighbors graph, default is 25.
        threads: int
            the number of threads for the cell graph construction, default is 20.

        raises
        --------------
        ValueError
            if `total_stage` is smaller than 2, or if the `0` or `model_save` folder already exists in the data folder.
        '''
        if total_stage < 2:
            raise ValueError('The total number of stages should be larger than 1')

        # Resolve the folder that holds (or will hold) the per-stage h5ad files.
        if os.path.isfile(data_path):
            # A bare file name has an empty dirname; fall back to the current directory.
            stage_folder = os.path.dirname(data_path) or os.curdir
        else:
            stage_folder = data_path
        self.data_folder = stage_folder
        self.stage_key = stage_key

        first_stage_file = os.path.join(stage_folder, '0.h5ad')
        if os.path.exists(first_stage_file):
            temp = sc.read(first_stage_file)
            self.input_dim = temp.shape[1]
            gcn_connectivities = 'gcn_connectivities' in temp.obsp.keys()
        else:
            print('The dataset is not splited into stages, please use setup_data function to split the dataset into stages first')
            self.data_path = data_path
            self.input_dim = split_dataset_into_stage(self.data_path, self.data_folder, self.stage_key)
            gcn_connectivities = False

        # Always point data_path at the first stage file inside the resolved folder,
        # so that a file input does not end up being treated as a folder.
        self.data_path = first_stage_file
        self.ns = total_stage
        #data folder is the folder that contains all the h5ad files
        self.data_folder = os.path.dirname(self.data_path)

        iteration_dir = os.path.join(self.data_folder, '0')
        stagedata_dir = os.path.join(iteration_dir, 'stagedata')
        model_dir = os.path.join(self.data_folder, 'model_save')
        if os.path.exists(iteration_dir):
            raise ValueError('The iteration 0 folder is already existed, please remove the folder and rerun the code')
        if os.path.exists(model_dir):
            raise ValueError('The iteration model_save folder is already existed, please remove the folder and rerun the code')

        # Create the folders synchronously (no shell), so they are guaranteed to
        # exist before anything below or in later calls writes into them.
        os.makedirs(stagedata_dir)
        os.makedirs(model_dir)

        if not gcn_connectivities:
            print('Cell graphs not found, calculating cell graphs for individual stages! Using K=%d and threads=%d for cell graph construction'%(neighbors,threads))
            self.calculate_neighbor_graph(neighbors,threads)
        else:
            print('Cell graphs found, skipping cell graph construction!')

    def calculate_neighbor_graph(self, neighbors=25,threads = 20):
        '''
        The function to calculate the cell graphs for each stage. The cell graphs will be used for the graph convolutional network (GCN) based cell graph construction. Requires `setup_data` to have set the data folder and the number of stages.

        parameters
        --------------
        neighbors: int
            the number of neighbors for each cell, default is 25.
        threads: int
            the number of threads for the cell graph construction, default is 20.
        '''
        get_gcn_exp(self.data_folder, self.ns ,neighbors,threads= threads)

    def setup_training(self,
                 task,
                 dist,
                 device=None,
                 epoch_iter=10,
                 epoch_initial=20,
                 lr=1e-4,
                 lr_dis = 5e-4,
                 beta=1,
                 hidden_dim=256,
                 latent_dim=64,
                 graph_dim=1024,
                 BATCHSIZE=512,
                 max_iter=10,
                 GPU=False,
                 adversarial=True,
                 GCN=True):
        '''
        Set up the training parameters and the model parameters. `setup_data` must be called first.

        parameters
        --------------
        task: str
            the name of this task. It is used to name the output folder.
        dist: str
            the distribution of the single-cell data. Chosen from 'ziln' (zero-inflated log normal), 'zinb' (zero-inflated negative binomial), 'zig' (zero-inflated gamma), and 'nb' (negative binomial).
        device: str
            the device to run the model. If GPU is enabled, the device should be specified. Default is None.
        epoch_iter: int
            the number of epochs for the iterative training process. Default is 10.
        epoch_initial: int
            the number of epochs for the inital iteration. Default is 20.
        lr: float
            the learning rate of the VAE model. Default is 1e-4.
        lr_dis: float
            the learning rate of the discriminator. Default is 5e-4.
        beta: float
            the beta parameter of the beta-VAE. Default is 1.
        hidden_dim: int
            the hidden dimension of the VAE model. Default is 256.
        latent_dim: int
            the latent dimension of the VAE model. Default is 64.
        graph_dim: int
            the dimension of the GCN layer. Default is 1024.
        BATCHSIZE: int
            the batch size for the model training. Default is 512.
        max_iter: int
            the maximum number of iterations for the model training. Default is 10.
        GPU: bool
            whether to use GPU for the model training. Default is False.
        adversarial: bool
            whether to build a discriminator alongside the VAE. Default is True.
        GCN: bool
            whether to use the `VAE` model (True) or the `Plain_VAE` model (False). Default is True.

        raises
        --------------
        ValueError
            if `setup_data` has not been called (the input dimension is unknown).
        AssertionError
            if `GPU` is True but `device` is None.
        '''
        self.dist = dist
        self.device = device
        self.epoch_iter = epoch_iter
        self.epoch_initial = epoch_initial
        self.lr = lr
        self.beta = beta
        self.lr_dis = lr_dis
        self.task = task
        self.latent_dim = latent_dim
        self.graph_dim = graph_dim
        self.hidden_dim = hidden_dim
        self.BATCHSIZE = BATCHSIZE
        self.max_iter = max_iter
        self.GPU = GPU
        self.GCN = GCN
        self.adversarial = adversarial

        # input_dim is only known after setup_data has read or split the data.
        if self.input_dim is None:
            raise ValueError('Please use setup_data function to prepare the data first')

        model_class = VAE if GCN else Plain_VAE
        self.model = model_class(self.input_dim, self.hidden_dim,self.graph_dim, self.latent_dim,beta=self.beta,distribution=self.dist)
        self.dis_model = Discriminator(self.input_dim) if self.adversarial else None

        # Snapshot of the settings; run_UNAGI dumps it to JSON, which is why the
        # device is stored here as passed in (before conversion to torch.device).
        self.training_parameters = {
            'dist': self.dist,
            'device': self.device,
            'epoch_iter': self.epoch_iter,
            'epoch_initial': self.epoch_initial,
            'lr': self.lr,
            'beta': self.beta,
            'lr_dis': self.lr_dis,
            'task': self.task,
            'latent_dim': self.latent_dim,
            'graph_dim': self.graph_dim,
            'hidden_dim': self.hidden_dim,
            'BATCHSIZE': self.BATCHSIZE,
            'max_iter': self.max_iter,
            'GPU': self.GPU,
            'input_dim': self.input_dim,
            'GCN': self.GCN,
            'adversarial': self.adversarial,
            'total_stage':self.ns
        }

        if self.GPU:
            # Raised explicitly (instead of an assert statement) so the check is
            # not stripped when Python runs with -O; the exception type is unchanged.
            if self.device is None:
                raise AssertionError("GPU is enabled but device is not specified")
            self.device = torch.device(self.device)
        else:
            self.device = torch.device('cpu')
        self.unagi_trainer = UNAGI_trainer(self.model,self.dis_model,self.task,self.BATCHSIZE,self.epoch_initial,self.epoch_iter,self.device,self.lr, self.lr_dis,GCN=self.GCN,cuda=self.GPU)

    def register_CPO_parameters(self,anchor_neighbors=15, max_neighbors=35, min_neighbors=10, resolution_min=0.8, resolution_max=1.5):
        '''
        The function to register the parameters for the CPO analysis. The parameters will be used to perform the CPO analysis.

        parameters
        --------------
        anchor_neighbors: int
            the number of neighbors for each anchor cell.
        max_neighbors: int
            the maximum number of neighbors for each cell.
        min_neighbors: int
            the minimum number of neighbors for each cell.
        resolution_min: float
            the minimum resolution for the Leiden community detection.
        resolution_max: float
            the maximum resolution for the Leiden community detection.
        '''
        self.CPO_parameters = {
            'anchor_neighbors': anchor_neighbors,
            'max_neighbors': max_neighbors,
            'min_neighbors': min_neighbors,
            'resolution_min': resolution_min,
            'resolution_max': resolution_max,
        }

    def register_species(self,species):
        '''
        The function to register the species of the single-cell data. The value is stored capitalised ('Human' or 'Mouse').

        parameters
        --------------
        species: str
            the species of the single-cell data. One of 'human', 'mouse', 'Human' or 'Mouse'.

        raises
        --------------
        ValueError
            if `species` is not one of the accepted values.
        '''
        if species not in ['human','mouse', 'Human', 'Mouse']:
            raise ValueError('species should be either human or mouse')
        # Only the four spellings above reach this point, so capitalize() maps
        # them onto exactly 'Human' / 'Mouse'.
        self.species = species.capitalize()

    def register_iDREM_parameters(self,Normalize_data = 'Log_normalize_data', Minimum_Absolute_Log_Ratio_Expression = 0.5, Convergence_Likelihood = 0.001, Minimum_Standard_Deviation = 0.5):
        '''
        The function to register the parameters for the iDREM analysis. The parameters will be used to perform the iDREM analysis.

        parameters
        --------------
        Normalize_data: str
            the method to normalize the data. Chosen from 'Log_normalize_data' (log normalize the data), 'Normalize_data' (normalize the data), and 'No_normalize_data' (do not normalize the data).
        Minimum_Absolute_Log_Ratio_Expression: float
            the minimum absolute log ratio expression for the iDREM analysis.
        Convergence_Likelihood: float
            the convergence likelihood for the iDREM analysis.
        Minimum_Standard_Deviation: float
            the minimum standard deviation for the iDREM analysis.

        raises
        --------------
        ValueError
            if `Normalize_data` is not one of the accepted values. In that case the previously registered parameters are left untouched.
        '''
        # Validate before touching any state, so a rejected call cannot leave a
        # half-filled dictionary behind.
        if Normalize_data not in ['Log_normalize_data','Normalize_data','No_normalize_data']:
            raise ValueError('Normalize_data should be chosen from Log_normalize_data, Normalize_data and No_normalize_data')
        self.iDREM_parameters = {
            'Normalize_data': Normalize_data,
            'Minimum_Absolute_Log_Ratio_Expression': Minimum_Absolute_Log_Ratio_Expression,
            'Convergence_Likelihood': Convergence_Likelihood,
            'Minimum_Standard_Deviation': Minimum_Standard_Deviation,
        }

    def run_UNAGI(self,idrem_dir,CPO=True,resume=False,resume_iteration=None):
        '''
        The function to launch the model training. The model will be trained iteratively. The number of iterations is specified by the `max_iter` parameter in the `setup_training` function. The training parameters are written to `model_save/training_parameters.json` in the data folder before training starts.

        parameters
        --------------
        idrem_dir: str
            the directory to the iDREM tool which is used to reconstruct the temporal dynamics.
        CPO: bool
            forwarded to `UNAGI_runner.run` for every iteration. Default is True.
        resume: bool
            whether to start from `resume_iteration` instead of iteration 0. Default is False.
        resume_iteration: int
            the iteration to start from when `resume` is True. Default is None.

        raises
        --------------
        ValueError
            if `resume` is True but `resume_iteration` is None, or if the registered CPO / iDREM parameters are not dictionaries.
        '''
        import json

        # Validate everything up front, before any file or folder is touched.
        start_iteration = 0
        if resume:
            if resume_iteration is None:
                raise ValueError('resume_iteration should be specified when resume is True')
            start_iteration = int(resume_iteration)
        if self.CPO_parameters is not None and not isinstance(self.CPO_parameters, dict):
            raise ValueError('CPO_parameters should be a dictionary')
        if self.iDREM_parameters is not None and not isinstance(self.iDREM_parameters, dict):
            raise ValueError('iDREM_parameters should be a dictionary')

        parameter_file = os.path.join(self.data_folder, 'model_save', 'training_parameters.json')
        with open(parameter_file, 'w') as json_file:
            json.dump(self.training_parameters, json_file, indent=4)

        for iteration in range(start_iteration,self.max_iter):
            if iteration != 0:
                # Iteration 0 folders are created by setup_data. Later ones are
                # created here, synchronously, so the runner never races against
                # folder creation; exist_ok tolerates folders left by a resumed run.
                os.makedirs(os.path.join(self.data_folder, str(iteration), 'stagedata'), exist_ok=True)

            unagi_runner = UNAGI_runner(self.data_folder,self.ns,iteration,self.unagi_trainer,idrem_dir,adversarial=self.adversarial,GCN = self.GCN)
            unagi_runner.set_up_species(self.species)
            if self.CPO_parameters is not None:
                unagi_runner.set_up_CPO(anchor_neighbors=self.CPO_parameters['anchor_neighbors'], max_neighbors=self.CPO_parameters['max_neighbors'], min_neighbors=self.CPO_parameters['min_neighbors'], resolution_min=self.CPO_parameters['resolution_min'], resolution_max=self.CPO_parameters['resolution_max'])
            if self.iDREM_parameters is not None:
                # NOTE(review): the registered 'Normalize_data' value is not forwarded
                # here - confirm whether set_up_iDREM is expected to receive it.
                unagi_runner.set_up_iDREM(Minimum_Absolute_Log_Ratio_Expression = self.iDREM_parameters['Minimum_Absolute_Log_Ratio_Expression'], Convergence_Likelihood = self.iDREM_parameters['Convergence_Likelihood'], Minimum_Standard_Deviation = self.iDREM_parameters['Minimum_Standard_Deviation'])
            unagi_runner.run(CPO)

    def test_geneweihts(self,iteration,idrem_dir):
        '''
        Load the stage data of one iteration and update its gene weights table, without running the training loop.

        parameters
        --------------
        iteration: int
            the iteration to work on; it is converted with `int`.
        idrem_dir: str
            the directory to the iDREM tool, passed to the runner.
        '''
        iteration = int(iteration)
        unagi_runner = UNAGI_runner(self.data_folder,self.ns,iteration,self.unagi_trainer,idrem_dir)
        unagi_runner.set_up_species(self.species)
        unagi_runner.load_stage_data()
        unagi_runner.update_gene_weights_table()

    def analyse_UNAGI(self,data_path,iteration,progressionmarker_background_sampling_times,run_pertubration,target_dir=None,customized_drug=None,cmap_dir=None):
        '''
        Perform downstream tasks including dynamic markers discoveries, hierarchical markers discoveries, pathway perturbations and compound perturbations.

        parameters
        ---------------
        data_path: str
            the directory of the data (h5ad format, e.g. dataset.h5ad).
        iteration: int
            the iteration used for analysis.
        progressionmarker_background_sampling_times: int
            the number of times to sample the background cells for dynamic markers discoveries.
        run_pertubration:
            forwarded to `analyst.start_analyse` as `run_pertubration`.
        target_dir: str
            the directory to save the results. Default is None.
        customized_drug: str
            the customized drug perturbation list. Default is None.
        cmap_dir: str
            the directory to the cmap database. Default is None.
        '''
        analysts = analyst(data_path,iteration,target_dir=target_dir,customized_drug=customized_drug,cmap_dir=cmap_dir)
        analysts.start_analyse(progressionmarker_background_sampling_times,run_pertubration=run_pertubration)
        print('The analysis has been done, please check the outputs!')

    def customize_pathway_perturbation(self,data_path,iteration,customized_pathway,bound,CUDA=True,save_csv = None,save_adata = None,target_dir=None,device='cuda:0',show=False,top_n=None,cut_off=None):
        '''
        Run the perturbation analysis on a customized pathway and return the resulting data object of the analyst.

        parameters
        ---------------
        data_path: str
            the directory of the data, passed to the analyst.
        iteration: int
            the iteration used for analysis.
        customized_pathway:
            the customized pathway input, forwarded to `analyst.perturbation_analyse_customized_pathway`.
        bound:
            the change level of the perturbation; must not be 1.
        CUDA: bool
            forwarded to the analyst as `CUDA`. Default is True.
        save_csv, save_adata:
            forwarded to the analyst. Default is None.
        target_dir: str
            the directory to save the results. Default is None.
        device: str
            forwarded to the analyst as `device`. Default is 'cuda:0'.
        show, top_n, cut_off:
            accepted but currently not used by this method.

        raises
        ---------------
        ValueError
            if `bound` equals 1.
        '''
        if bound == 1:
            raise ValueError('If change level is one, the perturbed gene expression will not change')
        analysts = analyst(data_path,iteration,target_dir=target_dir,customized_mode=True)
        analysts.perturbation_analyse_customized_pathway(customized_pathway,bound=bound,save_csv = save_csv,save_adata = save_adata,CUDA=CUDA,device=device)
        return analysts.adata

    def customize_drug_perturbation(self,data_path,iteration,customized_drug,bound,CUDA=True,save_csv = None,save_adata = None,target_dir=None,device='cuda:0',show=False,top_n=None,cut_off=None):
        '''
        Run the perturbation analysis on a customized drug list and return the resulting data object of the analyst.

        parameters
        ---------------
        data_path: str
            the directory of the data, passed to the analyst.
        iteration: int
            the iteration used for analysis.
        customized_drug:
            the customized drug input, forwarded to the analyst constructor and to `analyst.perturbation_analyse_customized_drug`.
        bound:
            the change level of the perturbation; must not be 1.
        CUDA: bool
            forwarded to the analyst as `CUDA`. Default is True.
        save_csv, save_adata:
            forwarded to the analyst. Default is None.
        target_dir: str
            the directory to save the results. Default is None.
        device: str
            forwarded to the analyst as `device`. Default is 'cuda:0'.
        show, top_n, cut_off:
            accepted but currently not used by this method.

        raises
        ---------------
        ValueError
            if `bound` equals 1.
        '''
        if bound == 1:
            raise ValueError('If change level is one, the perturbed gene expression will not change')
        analysts = analyst(data_path,iteration,target_dir=target_dir,customized_drug=customized_drug,customized_mode=True)
        analysts.perturbation_analyse_customized_drug(customized_drug,bound=bound,save_csv = save_csv,save_adata = save_adata,CUDA=CUDA,device=device)
        return analysts.adata