66from torch import nn
77from torch .types import FileLike
88
9- from typing import Literal
9+ from typing import Literal , Dict , List
1010
1111from torch_geometric .data import Data
1212from torch_geometric .explain import Explainer , ExplainerAlgorithm , ModelConfig
2121 setup_device ,
2222)
2323
24- from typing import List , Optional
25-
2624
2725class NeuralNetworkWrapper (nn .Module ):
2826 """A wrapper class for formatting DPMON Neural Network IO in a form pytorch_geometric requires"""
@@ -40,9 +38,9 @@ def __init__(self, nn: NeuralNetwork):
4038
4139 self .nn = nn
4240
def forward(self, x, edge_index, train_features, **kwargs):
    """Run the wrapped network using the call signature `Explainer` expects.

    Args:
        x: Node feature tensor of the omics network.
        edge_index: Edge index tensor of the omics network.
        train_features: Feature input forwarded unchanged to the wrapped
            network as its first argument.
        **kwargs: Any additional graph attributes; they are passed straight
            into the `torch_geometric.data.Data` constructor.

    Returns:
        The first element of the three-element result produced by the
        wrapped network.
    """
    # Re-assemble the individual tensors into the single graph object the
    # wrapped network takes as its second argument.
    graph = Data(x=x, edge_index=edge_index, **kwargs)

    # The wrapped network yields three values; only the first is exposed.
    prediction, _second, _third = self.nn(train_features, graph)
    return prediction
