Skip to content

Commit 680ba56

Browse files
committed
mv to use np.random.default_rng
1 parent 92229ef commit 680ba56

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

src/epidemik/EpiModel.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import Dict, List, Set, Union
77
import warnings
88
import string
9+
import time
10+
import os
911

1012
import networkx as nx
1113
import numpy as np
@@ -22,7 +24,7 @@ class EpiModel(object):
2224
2325
Provides a way to implement and numerically integrate
2426
"""
25-
def __init__(self, compartments=None):
27+
def __init__(self, compartments=None, seed=None, rng=None):
2628
"""
2729
Initialize the EpiModel object
2830
@@ -38,7 +40,15 @@ def __init__(self, compartments=None):
3840
self.population = None
3941
self.orig_comps = None
4042
self.demographics = False
41-
43+
44+
if seed is None:
45+
seed = int(time.time()) + os.getpid()
46+
47+
if rng is None:
48+
self.rng = np.random.default_rng(seed=seed)
49+
else:
50+
self.rng = rng
51+
4252
if compartments is not None:
4353
self.transitions.add_nodes_from([comp for comp in compartments])
4454

@@ -397,11 +407,11 @@ def simulate(self, timesteps: int, t_min: int = 1, seasonality: Union[np.ndarray
397407
comp_id = pos[comp]
398408

399409
if "birth" in data:
400-
births = np.random.binomial(pop[comp_id], data["birth"])
410+
births = self.rng.binomial(pop[comp_id], data["birth"])
401411
new_pop[comp_id] += births
402412

403413
if "death" in data:
404-
deaths = np.random.binomial(pop[comp_id], data["death"])
414+
deaths = self.rng.binomial(pop[comp_id], data["death"])
405415
new_pop[comp_id] -= deaths
406416

407417
values.append(new_pop)
@@ -438,7 +448,7 @@ def integrate(self, timesteps: int , t_min: int = 1, seasonality: Union[np.ndarr
438448
else:
439449
total_pop = self.population.sum()
440450
p = np.copy(self.population)/total_pop
441-
n = np.random.multinomial(kwargs[comp], p, 1)[0]
451+
n = self.rng.multinomial(kwargs[comp], p, 1)[0]
442452

443453
for i, age in enumerate(string.ascii_lowercase[:len(p)]):
444454
comp_age = comp + '_' + age

src/epidemik/MetaEpiModel.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MetaEpiModel:
2222
2323
Provides a way to implement and numerically integrate
2424
"""
25-
def __init__(self, travel_graph, populations, population='Population'):
25+
def __init__(self, travel_graph, populations, population='Population', seed=None):
2626
"""
2727
Initialize the EpiModel object
2828
@@ -43,8 +43,13 @@ def __init__(self, travel_graph, populations, population='Population'):
4343
self.prototype = None
4444
self.seasonality = None
4545

46+
if seed is None:
47+
seed = int(time.time()) + os.getpid()
48+
49+
self.rng = np.random.default_rng(seed=seed)
50+
4651
for state in travel_graph.index:
47-
models[state] = EpiModel()
52+
models[state] = EpiModel(rng=self.rng)
4853
if self.transitions is None:
4954
self.transitions = models[state].transitions
5055
self.prototype = models[state]
@@ -153,7 +158,7 @@ def _run_travel(self, compartments_, travel):
153158
def travel_step(x, populations):
154159
n = populations.loc[x.name]
155160
p = travel.loc[x.name].values.tolist()
156-
output = np.random.multinomial(n, p)
161+
output = self.rng.multinomial(n, p)
157162

158163
return pd.Series(output, index=travel.columns)
159164

src/epidemik/NetworkEpiModel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def simulate(self, timesteps, seeds, **kwargs):
8181
continue
8282

8383
current_active = list(active_nodes)
84-
np.random.shuffle(current_active)
84+
self.rng.shuffle(current_active)
8585

8686
for node_i in current_active:
8787
state_i = population[t-1, node_i]
8888

8989
if state_i in infections:
9090
# contact each neighbour to see if we infect them
9191
NN = list(self.network.neighbors(node_i))
92-
np.random.shuffle(NN)
92+
self.rng.shuffle(NN)
9393

9494
for node_j in NN:
9595
state_j = population[t-1, node_j]
9696

9797
if state_j in infections[state_i]:
98-
prob = np.random.random()
98+
prob = self.rng.random()
9999

100100
if prob < infections[state_i][state_j]['rate']:
101101
new_state = infections[state_i][state_j]['target']

0 commit comments

Comments
 (0)