1
0
mirror of https://github.com/craigerl/aprsd.git synced 2025-06-29 15:35:22 -04:00

Converted MsgTrack to ObjectStoreMixin

This commit is contained in:
Hemna 2021-10-22 16:07:20 -04:00
parent b0d25a76f7
commit e009791b75
5 changed files with 35 additions and 74 deletions

View File

@ -146,8 +146,7 @@ def signal_handler(sig, frame):
), ),
) )
time.sleep(1.5) time.sleep(1.5)
tracker = messaging.MsgTrack() messaging.MsgTrack().save()
tracker.save()
packets.WatchList().save() packets.WatchList().save()
packets.SeenList().save() packets.SeenList().save()
LOG.info(stats.APRSDStats()) LOG.info(stats.APRSDStats())
@ -480,13 +479,13 @@ def server(
packets.PacketList(config=config) packets.PacketList(config=config)
if flush: if flush:
LOG.debug("Deleting saved MsgTrack.") LOG.debug("Deleting saved MsgTrack.")
messaging.MsgTrack().flush() messaging.MsgTrack(config=config).flush()
packets.WatchList(config=config) packets.WatchList(config=config)
packets.SeenList(config=config) packets.SeenList(config=config)
else: else:
# Try and load saved MsgTrack list # Try and load saved MsgTrack list
LOG.debug("Loading saved MsgTrack object.") LOG.debug("Loading saved MsgTrack object.")
messaging.MsgTrack().load() messaging.MsgTrack(config=config).load()
packets.WatchList(config=config).load() packets.WatchList(config=config).load()
packets.SeenList(config=config).load() packets.SeenList(config=config).load()

View File

@ -2,16 +2,11 @@ import abc
import datetime import datetime
import logging import logging
from multiprocessing import RawValue from multiprocessing import RawValue
import os
import pathlib
import pickle
import re import re
import threading import threading
import time import time
from aprsd import client from aprsd import client, objectstore, packets, stats, threads
from aprsd import config as aprsd_config
from aprsd import packets, stats, threads
LOG = logging.getLogger("APRSD") LOG = logging.getLogger("APRSD")
@ -21,7 +16,7 @@ LOG = logging.getLogger("APRSD")
NULL_MESSAGE = -1 NULL_MESSAGE = -1
class MsgTrack: class MsgTrack(objectstore.ObjectStoreMixin):
"""Class to keep track of outstanding text messages. """Class to keep track of outstanding text messages.
This is a thread safe class that keeps track of active This is a thread safe class that keeps track of active
@ -38,7 +33,7 @@ class MsgTrack:
_start_time = None _start_time = None
lock = None lock = None
track = {} data = {}
total_messages_tracked = 0 total_messages_tracked = 0
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
@ -47,93 +42,65 @@ class MsgTrack:
cls._instance.track = {} cls._instance.track = {}
cls._instance._start_time = datetime.datetime.now() cls._instance._start_time = datetime.datetime.now()
cls._instance.lock = threading.Lock() cls._instance.lock = threading.Lock()
cls._instance.config = kwargs["config"]
cls._instance._init_store()
return cls._instance return cls._instance
def __getitem__(self, name): def __getitem__(self, name):
with self.lock: with self.lock:
return self.track[name] return self.data[name]
def __iter__(self): def __iter__(self):
with self.lock: with self.lock:
return iter(self.track) return iter(self.data)
def keys(self): def keys(self):
with self.lock: with self.lock:
return self.track.keys() return self.data.keys()
def items(self): def items(self):
with self.lock: with self.lock:
return self.track.items() return self.data.items()
def values(self): def values(self):
with self.lock: with self.lock:
return self.track.values() return self.data.values()
def __len__(self): def __len__(self):
with self.lock: with self.lock:
return len(self.track) return len(self.data)
def __str__(self): def __str__(self):
with self.lock: with self.lock:
result = "{" result = "{"
for key in self.track.keys(): for key in self.data.keys():
result += f"{key}: {str(self.track[key])}, " result += f"{key}: {str(self.data[key])}, "
result += "}" result += "}"
return result return result
def add(self, msg): def add(self, msg):
with self.lock: with self.lock:
key = int(msg.id) key = int(msg.id)
self.track[key] = msg self.data[key] = msg
stats.APRSDStats().msgs_tracked_inc() stats.APRSDStats().msgs_tracked_inc()
self.total_messages_tracked += 1 self.total_messages_tracked += 1
def get(self, id): def get(self, id):
with self.lock: with self.lock:
if id in self.track: if id in self.data:
return self.track[id] return self.data[id]
def remove(self, id): def remove(self, id):
with self.lock: with self.lock:
key = int(id) key = int(id)
if key in self.track.keys(): if key in self.data.keys():
del self.track[key] del self.data[key]
def save(self):
"""Save any queued to disk?"""
LOG.debug(f"Save tracker to disk? {len(self)}")
if len(self) > 0:
LOG.info(f"Saving {len(self)} tracking messages to disk")
pickle.dump(self.dump(), open(aprsd_config.DEFAULT_SAVE_FILE, "wb+"))
else:
LOG.debug(
"Nothing to save, flushing old save file '{}'".format(
aprsd_config.DEFAULT_SAVE_FILE,
),
)
self.flush()
def dump(self):
dump = {}
with self.lock:
for key in self.track.keys():
dump[key] = self.track[key]
return dump
def load(self):
if os.path.exists(aprsd_config.DEFAULT_SAVE_FILE):
raw = pickle.load(open(aprsd_config.DEFAULT_SAVE_FILE, "rb"))
if raw:
self.track = raw
LOG.debug("Loaded MsgTrack dict from disk.")
LOG.debug(self)
def restart(self): def restart(self):
"""Walk the list of messages and restart them if any.""" """Walk the list of messages and restart them if any."""
for key in self.track.keys(): for key in self.data.keys():
msg = self.track[key] msg = self.data[key]
if msg.last_send_attempt < msg.retry_count: if msg.last_send_attempt < msg.retry_count:
msg.send() msg.send()
@ -145,14 +112,14 @@ class MsgTrack:
"""Walk the list of delayed messages and restart them if any.""" """Walk the list of delayed messages and restart them if any."""
if not count: if not count:
# Send all the delayed messages # Send all the delayed messages
for key in self.track.keys(): for key in self.data.keys():
msg = self.track[key] msg = self.data[key]
if msg.last_send_attempt == msg.retry_count: if msg.last_send_attempt == msg.retry_count:
self._resend(msg) self._resend(msg)
else: else:
# They want to resend <count> delayed messages # They want to resend <count> delayed messages
tmp = sorted( tmp = sorted(
self.track.items(), self.data.items(),
reverse=most_recent, reverse=most_recent,
key=lambda x: x[1].last_send_time, key=lambda x: x[1].last_send_time,
) )
@ -160,13 +127,6 @@ class MsgTrack:
for (_key, msg) in msg_list: for (_key, msg) in msg_list:
self._resend(msg) self._resend(msg)
def flush(self):
"""Nuke the old pickle file that stored the old results from last aprsd run."""
if os.path.exists(aprsd_config.DEFAULT_SAVE_FILE):
pathlib.Path(aprsd_config.DEFAULT_SAVE_FILE).unlink()
with self.lock:
self.track = {}
class MessageCounter: class MessageCounter:
""" """

View File

@ -53,7 +53,6 @@ class ObjectStoreMixin:
def _save_filename(self): def _save_filename(self):
save_location = self._save_location() save_location = self._save_location()
LOG.debug(f"{self.__class__.__name__}::Using save location {save_location}")
return "{}/{}.p".format( return "{}/{}.p".format(
save_location, save_location,
self.__class__.__name__.lower(), self.__class__.__name__.lower(),
@ -65,15 +64,12 @@ class ObjectStoreMixin:
for key in self.data.keys(): for key in self.data.keys():
dump[key] = self.data[key] dump[key] = self.data[key]
LOG.debug(f"{self.__class__.__name__}:: DUMP")
LOG.debug(dump)
return dump return dump
def save(self): def save(self):
"""Save any queued to disk?""" """Save any queued to disk?"""
if len(self) > 0: if len(self) > 0:
LOG.info(f"{self.__class__.__name__}::Saving {len(self)} entries to disk") LOG.info(f"{self.__class__.__name__}::Saving {len(self)} entries to disk at {self._save_location()}")
pickle.dump(self._dump(), open(self._save_filename(), "wb+")) pickle.dump(self._dump(), open(self._save_filename(), "wb+"))
else: else:
LOG.debug( LOG.debug(

View File

@ -6,9 +6,14 @@ from aprsd import messaging
class TestMessageTrack(unittest.TestCase): class TestMessageTrack(unittest.TestCase):
def setUp(self) -> None:
config = {}
messaging.MsgTrack(config=config)
def _clean_track(self): def _clean_track(self):
track = messaging.MsgTrack() track = messaging.MsgTrack()
track.track = {} track.data = {}
track.total_messages_tracked = 0 track.total_messages_tracked = 0
return track return track

View File

@ -26,6 +26,7 @@ class TestPlugin(unittest.TestCase):
stats.APRSDStats(self.config) stats.APRSDStats(self.config)
packets.WatchList(config=self.config) packets.WatchList(config=self.config)
packets.SeenList(config=self.config) packets.SeenList(config=self.config)
messaging.MsgTrack(config=self.config)
@mock.patch.object(fake.FakeBaseNoThreadsPlugin, "process") @mock.patch.object(fake.FakeBaseNoThreadsPlugin, "process")
def test_base_plugin_no_threads(self, mock_process): def test_base_plugin_no_threads(self, mock_process):
@ -161,7 +162,7 @@ class TestQueryPlugin(TestPlugin):
@mock.patch("aprsd.messaging.MsgTrack.restart_delayed") @mock.patch("aprsd.messaging.MsgTrack.restart_delayed")
def test_query_restart_delayed(self, mock_restart): def test_query_restart_delayed(self, mock_restart):
track = messaging.MsgTrack() track = messaging.MsgTrack()
track.track = {} track.data = {}
packet = fake.fake_packet(message="!4") packet = fake.fake_packet(message="!4")
query = query_plugin.QueryPlugin(self.config) query = query_plugin.QueryPlugin(self.config)