@@ -55,18 +53,23 @@ def __init__(
5553 self ,
5654 f : FileLike ,
5755 dpmon : DPMON ,
58- algorithm : ExplainerAlgorithm ,
59- mode : Literal ["regression" , "binary_classification" , "multiclass_classification" ],
60- explanation_type : Literal ["model" , "phenomenon" ] = "model" ,
61- node_mask_type : Literal ["object" , "common_attributes" , "attributes" ] | None = "attributes" ,
62- edge_mask_type : Literal ["object" , "common_attributes" , "attributes" ] | None = "object" ,
63- task_level : Literal ["edge" , "node" , "graph" ] = "graph" ,
64- return_type : Literal ["raw" , "log_probs" , "probs" ] = "raw" ,
6556 weights_only : bool = True ,
6657 ):
6758 """Initialize DPMON explainer object.
68- This implementation is a first version.
69- There has to be a better way to do this
59+ This implementation is a first version. By default, it uses `torch_geometric.explain.GNNExplainer()`
60+ to produce feature importance explanations on `clinical` data. The raw node explanations are stored in
61+ `self.expl.node_mask`
62+
63+ It is important to note that these explanations likely do not capture the full picture of the predictions
64+ in a multi-omics network, rather provide insight into the clinical (patient) features which account
65+ for the explanations.
66+
67+ This explainer object also produces edge importances at an `object` level. It is currently
68+ unknown how useful these explanations are, but are stored in `self.expl.edge_mask`. A method is provided
69+ to retrieve the top n important edges for the user to observe if they want.
70+
71+ Otherwise, all of the default `pytorch_geometric.explain.Explainer` methods are available to use, whether
72+ through a provided wrapper or through the `expl` member variable
7073
7174 Args:
7275 f (FileLike): The file object or path to a saved model trained with DPMON
@@ -90,6 +93,8 @@ def __init__(
9093 phenotype_col = "phenotype" ,
9194 )[0 ]
9295
96+ self .clinical_data = dpmon .clinical_data
97+
9398 model = NeuralNetwork (
9499 model_type = dpmon .model ,
95100 gnn_input_dim = self .omics_network_tg .x .shape [1 ], # type: ignore
@@ -112,29 +117,103 @@ def __init__(
112117 ).to (device )
113118
114119 self .model = NeuralNetworkWrapper (model )
120+
def explain(
    self,
    algorithm: ExplainerAlgorithm,
    mode: Literal[
        "regression", "binary_classification", "multiclass_classification"
    ],
    explanation_type: Literal["model", "phenomenon"] = "model",
    node_mask_type: (
        Literal["object", "common_attributes", "attributes"] | None
    ) = "attributes",
    edge_mask_type: (
        Literal["object", "common_attributes", "attributes"] | None
    ) = "object",
    task_level: Literal["edge", "node", "graph"] = "graph",
    return_type: Literal["raw", "log_probs", "probs"] = "raw",
) -> None:
    """Generate explanations for the DPMON instance and the model loaded at the specified path.

    A new `Explainer` is built around the wrapped model and stored in
    `self.explainer`. When the omics network has both node features and an
    edge index, the explainer is run on it and the result is stored in
    `self.expl`; otherwise no explanation is produced and `self.expl` is
    left untouched.

    Args:
        algorithm (ExplainerAlgorithm): The `torch_geometric.explain` explainer algorithm to use. Currently only tested with `GNNExplainer()`
        mode (Literal["regression", "binary_classification", "multiclass_classification"]): The type of prediction the GNN is making
        explanation_type (Literal["model", "phenomenon"], optional): Whether to explain the `model` predictions or the `phenomenon` that the model is trying to predict. Defaults to "model".
        node_mask_type (Literal["object", "common_attributes", "attributes"] | None, optional): The node explanation type to generate. Defaults to "attributes".
        edge_mask_type (Literal["object", "common_attributes", "attributes"] | None, optional): The edge explanation type to generate. Defaults to "object".
        task_level (Literal["edge", "node", "graph"], optional): The prediction scope of the model. Defaults to "graph".
        return_type (Literal["raw", "log_probs", "probs"], optional): The output of the model. Defaults to "raw".
    """
    self.explainer = Explainer(
        self.model,
        algorithm,
        explanation_type=explanation_type,
        node_mask_type=node_mask_type,
        edge_mask_type=edge_mask_type,
        model_config=ModelConfig(
            mode=mode, task_level=task_level, return_type=return_type
        ),
    )

    graph = self.omics_network_tg
    # Identity checks: `!=` on a tensor is an overloaded operator and should
    # not be relied on to test for a missing attribute.
    if graph.x is not None and graph.edge_index is not None:
        # Every stored graph attribute is forwarded as a keyword argument,
        # together with the training features the wrapped model needs.
        self.expl = self.explainer(
            **graph.to_dict(), train_features=self.train_features
        )
163+
def edge_importance_data(self, top_n: int = 5) -> List[Dict]:
    """Summarize the most important edges of the current explanation.

    Edges are ranked by their value in `self.expl.edge_mask`, highest
    first. Edges with equal importance keep their original relative order.

    Args:
        top_n (int, optional): The number of important edges to retrieve. Defaults to 5.

    Returns:
        List[Dict]: One dict per selected edge, ordered by decreasing
        importance, with keys `importance` (float), `edge` (the matching
        column of `edge_index`) and, when the graph has `edge_attr`,
        `edge_attr` (a scalar for 1-D attributes, otherwise the attribute row).

    Raises:
        AttributeError: If no edge mask is available, i.e. explanations on
            edges have not been generated yet.
    """
    explanation = getattr(self, "expl", None)
    edge_mask = getattr(explanation, "edge_mask", None)
    if edge_mask is None:
        raise AttributeError(
            "edge_mask is not defined. Generate explanations on edges first"
        )

    importances = [score.item() for score in edge_mask]

    # Rank edge indices by importance. sorted() is stable even with
    # reverse=True, so ties keep their original order. Ranking every edge
    # (rather than inserting into a running list) guarantees that edges less
    # important than the ones seen so far are still considered.
    ranked = sorted(
        range(len(importances)), key=importances.__getitem__, reverse=True
    )[:top_n]

    graph = self.omics_network_tg
    edge_attr = graph.edge_attr

    edges = []
    for idx in ranked:
        entry = {"importance": importances[idx], "edge": graph.edge_index[:, idx]}  # type: ignore
        if edge_attr is not None:
            if edge_attr.ndim == 1:
                entry["edge_attr"] = edge_attr[idx].item()
            else:
                entry["edge_attr"] = edge_attr[idx, :]
        edges.append(entry)
    return edges
134203
204+ def visualize_feature_importance (
205+ self , path : os .PathLike | None = None , top_k : int | None = None
206+ ):
207+ """Wrapper of the `pytorch_geometric.explain.Explainer.visualize_feature_importance` method
135208
136- def visualize_feature_importance (self , path : os .PathLike ):
137- self .expl .visualize_feature_importance (str (path ))
209+ Args:
210+ path (os.PathLike | None, optional): Path to save the feature importance graph. Defaults to None.
211+ top_k (int | None, optional): The number of features to include in the graph. Defaults to None.
212+ """
213+ feat_labels = None
214+ if isinstance (self .clinical_data , pd .DataFrame ):
215+ feat_labels = self .clinical_data .columns .to_list ()
138216
139- def visualize_graph (self , path : os .PathLike ):
140- self .expl .visualize_graph (str (path ))
217+ self .expl .visualize_feature_importance (
218+ str (path ) if path != None else path , top_k = top_k , feat_labels = feat_labels # type: ignore
219+ )
0 commit comments