Skip to content

Commit a159a2e

Browse files
committed
docstrings + edge_importance_data meth
1 parent 898e42a commit a159a2e

File tree

1 file changed

+104
-25
lines changed

1 file changed

+104
-25
lines changed

bioneuralnet/explainability/dpmon_explainer.py

Lines changed: 104 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn
77
from torch.types import FileLike
88

9-
from typing import Literal
9+
from typing import Literal, Dict, List
1010

1111
from torch_geometric.data import Data
1212
from torch_geometric.explain import Explainer, ExplainerAlgorithm, ModelConfig
@@ -21,8 +21,6 @@
2121
setup_device,
2222
)
2323

24-
from typing import List, Optional
25-
2624

2725
class NeuralNetworkWrapper(nn.Module):
2826
"""A wrapper class for formatting DPMON Neural Network IO in a form pytorch_geometric requires"""
@@ -40,9 +38,9 @@ def __init__(self, nn: NeuralNetwork):
4038

4139
self.nn = nn
4240

43-
def forward(self, x, edge_index, train_features, **kwargs):
    """Run the wrapped DPMON network on a graph supplied as loose tensors.

    The explainer calls the model with the graph's fields as separate
    arguments, so they are re-packed into a `Data` object before being
    handed to the wrapped network.

    Args:
        x: Node features of the omics network.
        edge_index: Edge connectivity of the omics network.
        train_features: Passed to the wrapped network as its first argument.
        **kwargs: Any further graph fields; forwarded unchanged to `Data`.

    Returns:
        The first element of the wrapped network's three-part output.
    """
    graph = Data(x=x, edge_index=edge_index, **kwargs)
    outputs = self.nn(train_features, graph)
    # The wrapped network returns three values; only the prediction is needed.
    prediction, _, _ = outputs
    return prediction
