{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ablations and embedding quality evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UNAGI w.o GAN and GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import shutil\n",
    "# Filter warnings before importing UNAGI so import-time warnings are silenced too.\n",
    "warnings.filterwarnings('ignore')\n",
    "from UNAGI import UNAGI\n",
    "\n",
    "# Ablation: both the GCN and the adversarial (GAN) component are switched off.\n",
    "DATA_DIR = '../data'\n",
    "RAW_DATA_DIR = os.path.join(DATA_DIR, 'mes_raw')\n",
    "IDREM_DIR = 'PATH_TO_IDREM'  # placeholder: point this at the local iDREM directory\n",
    "\n",
    "for seed in range(15):\n",
    "    # Seed every RNG used here so each replicate can be reproduced.\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    unagi = UNAGI()\n",
    "\n",
    "    print('seed %d' % seed)\n",
    "    task = 'plain_ziln_CPO_seed_%d' % seed\n",
    "    run_dir = os.path.join(DATA_DIR, task)\n",
    "    # Each seed trains on its own copy of the raw data. copytree raises if the\n",
    "    # copy fails (os.system('cp -r ...') ignored failures), and dirs_exist_ok\n",
    "    # keeps a re-run from nesting mes_raw inside an existing run directory.\n",
    "    shutil.copytree(RAW_DATA_DIR, run_dir, dirs_exist_ok=True)\n",
    "    unagi.setup_data(run_dir, total_stage=4, stage_key='stage')\n",
    "    unagi.setup_training(\n",
    "        task=task, dist='ziln', device='cuda:0', GPU=True,\n",
    "        epoch_iter=0, epoch_initial=10, max_iter=1, BATCHSIZE=1024,\n",
    "        GCN=False, adversarial=False,\n",
    "    )\n",
    "    unagi.run_UNAGI(idrem_dir=IDREM_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UNAGI w.o GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import shutil\n",
    "# Filter warnings before importing UNAGI so import-time warnings are silenced too.\n",
    "warnings.filterwarnings('ignore')\n",
    "from UNAGI import UNAGI\n",
    "\n",
    "# Ablation: adversarial (GAN) training stays on, the GCN is switched off.\n",
    "DATA_DIR = '../data'\n",
    "RAW_DATA_DIR = os.path.join(DATA_DIR, 'mes_raw')\n",
    "IDREM_DIR = 'PATH_TO_IDREM'  # placeholder: point this at the local iDREM directory\n",
    "\n",
    "for seed in range(15):\n",
    "    # Seed every RNG used here so each replicate can be reproduced.\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    unagi = UNAGI()\n",
    "\n",
    "    print('seed %d' % seed)\n",
    "    task = 'gan_ziln_CPO_seed_%d' % seed\n",
    "    run_dir = os.path.join(DATA_DIR, task)\n",
    "    # Each seed trains on its own copy of the raw data. copytree raises if the\n",
    "    # copy fails (os.system('cp -r ...') ignored failures), and dirs_exist_ok\n",
    "    # keeps a re-run from nesting mes_raw inside an existing run directory.\n",
    "    shutil.copytree(RAW_DATA_DIR, run_dir, dirs_exist_ok=True)\n",
    "    unagi.setup_data(run_dir, total_stage=4, stage_key='stage')\n",
    "    unagi.setup_training(\n",
    "        task=task, dist='ziln', device='cuda:0', GPU=True,\n",
    "        epoch_iter=0, epoch_initial=10, max_iter=1, BATCHSIZE=1024,\n",
    "        GCN=False, adversarial=True,\n",
    "    )\n",
    "    unagi.run_UNAGI(idrem_dir=IDREM_DIR, CPO=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UNAGI w.o. 
GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import shutil\n",
    "# Filter warnings before importing UNAGI so import-time warnings are silenced too.\n",
    "warnings.filterwarnings('ignore')\n",
    "from UNAGI import UNAGI\n",
    "\n",
    "# Ablation: the GCN stays on, adversarial (GAN) training is switched off.\n",
    "DATA_DIR = '../data'\n",
    "RAW_DATA_DIR = os.path.join(DATA_DIR, 'mes_raw')\n",
    "# Placeholder instead of a machine-specific absolute path: point this at the\n",
    "# local iDREM directory before running.\n",
    "IDREM_DIR = 'PATH_TO_IDREM'\n",
    "\n",
    "for seed in range(15):\n",
    "    # Seed every RNG used here so each replicate can be reproduced.\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    unagi = UNAGI()\n",
    "\n",
    "    print('seed %d' % seed)\n",
    "    task = 'gcn_ziln_CPO_seed_%d' % seed\n",
    "    run_dir = os.path.join(DATA_DIR, task)\n",
    "    # Each seed trains on its own copy of the raw data. copytree raises if the\n",
    "    # copy fails (os.system('cp -r ...') ignored failures), and dirs_exist_ok\n",
    "    # keeps a re-run from nesting mes_raw inside an existing run directory.\n",
    "    shutil.copytree(RAW_DATA_DIR, run_dir, dirs_exist_ok=True)\n",
    "    unagi.setup_data(run_dir, total_stage=4, stage_key='stage')\n",
    "    unagi.setup_training(\n",
    "        task=task, dist='ziln', device='cuda:0', GPU=True,\n",
    "        epoch_iter=0, epoch_initial=10, max_iter=1, BATCHSIZE=1024,\n",
    "        GCN=True, adversarial=False,\n",
    "    )\n",
    "    unagi.run_UNAGI(idrem_dir=IDREM_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate embedding quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scanpy as sc\n",
    "from UNAGI.utils import evaluate\n",
    "import pickle\n",
    "import concurrent.futures\n",
    "from functools import partial\n",
    "\n",
    "# Row labels of the metrics files, in the order evaluate.run_metrics returns\n",
    "# its ten values.\n",
    "METRIC_NAMES = (\n",
    "    'ari', 'nmi', 'dbi', 'label_score', 'sil', 'isolated_asw',\n",
    "    'cell_type_asw', 'isolated_labels_f1', 'clisi_graphs', 'overall_scibs',\n",
    ")\n",
    "\n",
    "\n",
    "def _write_metrics(path, values):\n",
    "    \"\"\"Write one 'name<TAB> value' line per entry of METRIC_NAMES to path.\"\"\"\n",
    "    with open(path, 'w') as f:\n",
    "        for name, value in zip(METRIC_NAMES, values):\n",
    "            f.write(name + '\\t ' + str(value) + '\\n')\n",
    "\n",
    "\n",
    "def process_seed(i, prefix='plain_ziln_CPO', target_dir='../data/'):\n",
    "    \"\"\"Compute and save the embedding-quality metrics of one seed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    i : int\n",
    "        Seed index; selects the run directory '<prefix>_seed_<i>'.\n",
    "    prefix : str\n",
    "        Task prefix of the run to evaluate, e.g. 'gan_ziln_CPO'.\n",
    "    target_dir : str\n",
    "        Directory holding the run directories; must end with a path separator\n",
    "        because paths are built by string concatenation.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        The ten values returned by evaluate.run_metrics, in METRIC_NAMES order.\n",
    "        They are also written to '<target_dir><prefix>_seed_<i>.txt'.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If evaluate.run_metrics does not return exactly len(METRIC_NAMES) values.\n",
    "    \"\"\"\n",
    "    file_name = prefix + '_seed_' + str(i)\n",
    "    adata = sc.read(target_dir + file_name + '/0/stagedata/org_dataset.h5ad')\n",
    "\n",
    "    # NOTE(review): 'ident' and 'stage' are presumably the label and stage keys\n",
    "    # of adata.obs - confirm against evaluate.run_metrics.\n",
    "    metrics = tuple(evaluate.run_metrics(adata, 'ident', 'stage'))\n",
    "    if len(metrics) != len(METRIC_NAMES):\n",
    "        raise ValueError(\n",
    "            'expected %d metrics, got %d' % (len(METRIC_NAMES), len(metrics))\n",
    "        )\n",
    "\n",
    "    _write_metrics(target_dir + file_name + '.txt', metrics)\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def main(prefix='plain_ziln_CPO', target_dir='../data/', seeds=range(15), max_workers=3):\n",
    "    \"\"\"Evaluate every seed of one ablation in parallel and aggregate the metrics.\n",
    "\n",
    "    Runs process_seed for each entry of seeds in a process pool, then writes\n",
    "    one line per metric, each holding the list of per-seed values, to\n",
    "    '<target_dir><prefix>_metrics.txt'. Call e.g. main(prefix='gan_ziln_CPO')\n",
    "    to evaluate another ablation.\n",
    "    \"\"\"\n",
    "    worker = partial(process_seed, prefix=prefix, target_dir=target_dir)\n",
    "\n",
    "    # Using ProcessPoolExecutor to parallelize; map() keeps the seed order.\n",
    "    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:\n",
    "        results = list(executor.map(worker, seeds))\n",
    "\n",
    "    # Transpose the per-seed tuples into one list per metric. With no seeds,\n",
    "    # still emit every metric with an empty list.\n",
    "    if results:\n",
    "        per_metric = [list(column) for column in zip(*results)]\n",
    "    else:\n",
    "        per_metric = [[] for _ in METRIC_NAMES]\n",
    "\n",
    "    # Write results to file\n",
    "    _write_metrics(target_dir + prefix + '_metrics.txt', per_metric)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}