@@ -40,6 +40,7 @@ def __init__(self, compartments=None, seed=None, rng=None):
4040 self .population = None
4141 self .orig_comps = None
4242 self .demographics = False
43+ self .params = {}
4344
4445 if seed is None :
4546 seed = int (time .time ()) + os .getpid ()
@@ -48,11 +49,15 @@ def __init__(self, compartments=None, seed=None, rng=None):
4849 self .rng = np .random .default_rng (seed = seed )
4950 else :
5051 self .rng = rng
51-
52+
5253 if compartments is not None :
5354 self .transitions .add_nodes_from ([comp for comp in compartments ])
5455
55- def add_interaction (self , source : str , target : str , agent : str , rate : float ) -> None :
56+ def add_interaction (self ,
57+ source : str ,
58+ target : str ,
59+ agent : str ,
60+ ** rates ) -> None :
5661 """
5762 Add an interaction between two compartments
5863
@@ -63,15 +68,19 @@ def add_interaction(self, source: str, target: str, agent: str, rate: float) ->
6368 Name of the target compartment
6469 - agent: string
6570 Name of the agent
66- - rate: float
67- Rate of the interaction
71+ - params: string
72+ Named parameters for the interaction
6873
6974 Returns:
7075 None
71- """
72- self .transitions .add_edge (source , target , agent = agent , rate = rate )
76+ """
77+
78+ self .params .update (rates )
79+ rates = list (rates .keys ())
80+
81+ self .transitions .add_edge (source , target , agent = agent , rate = rates [0 ])
7382
74- def add_spontaneous (self , source : str , target : str , rate : float ) -> None :
83+ def add_spontaneous (self , source : str , target : str , ** rates ) -> None :
7584 """
7685 Add a spontaneous transition between two compartments
7786
@@ -86,7 +95,11 @@ def add_spontaneous(self, source: str, target: str, rate: float) -> None:
8695 Returns:
8796 None
8897 """
89- self .transitions .add_edge (source , target , rate = rate )
98+
99+ self .params .update (rates )
100+ rates = list (rates .keys ())
101+
102+ self .transitions .add_edge (source , target , rate = rates [0 ])
90103
91104 def add_birth_rate (self , rate : float , comps : Union [List , None ] = None ) -> None :
92105 """
@@ -224,7 +237,8 @@ def _new_cases(self, time: float, population: np.ndarray, pos: Dict) -> np.ndar
224237 target = edge [1 ]
225238 trans = edge [2 ]
226239
227- rate = trans ['rate' ]* population [pos [source ]]
240+ rate_val = self .params [trans ['rate' ]]
241+ rate = rate_val * population [pos [source ]]
228242
229243 if 'start' in trans and trans ['start' ] >= time :
230244 continue
@@ -370,7 +384,7 @@ def simulate(self, timesteps: int, t_min: int = 1, seasonality: Union[np.ndarray
370384 source = pos [comp ]
371385 target = pos [node_j ]
372386
373- rate = data ['rate' ]
387+ rate = self . params [ data ['rate' ] ]
374388
375389 if 'start' in data and data ['start' ] >= t :
376390 continue
@@ -505,21 +519,24 @@ def __repr__(self) -> str:
505519 (self .transitions .number_of_nodes (),
506520 self .transitions .number_of_edges ())
507521
522+ text += "Parameters:\n "
523+ for rate , value in self .params .items ():
524+ text += "%s= %f\n " % (rate , value )
525+ text += "\n \n Transitions:\n "
526+
508527 for edge in self .transitions .edges (data = True ):
509528 source = edge [0 ]
510529 target = edge [1 ]
511530 trans = edge [2 ]
512-
513- rate = trans ['rate' ]
514-
531+
515532 if 'agent' in trans :
516533 agent = trans ['agent' ]
517- text += "%s + %s = %s %f \n " % (source , agent , target , rate )
534+ text += "%s + %s = %s %s \n " % (source , agent , target , trans [ ' rate' ] )
518535 elif 'start' in trans :
519536 start = trans ['start' ]
520- text += "%s -> %s %f starting at %s days\n " % (source , target , rate , start )
537+ text += "%s -> %s %s starting at %s days\n " % (source , target , rate , start )
521538 else :
522- text += "%s -> %s %f \n " % (source , target , rate )
539+ text += "%s -> %s %s \n " % (source , target , rate )
523540
524541 R0 = self .R0 ()
525542
@@ -664,7 +681,7 @@ def R0(self) -> Union[float, None]:
664681
665682 try :
666683 for node_i , node_j , data in self .transitions .edges (data = True ):
667- rate = data ['rate' ]
684+ rate = self . params [ data ['rate' ] ]
668685
669686 if "agent" in data :
670687 target = pos [node_j ]
0 commit comments