@@ -55,18 +53,23 @@ def __init__(
5553
self,
5654
f: FileLike,
5755
dpmon: DPMON,
58-
algorithm: ExplainerAlgorithm,
59-
mode: Literal["regression", "binary_classification", "multiclass_classification"],
60-
explanation_type: Literal["model", "phenomenon"] = "model",
61-
node_mask_type: Literal["object", "common_attributes", "attributes"] | None = "attributes",
62-
edge_mask_type: Literal["object", "common_attributes", "attributes"] | None = "object",
63-
task_level: Literal["edge", "node", "graph"] = "graph",
64-
return_type: Literal["raw", "log_probs", "probs"] = "raw",
6556
weights_only: bool = True,
6657
):
6758
"""Initialize DPMON explainer object.
68-
This implementation is a first version.
69-
There has to be a better way to do this
59+
This implementation is a first version. By default, it uses `torch_geometric.explain.GNNExplainer()`
60+
to produce feature importance explanations on `clinical` data. The raw node explanations are stored in
61+
`self.expl.node_mask`
62+
63+
It is important to note that these explanations likely do not capture the full picture of the predictions
64+
in a multi-omics network, rather provide insight into the clinical (patient) features which account
65+
for the explanations.
66+
67+
This explainer object also produces edge importances at an `object` level. It is currently
68+
unknown how useful these explanations are, but are stored in `self.expl.edge_mask`. A method is provided
69+
to retrieve the top n important edges for the user to observe if they want.
70+
71+
Otherwise, all of the default `pytorch_geometric.explain.Explainer` methods are available to use, whether
72+
through a provided wrapper or through the `expl` member variable
7073
7174
Args:
7275
f (FileLike): The file object or path to a saved model trained with DPMON
@@ -90,6 +93,8 @@ def __init__(
9093
phenotype_col="phenotype",
9194
)[0]
9295

96+
self.clinical_data = dpmon.clinical_data
97+
9398
model = NeuralNetwork(
9499
model_type=dpmon.model,
95100
gnn_input_dim=self.omics_network_tg.x.shape[1], # type: ignore
@@ -112,29 +117,103 @@ def __init__(
112117
).to(device)
113118

114119
self.model = NeuralNetworkWrapper(model)
120+
121+
def explain(
    self,
    algorithm: ExplainerAlgorithm,
    mode: Literal[
        "regression", "binary_classification", "multiclass_classification"
    ],
    explanation_type: Literal["model", "phenomenon"] = "model",
    node_mask_type: (
        Literal["object", "common_attributes", "attributes"] | None
    ) = "attributes",
    edge_mask_type: (
        Literal["object", "common_attributes", "attributes"] | None
    ) = "object",
    task_level: Literal["edge", "node", "graph"] = "graph",
    return_type: Literal["raw", "log_probs", "probs"] = "raw",
):
    """Generate explanations for the DPMON instance and the loaded model.

    Builds an `Explainer` around the wrapped model, stores it on
    `self.explainer`, runs it on the omics network and stores the result on
    `self.expl`. If the omics network has no `x` or no `edge_index`, no
    explanation is produced and `self.expl` is left untouched.

    Args:
        algorithm (ExplainerAlgorithm): The `pytorch_geometric.explain`
            explainer algorithm to use. Currently only tested with
            `GNNExplainer()`.
        mode: The type of prediction the GNN is making.
        explanation_type: Whether to explain the `model` predictions or the
            `phenomenon` the model is trying to predict. Defaults to "model".
        node_mask_type: The node explanation type to generate.
            Defaults to "attributes".
        edge_mask_type: The edge explanation type to generate.
            Defaults to "object".
        task_level: The prediction scope of the model. Defaults to "graph".
        return_type: The output of the model. Defaults to "raw".
    """
    self.explainer = Explainer(
        self.model,
        algorithm,
        explanation_type=explanation_type,
        node_mask_type=node_mask_type,
        edge_mask_type=edge_mask_type,
        model_config=ModelConfig(
            mode=mode, task_level=task_level, return_type=return_type
        ),
    )

    graph = self.omics_network_tg
    # Identity checks: `!=` would dispatch to the tensor's overloaded
    # comparison instead of simply asking "is this field present?".
    if graph.x is not None and graph.edge_index is not None:
        self.expl = self.explainer(
            **graph.to_dict(), train_features=self.train_features
        )
163+
164+
def edge_importance_data(self, top_n: int = 5) -> List[Dict]:
    """Summarise the most important edges of the current explanation.

    Args:
        top_n (int, optional): The number of important edges to retrieve.
            Non-positive values yield an empty list. Defaults to 5.

    Returns:
        List[Dict]: At most `top_n` entries ordered from most to least
        important (ties keep their original edge order). Each entry holds
        `"importance"` (the edge-mask value as a Python number), `"edge"`
        (the matching column of `edge_index`) and, when the graph has edge
        attributes, `"edge_attr"` (a Python scalar for 1-D attributes,
        otherwise the attribute row).

    Raises:
        AttributeError: If the explanation has no edge mask.
    """
    edge_mask = self.expl.edge_mask
    if edge_mask is None:
        raise AttributeError(
            "edge_mask is not defined. Generate explanations on edges first"
        )

    edge_index = self.omics_network_tg.edge_index
    edge_attr = self.omics_network_tg.edge_attr

    importances = [score.item() for score in edge_mask]
    # Rank every edge, then cut. The previous hand-rolled insertion dropped
    # any edge that did not outrank one already kept, so it could return
    # fewer than `top_n` edges. `sorted` is stable, so ties keep edge order.
    ranked = sorted(
        range(len(importances)), key=importances.__getitem__, reverse=True
    )

    edges = []
    for idx in ranked[: max(top_n, 0)]:
        entry = {"importance": importances[idx], "edge": edge_index[:, idx]}  # type: ignore
        if edge_attr is not None:
            if edge_attr.ndim == 1:
                entry["edge_attr"] = edge_attr[idx].item()
            else:
                entry["edge_attr"] = edge_attr[idx, :]
        edges.append(entry)
    return edges
134203

204+
def visualize_feature_importance(
205+
self, path: os.PathLike | None = None, top_k: int | None = None
206+
):
207+
"""Wrapper of the `pytorch_geometric.explain.Explainer.visualize_feature_importance` method
135208
136-
def visualize_feature_importance(self, path: os.PathLike):
137-
self.expl.visualize_feature_importance(str(path))
209+
Args:
210+
path (os.PathLike | None, optional): Path to save the feature importance graph. Defaults to None.
211+
top_k (int | None, optional): The number of features to include in the graph. Defaults to None.
212+
"""
213+
feat_labels = None
214+
if isinstance(self.clinical_data, pd.DataFrame):
215+
feat_labels = self.clinical_data.columns.to_list()
138216

139-
def visualize_graph(self, path: os.PathLike):
140-
self.expl.visualize_graph(str(path))
217+
self.expl.visualize_feature_importance(
218+
str(path) if path != None else path, top_k=top_k, feat_labels=feat_labels # type: ignore
219+
)

0 commit comments

Comments
 (0)