-
-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathmemory_data_store.py
More file actions
321 lines (270 loc) · 13.1 KB
/
memory_data_store.py
File metadata and controls
321 lines (270 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
#!/usr/bin/env python3
import asyncio
import copy
from datetime import datetime, timedelta
import logging
from zoneinfo import ZoneInfo
import aiohttp
from models.node import Node
from storage.db.postgres import PostgresStorage
import utils
logger = logging.getLogger(__name__)
class MemoryDataStore:
    """In-memory store for mesh state (nodes, chat, telemetry, traceroutes),
    with an optional PostgreSQL persistence backend."""

    def __init__(self, config):
        self.config = config
        # Chat history keyed by channel id; channel '0' always exists.
        self.chat: dict = {
            'channels': {
                '0': {
                    'name': 'General',
                    'messages': []
                }
            }
        }
        self.graph: dict | None = {}
        self.messages: list = []
        self.mqtt_messages: list = []
        # Treated as "connected since server start" until MQTT reconnects.
        self.mqtt_connect_time: datetime = self.config['server']['start_time']
        self.nodes: dict = {}
        self.telemetry: list = []
        self.telemetry_by_node: dict = {}
        self.traceroutes: list = []
        self.traceroutes_by_node: dict = {}
        # Postgres persistence; the connection itself is established in load().
        self.pg_storage = PostgresStorage(config)
def __deepcopy__(self, memo):
"""
Custom deepcopy to avoid copying non-copyable runtime objects (e.g., asyncpg buffers
held by PostgresStorage / pools / connections). Renderers only need the in-memory
data snapshot, not the live DB connection.
"""
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# Never deepcopy PostgresStorage (it can contain asyncpg internals)
if k == "pg_storage":
setattr(result, k, None)
continue
# Skip deepcopy for asyncpg internals if they somehow land on the store
mod = type(v).__module__
if isinstance(mod, str) and mod.startswith("asyncpg"):
setattr(result, k, None)
continue
# Skip common non-copyable runtime objects
try:
if isinstance(v, (asyncio.Lock, asyncio.Event, asyncio.Task, logging.Logger)):
setattr(result, k, v)
continue
except Exception as exc:
# If runtime/types differ, don't block deepcopy.
logger.debug(
"MemoryDataStore.__deepcopy__: isinstance() guard failed for key %r (type=%s): %s",
k, type(v),
exc,
exc_info=True,
)
setattr(result, k, copy.deepcopy(v, memo))
return result
def update(self, key, value):
self.__dict__[key] = value
def update_node(self, id: str, node):
n = node.copy()
if n['position'] is None:
n['position'] = {}
if self.config['integrations']['geocoding']['enabled']:
if 'geocoded' not in n['position']:
n['position']['geocoded'] = None
if 'latitude_i' in n['position'] and 'longitude_i' in n['position'] and n['position']['latitude_i'] is not None and n['position']['longitude_i'] is not None:
if n['position']['geocoded'] is None or n['position']['last_geocoding'] is None or n['position']['last_geocoding'] < datetime.now().astimezone(ZoneInfo(self.config['server']['timezone'])) - timedelta(minutes=60):
try:
geocoded = utils.geocode_position(self.config['integrations']['geocoding']['geocode.maps.co']['api_key'], n['position']['latitude_i'] / 10000000, n['position']['longitude_i'] / 10000000)
if geocoded is not None:
n['position']['geocoded'] = geocoded
n['position']['last_geocoding'] = datetime.now().astimezone(ZoneInfo(self.config['server']['timezone']))
except Exception as e:
logger.warning("Failed to geocode position: %s", e)
n['active'] = True
if 'last_seen' in n and n['last_seen'] is not None and isinstance(n['last_seen'], str):
n['last_seen'] = datetime.fromisoformat(n['last_seen']).astimezone(ZoneInfo(self.config['server']['timezone']))
n['since'] = datetime.now().astimezone(ZoneInfo(self.config['server']['timezone'])) - n['last_seen']
n['last_seen'] = datetime.now().astimezone(ZoneInfo(self.config['server']['timezone']))
self.nodes[id] = n
# Real-time write to Postgres (non-blocking)
if (
'postgres' in self.config.get('storage', {}).get('write_to', [])
and self.pg_storage is not None
and getattr(self.pg_storage, "enabled", False)
and getattr(self.pg_storage, "pool", None) is not None
):
try:
try:
loop = asyncio.get_running_loop()
loop.create_task(self.pg_storage.write_node(id, n))
except RuntimeError:
# No running loop in this thread/context
asyncio.run(self.pg_storage.write_node(id, n))
except Exception as e:
logger.error("Failed to write node %s to Postgres (non-blocking): %s", id, e)
async def load(self):
logger.info("Loading data from PostgreSQL")
await self._load_from_postgres()
async def _load_from_postgres(self):
"""Initialize PostgreSQL connection and prepare postgres-backed mode."""
try:
ok = await self.pg_storage.connect()
if not ok:
raise RuntimeError("PostgreSQL connect failed")
await self.pg_storage.ensure_schema()
# NOTE: your current code intentionally does NOT load rows into memory.
# If your API still reads from self.nodes/chat/telemetry, you'll get empty results.
self.nodes = {}
self.chat = {'channels': {'0': {'name': 'General', 'messages': []}}}
self.telemetry = []
self.telemetry_by_node = {}
self.traceroutes = []
self.traceroutes_by_node = {}
# Ensure default nodes exist in Postgres
default_id = self.config['server']['node_id']
default_node = Node.default_node(default_id)
self.nodes[default_id] = default_node
await self.pg_storage.write_node(default_id, default_node)
broadcast_node = Node.default_node('ffffffff')
self.nodes['ffffffff'] = broadcast_node
await self.pg_storage.write_node('ffffffff', broadcast_node)
logger.info("PostgreSQL mode: Data will be queried directly from database")
except Exception as e:
logger.exception("Failed to initialize PostgreSQL connection: %s", e)
raise
async def save(self):
save_start = datetime.now(ZoneInfo(self.config['server']['timezone']))
last_data = self.config['server']['last_data_save'] if 'last_data_save' in self.config['server'] else self.config['server']['start_time']
since_last_data = (save_start - last_data).total_seconds()
last_backfill = self.config['server']['last_backfill'] if 'last_backfill' in self.config['server'] else self.config['server']['start_time']
since_last_backfill = (save_start - last_backfill).total_seconds()
logger.debug(
"Save (since last): graph: %s (threshold: %s), enrich: %s (threshold: %s)",
since_last_data, self.config['server']['intervals']['data_save'],
since_last_backfill, self.config['server']['enrich']['interval'],
)
if 'enrich' in self.config['server'] and self.config['server']['enrich']['enabled']:
if since_last_backfill >= self.config['server']['enrich']['interval']:
await self.backfill_node_infos()
end = datetime.now(ZoneInfo(self.config['server']['timezone']))
logger.debug("Enriched in %.2f seconds", end.timestamp() - save_start.timestamp())
self.config['server']['last_backfill'] = end
if since_last_data >= self.config['server']['intervals']['data_save']:
end = datetime.now(ZoneInfo(self.config['server']['timezone']))
logger.debug("Rebuilt graph in %.2f seconds", end.timestamp() - save_start.timestamp())
self.config['server']['last_data_save'] = end
self.graph = self.graph_node(self.config['server']['node_id'])
### helpers
async def backfill_node_infos(self):
nodes_needing_enrichment = {}
for id, node in self.nodes.items():
if 'shortname' not in node or 'longname' not in node or node['shortname'] == 'UNK' or node['longname'] == 'Unknown':
nodes_needing_enrichment[id] = node
logger.info("Nodes needing enrichment: %d", len(nodes_needing_enrichment))
if len(nodes_needing_enrichment) > 0:
await self.enrich_nodes(nodes_needing_enrichment)
async def enrich_nodes(self, node_to_enrich):
async with aiohttp.ClientSession() as session:
node_ids = list(node_to_enrich.keys())
logger.info("Enriching nodes: %s", ','.join(node_ids))
if self.config['server']['enrich']['provider'] == 'bayme':
for node_id in node_ids:
logger.debug("Enriching %s", node_id)
url = f"https://data.bayme.sh/api/node/infos?ids={node_id}"
try:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
for node_id, node_info in data.items():
logger.debug("Got info for %s", node_id)
if node_id in self.nodes:
logger.debug("Enriched %s", node_id)
node = self.nodes[node_id]
node['shortname'] = node_info['shortName']
node['longname'] = node_info['longName']
self.nodes[node_id] = node
else:
logger.warning("Failed to get info for %s", node_id)
except Exception as e:
logger.warning("Failed to get info for %s: %s", node_id, e)
elif self.config['server']['enrich']['provider'] == 'world.meshinfo.network':
url = f"https://world.meshinfo.network/api/v1/nodes?ids={','.join(node_ids)}"
try:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
for node_id, node_info in data.items():
logger.debug("Got info for %s", node_id)
if node_id in self.nodes:
logger.debug("Enriched %s", node_id)
node = self.nodes[node_id]
node['shortname'] = node_info['shortName']
node['longname'] = node_info['longName']
self.nodes[node_id] = node
else:
logger.warning("Failed to get info for %s", node_ids)
except Exception as e:
logger.warning("Failed to get info for %s: %s", node_ids, e)
def find_node_by_int_id(self, id: int):
return self.nodes.get(utils.convert_node_id_from_int_to_hex(id), None)
def find_node_by_hex_id(self, id: str, include_neighbors: bool = False):
if not isinstance(id, str) or len(id) != 8 or not all(c in '0123456789abcdefABCDEF' for c in id):
return None
n = self.nodes.get(id, None)
if n is None:
return None
node = n.copy()
if include_neighbors:
neighbors_heard = []
if 'neighborinfo' in node and node['neighborinfo'] is not None and 'neighbors' in node['neighborinfo'] and len(node['neighborinfo']['neighbors']) > 0:
for neighbor in node['neighborinfo']['neighbors']:
nn = self.find_node_by_hex_id(utils.convert_node_id_from_int_to_hex(neighbor["node_id"]), include_neighbors=False)
if nn is not None:
neighbors_heard.append(nn.copy())
neighbors_heard_by = []
for nid, n in self.nodes.items():
if 'neighborinfo' in n and n['neighborinfo'] is not None and 'neighbors' in n['neighborinfo'] and len(n['neighborinfo']['neighbors']) > 0:
if id in n['neighborinfo']['neighbors']:
nn = self.find_node_by_hex_id(utils.convert_node_id_from_int_to_hex(nid), include_neighbors=False)
if nn is not None:
neighbors_heard_by.append(nn.copy())
node['neighbors_heard'] = neighbors_heard
node['neighbors_heard_by'] = neighbors_heard_by
return node
def find_node_by_short_name(self, sn: str):
for _id, node in self.nodes.items():
if node['shortname'] == sn:
return node
return None
def find_node_by_longname(self, ln: str):
for _id, node in self.nodes.items():
if node['longname'] == ln:
return node
return None
    def graph_node(self, node_id: str) -> dict|None:
        """Build a neighbor graph rooted at *node_id* via recursive traversal.

        Returns a node dict whose 'neighbors_heard' entries are themselves
        recursively-graphed node dicts, or None when the root id is unknown.
        Depth is capped by config['server']['graph']['max_depth'] (when set),
        and edges back to the start node are skipped.
        """
        logger.debug("Graphing node: %s", node_id)
        visited = set() # Set to keep track of visited nodes
        def recursive_graph_node(node_id, start_id="", level=0) -> dict|None:
            node = self.find_node_by_hex_id(node_id, include_neighbors=True)
            if node is None:
                return None
            # Cycle guard: below level 2 a revisit is allowed so direct
            # neighbors of the root still get expanded once.
            if level > 1 and node_id in visited:
                return node # Return the node if it has already been visited
            logger.debug("%s - %s", " " * level, node_id)
            visited.add(node_id) # Mark the node as visited
            neighbors_heard = []
            # NOTE(review): neighbors_heard_by is never populated here, so the
            # list built by find_node_by_hex_id is overwritten with [] below —
            # confirm this is intended.
            neighbors_heard_by = []
            if node['neighborinfo'] and node['neighborinfo']['neighbors']:
                for neighbor in node['neighborinfo']['neighbors']:
                    nid = utils.convert_node_id_from_int_to_hex(neighbor["node_id"])
                    # Skip the edge back to the start node, and stop descending
                    # past max_depth. (Precedence: `A and B or C` parses as
                    # `(A and B) or C`.)
                    if start_id is not None and start_id == nid or (self.config['server']['graph']['max_depth'] is not None and level >= self.config['server']['graph']['max_depth']):
                        continue
                    nn = recursive_graph_node(nid, start_id=start_id, level=level+1)
                    if nn is not None:
                        neighbors_heard.append(nn.copy())
            node['neighbors_heard'] = neighbors_heard
            node['neighbors_heard_by'] = neighbors_heard_by
            return node
        return recursive_graph_node(node_id, start_id=node_id)