From e009791b754d3ddde52609f836e67330d2f55e20 Mon Sep 17 00:00:00 2001
From: Hemna <waboring@hemna.com>
Date: Fri, 22 Oct 2021 16:07:20 -0400
Subject: [PATCH] Converted MsgTrack to ObjectStoreMixin

---
 aprsd/main.py           |  7 ++--
 aprsd/messaging.py      | 86 +++++++++++------------------------------
 aprsd/objectstore.py    |  6 +--
 tests/test_messaging.py |  7 +++-
 tests/test_plugin.py    |  3 +-
 5 files changed, 35 insertions(+), 74 deletions(-)

diff --git a/aprsd/main.py b/aprsd/main.py
index ef18cf6..f027d7b 100644
--- a/aprsd/main.py
+++ b/aprsd/main.py
@@ -146,8 +146,7 @@ def signal_handler(sig, frame):
             ),
         )
         time.sleep(1.5)
-        tracker = messaging.MsgTrack()
-        tracker.save()
+        messaging.MsgTrack().save()
         packets.WatchList().save()
         packets.SeenList().save()
         LOG.info(stats.APRSDStats())
@@ -480,13 +479,13 @@ def server(
     packets.PacketList(config=config)
     if flush:
         LOG.debug("Deleting saved MsgTrack.")
-        messaging.MsgTrack().flush()
+        messaging.MsgTrack(config=config).flush()
         packets.WatchList(config=config)
         packets.SeenList(config=config)
     else:
         # Try and load saved MsgTrack list
         LOG.debug("Loading saved MsgTrack object.")
-        messaging.MsgTrack().load()
+        messaging.MsgTrack(config=config).load()
         packets.WatchList(config=config).load()
         packets.SeenList(config=config).load()
 
diff --git a/aprsd/messaging.py b/aprsd/messaging.py
index f80574c..23a6c8f 100644
--- a/aprsd/messaging.py
+++ b/aprsd/messaging.py
@@ -2,16 +2,11 @@ import abc
 import datetime
 import logging
 from multiprocessing import RawValue
-import os
-import pathlib
-import pickle
 import re
 import threading
 import time
 
-from aprsd import client
-from aprsd import config as aprsd_config
-from aprsd import packets, stats, threads
+from aprsd import client, objectstore, packets, stats, threads
 
 
 LOG = logging.getLogger("APRSD")
@@ -21,7 +16,7 @@ LOG = logging.getLogger("APRSD")
 NULL_MESSAGE = -1
 
 
-class MsgTrack:
+class MsgTrack(objectstore.ObjectStoreMixin):
     """Class to keep track of outstanding text messages.
 
     This is a thread safe class that keeps track of active
@@ -38,7 +33,7 @@ class MsgTrack:
     _start_time = None
     lock = None
 
-    track = {}
+    data = {}
     total_messages_tracked = 0
 
     def __new__(cls, *args, **kwargs):
@@ -47,93 +42,65 @@ class MsgTrack:
             cls._instance.track = {}
             cls._instance._start_time = datetime.datetime.now()
             cls._instance.lock = threading.Lock()
+            cls._instance.config = kwargs["config"]
+            cls._instance._init_store()
         return cls._instance
 
     def __getitem__(self, name):
         with self.lock:
-            return self.track[name]
+            return self.data[name]
 
     def __iter__(self):
         with self.lock:
-            return iter(self.track)
+            return iter(self.data)
 
     def keys(self):
         with self.lock:
-            return self.track.keys()
+            return self.data.keys()
 
     def items(self):
         with self.lock:
-            return self.track.items()
+            return self.data.items()
 
     def values(self):
         with self.lock:
-            return self.track.values()
+            return self.data.values()
 
     def __len__(self):
         with self.lock:
-            return len(self.track)
+            return len(self.data)
 
     def __str__(self):
         with self.lock:
             result = "{"
-            for key in self.track.keys():
-                result += f"{key}: {str(self.track[key])}, "
+            for key in self.data.keys():
+                result += f"{key}: {str(self.data[key])}, "
             result += "}"
             return result
 
     def add(self, msg):
         with self.lock:
             key = int(msg.id)
-            self.track[key] = msg
+            self.data[key] = msg
             stats.APRSDStats().msgs_tracked_inc()
             self.total_messages_tracked += 1
 
     def get(self, id):
         with self.lock:
-            if id in self.track:
-                return self.track[id]
+            if id in self.data:
+                return self.data[id]
 
     def remove(self, id):
         with self.lock:
             key = int(id)
-            if key in self.track.keys():
-                del self.track[key]
-
-    def save(self):
-        """Save any queued to disk?"""
-        LOG.debug(f"Save tracker to disk? {len(self)}")
-        if len(self) > 0:
-            LOG.info(f"Saving {len(self)} tracking messages to disk")
-            pickle.dump(self.dump(), open(aprsd_config.DEFAULT_SAVE_FILE, "wb+"))
-        else:
-            LOG.debug(
-                "Nothing to save, flushing old save file '{}'".format(
-                    aprsd_config.DEFAULT_SAVE_FILE,
-                ),
-            )
-            self.flush()
-
-    def dump(self):
-        dump = {}
-        with self.lock:
-            for key in self.track.keys():
-                dump[key] = self.track[key]
-
-        return dump
-
-    def load(self):
-        if os.path.exists(aprsd_config.DEFAULT_SAVE_FILE):
-            raw = pickle.load(open(aprsd_config.DEFAULT_SAVE_FILE, "rb"))
-            if raw:
-                self.track = raw
-                LOG.debug("Loaded MsgTrack dict from disk.")
-                LOG.debug(self)
+            if key in self.data.keys():
+                del self.data[key]
 
     def restart(self):
         """Walk the list of messages and restart them if any."""
 
-        for key in self.track.keys():
-            msg = self.track[key]
+        for key in self.data.keys():
+            msg = self.data[key]
             if msg.last_send_attempt < msg.retry_count:
                 msg.send()
 
@@ -145,14 +112,14 @@ class MsgTrack:
         """Walk the list of delayed messages and restart them if any."""
         if not count:
             # Send all the delayed messages
-            for key in self.track.keys():
-                msg = self.track[key]
+            for key in self.data.keys():
+                msg = self.data[key]
                 if msg.last_send_attempt == msg.retry_count:
                     self._resend(msg)
         else:
             # They want to resend <count> delayed messages
             tmp = sorted(
-                self.track.items(),
+                self.data.items(),
                 reverse=most_recent,
                 key=lambda x: x[1].last_send_time,
             )
@@ -160,13 +127,6 @@ class MsgTrack:
             for (_key, msg) in msg_list:
                 self._resend(msg)
 
-    def flush(self):
-        """Nuke the old pickle file that stored the old results from last aprsd run."""
-        if os.path.exists(aprsd_config.DEFAULT_SAVE_FILE):
-            pathlib.Path(aprsd_config.DEFAULT_SAVE_FILE).unlink()
-        with self.lock:
-            self.track = {}
-
 
 class MessageCounter:
     """
diff --git a/aprsd/objectstore.py b/aprsd/objectstore.py
index 33e0934..90c7773 100644
--- a/aprsd/objectstore.py
+++ b/aprsd/objectstore.py
@@ -53,7 +53,6 @@ class ObjectStoreMixin:
     def _save_filename(self):
         save_location = self._save_location()
 
-        LOG.debug(f"{self.__class__.__name__}::Using save location {save_location}")
         return "{}/{}.p".format(
             save_location,
             self.__class__.__name__.lower(),
@@ -65,15 +64,12 @@ class ObjectStoreMixin:
             for key in self.data.keys():
                 dump[key] = self.data[key]
 
-        LOG.debug(f"{self.__class__.__name__}:: DUMP")
-        LOG.debug(dump)
-
         return dump
 
     def save(self):
         """Save any queued to disk?"""
         if len(self) > 0:
-            LOG.info(f"{self.__class__.__name__}::Saving {len(self)} entries to disk")
+            LOG.info(f"{self.__class__.__name__}::Saving {len(self)} entries to disk at {self._save_location()}")
             pickle.dump(self._dump(), open(self._save_filename(), "wb+"))
         else:
             LOG.debug(
diff --git a/tests/test_messaging.py b/tests/test_messaging.py
index d7a9110..7fd6136 100644
--- a/tests/test_messaging.py
+++ b/tests/test_messaging.py
@@ -6,9 +6,14 @@ from aprsd import messaging
 
 
 class TestMessageTrack(unittest.TestCase):
+
+    def setUp(self) -> None:
+        config = {}
+        messaging.MsgTrack(config=config)
+
     def _clean_track(self):
         track = messaging.MsgTrack()
-        track.track = {}
+        track.data = {}
         track.total_messages_tracked = 0
         return track
 
diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index 4082b61..0d15cc4 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -26,6 +26,7 @@ class TestPlugin(unittest.TestCase):
         stats.APRSDStats(self.config)
         packets.WatchList(config=self.config)
         packets.SeenList(config=self.config)
+        messaging.MsgTrack(config=self.config)
 
     @mock.patch.object(fake.FakeBaseNoThreadsPlugin, "process")
     def test_base_plugin_no_threads(self, mock_process):
@@ -161,7 +162,7 @@ class TestQueryPlugin(TestPlugin):
     @mock.patch("aprsd.messaging.MsgTrack.restart_delayed")
     def test_query_restart_delayed(self, mock_restart):
         track = messaging.MsgTrack()
-        track.track = {}
+        track.data = {}
         packet = fake.fake_packet(message="!4")
         query = query_plugin.QueryPlugin(self.config)