# colodict = {}
# for each in zip(all_types,category_colors):
# colodict[each[0]] = rgb2hex(each[1])
import scanpy as sc
import gc
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
from sklearn.neighbors import kneighbors_graph
from sklearn.metrics import adjusted_rand_score,normalized_mutual_info_score
from sklearn.neighbors import kneighbors_graph
def plot_with_colormap(values, color_dict):
    '''
    Assign a plotting color to every cell type that does not have one yet.

    Parameters
    ----------
    values : iterable
        Cell type labels. They must be hashable and mutually sortable;
        duplicates are ignored.
    color_dict : dict
        Existing mapping of cell type -> color. It is updated in place and
        entries that are already present are never overwritten.

    Returns
    -------
    color_dict : dict
        The same dictionary object that was passed in, now containing a
        color for every label in ``values``.
    '''
    color_list = [
        "#4dbbd5",  # Blue
        "#f39b7f",  # Orange
        "#00a087",  # Green
        "#e64b35",  # Red
        "#3c5488",  # Purple
        "#8c564b",  # Brown
        "#e377c2",  # Pink
        "#7f7f7f",  # Gray
        "#bcbd22",  # Yellow-Green
        "#17becf",  # Cyan
        "#ff9896",  # Light Red
        "#c5b0d5",  # Light Purple
        "#c49c94",  # Light Brown
        "#f7b6d2",  # Light Pink
        "#c7c7c7",  # Light Gray
        "#dbdb8d",  # Light Yellow-Green
        'tab:pink', 'tab:olive', 'tab:cyan', 'gold', 'springgreen', 'coral',
        'skyblue', 'tab:blue', 'tab:orange', 'tab:green', 'tab:red',
        'tab:purple', 'tab:brown', 'yellow', 'aqua', 'turquoise',
        'orangered', 'lightblue', 'darkorchid', 'fuchsia', 'royalblue',
        'slategray', 'silver', 'teal', 'fuchsia', 'grey', 'indigo', 'khaki',
        'magenta', 'tab:gray',
    ]
    n_colors = len(color_list)

    # Sorting the de-duplicated labels makes the label -> color assignment
    # deterministic for a given starting dictionary.
    for value in sorted(set(values)):
        if value in color_dict:
            continue
        # The "+1" offset is kept from the original implementation so that
        # previously produced figures keep their colors. The modulo wraps
        # around the palette instead of raising IndexError once there are
        # more cell types than colors (colors are then reused).
        color_dict[value] = color_list[(len(color_dict) + 1) % n_colors]
    return color_dict
def plot_stages_latent_representation(adatas, cell_type_key, stage_key, color_scheme=None, ax=None, dpi=300, save=None, ground_truth_key='name.simple'):
    '''
    Plot per-stage UMAPs colored by cell type and by leiden cluster, and
    print clustering metrics for every stage.

    One row of two panels is drawn per stage (stages are processed in
    descending sorted order): the left panel is colored by
    ``cell_type_key``, the right panel by the ``'leiden'`` column. For each
    stage the adjusted Rand index and normalized mutual information between
    ``ground_truth_key`` and ``cell_type_key`` are printed, together with
    the silhouette score of ``adata.obsm['z']`` with respect to the leiden
    labels.

    Parameters
    ----------
    adatas : AnnData object
        Annotated data matrix containing all stages. It must provide
        ``obs['leiden']``, ``obs[ground_truth_key]``, ``obsm['z']`` and
        whatever ``sc.pl.umap`` needs to draw a UMAP.
    cell_type_key : str
        Key for cell type column in adata.obs.
    stage_key : str
        Key for stage column in adata.obs.
    color_scheme : dict, optional
        Dictionary of cell types and their colors. The default is None
        (start from an empty mapping). A dictionary passed here is extended
        in place with colors for cell types it does not contain yet.
    ax : matplotlib axis, optional
        Currently ignored: the function always creates its own figure with
        one row per stage. Kept for interface compatibility.
    dpi : int, optional
        The default is 300.
    save : str, optional
        Path to save the figure. The default is None, in which case the
        figure is shown instead.
    ground_truth_key : str, optional
        Column in adata.obs holding the reference labels used for ARI and
        NMI. The default is 'name.simple'.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``adatas.obs[stage_key]`` contains no stages.
    '''
    sc.set_figure_params(scanpy=True, dpi=dpi)

    stage_keys = sorted(adatas.obs[stage_key].unique().tolist(), reverse=True)
    n_stages = len(stage_keys)
    if n_stages == 0:
        raise ValueError('No stages found in adatas.obs[%r].' % (stage_key,))

    color_dict_unagi = {} if color_scheme is None else color_scheme

    # One row per stage; 3.75 inches per row reproduces the former fixed
    # (10, 15) canvas for four stages. squeeze=False keeps `axes` 2-D even
    # when there is a single stage.
    fig, axes = plt.subplots(n_stages, 2, figsize=(10, 3.75 * n_stages), squeeze=False)

    ari_total = 0
    nmi_total = 0
    silhouette_total = 0
    ariss = []
    NMIs = []

    for row, stage in enumerate(stage_keys):
        stage_mask = (adatas.obs[stage_key] == stage).to_numpy()
        print(int(stage_mask.sum()))

        # Work on a real copy: columns are added to .obs below, which must
        # not be done on a view of the caller's object.
        adata = adatas[stage_mask].copy()
        adata.obs['UNAGI'] = adata.obs[cell_type_key].astype('category')
        adata.obs['leiden'] = adata.obs['leiden'].astype('string')

        # Extend the shared color mapping so a cell type keeps the same
        # color in every stage panel.
        color_dict_unagi = plot_with_colormap(sorted(adata.obs['UNAGI'].unique()), color_dict_unagi)

        sc.pl.umap(adata, color='UNAGI', ax=axes[row, 0], show=False, palette=color_dict_unagi, title=str(stage))
        sc.pl.umap(adata, color='leiden', ax=axes[row, 1], show=False, title=str(stage))

        temp_ari = adjusted_rand_score(adata.obs[ground_truth_key], adata.obs['UNAGI'])
        temp_nmi = normalized_mutual_info_score(adata.obs[ground_truth_key], adata.obs['UNAGI'])
        temp_silhouette_score = silhouette_score(adata.obsm['z'], adata.obs['leiden'])
        print('ARI: ', temp_ari)
        print('NMIs: ', temp_nmi)
        print('silhouette score: ', temp_silhouette_score)

        ari_total += temp_ari
        nmi_total += temp_nmi
        silhouette_total += temp_silhouette_score
        # Running totals divided by the number of stages; the last entry of
        # each list is therefore the mean over all stages.
        ariss.append(ari_total / n_stages)
        NMIs.append(nmi_total / n_stages)

    fig.tight_layout()
    if save is not None:
        fig.savefig(save, dpi=dpi)
    else:
        plt.show()

    print('ARIs: ', ariss)
    print('NMI: ', NMIs)
    print('silhouette score: ', silhouette_total / n_stages)
# if __name__ == '__main__':
# adata = sc.read_h5ad('../dataset.h5ad')
# plot_stages_latent_representation(adata,'ident','stage')