Skip to content

Commit 7dcab35

Browse files
authored
Merge pull request #10 from Mye-InfoBank/serafina/scvi_tools_dgea
Serafina/scvi tools dgea
2 parents cc3a849 + d272d36 commit 7dcab35

11 files changed

+473
-3
lines changed

app.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
from composition import composition_server, composition_ui
77
from export import export_ui, export_server
8-
from dgea.dgea import dgea_server, dgea_ui
8+
from dgea.dgea_scvi import dgea_server, dgea_ui
99
from tree import tree_server, tree_ui
1010

1111
with open("data/config.json") as f:
1212
config = json.load(f)
1313
adata = sc.read_h5ad("data/" + config["adata"])
1414
tree = pickle.load(open("data/" + config["tree"], "rb")) if "tree" in config else None
1515
name = config["name"]
16+
model_path = config["model_path"]
1617

1718
categorical_columns = adata.obs.select_dtypes(include="category").columns.to_list()
1819

@@ -34,6 +35,7 @@ def server(input, output, session):
3435
_dataframe = reactive.value(adata.obs)
3536
_adata = reactive.value(adata)
3637
_tree = reactive.value(tree)
38+
_model = reactive.value(model_path)
3739
composition_server("composition", _dataframe)
3840
export_server("export")
3941
dgea_server("dgea", _adata)

data/atlas.h5ad

9.81 MB
Binary file not shown.

data/config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"adata": "merged.h5ad",
33
"tree": "scarches.tree.pkl",
4-
"name": "possible_atlas"
4+
"name": "possible_atlas",
5+
"model_path": "./data"
56
}

data/model.pt

3.21 MB
Binary file not shown.

dgea/dgea_scvi.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from shiny import reactive, ui, render, module
2+
import anndata as ad
3+
4+
from dgea.run_dgea_scvi import run_dgea_ui, run_dgea_server
5+
from dgea.filter_dgea_scvi import filter_dgea_ui, filter_dgea_server
6+
from dgea.plot_dgea_scvi import plot_dgea_ui, plot_dgea_server
7+
8+
@module.ui
9+
def dgea_ui():
10+
return ui.layout_sidebar(
11+
ui.sidebar(
12+
run_dgea_ui("run_dgea"),
13+
filter_dgea_ui("filter_dgea"),
14+
title="Select covariates"
15+
),
16+
*plot_dgea_ui("plot_dgea")
17+
)
18+
19+
@module.server
20+
def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData]):
21+
_counts = reactive.value(None)
22+
_uniques = reactive.value([])
23+
24+
_contrast = reactive.value(None)
25+
_reference = reactive.value(None)
26+
_alternative = reactive.value(None)
27+
_log10_p = reactive.value(0.05)
28+
_lfc = reactive.value(1)
29+
30+
_result = reactive.value(None)
31+
_filtered_result = reactive.value(None)
32+
_filtered_genes = reactive.value(None)
33+
_filtered_counts = reactive.value(None)
34+
35+
run_dgea_server("run_dgea", _adata, _result, _counts, _reference, _alternative, _uniques, _contrast)
36+
filter_dgea_server("filter_dgea", _adata, _counts, _uniques,
37+
_result, _filtered_result, _filtered_genes,
38+
_filtered_counts, _reference, _alternative, _contrast, _log10_p, _lfc)
39+
plot_dgea_server("plot_dgea", _filtered_counts, _contrast, _reference, _alternative,
40+
_result, _log10_p, _lfc)

