1
0
mirror of https://github.com/craigerl/aprsd.git synced 2024-12-21 00:51:06 -05:00

Make all the Objectstore children use the same lock

This patch updates the ObjectStore and it's child classes
all use the same lock.
This commit is contained in:
Hemna 2024-04-24 13:49:55 -04:00
parent 2b2bf6c92d
commit c206f52a76
6 changed files with 144 additions and 168 deletions

View File

@ -1,9 +1,7 @@
from collections import OrderedDict
import logging
import threading
from oslo_config import cfg
import wrapt
from aprsd.packets import collector, core
from aprsd.utils import objectstore
@ -16,7 +14,6 @@ LOG = logging.getLogger("APRSD")
class PacketList(objectstore.ObjectStoreMixin):
"""Class to keep track of the packets we tx/rx."""
_instance = None
lock = threading.Lock()
_total_rx: int = 0
_total_tx: int = 0
maxlen: int = 100
@ -34,9 +31,9 @@ class PacketList(objectstore.ObjectStoreMixin):
"packets": OrderedDict(),
}
@wrapt.synchronized(lock)
def rx(self, packet: type[core.Packet]):
"""Add a packet that was received."""
with self.lock:
self._total_rx += 1
self._add(packet)
ptype = packet.__class__.__name__
@ -44,9 +41,9 @@ class PacketList(objectstore.ObjectStoreMixin):
self.data["types"][ptype] = {"tx": 0, "rx": 0}
self.data["types"][ptype]["rx"] += 1
@wrapt.synchronized(lock)
def tx(self, packet: type[core.Packet]):
"""Add a packet that was received."""
with self.lock:
self._total_tx += 1
self._add(packet)
ptype = packet.__class__.__name__
@ -54,8 +51,8 @@ class PacketList(objectstore.ObjectStoreMixin):
self.data["types"][ptype] = {"tx": 0, "rx": 0}
self.data["types"][ptype]["tx"] += 1
@wrapt.synchronized(lock)
def add(self, packet):
with self.lock:
self._add(packet)
def _add(self, packet):
@ -67,29 +64,25 @@ class PacketList(objectstore.ObjectStoreMixin):
self.data["packets"].popitem(last=False)
self.data["packets"][packet.key] = packet
@wrapt.synchronized(lock)
def copy(self):
return self.data.copy()
@wrapt.synchronized(lock)
def find(self, packet):
with self.lock:
return self.data["packets"][packet.key]
@wrapt.synchronized(lock)
def __len__(self):
with self.lock:
return len(self.data["packets"])
@wrapt.synchronized(lock)
def total_rx(self):
with self.lock:
return self._total_rx
@wrapt.synchronized(lock)
def total_tx(self):
with self.lock:
return self._total_tx
@wrapt.synchronized(lock)
def stats(self, serializable=False) -> dict:
# limit the number of packets to return to 50
with self.lock:
tmp = OrderedDict(
reversed(
list(

View File

@ -1,9 +1,7 @@
import datetime
import logging
import threading
from oslo_config import cfg
import wrapt
from aprsd.packets import collector, core
from aprsd.utils import objectstore
@ -17,7 +15,6 @@ class SeenList(objectstore.ObjectStoreMixin):
"""Global callsign seen list."""
_instance = None
lock = threading.Lock()
data: dict = {}
def __new__(cls, *args, **kwargs):
@ -26,19 +23,14 @@ class SeenList(objectstore.ObjectStoreMixin):
cls._instance.data = {}
return cls._instance
@wrapt.synchronized(lock)
def stats(self, serializable=False):
"""Return the stats for the PacketTrack class."""
with self.lock:
return self.data
@wrapt.synchronized(lock)
def copy(self):
"""Return a copy of the data."""
return self.data.copy()
@wrapt.synchronized(lock)
def rx(self, packet: type[core.Packet]):
"""When we get a packet from the network, update the seen list."""
with self.lock:
callsign = None
if packet.from_call:
callsign = packet.from_call

View File

@ -1,9 +1,7 @@
import datetime
import logging
import threading
from oslo_config import cfg
import wrapt
from aprsd.packets import collector, core
from aprsd.utils import objectstore
@ -28,7 +26,6 @@ class PacketTrack(objectstore.ObjectStoreMixin):
_instance = None
_start_time = None
lock = threading.Lock()
data: dict = {}
total_tracked: int = 0
@ -40,28 +37,28 @@ class PacketTrack(objectstore.ObjectStoreMixin):
cls._instance._init_store()
return cls._instance
@wrapt.synchronized(lock)
def __getitem__(self, name):
with self.lock:
return self.data[name]
@wrapt.synchronized(lock)
def __iter__(self):
with self.lock:
return iter(self.data)
@wrapt.synchronized(lock)
def keys(self):
with self.lock:
return self.data.keys()
@wrapt.synchronized(lock)
def items(self):
with self.lock:
return self.data.items()
@wrapt.synchronized(lock)
def values(self):
with self.lock:
return self.data.values()
@wrapt.synchronized(lock)
def stats(self, serializable=False):
with self.lock:
stats = {
"total_tracked": self.total_tracked,
}
@ -77,11 +74,6 @@ class PacketTrack(objectstore.ObjectStoreMixin):
stats["packets"] = pkts
return stats
@wrapt.synchronized(lock)
def __len__(self):
return len(self.data)
@wrapt.synchronized(lock)
def rx(self, packet: type[core.Packet]) -> None:
"""When we get a packet from the network, check if we should remove it."""
if isinstance(packet, core.AckPacket):
@ -92,23 +84,19 @@ class PacketTrack(objectstore.ObjectStoreMixin):
# Got a piggyback ack, so remove the original message
self._remove(packet.ackMsgNo)
@wrapt.synchronized(lock)
def tx(self, packet: type[core.Packet]) -> None:
"""Add a packet that was sent."""
with self.lock:
key = packet.msgNo
packet.send_count = 0
self.data[key] = packet
self.total_tracked += 1
@wrapt.synchronized(lock)
def get(self, key):
return self.data.get(key)
@wrapt.synchronized(lock)
def remove(self, key):
self._remove(key)
def _remove(self, key):
with self.lock:
try:
del self.data[key]
except KeyError:

View File

@ -1,9 +1,7 @@
import datetime
import logging
import threading
from oslo_config import cfg
import wrapt
from aprsd import utils
from aprsd.packets import collector, core
@ -18,7 +16,6 @@ class WatchList(objectstore.ObjectStoreMixin):
"""Global watch list and info for callsigns."""
_instance = None
lock = threading.Lock()
data = {}
def __new__(cls, *args, **kwargs):
@ -28,6 +25,7 @@ class WatchList(objectstore.ObjectStoreMixin):
return cls._instance
def _update_from_conf(self, config=None):
with self.lock:
if CONF.watch_list.enabled and CONF.watch_list.callsigns:
for callsign in CONF.watch_list.callsigns:
call = callsign.replace("*", "")
@ -41,9 +39,9 @@ class WatchList(objectstore.ObjectStoreMixin):
"packet": None,
}
@wrapt.synchronized(lock)
def stats(self, serializable=False) -> dict:
stats = {}
with self.lock:
for callsign in self.data:
stats[callsign] = {
"last": self.data[callsign]["last"],
@ -57,14 +55,15 @@ class WatchList(objectstore.ObjectStoreMixin):
return CONF.watch_list.enabled
def callsign_in_watchlist(self, callsign):
with self.lock:
return callsign in self.data
@wrapt.synchronized(lock)
def rx(self, packet: type[core.Packet]) -> None:
"""Track when we got a packet from the network."""
callsign = packet.from_call
if self.callsign_in_watchlist(callsign):
with self.lock:
self.data[callsign]["last"] = datetime.datetime.now()
self.data[callsign]["packet"] = packet
@ -72,6 +71,7 @@ class WatchList(objectstore.ObjectStoreMixin):
"""We don't care about TX packets."""
def last_seen(self, callsign):
with self.lock:
if self.callsign_in_watchlist(callsign):
return self.data[callsign]["last"]

View File

@ -27,7 +27,7 @@ class Collector:
cls = name()
if isinstance(cls, StatsProducer):
try:
stats[cls.__class__.__name__] = cls.stats(serializable=serializable)
stats[cls.__class__.__name__] = cls.stats(serializable=serializable).copy()
except Exception as e:
LOG.error(f"Error in producer {name} (stats): {e}")
else:

View File

@ -2,6 +2,7 @@ import logging
import os
import pathlib
import pickle
import threading
from oslo_config import cfg
@ -25,19 +26,28 @@ class ObjectStoreMixin:
aprsd server -f (flush) will wipe all saved objects.
"""
def __init__(self):
self.lock = threading.RLock()
def __len__(self):
with self.lock:
return len(self.data)
def __iter__(self):
with self.lock:
return iter(self.data)
def get_all(self):
with self.lock:
return self.data
def get(self, id):
def get(self, key):
with self.lock:
return self.data[id]
return self.data.get(key)
def copy(self):
with self.lock:
return self.data.copy()
def _init_store(self):
if not CONF.enable_save:
@ -58,14 +68,6 @@ class ObjectStoreMixin:
self.__class__.__name__.lower(),
)
def _dump(self):
dump = {}
with self.lock:
for key in self.data.keys():
dump[key] = self.data[key]
return dump
def save(self):
"""Save any queued to disk?"""
if not CONF.enable_save:
@ -78,8 +80,9 @@ class ObjectStoreMixin:
f" {len(self)} entries to disk at "
f"{save_filename}",
)
with self.lock:
with open(save_filename, "wb+") as fp:
pickle.dump(self._dump(), fp)
pickle.dump(self.data, fp)
else:
LOG.debug(
"{} Nothing to save, flushing old save file '{}'".format(