From c206f52a763a665c4b22aaaa4bf0234d0d998da7 Mon Sep 17 00:00:00 2001 From: Hemna Date: Wed, 24 Apr 2024 13:49:55 -0400 Subject: [PATCH] Make all the Objectstore children use the same lock This patch updates the ObjectStore and it's child classes all use the same lock. --- aprsd/packets/packet_list.py | 103 ++++++++++++++++------------------- aprsd/packets/seen_list.py | 40 ++++++-------- aprsd/packets/tracker.py | 78 +++++++++++--------------- aprsd/packets/watch_list.py | 58 ++++++++++---------- aprsd/stats/collector.py | 2 +- aprsd/utils/objectstore.py | 31 ++++++----- 6 files changed, 144 insertions(+), 168 deletions(-) diff --git a/aprsd/packets/packet_list.py b/aprsd/packets/packet_list.py index 48969df..c97b429 100644 --- a/aprsd/packets/packet_list.py +++ b/aprsd/packets/packet_list.py @@ -1,9 +1,7 @@ from collections import OrderedDict import logging -import threading from oslo_config import cfg -import wrapt from aprsd.packets import collector, core from aprsd.utils import objectstore @@ -16,7 +14,6 @@ LOG = logging.getLogger("APRSD") class PacketList(objectstore.ObjectStoreMixin): """Class to keep track of the packets we tx/rx.""" _instance = None - lock = threading.Lock() _total_rx: int = 0 _total_tx: int = 0 maxlen: int = 100 @@ -34,29 +31,29 @@ class PacketList(objectstore.ObjectStoreMixin): "packets": OrderedDict(), } - @wrapt.synchronized(lock) def rx(self, packet: type[core.Packet]): """Add a packet that was received.""" - self._total_rx += 1 - self._add(packet) - ptype = packet.__class__.__name__ - if not ptype in self.data["types"]: - self.data["types"][ptype] = {"tx": 0, "rx": 0} - self.data["types"][ptype]["rx"] += 1 + with self.lock: + self._total_rx += 1 + self._add(packet) + ptype = packet.__class__.__name__ + if not ptype in self.data["types"]: + self.data["types"][ptype] = {"tx": 0, "rx": 0} + self.data["types"][ptype]["rx"] += 1 - @wrapt.synchronized(lock) def tx(self, packet: type[core.Packet]): """Add a packet that was received.""" - self._total_tx += 1 - self._add(packet) - ptype = packet.__class__.__name__ - if not ptype in self.data["types"]: - self.data["types"][ptype] = {"tx": 0, "rx": 0} - self.data["types"][ptype]["tx"] += 1 + with self.lock: + self._total_tx += 1 + self._add(packet) + ptype = packet.__class__.__name__ + if not ptype in self.data["types"]: + self.data["types"][ptype] = {"tx": 0, "rx": 0} + self.data["types"][ptype]["tx"] += 1 - @wrapt.synchronized(lock) def add(self, packet): - self._add(packet) + with self.lock: + self._add(packet) def _add(self, packet): if not self.data.get("packets"): @@ -67,54 +64,50 @@ class PacketList(objectstore.ObjectStoreMixin): self.data["packets"].popitem(last=False) self.data["packets"][packet.key] = packet - @wrapt.synchronized(lock) - def copy(self): - return self.data.copy() - - @wrapt.synchronized(lock) def find(self, packet): - return self.data["packets"][packet.key] + with self.lock: + return self.data["packets"][packet.key] - @wrapt.synchronized(lock) def __len__(self): - return len(self.data["packets"]) + with self.lock: + return len(self.data["packets"]) - @wrapt.synchronized(lock) def total_rx(self): - return self._total_rx + with self.lock: + return self._total_rx - @wrapt.synchronized(lock) def total_tx(self): - return self._total_tx + with self.lock: + return self._total_tx - @wrapt.synchronized(lock) def stats(self, serializable=False) -> dict: # limit the number of packets to return to 50 - tmp = OrderedDict( - reversed( - list( - self.data.get("packets", OrderedDict()).items(), + with self.lock: + tmp = OrderedDict( + reversed( + list( + self.data.get("packets", OrderedDict()).items(), + ), ), - ), - ) - pkts = [] - count = 1 - for packet in tmp: - pkts.append(tmp[packet]) - count += 1 - if count > CONF.packet_list_stats_maxlen: - break + ) + pkts = [] + count = 1 + for packet in tmp: + pkts.append(tmp[packet]) + count += 1 + if count > CONF.packet_list_stats_maxlen: + break - stats = { - "total_tracked": self._total_rx + self._total_rx, - "rx": self._total_rx, - "tx": self._total_tx, - "types": self.data.get("types", []), - "packet_count": len(self.data.get("packets", [])), - "maxlen": self.maxlen, - "packets": pkts, - } - return stats + stats = { + "total_tracked": self._total_rx + self._total_rx, + "rx": self._total_rx, + "tx": self._total_tx, + "types": self.data.get("types", []), + "packet_count": len(self.data.get("packets", [])), + "maxlen": self.maxlen, + "packets": pkts, + } + return stats # Now register the PacketList with the collector diff --git a/aprsd/packets/seen_list.py b/aprsd/packets/seen_list.py index e72324c..7150b64 100644 --- a/aprsd/packets/seen_list.py +++ b/aprsd/packets/seen_list.py @@ -1,9 +1,7 @@ import datetime import logging -import threading from oslo_config import cfg -import wrapt from aprsd.packets import collector, core from aprsd.utils import objectstore @@ -17,7 +15,6 @@ class SeenList(objectstore.ObjectStoreMixin): """Global callsign seen list.""" _instance = None - lock = threading.Lock() data: dict = {} def __new__(cls, *args, **kwargs): @@ -26,32 +23,27 @@ class SeenList(objectstore.ObjectStoreMixin): cls._instance.data = {} return cls._instance - @wrapt.synchronized(lock) def stats(self, serializable=False): """Return the stats for the PacketTrack class.""" - return self.data + with self.lock: + return self.data - @wrapt.synchronized(lock) - def copy(self): - """Return a copy of the data.""" - return self.data.copy() - - @wrapt.synchronized(lock) def rx(self, packet: type[core.Packet]): """When we get a packet from the network, update the seen list.""" - callsign = None - if packet.from_call: - callsign = packet.from_call - else: - LOG.warning(f"Can't find FROM in packet {packet}") - return - if callsign not in self.data: - self.data[callsign] = { - "last": None, - "count": 0, - } - self.data[callsign]["last"] = datetime.datetime.now() - self.data[callsign]["count"] += 1 + with self.lock: + callsign = None + if packet.from_call: + callsign = packet.from_call + else: + LOG.warning(f"Can't find FROM in packet {packet}") + return + if callsign not in self.data: + self.data[callsign] = { + "last": None, + "count": 0, + } + self.data[callsign]["last"] = datetime.datetime.now() + self.data[callsign]["count"] += 1 def tx(self, packet: type[core.Packet]): """We don't care about TX packets.""" diff --git a/aprsd/packets/tracker.py b/aprsd/packets/tracker.py index d848712..83fa1c3 100644 --- a/aprsd/packets/tracker.py +++ b/aprsd/packets/tracker.py @@ -1,9 +1,7 @@ import datetime import logging -import threading from oslo_config import cfg -import wrapt from aprsd.packets import collector, core from aprsd.utils import objectstore @@ -28,7 +26,6 @@ class PacketTrack(objectstore.ObjectStoreMixin): _instance = None _start_time = None - lock = threading.Lock() data: dict = {} total_tracked: int = 0 @@ -40,48 +37,43 @@ class PacketTrack(objectstore.ObjectStoreMixin): cls._instance._init_store() return cls._instance - @wrapt.synchronized(lock) def __getitem__(self, name): - return self.data[name] + with self.lock: + return self.data[name] - @wrapt.synchronized(lock) def __iter__(self): - return iter(self.data) + with self.lock: + return iter(self.data) - @wrapt.synchronized(lock) def keys(self): - return self.data.keys() + with self.lock: + return self.data.keys() - @wrapt.synchronized(lock) def items(self): - return self.data.items() + with self.lock: + return self.data.items() - @wrapt.synchronized(lock) def values(self): - return self.data.values() + with self.lock: + return self.data.values() - @wrapt.synchronized(lock) def stats(self, serializable=False): - stats = { - "total_tracked": self.total_tracked, - } - pkts = {} - for key in self.data: - last_send_time = self.data[key].last_send_time - pkts[key] = { - "last_send_time": last_send_time, - "send_count": self.data[key].send_count, - "retry_count": self.data[key].retry_count, - "message": self.data[key].raw, + with self.lock: + stats = { + "total_tracked": self.total_tracked, } - stats["packets"] = pkts + pkts = {} + for key in self.data: + last_send_time = self.data[key].last_send_time + pkts[key] = { + "last_send_time": last_send_time, + "send_count": self.data[key].send_count, + "retry_count": self.data[key].retry_count, + "message": self.data[key].raw, + } + stats["packets"] = pkts return stats - @wrapt.synchronized(lock) - def __len__(self): - return len(self.data) - - @wrapt.synchronized(lock) def rx(self, packet: type[core.Packet]) -> None: """When we get a packet from the network, check if we should remove it.""" if isinstance(packet, core.AckPacket): @@ -92,27 +84,23 @@ class PacketTrack(objectstore.ObjectStoreMixin): # Got a piggyback ack, so remove the original message self._remove(packet.ackMsgNo) - @wrapt.synchronized(lock) def tx(self, packet: type[core.Packet]) -> None: """Add a packet that was sent.""" - key = packet.msgNo - packet.send_count = 0 - self.data[key] = packet - self.total_tracked += 1 + with self.lock: + key = packet.msgNo + packet.send_count = 0 + self.data[key] = packet + self.total_tracked += 1 - @wrapt.synchronized(lock) - def get(self, key): - return self.data.get(key) - - @wrapt.synchronized(lock) def remove(self, key): self._remove(key) def _remove(self, key): - try: - del self.data[key] - except KeyError: - pass + with self.lock: + try: + del self.data[key] + except KeyError: + pass # Now register the PacketList with the collector diff --git a/aprsd/packets/watch_list.py b/aprsd/packets/watch_list.py index 0925613..dec5e21 100644 --- a/aprsd/packets/watch_list.py +++ b/aprsd/packets/watch_list.py @@ -1,9 +1,7 @@ import datetime import logging -import threading from oslo_config import cfg -import wrapt from aprsd import utils from aprsd.packets import collector, core @@ -18,7 +16,6 @@ class WatchList(objectstore.ObjectStoreMixin): """Global watch list and info for callsigns.""" _instance = None - lock = threading.Lock() data = {} def __new__(cls, *args, **kwargs): @@ -28,52 +25,55 @@ class WatchList(objectstore.ObjectStoreMixin): return cls._instance def _update_from_conf(self, config=None): - if CONF.watch_list.enabled and CONF.watch_list.callsigns: - for callsign in CONF.watch_list.callsigns: - call = callsign.replace("*", "") - # FIXME(waboring) - we should fetch the last time we saw - # a beacon from a callsign or some other mechanism to find - # last time a message was seen by aprs-is. For now this - # is all we can do. - if call not in self.data: - self.data[call] = { - "last": None, - "packet": None, - } + with self.lock: + if CONF.watch_list.enabled and CONF.watch_list.callsigns: + for callsign in CONF.watch_list.callsigns: + call = callsign.replace("*", "") + # FIXME(waboring) - we should fetch the last time we saw + # a beacon from a callsign or some other mechanism to find + # last time a message was seen by aprs-is. For now this + # is all we can do. + if call not in self.data: + self.data[call] = { + "last": None, + "packet": None, + } - @wrapt.synchronized(lock) def stats(self, serializable=False) -> dict: stats = {} - for callsign in self.data: - stats[callsign] = { - "last": self.data[callsign]["last"], - "packet": self.data[callsign]["packet"], - "age": self.age(callsign), - "old": self.is_old(callsign), - } + with self.lock: + for callsign in self.data: + stats[callsign] = { + "last": self.data[callsign]["last"], + "packet": self.data[callsign]["packet"], + "age": self.age(callsign), + "old": self.is_old(callsign), + } return stats def is_enabled(self): return CONF.watch_list.enabled def callsign_in_watchlist(self, callsign): - return callsign in self.data + with self.lock: + return callsign in self.data - @wrapt.synchronized(lock) def rx(self, packet: type[core.Packet]) -> None: """Track when we got a packet from the network.""" callsign = packet.from_call if self.callsign_in_watchlist(callsign): - self.data[callsign]["last"] = datetime.datetime.now() - self.data[callsign]["packet"] = packet + with self.lock: + self.data[callsign]["last"] = datetime.datetime.now() + self.data[callsign]["packet"] = packet def tx(self, packet: type[core.Packet]) -> None: """We don't care about TX packets.""" def last_seen(self, callsign): - if self.callsign_in_watchlist(callsign): - return self.data[callsign]["last"] + with self.lock: + if self.callsign_in_watchlist(callsign): + return self.data[callsign]["last"] def age(self, callsign): now = datetime.datetime.now() diff --git a/aprsd/stats/collector.py b/aprsd/stats/collector.py index 91e1833..9928bb3 100644 --- a/aprsd/stats/collector.py +++ b/aprsd/stats/collector.py @@ -27,7 +27,7 @@ class Collector: cls = name() if isinstance(cls, StatsProducer): try: - stats[cls.__class__.__name__] = cls.stats(serializable=serializable) + stats[cls.__class__.__name__] = cls.stats(serializable=serializable).copy() except Exception as e: LOG.error(f"Error in producer {name} (stats): {e}") else: diff --git a/aprsd/utils/objectstore.py b/aprsd/utils/objectstore.py index 7637fc2..b04f6e6 100644 --- a/aprsd/utils/objectstore.py +++ b/aprsd/utils/objectstore.py @@ -2,6 +2,7 @@ import logging import os import pathlib import pickle +import threading from oslo_config import cfg @@ -25,19 +26,28 @@ class ObjectStoreMixin: aprsd server -f (flush) will wipe all saved objects. """ + def __init__(self): + self.lock = threading.RLock() + def __len__(self): - return len(self.data) + with self.lock: + return len(self.data) def __iter__(self): - return iter(self.data) + with self.lock: + return iter(self.data) def get_all(self): with self.lock: return self.data - def get(self, id): + def get(self, key): with self.lock: - return self.data[id] + return self.data.get(key) + + def copy(self): + with self.lock: + return self.data.copy() def _init_store(self): if not CONF.enable_save: @@ -58,14 +68,6 @@ class ObjectStoreMixin: self.__class__.__name__.lower(), ) - def _dump(self): - dump = {} - with self.lock: - for key in self.data.keys(): - dump[key] = self.data[key] - - return dump - def save(self): """Save any queued to disk?""" if not CONF.enable_save: @@ -78,8 +80,9 @@ class ObjectStoreMixin: f" {len(self)} entries to disk at " f"{save_filename}", ) - with open(save_filename, "wb+") as fp: - pickle.dump(self._dump(), fp) + with self.lock: + with open(save_filename, "wb+") as fp: + pickle.dump(self.data, fp) else: LOG.debug( "{} Nothing to save, flushing old save file '{}'".format(