dgea/dgea_scvi_helpers.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pandas as pd
3+
import scanpy as sc
4+
import anndata as ad
5+
6+
from scvi.model import SCANVI, SCVI
7+
8+
def scanvi_dgea(adata:ad.AnnData, groupby:str, reference:str, alternative:str, directory_model:str):
9+
10+
if 'cell_type' in adata.obs.columns:
11+
model_type = SCANVI
12+
print('is scavi')
13+
14+
else:
15+
model_type = SCVI
16+
print('is scanvi')
17+
18+
model_type.prepare_query_anndata(adata = adata, reference_model=directory_model)
19+
20+
model = model_type.load_query_data(adata, directory_model)
21+
22+
groups = np.array(adata.obs[groupby].unique())
23+
24+
idx1 = adata.obs[groupby] == reference
25+
idx2 = adata.obs[groupby] == alternative
26+
27+
dge_change = model.differential_expression(adata=adata, groupby=groupby, idx1=idx1, idx2=idx2, mode="change")
28+
29+
epsilon = 1e-10
30+
dge_change['proba_not_de'] = np.maximum(dge_change["proba_not_de"], epsilon)
31+
dge_change["log10_pscore"] = np.log10(dge_change["proba_not_de"])
32+
dge_change["-log10_pscore"] = -np.log10(dge_change["proba_not_de"])
33+
34+
return dge_change
35+
36+
def get_normalized_counts(adata):
37+
print(adata.shape)
38+
sc.pp.normalize_total(adata, target_sum=1e4)
39+
sc.pp.log1p(adata)
40+
adata.layers["counts"] = adata.X.copy().tocsr()
41+
counts = adata.layers["counts"]
42+
dense_matrix = counts.toarray()
43+
df_counts = pd.DataFrame(dense_matrix, index=adata.obs_names, columns=adata.var_names)
44+
return df_counts
45+
46+
if __name__ == '__main__':
47+
print('Running DGEA test')
48+
adata = sc.read_h5ad('/workspaces/SIMBA-Downstream_1/data/atlas.h5ad')
49+
dge_test = scanvi_dgea(adata, "cell_type", "Endothelial", "Epithelial", './data')
50+
print(dge_test.head())

dgea/filter_dgea_scvi.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from shiny import reactive, ui, render, module
2+
import anndata as ad
3+
import pandas as pd
4+
from dgea.dgea_scvi_helpers import scanvi_dgea, get_normalized_counts
5+
6+
@module.ui
7+
def filter_dgea_ui():
8+
return ui.div(
9+
ui.output_ui("select_reference"),
10+
ui.output_ui("select_alternative"),
11+
ui.input_slider("log10_pscore", "Ropability in Reference (significance threshold)", min=0, max=20, step=0.01, value=3),
12+
ui.input_slider("lfc", "Log2 fold change", min=0, max=10, step=0.1, value=1),
13+
ui.output_ui("open_gprofiler")
14+
)
15+
16+
@module.server
17+
def filter_dgea_server(input, output, session,
18+
_adata, _counts, _uniques,
19+
_result, _filtered_result, _filtered_genes, _filtered_counts,
20+
_reference, _alternative, _contrast,
21+
_log10_p, _lfc
22+
):
23+
24+
@output
25+
@render.ui
26+
def select_reference():
27+
uniques = _uniques.get()
28+
29+
if not uniques or len(uniques) < 2:
30+
return ui.p("Run analysis to see options")
31+
32+
return ui.input_select("reference", "Reference", choices=uniques, selected=uniques[0])
33+
34+
@output
35+
@render.ui
36+
def select_alternative():
37+
uniques = _uniques.get()
38+
39+
if not uniques or len(uniques) < 2:
40+
return ui.p("Run analysis to see options")
41+
42+
print(uniques)
43+
44+
return ui.input_select("alternative", "Alternative", choices=uniques, selected=uniques[1])
45+
46+
47+
@reactive.effect
48+
def update_filters():
49+
_reference.set(input["reference"].get())
50+
_alternative.set(input["alternative"].get())
51+
_log10_p.set(input["log10_pscore"].get())
52+
_lfc.set(input["lfc"].get())
53+
54+
@reactive.effect
55+
def update_result():
56+
adata = _adata.get()
57+
reference = _reference.get()
58+
alternative = _alternative.get()
59+
contrast = _contrast.get()
60+
61+
if None in (reference, alternative, contrast):
62+
return
63+
64+
res_df = scanvi_dgea(adata, contrast, reference, alternative)
65+
res_counts = get_normalized_counts(adata)
66+
_result.set(res_df)
67+
_counts.set(res_counts)
68+
69+
@reactive.effect
70+
def filter_result():
71+
result = _result.get()
72+
log10_p = input["log10_pscore"].get()
73+
lfc = input["lfc"].get()
74+
counts = _counts.get()
75+
76+
if result is None:
77+
return None
78+
79+
result = result[(result["-log10_pscore"] < log10_p) & (result["lfc_mean"].abs() > lfc)]
80+
_filtered_result.set(result)
81+
genes = result.index.tolist()
82+
genes_not_found = [gene for gene in genes if gene not in counts.columns]
83+
if genes_not_found:
84+
print(f"Genes not found in the DataFrame: {genes_not_found}")
85+
_filtered_genes.set(genes)
86+
_filtered_counts.set(counts.loc[:, genes])
87+
88+
@render.ui
89+
def open_gprofiler():
90+
genes = _filtered_genes.get()
91+
if not genes:
92+
return None
93+
94+
return ui.input_action_button("gprofiler", label="Open g:Profiler",
95+
onclick=f"window.open('https://biit.cs.ut.ee/gprofiler/gost?organism=hsapiens&query={'%0A'.join(genes)}', '_blank')")

dgea/plot_dgea_scvi.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from shiny import reactive, ui, render, module
2+
import numpy as np
3+
import seaborn as sns
4+
import shinywidgets as sw
5+
import tempfile
6+
import plotly.express as px
7+
8+
@module.ui
9+
def plot_dgea_ui():
10+
return [
11+
ui.card(
12+
ui.card_header("Volcano plot"),
13+
sw.output_widget("plot_volcano")
14+
),
15+
ui.card(
16+
ui.card_header("Heatmap"),
17+
ui.output_plot("plot_heatmap"),
18+
ui.card_footer(
19+
ui.download_button("download_dgea", "Download DGEA matrix"),
20+
ui.download_button("download_plot", "Download plot")
21+
)
22+
)
23+
]
24+
25+
@module.server
26+
def plot_dgea_server(input, output, session,
27+
_filtered_counts,
28+
_contrast,
29+
_reference,
30+
_alternative,
31+
_result,
32+
_log10_p,
33+
_lfc
34+
):
35+
_heatmap = reactive.value(None)
36+
37+
@output
38+
@render.plot
39+
def plot_heatmap():
40+
counts_df = _filtered_counts.get()
41+
contrast = _contrast.get()
42+
reference = _reference.get()
43+
alternative = _alternative.get()
44+
45+
if counts_df is None:
46+
return None
47+
48+
if counts_df.empty:
49+
return None
50+
51+
plot = sns.clustermap(counts_df.T, cmap="viridis", figsize=(10, 10))
52+
_heatmap.set(plot)
53+
54+
return plot
55+
56+
@render.download(
57+
filename=lambda: f"dgea_matrix_{_contrast.get()}-{_reference.get()}:{_alternative.get()}.csv"
58+
)
59+
def download_dgea():
60+
scanvi_results = _result.get()
61+
if scanvi_results is None:
62+
return None
63+
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp:
64+
scanvi_results.to_csv(temp.name)
65+
return temp.name
66+
67+
@render.download(
68+
filename=lambda: f"heatmap_{_contrast.get()}-{_reference.get()}:{_alternative.get()}.png"
69+
)
70+
def download_plot():
71+
plot = _heatmap.get()
72+
if plot is None:
73+
return None
74+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp:
75+
plot.savefig(temp.name, bbox_inches="tight")
76+
return temp.name
77+
78+
@output
79+
@sw.render_plotly
80+
def plot_volcano():
81+
scanvi_results = _result.get()
82+
log10_p = _log10_p.get()
83+
lfc = _lfc.get()
84+
85+
if scanvi_results is None:
86+
return None
87+
88+
df_plot = scanvi_results.copy()
89+
90+
df_plot["category"] = "Not significant"
91+
df_plot["gene"] = df_plot.index
92+
df_plot.loc[(df_plot["lfc_mean"] < -lfc) & (df_plot["-log10_pscore"] < log10_p), "category"] = f"High in {_alternative.get()}"
93+
df_plot.loc[(df_plot["lfc_mean"] > lfc) & (df_plot["-log10_pscore"] < log10_p), "category"] = f"High in {_reference.get()}"
94+
95+
colormap = {"Not significant": "grey", f"High in {_alternative.get()}": "blue", f"High in {_reference.get()}": "red"}
96+
97+
hover_data = {
98+
"lfc_mean": True,
99+
"-log10_pscore": True,
100+
"category": True,
101+
}
102+
103+
fig = px.scatter(df_plot, x="lfc_mean", y="-log10_pscore", color="category",
104+
color_discrete_map=colormap,
105+
hover_name="gene",
106+
hover_data=hover_data,
107+
labels={"lfc_mean": "Log2 fold change mean",
108+
"-log10_pscore": "Negative log 10 P-value",
109+
"category": "Category"})
110+
111+
num_points = len(df_plot)
112+
fig.add_annotation(text=f"number of counts = {num_points}", x=0.5, y=-0.15, showarrow=False, font=dict(size=12), xanchor="center")
113+
fig.update_layout(margin=dict(l=40, r=40, b=80, t=40), height=600)
114+
115+
return fig
116+

0 commit comments

Comments
 (0)