Compare commits

..

1 Commits

45 changed files with 3719 additions and 1844 deletions

View File

@ -3,7 +3,6 @@ project(MILSTD110C)
# Set C++17 standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Include directories
include_directories(include)
@ -12,49 +11,20 @@ include_directories(include/modulation)
include_directories(include/utils)
# Add subdirectories for organization
enable_testing()
add_subdirectory(tests)
# Set source files
set(SOURCES main.cpp)
# Find required packages
# Link with libsndfile
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
find_package(SndFile REQUIRED)
find_package(FFTW3 REQUIRED)
find_package(fmt REQUIRED)
find_package(Gnuradio REQUIRED COMPONENTS
analog
blocks
channels
filter
fft
runtime
)
if(NOT Gnuradio_FOUND)
message(FATAL_ERROR "GNU Radio not found!")
endif()
# Include GNU Radio directories
include_directories(${Gnuradio_INCLUDE_DIRS})
link_directories(${Gnuradio_LIBRARY_DIRS})
# Add executable
add_executable(MILSTD110C ${SOURCES})
# Link executable with required libraries
target_link_libraries(MILSTD110C
SndFile::sndfile
FFTW3::fftw3
gnuradio::gnuradio-runtime
gnuradio::gnuradio-analog
gnuradio::gnuradio-blocks
gnuradio::gnuradio-filter
gnuradio::gnuradio-fft
gnuradio::gnuradio-channels
fmt::fmt
)
# Link executable with libsndfile library
target_link_libraries(MILSTD110C SndFile::sndfile)
# Debug and Release Build Types
set(CMAKE_CONFIGURATION_TYPES "Debug;Release" CACHE STRING "" FORCE)

View File

@ -1,60 +0,0 @@
# FindFFTW3.cmake
# This file is used by CMake to locate the FFTW3 library on the system.
# It sets the FFTW3_INCLUDE_DIRS and FFTW3_LIBRARIES variables.
# Find the include directory for FFTW3
find_path(FFTW3_INCLUDE_DIR fftw3.h
HINTS
${FFTW3_DIR}/include
/usr/include
/usr/local/include
/opt/local/include
)
# Find the library for FFTW3
find_library(FFTW3_LIBRARY fftw3
HINTS
${FFTW3_DIR}/lib
/usr/lib
/usr/local/lib
/opt/local/lib
)
# Find the multi-threaded FFTW3 library, if available
find_library(FFTW3_THREADS_LIBRARY fftw3_threads
HINTS
${FFTW3_DIR}/lib
/usr/lib
/usr/local/lib
/opt/local/lib
)
# Check if the FFTW3 library was found
if(FFTW3_INCLUDE_DIR AND FFTW3_LIBRARY)
set(FFTW3_FOUND TRUE)
# Create the FFTW3 imported target
add_library(FFTW3::fftw3 UNKNOWN IMPORTED)
set_target_properties(FFTW3::fftw3 PROPERTIES
IMPORTED_LOCATION ${FFTW3_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${FFTW3_INCLUDE_DIR}
)
# Create the FFTW3 Threads imported target, if found
if(FFTW3_THREADS_LIBRARY)
add_library(FFTW3::fftw3_threads UNKNOWN IMPORTED)
set_target_properties(FFTW3::fftw3_threads PROPERTIES
IMPORTED_LOCATION ${FFTW3_THREADS_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${FFTW3_INCLUDE_DIR}
)
endif()
message(STATUS "Found FFTW3: ${FFTW3_LIBRARY}")
else()
set(FFTW3_FOUND FALSE)
message(STATUS "FFTW3 not found.")
endif()
# Mark variables as advanced to hide from the cache
mark_as_advanced(FFTW3_INCLUDE_DIR FFTW3_LIBRARY FFTW3_THREADS_LIBRARY)

28
include/bitstream/LICENSE Normal file
View File

@ -0,0 +1,28 @@
BSD 3-Clause License
Copyright (c) 2023, Krede
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018 Stanislav Denisov (nxrighthere@gmail.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1 @@
0.1.4

View File

@ -0,0 +1,22 @@
#pragma once
// Quantization
#include "quantization/bounded_range.h"
#include "quantization/half_precision.h"
#include "quantization/smallest_three.h"
// Stream
#include "stream/bit_measure.h"
#include "stream/bit_reader.h"
#include "stream/bit_writer.h"
#include "stream/byte_buffer.h"
#include "stream/serialize_traits.h"
// Traits
#include "traits/array_traits.h"
#include "traits/bool_trait.h"
#include "traits/enum_trait.h"
#include "traits/float_trait.h"
#include "traits/integral_traits.h"
#include "traits/quantization_traits.h"
#include "traits/string_traits.h"

View File

@ -0,0 +1,104 @@
#pragma once
/*
* Copyright (c) 2018 Stanislav Denisov
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include <cstdint>
namespace bitstream
{
/**
* @brief Class for quantizing single-precision floats into a range and precision
*/
class bounded_range
{
public:
constexpr bounded_range() noexcept :
m_Min(0),
m_Max(0),
m_Precision(0),
m_BitsRequired(0),
m_Mask(0) {}
constexpr bounded_range(float min, float max, float precision) noexcept :
m_Min(min),
m_Max(max),
m_Precision(precision),
m_BitsRequired(log2(static_cast<uint32_t>((m_Max - m_Min) * (1.0f / precision) + 0.5f)) + 1),
m_Mask((1U << m_BitsRequired) - 1U) {}
constexpr inline float get_min() const noexcept { return m_Min; }
constexpr inline float get_max() const noexcept { return m_Max; }
constexpr inline float get_precision() const noexcept { return m_Precision; }
constexpr inline uint32_t get_bits_required() const noexcept { return m_BitsRequired; }
constexpr inline uint32_t quantize(float value) const noexcept
{
if (value < m_Min)
value = m_Min;
else if (value > m_Max)
value = m_Max;
return static_cast<uint32_t>(static_cast<float>((value - m_Min) * (1.0f / m_Precision)) + 0.5f) & m_Mask;
}
constexpr inline float dequantize(uint32_t data) const noexcept
{
float adjusted = (static_cast<float>(data) * m_Precision) + m_Min;
if (adjusted < m_Min)
adjusted = m_Min;
else if (adjusted > m_Max)
adjusted = m_Max;
return adjusted;
}
private:
constexpr inline static uint32_t log2(uint32_t value) noexcept
{
value |= value >> 1;
value |= value >> 2;
value |= value >> 4;
value |= value >> 8;
value |= value >> 16;
return DE_BRUIJN[(value * 0x07C4ACDDU) >> 27];
}
private:
float m_Min;
float m_Max;
float m_Precision;
uint32_t m_BitsRequired;
uint32_t m_Mask;
constexpr inline static uint32_t DE_BRUIJN[32]
{
0, 9, 1, 10, 13, 21, 2, 29,
11, 14, 16, 18, 22, 25, 3, 30,
8, 12, 20, 28, 15, 17, 24, 7,
19, 27, 23, 6, 26, 5, 4, 31
};
};
}

View File

@ -0,0 +1,114 @@
#pragma once
/*
* Copyright (c) 2018 Stanislav Denisov
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include <cstdint>
#include <cstring>
namespace bitstream
{
/**
* @brief Class for quantizing single-precision floats into half-precision
*/
class half_precision
{
public:
inline static uint16_t quantize(float value) noexcept
{
int32_t tmp;
std::memcpy(&tmp, &value, sizeof(float));
int32_t s = (tmp >> 16) & 0x00008000;
int32_t e = ((tmp >> 23) & 0X000000FF) - (127 - 15);
int32_t m = tmp & 0X007FFFFF;
if (e <= 0) {
if (e < -10)
return static_cast<uint16_t>(s);
m |= 0x00800000;
int32_t t = 14 - e;
int32_t a = (1 << (t - 1)) - 1;
int32_t b = (m >> t) & 1;
m = (m + a + b) >> t;
return static_cast<uint16_t>(s | m);
}
if (e == 0XFF - (127 - 15)) {
if (m == 0)
return static_cast<uint16_t>(s | 0X7C00);
m >>= 13;
return static_cast<uint16_t>(s | 0X7C00 | m | ((m == 0) ? 1 : 0));
}
m = m + 0X00000FFF + ((m >> 13) & 1);
if ((m & 0x00800000) != 0) {
m = 0;
e++;
}
if (e > 30)
return static_cast<uint16_t>(s | 0X7C00);
return static_cast<uint16_t>(s | (e << 10) | (m >> 13));
}
inline static float dequantize(uint16_t value) noexcept
{
uint32_t tmp;
uint32_t mantissa = static_cast<uint32_t>(value & 1023);
uint32_t exponent = 0XFFFFFFF2;
if ((value & -33792) == 0) {
if (mantissa != 0) {
while ((mantissa & 1024) == 0) {
exponent--;
mantissa <<= 1;
}
mantissa &= 0XFFFFFBFF;
tmp = ((static_cast<uint32_t>(value) & 0x8000) << 16) | ((exponent + 127) << 23) | (mantissa << 13);
}
else
{
tmp = static_cast<uint32_t>((value & 0x8000) << 16);
}
}
else
{
tmp = ((static_cast<uint32_t>(value) & 0x8000) << 16) | (((((static_cast<uint32_t>(value) >> 10) & 0X1F) - 15) + 127) << 23) | (mantissa << 13);
}
float result;
std::memcpy(&result, &tmp, sizeof(float));
return result;
}
};
}

View File

@ -0,0 +1,156 @@
#pragma once
/*
* Copyright (c) 2020 Stanislav Denisov, Maxim Munning, Davin Carten
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include <cstdint>
#include <cmath>
namespace bitstream
{
/**
* @brief A quantized representation of a quaternion
*/
struct quantized_quaternion
{
uint32_t m;
uint32_t a;
uint32_t b;
uint32_t c;
constexpr quantized_quaternion() noexcept :
m(0),
a(0),
b(0),
c(0) {}
constexpr quantized_quaternion(uint32_t w, uint32_t x, uint32_t y, uint32_t z) noexcept :
m(w), a(x), b(y), c(z) {}
};
/**
* @brief Class for quantizing a user-specified quaternion into fewer bits using the smallest-three algorithm
* @tparam T The quaternion-type to quantize
*/
template<typename T, size_t BitsPerElement = 12>
class smallest_three
{
private:
static constexpr float SMALLEST_THREE_UNPACK = 0.70710678118654752440084436210485f + 0.0000001f;
static constexpr float SMALLEST_THREE_PACK = 1.0f / SMALLEST_THREE_UNPACK;
public:
inline static quantized_quaternion quantize(const T& quaternion) noexcept
{
constexpr float half_range = static_cast<float>(1 << (BitsPerElement - 1));
constexpr float packer = SMALLEST_THREE_PACK * half_range;
float max_value = -1.0f;
bool sign_minus = false;
uint32_t m = 0;
uint32_t a = 0;
uint32_t b = 0;
uint32_t c = 0;
for (uint32_t i = 0; i < 4; i++)
{
float element = quaternion[i];
float abs = element > 0.0f ? element : -element;
if (abs > max_value)
{
sign_minus = element < 0.0f;
m = i;
max_value = abs;
}
}
float af = 0.0f;
float bf = 0.0f;
float cf = 0.0f;
switch (m)
{
case 0:
af = quaternion[1];
bf = quaternion[2];
cf = quaternion[3];
break;
case 1:
af = quaternion[0];
bf = quaternion[2];
cf = quaternion[3];
break;
case 2:
af = quaternion[0];
bf = quaternion[1];
cf = quaternion[3];
break;
default: // case 3
af = quaternion[0];
bf = quaternion[1];
cf = quaternion[2];
break;
}
if (sign_minus)
{
a = static_cast<uint32_t>((-af * packer) + half_range);
b = static_cast<uint32_t>((-bf * packer) + half_range);
c = static_cast<uint32_t>((-cf * packer) + half_range);
}
else
{
a = static_cast<uint32_t>((af * packer) + half_range);
b = static_cast<uint32_t>((bf * packer) + half_range);
c = static_cast<uint32_t>((cf * packer) + half_range);
}
return { m, a, b, c };
}
inline static T dequantize(const quantized_quaternion& data) noexcept
{
constexpr uint32_t half_range = (1 << (BitsPerElement - 1));
constexpr float unpacker = SMALLEST_THREE_UNPACK * (1.0f / half_range);
float a = static_cast<float>(data.a * unpacker - half_range * unpacker);
float b = static_cast<float>(data.b * unpacker - half_range * unpacker);
float c = static_cast<float>(data.c * unpacker - half_range * unpacker);
float d = std::sqrt(1.0f - ((a * a) + (b * b) + (c * c)));
switch (data.m)
{
case 0:
return T{ d, a, b, c };
case 1:
return T{ a, d, b, c };
case 2:
return T{ a, b, d, c };
default: // case 3
return T{ a, b, c, d };
}
}
};
}

View File

@ -0,0 +1,235 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/crc.h"
#include "../utility/endian.h"
#include "../utility/meta.h"
#include "byte_buffer.h"
#include "serialize_traits.h"
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>
namespace bitstream
{
/**
* @brief A stream for writing objects tightly into a buffer
* @note Does not take ownership of the buffer
*/
class bit_measure
{
public:
static constexpr bool writing = true;
static constexpr bool reading = false;
/**
* @brief Default construct a writer pointing to a null buffer
*/
bit_measure() noexcept :
m_NumBitsWritten(0),
m_TotalBits((std::numeric_limits<uint32_t>::max)()) {}
/**
* @brief Construct a writer pointing to the given byte array with @p num_bytes size
* @param num_bytes The number of bytes in the array
*/
bit_measure(uint32_t num_bytes) noexcept :
m_NumBitsWritten(0),
m_TotalBits(num_bytes * 8) {}
bit_measure(const bit_measure&) = delete;
bit_measure(bit_measure&& other) noexcept :
m_NumBitsWritten(other.m_NumBitsWritten),
m_TotalBits(other.m_TotalBits)
{
other.m_NumBitsWritten = 0;
other.m_TotalBits = 0;
}
bit_measure& operator=(const bit_measure&) = delete;
bit_measure& operator=(bit_measure&& rhs) noexcept
{
m_NumBitsWritten = rhs.m_NumBitsWritten;
m_TotalBits = rhs.m_TotalBits;
rhs.m_NumBitsWritten = 0;
rhs.m_TotalBits = 0;
return *this;
}
/**
* @brief Returns the buffer that this writer is currently serializing into
* @return The buffer
*/
[[nodiscard]] uint8_t* get_buffer() const noexcept { return nullptr; }
/**
* @brief Returns the number of bits which have been written to the buffer
* @return The number of bits which have been written
*/
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsWritten; }
/**
* @brief Returns the number of bytes which have been written to the buffer
* @return The number of bytes which have been written
*/
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return m_NumBitsWritten > 0U ? ((m_NumBitsWritten - 1U) / 8U + 1U) : 0U; }
/**
* @brief Returns whether the @p num_bits can fit in the buffer
* @param num_bits The number of bits to test
* @return Whether the number of bits can fit in the buffer
*/
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsWritten + num_bits <= m_TotalBits; }
/**
* @brief Returns the number of bits which have not been written yet
* @note The same as get_total_bits() - get_num_bits_serialized()
* @return The remaining space in the buffer
*/
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return m_TotalBits - m_NumBitsWritten; }
/**
* @brief Returns the size of the buffer, in bits
* @return The size of the buffer, in bits
*/
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_TotalBits; }
/**
* @brief Instructs the writer that you intend to use `serialize_checksum()` later on, and to reserve the first 32 bits.
* @return Returns false if anything has already been written to the buffer or if there's no space to write the checksum
*/
[[nodiscard]] bool prepend_checksum() noexcept
{
BS_ASSERT(m_NumBitsWritten == 0);
BS_ASSERT(can_serialize_bits(32U));
m_NumBitsWritten += 32U;
return true;
}
/**
* @brief Writes a checksum of the @p protocol_version and the rest of the buffer as the first 32 bits
* @param protocol_version A unique version number
* @return The number of bytes written to the buffer
*/
uint32_t serialize_checksum(uint32_t protocol_version) noexcept
{
return m_NumBitsWritten;
}
/**
* @brief Pads the buffer up to the given number of bytes with zeros
* @param num_bytes The byte number to pad to
* @return Returns false if the current size of the buffer is bigger than @p num_bytes
*/
[[nodiscard]] bool pad_to_size(uint32_t num_bytes) noexcept
{
BS_ASSERT(num_bytes * 8U <= m_TotalBits);
BS_ASSERT(num_bytes * 8U >= m_NumBitsWritten);
m_NumBitsWritten = num_bytes * 8U;
return true;
}
/**
* @brief Pads the buffer up with the given number of bytes
* @param num_bytes The amount of bytes to pad
* @return Returns false if the current size of the buffer is bigger than @p num_bytes or if the padded bits are not zeros.
*/
[[nodiscard]] bool pad(uint32_t num_bytes) noexcept
{
return pad_to_size(get_num_bytes_serialized() + num_bytes);
}
/**
* @brief Pads the buffer with up to 8 zeros, so that the next write is byte-aligned
* @return Success
*/
[[nodiscard]] bool align() noexcept
{
uint32_t remainder = m_NumBitsWritten % 8U;
if (remainder != 0U)
m_NumBitsWritten += 8U - remainder;
return true;
}
/**
* @brief Writes the first @p num_bits bits of @p value into the buffer
* @param value The value to serialize
* @param num_bits The number of bits of the @p value to serialize
* @return Returns false if @p num_bits is less than 1 or greater than 32 or if writing the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bits(uint32_t value, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U && num_bits <= 32U);
BS_ASSERT(can_serialize_bits(num_bits));
m_NumBitsWritten += num_bits;
return true;
}
/**
* @brief Writes the first @p num_bits bits of the given byte array, 32 bits at a time
* @param bytes The bytes to serialize
* @param num_bits The number of bits of the @p bytes to serialize
* @return Returns false if @p num_bits is less than 1 or if writing the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bytes(const uint8_t* bytes, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U);
BS_ASSERT(can_serialize_bits(num_bits));
m_NumBitsWritten += num_bits;
return true;
}
/**
* @brief Writes to the buffer, using the given @p Trait.
* @note The Trait type in this function must always be explicitly declared
* @tparam Trait A template specialization of serialize_trait<>
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param ...args The arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename Trait, typename... Args, typename = utility::has_serialize_t<Trait, bit_measure, Args...>>
[[nodiscard]] bool serialize(Args&&... args) noexcept(utility::is_serialize_noexcept_v<Trait, bit_measure, Args...>)
{
return serialize_traits<Trait>::serialize(*this, std::forward<Args>(args)...);
}
/**
* @brief Writes to the buffer, by trying to deduce the trait.
* @note The Trait type in this function is always implicit and will be deduced from the first argument if possible.
* If the trait cannot be deduced it will not compile.
* @tparam Trait The type of the first argument, which will be used to deduce the trait specialization
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param arg The first argument to pass to the serialize function
* @param ...args The rest of the arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename... Args, typename Trait, typename = utility::has_deduce_serialize_t<Trait, bit_measure, Args...>>
[[nodiscard]] bool serialize(Trait&& arg, Args&&... args) noexcept(utility::is_deduce_serialize_noexcept_v<Trait, bit_measure, Args...>)
{
return serialize_traits<utility::deduce_trait_t<Trait, bit_measure, Args...>>::serialize(*this, std::forward<Trait>(arg), std::forward<Args>(args)...);
}
private:
uint32_t m_NumBitsWritten;
uint32_t m_TotalBits;
};
}

View File

@ -0,0 +1,343 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/crc.h"
#include "../utility/endian.h"
#include "../utility/meta.h"
#include "byte_buffer.h"
#include "serialize_traits.h"
#include "stream_traits.h"
#include <cstdint>
#include <cstring>
#include <string>
#include <type_traits>
namespace bitstream
{
/**
* @brief A stream for reading objects from a tightly packed buffer
* @tparam Policy The underlying representation of the buffer
*/
template<typename Policy>
class bit_reader
{
public:
static constexpr bool writing = false;
static constexpr bool reading = true;
/**
* @brief Construct a reader with the parameters passed to the underlying policy
* @param ...args The arguments to pass to the policy
*/
template<typename... Ts,
typename = std::enable_if_t<std::is_constructible_v<Policy, Ts...>>>
bit_reader(Ts&&... args)
noexcept(std::is_nothrow_constructible_v<Policy, Ts...>) :
m_Policy(std::forward<Ts>(args) ...),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}
bit_reader(const bit_reader&) = delete;
bit_reader(bit_reader&& other) noexcept :
m_Policy(std::move(other.m_Policy)),
m_Scratch(other.m_Scratch),
m_ScratchBits(other.m_ScratchBits),
m_WordIndex(other.m_WordIndex)
{
other.m_Scratch = 0;
other.m_ScratchBits = 0;
other.m_WordIndex = 0;
}
bit_reader& operator=(const bit_reader&) = delete;
bit_reader& operator=(bit_reader&& rhs) noexcept
{
m_Policy = std::move(rhs.m_Policy);
m_Scratch = rhs.m_Scratch;
m_ScratchBits = rhs.m_ScratchBits;
m_WordIndex = rhs.m_WordIndex;
rhs.m_Scratch = 0;
rhs.m_ScratchBits = 0;
rhs.m_WordIndex = 0;
return *this;
}
/**
* @brief Returns the buffer that this reader is currently serializing from
* @return The buffer
*/
[[nodiscard]] const uint8_t* get_buffer() const noexcept { return reinterpret_cast<const uint8_t*>(m_Policy.get_buffer()); }
/**
* @brief Returns the number of bits which have been read from the buffer
* @return The number of bits which have been read
*/
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_Policy.get_num_bits_serialized(); }
/**
* @brief Returns the number of bytes which have been read from the buffer
* @return The number of bytes which have been read
*/
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return get_num_bits_serialized() > 0U ? ((get_num_bits_serialized() - 1U) / 8U + 1U) : 0U; }
/**
* @brief Returns whether the @p num_bits be read from the buffer
* @param num_bits The number of bits to test
* @return Whether the number of bits can be read from the buffer
*/
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_Policy.can_serialize_bits(num_bits); }
/**
* @brief Returns the number of bits which have not been read yet
* @note The same as get_total_bits() - get_num_bits_serialized()
* @return The remaining space in the buffer
*/
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return get_total_bits() - get_num_bits_serialized(); }
/**
* @brief Returns the size of the buffer, in bits
* @return The size of the buffer, in bits
*/
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_Policy.get_total_bits(); }
/**
* @brief Reads the first 32 bits of the buffer and compares it to a checksum of the @p protocol_version and the rest of the buffer
* @param protocol_version A unique version number
* @return Whether the checksum matches what was written
*/
[[nodiscard]] bool serialize_checksum(uint32_t protocol_version) noexcept
{
BS_ASSERT(get_num_bits_serialized() == 0);
BS_ASSERT(can_serialize_bits(32U));
uint32_t num_bytes = (get_total_bits() - 1U) / 8U + 1U;
const uint32_t* buffer = m_Policy.get_buffer();
// Generate checksum to compare against
uint32_t generated_checksum = utility::crc_uint32(reinterpret_cast<const uint8_t*>(&protocol_version), reinterpret_cast<const uint8_t*>(buffer + 1), num_bytes - 4);
// Advance the reader by the size of the checksum (32 bits / 1 word)
m_WordIndex++;
BS_ASSERT(m_Policy.extend(32U));
// Read the checksum
uint32_t checksum = *buffer;
// Compare the checksum
return generated_checksum == checksum;
}
/**
* @brief Pads the buffer up to the given number of bytes
* @param num_bytes The byte number to pad to
* @return Returns false if the current size of the buffer is bigger than @p num_bytes or if the padded bits are not zeros.
*/
[[nodiscard]] bool pad_to_size(uint32_t num_bytes) noexcept
{
uint32_t num_bits_read = get_num_bits_serialized();
BS_ASSERT(num_bytes * 8U >= num_bits_read);
BS_ASSERT(can_serialize_bits(num_bytes * 8U - num_bits_read));
uint32_t remainder = (num_bytes * 8U - num_bits_read) % 32U;
uint32_t zero;
// Test the last word more carefully, as it may have data
if (remainder != 0U)
{
bool status = serialize_bits(zero, remainder);
BS_ASSERT(status && zero == 0);
}
uint32_t offset = get_num_bits_serialized() / 32;
uint32_t max = num_bytes / 4;
// Test for zeros in padding
for (uint32_t i = offset; i < max; i++)
{
bool status = serialize_bits(zero, 32);
BS_ASSERT(status && zero == 0);
}
return true;
}
/**
* @brief Pads the buffer up with the given number of bytes
* @param num_bytes The amount of bytes to pad
* @return Returns false if the current size of the buffer is bigger than @p num_bytes or if the padded bits are not zeros.
*/
[[nodiscard]] bool pad(uint32_t num_bytes) noexcept
{
return pad_to_size(get_num_bytes_serialized() + num_bytes);
}
/**
* @brief Pads the buffer with up to 8 zeros, so that the next read is byte-aligned
* @notes Return false if the padded bits are not zeros
* @return Returns false if the padded bits are not zeros
*/
[[nodiscard]] bool align() noexcept
{
uint32_t remainder = get_num_bits_serialized() % 8U;
if (remainder != 0U)
{
uint32_t zero;
bool status = serialize_bits(zero, 8U - remainder);
BS_ASSERT(status && zero == 0U && get_num_bits_serialized() % 8U == 0U);
}
return true;
}
/**
* @brief Reads the first @p num_bits bits of @p value from the buffer
* @param value The value to serialize
* @param num_bits The number of bits of the @p value to serialize
* @return Returns false if @p num_bits is less than 1 or greater than 32 or if reading the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bits(uint32_t& value, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U && num_bits <= 32U);
BS_ASSERT(m_Policy.extend(num_bits));
// This is actually slower
// Possibly due to unlikely branching
/*if (num_bits == 32U && m_ScratchBits == 0U)
{
const uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
value = utility::to_big_endian32(*ptr);
m_WordIndex++;
return true;
}*/
if (m_ScratchBits < num_bits)
{
const uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
uint64_t ptr_value = static_cast<uint64_t>(utility::to_big_endian32(*ptr)) << (32U - m_ScratchBits);
m_Scratch |= ptr_value;
m_ScratchBits += 32U;
m_WordIndex++;
}
uint32_t offset = 64U - num_bits;
value = static_cast<uint32_t>(m_Scratch >> offset);
m_Scratch <<= num_bits;
m_ScratchBits -= num_bits;
return true;
}
/**
* @brief Reads the first @p num_bits bits of the given byte array, 32 bits at a time
* @param bytes The bytes to serialize
* @param num_bits The number of bits of the @p bytes to serialize
* @return Returns false if @p num_bits is less than 1 or if reading the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bytes(uint8_t* bytes, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U);
BS_ASSERT(can_serialize_bits(num_bits));
// Read the byte array as words
uint32_t* word_buffer = reinterpret_cast<uint32_t*>(bytes);
uint32_t num_words = num_bits / 32U;
if (m_ScratchBits % 32U == 0U && num_words > 0U)
{
BS_ASSERT(m_Policy.extend(num_words * 32U));
// If the read buffer is word-aligned, just memcpy it
std::memcpy(word_buffer, m_Policy.get_buffer() + m_WordIndex, num_words * 4U);
m_WordIndex += num_words;
}
else
{
// If the buffer is not word-aligned, serialize a word at a time
for (uint32_t i = 0U; i < num_words; i++)
{
uint32_t value;
BS_ASSERT(serialize_bits(value, 32U));
// Casting a byte-array to an int is wrong on little-endian systems
// We have to swap the bytes around
word_buffer[i] = utility::to_big_endian32(value);
}
}
// Early exit if the word-count matches
if (num_bits % 32 == 0)
return true;
uint32_t remaining_bits = num_bits - num_words * 32U;
uint32_t num_bytes = (remaining_bits - 1U) / 8U + 1U;
for (uint32_t i = 0; i < num_bytes; i++)
{
uint32_t value;
BS_ASSERT(serialize_bits(value, (std::min)(remaining_bits - i * 8U, 8U)));
bytes[num_words * 4 + i] = static_cast<uint8_t>(value);
}
return true;
}
/**
* @brief Reads from the buffer, using the given @p Trait.
* @note The Trait type in this function must always be explicitly declared
* @tparam Trait A template specialization of serialize_trait<>
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param ...args The arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename Trait, typename... Args, typename = utility::has_serialize_t<Trait, bit_reader, Args...>>
[[nodiscard]] bool serialize(Args&&... args) noexcept(utility::is_serialize_noexcept_v<Trait, bit_reader, Args...>)
{
return serialize_traits<Trait>::serialize(*this, std::forward<Args>(args)...);
}
/**
* @brief Reads from the buffer, by trying to deduce the trait.
* @note The Trait type in this function is always implicit and will be deduced from the first argument if possible.
* If the trait cannot be deduced it will not compile.
* @tparam Trait The type of the first argument, which will be used to deduce the trait specialization
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param arg The first argument to pass to the serialize function
* @param ...args The rest of the arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename... Args, typename Trait, typename = utility::has_deduce_serialize_t<Trait, bit_reader, Args...>>
[[nodiscard]] bool serialize(Trait&& arg, Args&&... args) noexcept(utility::is_deduce_serialize_noexcept_v<Trait, bit_reader, Args...>)
{
return serialize_traits<utility::deduce_trait_t<Trait, bit_reader, Args...>>::serialize(*this, std::forward<Trait>(arg), std::forward<Args>(args)...);
}
private:
Policy m_Policy;
uint64_t m_Scratch;
uint32_t m_ScratchBits;
uint32_t m_WordIndex;
};
using fixed_bit_reader = bit_reader<fixed_policy>;
}

View File

@ -0,0 +1,400 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/crc.h"
#include "../utility/endian.h"
#include "../utility/meta.h"
#include "byte_buffer.h"
#include "serialize_traits.h"
#include "stream_traits.h"
#include <cstdint>
#include <cstring>
#include <memory>
#include <type_traits>
namespace bitstream
{
/**
* @brief A stream for writing objects tightly into a buffer
* @tparam Policy The underlying representation of the buffer
*/
template<typename Policy>
class bit_writer
{
public:
static constexpr bool writing = true;
static constexpr bool reading = false;
/**
* @brief Construct a writer with the parameters passed to the underlying policy
* @param ...args The arguments to pass to the policy
*/
template<typename... Ts,
typename = std::enable_if_t<std::is_constructible_v<Policy, Ts...>>>
bit_writer(Ts&&... args)
noexcept(std::is_nothrow_constructible_v<Policy, Ts...>) :
m_Policy(std::forward<Ts>(args) ...),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}
bit_writer(const bit_writer&) = delete;
bit_writer(bit_writer&& other) noexcept :
m_Policy(std::move(other.m_Policy)),
m_Scratch(other.m_Scratch),
m_ScratchBits(other.m_ScratchBits),
m_WordIndex(other.m_WordIndex)
{
other.m_Scratch = 0;
other.m_ScratchBits = 0;
other.m_WordIndex = 0;
}
bit_writer& operator=(const bit_writer&) = delete;
bit_writer& operator=(bit_writer&& rhs) noexcept
{
m_Policy = std::move(rhs.m_Policy);
m_Scratch = rhs.m_Scratch;
m_ScratchBits = rhs.m_ScratchBits;
m_WordIndex = rhs.m_WordIndex;
rhs.m_Scratch = 0;
rhs.m_ScratchBits = 0;
rhs.m_WordIndex = 0;
return *this;
}
/**
* @brief Returns the buffer that this writer is currently serializing into
* @return The buffer
*/
[[nodiscard]] uint8_t* get_buffer() const noexcept { return reinterpret_cast<uint8_t*>(m_Policy.get_buffer()); }
/**
* @brief Returns the number of bits which have been written to the buffer
* @return The number of bits which have been written
*/
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_Policy.get_num_bits_serialized(); }
/**
* @brief Returns the number of bytes which have been written to the buffer
* @return The number of bytes which have been written
*/
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return get_num_bits_serialized() > 0U ? ((get_num_bits_serialized() - 1U) / 8U + 1U) : 0U; }
/**
* @brief Returns whether the @p num_bits can fit in the buffer
* @param num_bits The number of bits to test
* @return Whether the number of bits can fit in the buffer
*/
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_Policy.can_serialize_bits(num_bits); }
/**
* @brief Returns the number of bits which have not been written yet
* @note The same as get_total_bits() - get_num_bits_serialized()
* @return The remaining space in the buffer
*/
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return get_total_bits() - get_num_bits_serialized(); }
/**
* @brief Returns the size of the buffer, in bits
* @return The size of the buffer, in bits
*/
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_Policy.get_total_bits(); }
/**
* @brief Flushes any remaining bits into the buffer. Use this when you no longer intend to write anything to the buffer.
* @return The number of bytes written to the buffer
*/
uint32_t flush() noexcept
{
if (m_ScratchBits > 0U)
{
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
uint32_t ptr_value = static_cast<uint32_t>(m_Scratch >> 32U);
*ptr = utility::to_big_endian32(ptr_value);
m_Scratch = 0U;
m_ScratchBits = 0U;
m_WordIndex++;
}
return get_num_bits_serialized();
}
/**
* @brief Instructs the writer that you intend to use `serialize_checksum()` later on, and to reserve the first 32 bits.
* @return Returns false if anything has already been written to the buffer or if there's no space to write the checksum
*/
[[nodiscard]] bool prepend_checksum() noexcept
{
BS_ASSERT(get_num_bits_serialized() == 0);
BS_ASSERT(m_Policy.extend(32U));
// Advance the reader by the size of the checksum (32 bits / 1 word)
m_WordIndex++;
return true;
}
/**
* @brief Writes a checksum of the @p protocol_version and the rest of the buffer as the first 32 bits
* @param protocol_version A unique version number
* @return The number of bytes written to the buffer
*/
uint32_t serialize_checksum(uint32_t protocol_version) noexcept
{
uint32_t num_bits = flush();
BS_ASSERT(num_bits > 32U);
// Copy protocol version to buffer
uint32_t* buffer = m_Policy.get_buffer();
*buffer = protocol_version;
// Generate checksum of version + data
uint32_t checksum = utility::crc_uint32(reinterpret_cast<uint8_t*>(buffer), get_num_bytes_serialized());
// Put checksum at beginning
*buffer = checksum;
return num_bits;
}
/**
* @brief Pads the buffer up to the given number of bytes with zeros
* @param num_bytes The byte number to pad to
* @return Returns false if the current size of the buffer is bigger than @p num_bytes
*/
[[nodiscard]] bool pad_to_size(uint32_t num_bytes) noexcept
{
uint32_t num_bits_written = get_num_bits_serialized();
BS_ASSERT(num_bytes * 8U >= num_bits_written);
BS_ASSERT(can_serialize_bits(num_bytes * 8U - num_bits_written));
if (num_bits_written == 0)
{
BS_ASSERT(m_Policy.extend(num_bytes * 8U - num_bits_written));
std::memset(m_Policy.get_buffer(), 0, num_bytes);
m_Scratch = 0;
m_ScratchBits = 0;
m_WordIndex = num_bytes / 4;
return true;
}
uint32_t remainder = (num_bytes * 8U - num_bits_written) % 32U;
uint32_t zero = 0;
// Align to byte
if (remainder != 0U)
BS_ASSERT(serialize_bits(zero, remainder));
uint32_t offset = get_num_bits_serialized() / 32;
uint32_t max = num_bytes / 4;
// Serialize words
for (uint32_t i = offset; i < max; i++)
BS_ASSERT(serialize_bits(zero, 32));
return true;
}
/**
* @brief Pads the buffer up with the given number of bytes
* @param num_bytes The amount of bytes to pad
* @return Returns false if the current size of the buffer is bigger than @p num_bytes or if the padded bits are not zeros.
*/
[[nodiscard]] bool pad(uint32_t num_bytes) noexcept
{
return pad_to_size(get_num_bytes_serialized() + num_bytes);
}
/**
* @brief Pads the buffer with up to 8 zeros, so that the next write is byte-aligned
* @return Success
*/
[[nodiscard]] bool align() noexcept
{
uint32_t remainder = m_ScratchBits % 8U;
if (remainder != 0U)
{
uint32_t zero = 0U;
bool status = serialize_bits(zero, 8U - remainder);
BS_ASSERT(status && get_num_bits_serialized() % 8U == 0U);
}
return true;
}
/**
* @brief Writes the first @p num_bits bits of @p value into the buffer
* @param value The value to serialize
* @param num_bits The number of bits of the @p value to serialize
* @return Returns false if @p num_bits is less than 1 or greater than 32 or if writing the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bits(uint32_t value, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U && num_bits <= 32U);
BS_ASSERT(m_Policy.extend(num_bits));
// This is actually slower
// Possibly due to unlikely branching
/*if (num_bits == 32U && m_ScratchBits == 0U)
{
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
*ptr = utility::to_big_endian32(value);
m_WordIndex++;
return true;
}*/
uint32_t offset = 64U - num_bits - m_ScratchBits;
uint64_t ls_value = static_cast<uint64_t>(value) << offset;
m_Scratch |= ls_value;
m_ScratchBits += num_bits;
if (m_ScratchBits >= 32U)
{
uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;
uint32_t ptr_value = static_cast<uint32_t>(m_Scratch >> 32U);
*ptr = utility::to_big_endian32(ptr_value);
m_Scratch <<= 32ULL;
m_ScratchBits -= 32U;
m_WordIndex++;
}
return true;
}
/**
* @brief Writes the first @p num_bits bits of the given byte array, 32 bits at a time
* @param bytes The bytes to serialize
* @param num_bits The number of bits of the @p bytes to serialize
* @return Returns false if @p num_bits is less than 1 or if writing the given number of bits would overflow the buffer
*/
[[nodiscard]] bool serialize_bytes(const uint8_t* bytes, uint32_t num_bits) noexcept
{
BS_ASSERT(num_bits > 0U);
BS_ASSERT(can_serialize_bits(num_bits));
// Write the byte array as words
const uint32_t* word_buffer = reinterpret_cast<const uint32_t*>(bytes);
uint32_t num_words = num_bits / 32U;
if (m_ScratchBits % 32U == 0U && num_words > 0U)
{
BS_ASSERT(m_Policy.extend(num_words * 32U));
// If the written buffer is word-aligned, just memcpy it
std::memcpy(m_Policy.get_buffer() + m_WordIndex, word_buffer, num_words * 4U);
m_WordIndex += num_words;
}
else
{
// If the buffer is not word-aligned, serialize a word at a time
for (uint32_t i = 0U; i < num_words; i++)
{
// Casting a byte-array to an int is wrong on little-endian systems
// We have to swap the bytes around
uint32_t value = utility::to_big_endian32(word_buffer[i]);
BS_ASSERT(serialize_bits(value, 32U));
}
}
// Early exit if the word-count matches
if (num_bits % 32U == 0U)
return true;
uint32_t remaining_bits = num_bits - num_words * 32U;
uint32_t num_bytes = (remaining_bits - 1U) / 8U + 1U;
for (uint32_t i = 0U; i < num_bytes; i++)
{
uint32_t value = static_cast<uint32_t>(bytes[num_words * 4U + i]);
BS_ASSERT(serialize_bits(value, (std::min)(remaining_bits - i * 8U, 8U)));
}
return true;
}
/**
* @brief Writes the contents of the buffer into the given @p writer. Essentially copies the entire buffer without modifying it.
* @param writer The writer to copy into
* @return Returns false if writing would overflow the buffer
*/
[[nodiscard]] bool serialize_into(bit_writer& writer) const noexcept
{
uint8_t* buffer = reinterpret_cast<uint8_t*>(m_Policy.get_buffer());
uint32_t num_bits = get_num_bits_serialized();
uint32_t remainder_bits = num_bits % 8U;
BS_ASSERT(writer.serialize_bytes(buffer, num_bits - remainder_bits));
if (remainder_bits > 0U)
{
uint32_t byte_value = buffer[num_bits / 8U] >> (8U - remainder_bits);
BS_ASSERT(writer.serialize_bits(byte_value, remainder_bits));
}
return true;
}
/**
* @brief Writes to the buffer, using the given @p Trait.
* @note The Trait type in this function must always be explicitly declared
* @tparam Trait A template specialization of serialize_trait<>
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param ...args The arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename Trait, typename... Args, typename = utility::has_serialize_t<Trait, bit_writer, Args...>>
[[nodiscard]] bool serialize(Args&&... args) noexcept(utility::is_serialize_noexcept_v<Trait, bit_writer, Args...>)
{
return serialize_traits<Trait>::serialize(*this, std::forward<Args>(args)...);
}
/**
* @brief Writes to the buffer, by trying to deduce the trait.
* @note The Trait type in this function is always implicit and will be deduced from the first argument if possible.
* If the trait cannot be deduced it will not compile.
* @tparam Trait The type of the first argument, which will be used to deduce the trait specialization
* @tparam ...Args The types of the arguments to pass to the serialize function
* @param arg The first argument to pass to the serialize function
* @param ...args The rest of the arguments to pass to the serialize function
* @return Whether successful or not
*/
template<typename... Args, typename Trait, typename = utility::has_deduce_serialize_t<Trait, bit_writer, Args...>>
[[nodiscard]] bool serialize(Trait&& arg, Args&&... args) noexcept(utility::is_deduce_serialize_noexcept_v<Trait, bit_writer, Args...>)
{
return serialize_traits<utility::deduce_trait_t<Trait, bit_writer, Args...>>::serialize(*this, std::forward<Trait>(arg), std::forward<Args>(args)...);
}
private:
Policy m_Policy;
uint64_t m_Scratch;
int m_ScratchBits;
size_t m_WordIndex;
};
using fixed_bit_writer = bit_writer<fixed_policy>;
template<typename T>
using growing_bit_writer = bit_writer<growing_policy<T>>;
}

View File

@ -0,0 +1,22 @@
#pragma once
#include <cstddef>
#include <cstdint>
namespace bitstream
{
/**
* @brief A byte buffer aligned to 4 bytes.
* Can be used with bit_reader and bit_writer.
* @note Size must be a multiple of 4
*/
template<size_t Size>
struct byte_buffer
{
static_assert(Size % 4 == 0, "Buffer size must be a multiple of 4");
alignas(uint32_t) uint8_t Bytes[Size];
uint8_t& operator[](size_t i) noexcept { return Bytes[i]; }
};
}

View File

@ -0,0 +1,12 @@
#pragma once
namespace bitstream
{
/**
* @brief A class for specializing trait serialization functions
* @tparam Trait Make a specialization on this type
* @tparam Void Use std::enable_if here if you need to, otherwise leave empty
*/
template<typename Trait, typename Void = void>
struct serialize_traits;
}

View File

@ -0,0 +1,96 @@
#pragma once
#include "byte_buffer.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace bitstream
{
struct fixed_policy
{
/**
* @brief Construct a stream pointing to the given byte array with @p num_bytes size
* @param bytes The byte array to serialize to/from. Must be 4-byte aligned and the size must be a multiple of 4
* @param num_bytes The number of bytes in the array
*/
fixed_policy(void* buffer, uint32_t num_bits) noexcept :
m_Buffer(static_cast<uint32_t*>(buffer)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}
/**
* @brief Construct a stream pointing to the given @p buffer
* @param buffer The buffer to serialize to/from
* @param num_bits The maximum number of bits that we can read
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer, uint32_t num_bits) noexcept :
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}
/**
* @brief Construct a stream pointing to the given @p buffer
* @param buffer The buffer to serialize to/from
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer) noexcept :
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(Size * 8) {}
uint32_t* get_buffer() const noexcept { return m_Buffer; }
// TODO: Transition sizes to size_t
uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsSerialized; }
bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsSerialized + num_bits <= m_TotalBits; }
uint32_t get_total_bits() const noexcept { return m_TotalBits; }
bool extend(uint32_t num_bits) noexcept
{
if (!can_serialize_bits(num_bits))
return false;
m_NumBitsSerialized += num_bits;
return true;
}
uint32_t* m_Buffer;
// TODO: Transition sizes to size_t
uint32_t m_NumBitsSerialized;
uint32_t m_TotalBits;
};
template<typename T>
struct growing_policy
{
growing_policy(T& container) noexcept :
m_Buffer(container),
m_NumBitsSerialized(0) {}
uint32_t* get_buffer() const noexcept { return m_Buffer.data(); }
uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsSerialized; }
bool can_serialize_bits(uint32_t num_bits) const noexcept { return true; }
uint32_t get_total_bits() const noexcept { return (std::numeric_limits<uint32_t>::max)(); }
bool extend(uint32_t num_bits)
{
m_NumBitsSerialized += num_bits;
uint32_t num_bytes = (m_NumBitsSerialized - 1) / 8U + 1;
m_Buffer.resize(num_bytes);
return true;
}
T& m_Buffer;
uint32_t m_NumBitsSerialized;
};
}

View File

@ -0,0 +1,165 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include "../traits/bool_trait.h"
#include "../traits/integral_traits.h"
#include <cstdint>
namespace bitstream
{
/**
* @brief Wrapper type for subsets of arrays
* @tparam T The type of the array
*/
template<typename T, typename = T>
struct array_subset;
/**
* @brief A trait used for serializing a subset of an array of objects
* @tparam T The type of the object in the array
* @tparam Trait
*/
template<typename T, typename Trait>
struct serialize_traits<array_subset<T, Trait>>
{
private:
template<uint32_t Min, uint32_t Max, typename Stream>
static bool serialize_difference(Stream& stream, int& previous, int& current, uint32_t& difference)
{
bool use_bits;
if constexpr (Stream::writing)
use_bits = difference <= Max;
BS_ASSERT(stream.template serialize<bool>(use_bits));
if (use_bits)
{
using bounded_trait = bounded_int<uint32_t, Min, Max>;
BS_ASSERT(stream.template serialize<bounded_trait>(difference));
if constexpr (Stream::reading)
current = previous + difference;
previous = current;
return true;
}
return false;
}
template<typename Stream>
static bool serialize_index(Stream& stream, int& previous, int& current, int max_size)
{
uint32_t difference;
if constexpr (Stream::writing)
{
BS_ASSERT(previous < current);
difference = current - previous;
BS_ASSERT(difference > 0);
}
// +1 (1 bit)
bool plus_one;
if constexpr (Stream::writing)
plus_one = difference == 1;
BS_ASSERT(stream.template serialize<bool>(plus_one));
if (plus_one)
{
if constexpr (Stream::reading)
current = previous + 1;
previous = current;
return true;
}
// [+2,5] -> [0,3] (2 bits)
if (serialize_difference<2, 5>(stream, previous, current, difference))
return true;
// [6,13] -> [0,7] (3 bits)
if (serialize_difference<6, 13>(stream, previous, current, difference))
return true;
// [14,29] -> [0,15] (4 bits)
if (serialize_difference<14, 29>(stream, previous, current, difference))
return true;
// [30,61] -> [0,31] (5 bits)
if (serialize_difference<30, 61>(stream, previous, current, difference))
return true;
// [62,125] -> [0,63] (6 bits)
if (serialize_difference<62, 125>(stream, previous, current, difference))
return true;
// [126,MaxObjects+1]
BS_ASSERT(stream.template serialize<uint32_t>(difference, 126, max_size));
if constexpr (Stream::reading)
current = previous + difference;
previous = current;
return true;
}
public:
/**
* @brief Writes a subset of the array @p values into the writer
* @tparam Compare A function type which returns a bool
* @tparam ...Args The types of any additional arguments
* @param writer The stream to write to
* @param values The array of objects to serialize
* @param max_size The size of the array
* @param compare A function which returns true if the object should be written, false otherwise
* @param ...args Any additional arguments to use when serializing each individual object
* @return Success
*/
template<typename Stream, typename Compare, typename... Args>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, T* values, int max_size, Compare compare, Args&&... args) noexcept
{
int prev_index = -1;
for (int index = 0; index < max_size; index++)
{
if (!compare(values[index]))
continue;
BS_ASSERT(serialize_index(writer, prev_index, index, max_size));
BS_ASSERT(writer.template serialize<Trait>(values[index], std::forward<Args>(args)...));
}
BS_ASSERT(serialize_index(writer, prev_index, max_size, max_size));
return true;
}
/**
* @brief Writes a subset of a serialized array into @p values
* @tparam ...Args The types of any additional arguments
* @param reader The stream to read from
* @param values The array of objects to read into
* @param max_size The size of the array
* @param compare Not used, but kept for compatibility with the serialize write function
* @param ...args Any additional arguments to use when serializing each individual object
* @return Success
*/
template<typename Stream, typename... Args>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, T* values, int max_size, Args&&... args) noexcept
{
int prev_index = -1;
int index = 0;
while (true)
{
BS_ASSERT(serialize_index(reader, prev_index, index, max_size));
if (index == max_size)
break;
BS_ASSERT(reader.template serialize<Trait>(values[index], std::forward<Args>(args)...));
}
return true;
}
};
}

View File

@ -0,0 +1,74 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
namespace bitstream
{
/**
* @brief A trait used to serialize a boolean as a single bit
*/
template<>
struct serialize_traits<bool>
{
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<bool> value) noexcept
{
uint32_t unsigned_value = value;
return writer.serialize_bits(unsigned_value, 1U);
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, out<bool> value) noexcept
{
uint32_t unsigned_value;
BS_ASSERT(reader.serialize_bits(unsigned_value, 1U));
value = unsigned_value;
return true;
}
};
/**
* @brief A trait used to serialize multiple boolean values
*/
template<size_t Size>
struct serialize_traits<bool[Size]>
{
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, const bool* values) noexcept
{
uint32_t unsigned_value;
for (size_t i = 0; i < Size; i++)
{
unsigned_value = values[i];
BS_ASSERT(writer.serialize_bits(unsigned_value, 1U));
}
return writer.serialize_bits(unsigned_value, 1U);
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, bool* values) noexcept
{
uint32_t unsigned_value;
for (size_t i = 0; i < Size; i++)
{
BS_ASSERT(reader.serialize_bits(unsigned_value, 1U));
values[i] = unsigned_value;
}
return true;
}
};
}

View File

@ -0,0 +1,81 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include "../traits/integral_traits.h"
namespace bitstream
{
/**
* @brief Wrapper type for compiletime known integer bounds
* @tparam T
*/
template<typename T, std::underlying_type_t<T> = (std::numeric_limits<T>::min)(), std::underlying_type_t<T> = (std::numeric_limits<T>::max)()>
struct bounded_enum;
/**
* @brief A trait used to serialize an enum type with runtime bounds
*/
template<typename T>
struct serialize_traits<T, typename std::enable_if_t<std::is_enum_v<T>>>
{
using value_type = std::underlying_type_t<T>;
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, T value, value_type min = 0, value_type max = (std::numeric_limits<value_type>::max)()) noexcept
{
value_type unsigned_value = static_cast<value_type>(value);
return writer.template serialize<value_type>(unsigned_value, min, max);
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, T& value, value_type min = 0, value_type max = (std::numeric_limits<value_type>::max)()) noexcept
{
value_type unsigned_value;
BS_ASSERT(reader.template serialize<value_type>(unsigned_value, min, max));
value = static_cast<T>(unsigned_value);
return true;
}
};
/**
* @brief A trait used to serialize an enum type with compiletime bounds
*/
template<typename T, std::underlying_type_t<T> Min, std::underlying_type_t<T> Max>
struct serialize_traits<bounded_enum<T, Min, Max>, typename std::enable_if_t<std::is_enum_v<T>>>
{
using value_type = std::underlying_type_t<T>;
using bound_type = bounded_int<value_type, Min, Max>;
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, T value) noexcept
{
value_type unsigned_value = static_cast<value_type>(value);
return writer.template serialize<bound_type>(unsigned_value);
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, T& value) noexcept
{
value_type unsigned_value;
BS_ASSERT(reader.template serialize<bound_type>(unsigned_value));
value = static_cast<T>(unsigned_value);
return true;
}
};
}

View File

@ -0,0 +1,102 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include <cstdint>
#include <cstring>
namespace bitstream
{
/**
* @brief A trait used to serialize a float as-is, without any bound checking or quantization
*/
template<>
struct serialize_traits<float>
{
/**
* @brief Serializes a whole float into the writer
* @param writer The stream to write to
* @param value The float to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<float> value) noexcept
{
uint32_t tmp;
std::memcpy(&tmp, &value, sizeof(float));
BS_ASSERT(writer.serialize_bits(tmp, 32));
return true;
}
/**
* @brief Serializes a whole float from the reader
* @param reader The stream to read from
* @param value The float to serialize to
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, float& value) noexcept
{
uint32_t tmp;
BS_ASSERT(reader.serialize_bits(tmp, 32));
std::memcpy(&value, &tmp, sizeof(float));
return true;
}
};
/**
* @brief A trait used to serialize a double as-is, without any bound checking or quantization
*/
template<>
struct serialize_traits<double>
{
/**
* @brief Serializes a whole double into the writer
* @param writer The stream to write to
* @param value The double to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<double> value) noexcept
{
uint32_t tmp[2];
std::memcpy(tmp, &value, sizeof(double));
BS_ASSERT(writer.serialize_bits(tmp[0], 32));
BS_ASSERT(writer.serialize_bits(tmp[1], 32));
return true;
}
/**
* @brief Serializes a whole double from the reader
* @param reader The stream to read from
* @param value The double to serialize to
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, double& value) noexcept
{
uint32_t tmp[2];
BS_ASSERT(reader.serialize_bits(tmp[0], 32));
BS_ASSERT(reader.serialize_bits(tmp[1], 32));
std::memcpy(&value, tmp, sizeof(double));
return true;
}
};
}

View File

@ -0,0 +1,233 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/bits.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include <cstdint>
#include <limits>
#include <type_traits>
namespace bitstream
{
/**
* @brief Wrapper type for compiletime known integer bounds
* @tparam T
*/
template<typename T, T = (std::numeric_limits<T>::min)(), T = (std::numeric_limits<T>::max)()>
struct bounded_int;
#pragma region const integral types
/**
* @brief A trait used to serialize integer values with compiletime bounds
* @tparam T A type matching an integer value
* @tparam Min The lower bound. Inclusive
* @tparam Max The upper bound. Inclusive
*/
template<typename T, T Min, T Max>
struct serialize_traits<bounded_int<T, Min, Max>, typename std::enable_if_t<std::is_integral_v<T> && !std::is_const_v<T>>>
{
static_assert(sizeof(T) <= 8, "Integers larger than 8 bytes are currently not supported. You will have to write this functionality yourself");
/**
* @brief Writes an integer into the @p writer
* @param writer The stream to write to
* @param value The value to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<T> value) noexcept
{
static_assert(Min < Max);
BS_ASSERT(value >= Min && value <= Max);
constexpr uint32_t num_bits = utility::bits_in_range(Min, Max);
static_assert(num_bits <= sizeof(T) * 8);
if constexpr (sizeof(T) > 4 && num_bits > 32)
{
// If the given range is bigger than a word (32 bits)
uint32_t unsigned_value = static_cast<uint32_t>(value - Min);
BS_ASSERT(writer.serialize_bits(unsigned_value, 32));
unsigned_value = static_cast<uint32_t>((value - Min) >> 32);
BS_ASSERT(writer.serialize_bits(unsigned_value, num_bits - 32));
}
else
{
// If the given range is smaller than or equal to a word (32 bits)
uint32_t unsigned_value = static_cast<uint32_t>(value - Min);
BS_ASSERT(writer.serialize_bits(unsigned_value, num_bits));
}
return true;
}
/**
* @brief Reads an integer from the @p writer into @p value
* @param reader The stream to read from
* @param value The value to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, T& value) noexcept
{
static_assert(Min < Max);
constexpr uint32_t num_bits = utility::bits_in_range(Min, Max);
static_assert(num_bits <= sizeof(T) * 8);
if constexpr (sizeof(T) > 4 && num_bits > 32)
{
// If the given range is bigger than a word (32 bits)
value = 0;
uint32_t unsigned_value;
BS_ASSERT(reader.serialize_bits(unsigned_value, 32));
value |= static_cast<T>(unsigned_value);
BS_ASSERT(reader.serialize_bits(unsigned_value, num_bits - 32));
value |= static_cast<T>(unsigned_value) << 32;
value += Min;
}
else
{
// If the given range is smaller than or equal to a word (32 bits)
uint32_t unsigned_value;
BS_ASSERT(reader.serialize_bits(unsigned_value, num_bits));
value = static_cast<T>(unsigned_value) + Min;
}
BS_ASSERT(value >= Min && value <= Max);
return true;
}
};
#pragma endregion
#pragma region integral types
/**
* @brief A trait used to serialize integer values with runtime bounds
* @tparam T A type matching an integer value
*/
template<typename T>
struct serialize_traits<T, typename std::enable_if_t<std::is_integral_v<T> && !std::is_const_v<T>>>
{
static_assert(sizeof(T) <= 8, "Integers larger than 8 bytes are currently not supported. You will have to write this functionality yourself");
/**
* @brief Writes an integer into the @p writer
* @param writer The stream to write to
* @param value The value to serialize
* @param min The minimum bound that @p value can be. Inclusive
* @param max The maximum bound that @p value can be. Inclusive
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<T> value, T min, T max) noexcept
{
BS_ASSERT(min < max);
BS_ASSERT(value >= min && value <= max);
uint32_t num_bits = utility::bits_in_range(min, max);
BS_ASSERT(num_bits <= sizeof(T) * 8);
if constexpr (sizeof(T) > 4)
{
if (num_bits > 32)
{
// If the given range is bigger than a word (32 bits)
uint32_t unsigned_value = static_cast<uint32_t>(value - min);
BS_ASSERT(writer.serialize_bits(unsigned_value, 32));
unsigned_value = static_cast<uint32_t>((value - min) >> 32);
BS_ASSERT(writer.serialize_bits(unsigned_value, num_bits - 32));
return true;
}
}
// If the given range is smaller than or equal to a word (32 bits)
uint32_t unsigned_value = static_cast<uint32_t>(value - min);
BS_ASSERT(writer.serialize_bits(unsigned_value, num_bits));
return true;
}
/**
* @brief Reads an integer from the @p reader into @p value
* @param reader The stream to read from
* @param value The value to read into
* @param min The minimum bound that @p value can be. Inclusive
* @param max The maximum bound that @p value can be. Inclusive
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, T& value, T min, T max) noexcept
{
BS_ASSERT(min < max);
uint32_t num_bits = utility::bits_in_range(min, max);
BS_ASSERT(num_bits <= sizeof(T) * 8);
if constexpr (sizeof(T) > 4)
{
if (num_bits > 32)
{
// If the given range is bigger than a word (32 bits)
value = 0;
uint32_t unsigned_value;
BS_ASSERT(reader.serialize_bits(unsigned_value, 32));
value |= static_cast<T>(unsigned_value);
BS_ASSERT(reader.serialize_bits(unsigned_value, num_bits - 32));
value |= static_cast<T>(unsigned_value) << 32;
value += min;
BS_ASSERT(value >= min && value <= max);
return true;
}
}
// If the given range is smaller than or equal to a word (32 bits)
uint32_t unsigned_value;
BS_ASSERT(reader.serialize_bits(unsigned_value, num_bits));
value = static_cast<T>(unsigned_value) + min;
BS_ASSERT(value >= min && value <= max);
return true;
}
/**
* @brief Writes or reads an integer into the @p stream
* @param stream The stream to serialize to/from
* @param value The value to serialize
* @return Success
*/
template<typename Stream, typename U>
static bool serialize(Stream& stream, U&& value) noexcept
{
return serialize_traits<bounded_int<T>>::serialize(stream, std::forward<U>(value));
}
};
#pragma endregion
}

View File

@ -0,0 +1,113 @@
#pragma once
#include "../quantization/bounded_range.h"
#include "../quantization/half_precision.h"
#include "../quantization/smallest_three.h"
#include "../utility/assert.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include <cstdint>
namespace bitstream
{
/**
* @brief A trait used to serialize a single-precision float as half-precision
*/
template<>
struct serialize_traits<half_precision>
{
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& stream, in<float> value) noexcept
{
uint32_t int_value = half_precision::quantize(value);
BS_ASSERT(stream.serialize_bits(int_value, 16));
return true;
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& stream, out<float> value) noexcept
{
uint32_t int_value;
BS_ASSERT(stream.serialize_bits(int_value, 16));
value = half_precision::dequantize(int_value);
return true;
}
};
/**
* @brief A trait used to quantize and serialize a float to be within a given range and precision
*/
template<>
struct serialize_traits<bounded_range>
{
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& stream, in<bounded_range> range, in<float> value) noexcept
{
uint32_t int_value = range.quantize(value);
BS_ASSERT(stream.serialize_bits(int_value, range.get_bits_required()));
return true;
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& stream, in<bounded_range> range, out<float> value) noexcept
{
uint32_t int_value;
BS_ASSERT(stream.serialize_bits(int_value, range.get_bits_required()));
value = range.dequantize(int_value);
return true;
}
};
/**
* @brief A trait used to quantize and serialize quaternions using the smallest-three algorithm
*/
template<typename Q, size_t BitsPerElement>
struct serialize_traits<smallest_three<Q, BitsPerElement>>
{
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& stream, in<Q> value) noexcept
{
quantized_quaternion quantized_quat = smallest_three<Q, BitsPerElement>::quantize(value);
BS_ASSERT(stream.serialize_bits(quantized_quat.m, 2));
BS_ASSERT(stream.serialize_bits(quantized_quat.a, BitsPerElement));
BS_ASSERT(stream.serialize_bits(quantized_quat.b, BitsPerElement));
BS_ASSERT(stream.serialize_bits(quantized_quat.c, BitsPerElement));
return true;
}
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& stream, out<Q> value) noexcept
{
quantized_quaternion quantized_quat;
BS_ASSERT(stream.serialize_bits(quantized_quat.m, 2));
BS_ASSERT(stream.serialize_bits(quantized_quat.a, BitsPerElement));
BS_ASSERT(stream.serialize_bits(quantized_quat.b, BitsPerElement));
BS_ASSERT(stream.serialize_bits(quantized_quat.c, BitsPerElement));
value = smallest_three<Q, BitsPerElement>::dequantize(quantized_quat);
return true;
}
};
}

View File

@ -0,0 +1,344 @@
#pragma once
#include "../utility/assert.h"
#include "../utility/bits.h"
#include "../utility/meta.h"
#include "../utility/parameter.h"
#include "../stream/serialize_traits.h"
#include <cstdint>
#include <string>
namespace bitstream
{
/**
* @brief Wrapper type for compiletime known string max_size
*/
template<typename T, size_t I>
struct bounded_string;
#pragma region const char*
/**
* @brief A trait used to serialize bounded c-style strings
*/
template<>
struct serialize_traits<const char*>
{
/**
* @brief Writes a c-style string into the @p writer
* @param writer The stream to write to
* @param value The string to serialize
* @param max_size The maximum expected length of the string, including the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, const char* value, uint32_t max_size) noexcept
{
uint32_t length = static_cast<uint32_t>(std::char_traits<char>::length(value));
BS_ASSERT(length < max_size);
if (length == 0)
return true;
uint32_t num_bits = utility::bits_to_represent(max_size);
BS_ASSERT(writer.serialize_bits(length, num_bits));
return writer.serialize_bytes(reinterpret_cast<const uint8_t*>(value), length * 8);
}
/**
* @brief Read a c-style string from the @p reader into @p value
* @param reader The stream to read from
* @param value A pointer to the buffer that should be read into. The size of this buffer should be at least @p max_size
* @param max_size The maximum expected length of the string, including the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, char* value, uint32_t max_size) noexcept
{
uint32_t num_bits = utility::bits_to_represent(max_size);
uint32_t length;
BS_ASSERT(reader.serialize_bits(length, num_bits));
BS_ASSERT(length < max_size);
if (length == 0)
{
value[0] = '\0';
return true;
}
BS_ASSERT(reader.serialize_bytes(reinterpret_cast<uint8_t*>(value), length * 8));
value[length] = '\0';
return true;
}
};
/**
* @brief A trait used to serialize bounded c-style strings with compiletime bounds
* @tparam MaxSize The maximum expected length of the string, including the null terminator
*/
template<size_t MaxSize>
struct serialize_traits<bounded_string<const char*, MaxSize>>
{
/**
* @brief Writes a c-style string into the @p writer
* @param writer The stream to write to
* @param value The string to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, const char* value) noexcept
{
uint32_t length = static_cast<uint32_t>(std::char_traits<char>::length(value));
BS_ASSERT(length < MaxSize);
if (length == 0)
return true;
constexpr uint32_t num_bits = utility::bits_to_represent(MaxSize);
BS_ASSERT(writer.serialize_bits(length, num_bits));
return writer.serialize_bytes(reinterpret_cast<const uint8_t*>(value), length * 8);
}
/**
* @brief Read a c-style string from the @p reader into @p value
* @param reader The stream to read from
* @param value A pointer to the buffer that should be read into. The size of this buffer should be at least @p max_size
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, char* value) noexcept
{
constexpr uint32_t num_bits = utility::bits_to_represent(MaxSize);
uint32_t length;
BS_ASSERT(reader.serialize_bits(length, num_bits));
BS_ASSERT(length < MaxSize);
if (length == 0)
{
value[0] = '\0';
return true;
}
BS_ASSERT(reader.serialize_bytes(reinterpret_cast<uint8_t*>(value), length * 8));
value[length] = '\0';
return true;
}
};
#pragma endregion
#ifdef __cpp_char8_t
/**
* @brief A trait used to serialize bounded c-style UTF-8 strings
*/
template<>
struct serialize_traits<const char8_t*>
{
/**
* @brief Writes a c-style UTF-8 string into the @p writer
* @param writer The stream to write to
* @param value The string to serialize
* @param max_size The maximum expected length of the string, including the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, const char8_t* value, uint32_t max_size) noexcept
{
uint32_t length = static_cast<uint32_t>(std::char_traits<char8_t>::length(value));
BS_ASSERT(length < max_size);
if (length == 0)
return true;
uint32_t num_bits = utility::bits_to_represent(max_size);
BS_ASSERT(writer.serialize_bits(length, num_bits));
return writer.serialize_bytes(reinterpret_cast<const uint8_t*>(value), length * 8);
}
/**
* @brief Read a c-style UTF-8 string from the @p reader into @p value
* @param reader The stream to read from
* @param value A pointer to the buffer that should be read into. The size of this buffer should be at least @p max_size
* @param max_size The maximum expected length of the string, including the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, char8_t* value, uint32_t max_size) noexcept
{
uint32_t num_bits = utility::bits_to_represent(max_size);
uint32_t length;
BS_ASSERT(reader.serialize_bits(length, num_bits));
BS_ASSERT(length < max_size);
if (length == 0)
{
value[0] = '\0';
return true;
}
BS_ASSERT(reader.serialize_bytes(reinterpret_cast<uint8_t*>(value), length * 8));
value[length] = '\0';
return true;
}
};
#endif
#pragma region std::basic_string
/**
* @brief A trait used to serialize any combination of std::basic_string
* @tparam T The character type to use
* @tparam Traits The trait type for the T type
* @tparam Alloc The allocator to use
*/
template<typename T, typename Traits, typename Alloc>
struct serialize_traits<std::basic_string<T, Traits, Alloc>>
{
/**
* @brief Writes a string into the @p writer
* @param writer The stream to write to
* @param value The string to serialize
* @param max_size The maximum expected length of the string, excluding the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<std::basic_string<T, Traits, Alloc>> value, uint32_t max_size) noexcept
{
uint32_t length = static_cast<uint32_t>(value.size());
BS_ASSERT(length <= max_size);
uint32_t num_bits = utility::bits_to_represent(max_size);
BS_ASSERT(writer.serialize_bits(length, num_bits));
if (length == 0)
return true;
return writer.serialize_bytes(reinterpret_cast<const uint8_t*>(value.c_str()), length * sizeof(T) * 8);
}
/**
* @brief Reads a string from the @p reader into @p value
* @param reader The stream to read from
* @param value The string to read into. It will be resized if the read string won't fit
* @param max_size The maximum expected length of the string, excluding the null terminator
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, out<std::basic_string<T, Traits, Alloc>> value, uint32_t max_size)
{
uint32_t num_bits = utility::bits_to_represent(max_size);
uint32_t length;
BS_ASSERT(reader.serialize_bits(length, num_bits));
BS_ASSERT(length <= max_size);
if (length == 0)
{
value->clear();
return true;
}
value->resize(length);
BS_ASSERT(reader.serialize_bytes(reinterpret_cast<uint8_t*>(value->data()), length * sizeof(T) * 8));
return true;
}
};
/**
* @brief A trait used to serialize any combination of std::basic_string with compiletime bounds
* @tparam T The character type to use
* @tparam Traits The trait type for the T type
* @tparam Alloc The allocator to use
* @tparam MaxSize The maximum expected length of the string, excluding the null terminator
*/
template<typename T, typename Traits, typename Alloc, size_t MaxSize>
struct serialize_traits<bounded_string<std::basic_string<T, Traits, Alloc>, MaxSize>>
{
/**
* @brief Writes a string into the @p writer
* @param writer The stream to write to
* @param value The string to serialize
* @return Success
*/
template<typename Stream>
typename utility::is_writing_t<Stream>
static serialize(Stream& writer, in<std::basic_string<T, Traits, Alloc>> value) noexcept
{
uint32_t length = static_cast<uint32_t>(value.size());
BS_ASSERT(length <= MaxSize);
constexpr uint32_t num_bits = utility::bits_to_represent(MaxSize);
BS_ASSERT(writer.serialize_bits(length, num_bits));
if (length == 0)
return true;
return writer.serialize_bytes(reinterpret_cast<const uint8_t*>(value.c_str()), length * sizeof(T) * 8);
}
/**
* @brief Reads a string from the @p reader into @p value
* @param reader The stream to read from
* @param value The string to read into. It will be resized if the read string won't fit
* @return Success
*/
template<typename Stream>
typename utility::is_reading_t<Stream>
static serialize(Stream& reader, out<std::basic_string<T, Traits, Alloc>> value)
{
constexpr uint32_t num_bits = utility::bits_to_represent(MaxSize);
uint32_t length;
BS_ASSERT(reader.serialize_bits(length, num_bits));
BS_ASSERT(length <= MaxSize);
if (length == 0)
{
value->clear();
return true;
}
value->resize(length);
BS_ASSERT(reader.serialize_bytes(reinterpret_cast<uint8_t*>(value->data()), length * sizeof(T) * 8));
return true;
}
};
#pragma endregion
}

View File

@ -0,0 +1,18 @@
#pragma once
#ifdef BS_DEBUG_BREAK
#if defined(_WIN32) // Windows
#define BS_BREAKPOINT() __debugbreak()
#elif defined(__linux__) // Linux
#include <csignal>
#define BS_BREAKPOINT() std::raise(SIGTRAP)
#else // Non-supported
#define BS_BREAKPOINT() throw
#endif
#define BS_ASSERT(...) if (!(__VA_ARGS__)) { BS_BREAKPOINT(); return false; }
#else // BS_DEBUG_BREAK
#define BS_ASSERT(...) if (!(__VA_ARGS__)) { return false; }
#define BS_BREAKPOINT() throw
#endif // BS_DEBUG_BREAK

View File

@ -0,0 +1,26 @@
#pragma once
#include <cstddef>
#include <cstdint>
namespace bitstream::utility
{
constexpr inline uint32_t bits_to_represent(uintmax_t n)
{
uint32_t r = 0;
if (n >> 32) { r += 32U; n >>= 32U; }
if (n >> 16) { r += 16U; n >>= 16U; }
if (n >> 8) { r += 8U; n >>= 8U; }
if (n >> 4) { r += 4U; n >>= 4U; }
if (n >> 2) { r += 2U; n >>= 2U; }
if (n >> 1) { r += 1U; n >>= 1U; }
return r + static_cast<uint32_t>(n);
}
constexpr inline uint32_t bits_in_range(intmax_t min, intmax_t max)
{
return bits_to_represent(static_cast<uintmax_t>(max) - static_cast<uintmax_t>(min));
}
}

View File

@ -0,0 +1,47 @@
#pragma once
#include <array>
#include <cstdint>
namespace bitstream::utility
{
inline constexpr auto CHECKSUM_TABLE = []()
{
constexpr uint32_t POLYNOMIAL = 0xEDB88320;
std::array<uint32_t, 0x100> table{};
for (uint32_t i = 0; i < 0x100; ++i)
{
uint32_t item = i;
for (uint32_t bit = 0; bit < 8; ++bit)
item = ((item & 1) != 0) ? (POLYNOMIAL ^ (item >> 1)) : (item >> 1);
table[i] = item;
}
return table;
}();
inline constexpr uint32_t crc_uint32(const uint8_t* bytes, uint32_t size)
{
uint32_t result = 0xFFFFFFFF;
for (uint32_t i = 0; i < size; i++)
result = CHECKSUM_TABLE[(result & 0xFF) ^ *(bytes + i)] ^ (result >> 8);
return ~result;
}
inline constexpr uint32_t crc_uint32(const uint8_t* checksum, const uint8_t* bytes, uint32_t size)
{
uint32_t result = 0xFFFFFFFF;
for (uint32_t i = 0; i < 4; i++)
result = CHECKSUM_TABLE[(result & 0xFF) ^ *(checksum + i)] ^ (result >> 8);
for (uint32_t i = 0; i < size; i++)
result = CHECKSUM_TABLE[(result & 0xFF) ^ *(bytes + i)] ^ (result >> 8);
return ~result;
}
}

View File

@ -0,0 +1,89 @@
#pragma once
#include <cstdint>
#if defined(__cpp_lib_endian) && __cpp_lib_endian >= 201907L
#include <bit>
#else // __cpp_lib_endian
#ifndef BS_LITTLE_ENDIAN
// Detect with GCC 4.6's macro.
#if defined(__BYTE_ORDER__)
#if (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
#define BS_LITTLE_ENDIAN true
#elif (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
#define BS_LITTLE_ENDIAN false
#else
#error "Unknown machine byteorder endianness detected. Need to manually define BS_LITTLE_ENDIAN."
#endif
// Detect with GLIBC's endian.h.
#elif defined(__GLIBC__)
#include <endian.h>
#if (__BYTE_ORDER == __LITTLE_ENDIAN)
#define BS_LITTLE_ENDIAN true
#elif (__BYTE_ORDER == __BIG_ENDIAN)
#define BS_LITTLE_ENDIAN false
#else
#error "Unknown machine byteorder endianness detected. Need to manually define BS_LITTLE_ENDIAN."
#endif
// Detect with _LITTLE_ENDIAN and _BIG_ENDIAN macro.
#elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN)
#define BS_LITTLE_ENDIAN true
#elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN)
#define BS_LITTLE_ENDIAN false
// Detect with architecture macros.
#elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || defined(__ppc__) || defined(__hpux) || defined(__hppa) || defined(_MIPSEB) || defined(_POWER) || defined(__s390__)
#define BS_LITTLE_ENDIAN false
#elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) || defined(__bfin__)
#define BS_LITTLE_ENDIAN true
#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
#define BS_LITTLE_ENDIAN true
#else
#error "Unknown machine byteorder endianness detected. Need to manually define BS_LITTLE_ENDIAN."
#endif
#endif // BS_LITTLE_ENDIAN
#endif // __cpp_lib_endian
#if defined(_WIN32)
#include <intrin.h>
#endif
namespace bitstream::utility
{
inline constexpr bool little_endian()
{
#ifdef BS_LITTLE_ENDIAN
#if BS_LITTLE_ENDIAN
return true;
#else // BS_LITTLE_ENDIAN
return false;
#endif // BS_LITTLE_ENDIAN
#else // defined(BS_LITTLE_ENDIAN)
return std::endian::native == std::endian::little;
#endif // defined(BS_LITTLE_ENDIAN)
}
inline uint32_t endian_swap32(uint32_t value)
{
#if defined(_WIN32)
return _byteswap_ulong(value);
#elif defined(__linux__)
return __builtin_bswap32(value);
#else
const uint32_t first = (value << 24) & 0xFF000000;
const uint32_t second = (value << 8) & 0x00FF0000;
const uint32_t third = (value >> 8) & 0x0000FF00;
const uint32_t fourth = (value >> 24) & 0x000000FF;
return first | second | third | fourth;
#endif // _WIN32 || __linux__
}
inline uint32_t to_big_endian32(uint32_t value)
{
if constexpr (little_endian())
return endian_swap32(value);
else
return value;
}
}

View File

@ -0,0 +1,105 @@
#pragma once
#include "../stream/serialize_traits.h"
#include <type_traits>
namespace bitstream::utility
{
// Check if type has a serializable trait
template<typename Void, typename T, typename Stream, typename... Args>
struct has_serialize : std::false_type {};
template<typename T, typename Stream, typename... Args>
struct has_serialize<std::void_t<decltype(serialize_traits<T>::serialize(std::declval<Stream&>(), std::declval<Args>()...))>, T, Stream, Args...> : std::true_type {};
template<typename T, typename Stream, typename... Args>
using has_serialize_t = std::void_t<decltype(serialize_traits<T>::serialize(std::declval<Stream&>(), std::declval<Args>()...))>;
template<typename T, typename Stream, typename... Args>
constexpr bool has_serialize_v = has_serialize<void, T, Stream, Args...>::value;
// Check if stream is writing or reading
template<typename T, typename R = bool>
using is_writing_t = std::enable_if_t<T::writing, R>;
template<typename T, typename R = bool>
using is_reading_t = std::enable_if_t<T::reading, R>;
// Check if type is noexcept, if it exists
template<typename Void, typename T, typename Stream, typename... Args>
struct is_serialize_noexcept : std::false_type {};
template<typename T, typename Stream, typename... Args>
struct is_serialize_noexcept<std::enable_if_t<has_serialize_v<T, Stream, Args...>>, T, Stream, Args...> :
std::bool_constant<noexcept(serialize_traits<T>::serialize(std::declval<Stream&>(), std::declval<Args>()...))> {};
template<typename T, typename Stream, typename... Args>
constexpr bool is_serialize_noexcept_v = is_serialize_noexcept<void, T, Stream, Args...>::value;
// Get the underlying type without &, &&, * or const
template<typename T>
using base_t = typename std::remove_const_t<std::remove_pointer_t<std::decay_t<T>>>;
// Meta functions for guessing the trait type from the first argument
template<typename Void, typename Trait, typename Stream, typename... Args>
struct deduce_trait
{
using type = Trait;
};
// Non-const value
template<typename Trait, typename Stream, typename... Args>
struct deduce_trait<std::enable_if_t<
!std::is_pointer_v<std::decay_t<Trait>> &&
has_serialize_v<base_t<Trait>, Stream, Trait, Args...>>,
Trait, Stream, Args...>
{
using type = base_t<Trait>;
};
// Const value
template<typename Trait, typename Stream, typename... Args>
struct deduce_trait<std::enable_if_t<
!std::is_pointer_v<std::decay_t<Trait>> &&
has_serialize_v<std::add_const_t<base_t<Trait>>, Stream, Trait, Args...>>,
Trait, Stream, Args...>
{
using type = std::add_const_t<base_t<Trait>>;
};
// Non-const pointer
template<typename Trait, typename Stream, typename... Args>
struct deduce_trait<std::enable_if_t<
std::is_pointer_v<std::decay_t<Trait>> &&
has_serialize_v<std::add_pointer_t<base_t<Trait>>, Stream, Trait, Args...>>,
Trait, Stream, Args...>
{
using type = std::add_pointer_t<base_t<Trait>>;
};
// Const pointer
template<typename Trait, typename Stream, typename... Args>
struct deduce_trait<std::enable_if_t<
std::is_pointer_v<std::decay_t<Trait>> &&
has_serialize_v<std::add_pointer_t<std::add_const_t<base_t<Trait>>>, Stream, Trait, Args...>>,
Trait, Stream, Args...>
{
using type = std::add_pointer_t<std::add_const_t<base_t<Trait>>>;
};
template<typename Trait, typename Stream, typename... Args>
using deduce_trait_t = typename deduce_trait<void, Trait, Stream, Args...>::type;
// Shorthands for deduced type_traits
template<typename Trait, typename Stream, typename... Args>
using has_deduce_serialize_t = has_serialize_t<deduce_trait_t<Trait, Stream, Args...>, Stream, Trait, Args...>;
template<typename Trait, typename Stream, typename... Args>
constexpr bool is_deduce_serialize_noexcept_v = is_serialize_noexcept_v<deduce_trait_t<Trait, Stream, Args...>, Stream, Trait, Args...>;
}

View File

@ -0,0 +1,124 @@
#pragma once
#include "assert.h"
#include <utility>
#include <type_traits>
#ifdef __cpp_constexpr_dynamic_alloc
#define BS_CONSTEXPR constexpr
#else // __cpp_constexpr_dynamic_alloc
#define BS_CONSTEXPR
#endif // __cpp_constexpr_dynamic_alloc
namespace bitstream
{
#ifdef BS_DEBUG_BREAK
template<typename T>
class out
{
public:
BS_CONSTEXPR out(T& value) noexcept :
m_Value(value),
m_Constructed(false) {}
out(const out&) = delete;
out(out&&) = delete;
BS_CONSTEXPR ~out()
{
if (!m_Constructed)
BS_BREAKPOINT();
}
template<typename U, typename = std::enable_if_t<std::is_assignable_v<T&, U>>>
BS_CONSTEXPR out& operator=(U&& arg) noexcept(std::is_nothrow_assignable_v<T&, U>)
{
m_Value = std::forward<U>(arg);
m_Constructed = true;
return *this;
}
BS_CONSTEXPR T* operator->() noexcept
{
m_Constructed = true;
return &m_Value;
}
BS_CONSTEXPR T& operator*() noexcept
{
m_Constructed = true;
return m_Value;
}
private:
T& m_Value;
bool m_Constructed;
};
#else
template<typename T>
class out
{
public:
BS_CONSTEXPR out(T& value) noexcept :
m_Value(value) {}
out(const out&) = delete;
out(out&&) = delete;
template<typename U, typename = std::enable_if_t<std::is_assignable_v<T&, U>>>
BS_CONSTEXPR out& operator=(U&& arg) noexcept(std::is_nothrow_assignable_v<T&, U>)
{
m_Value = std::forward<U>(arg);
return *this;
}
BS_CONSTEXPR T* operator->() noexcept { return &m_Value; }
BS_CONSTEXPR T& operator*() noexcept { return m_Value; }
private:
T& m_Value;
};
#endif
/**
* @brief Passes by const or const reference depending on size
*/
template<typename T>
using in = std::conditional_t<(sizeof(T) <= 16 && std::is_trivially_copy_constructible_v<T>), std::add_const_t<T>, std::add_lvalue_reference_t<std::add_const_t<T>>>;
/**
* @brief Passes by reference
*/
template<typename Stream, typename T>
using inout = std::conditional_t<Stream::writing, in<T>, std::add_lvalue_reference_t<T>>;
/**
* @brief Test type
*/
template<typename Lambda>
class finally
{
public:
constexpr finally(Lambda func) noexcept :
m_Lambda(func) {}
~finally()
{
m_Lambda();
}
private:
Lambda m_Lambda;
};
template<typename Lambda>
finally(Lambda func) -> finally<Lambda>;
}

View File

@ -5,7 +5,7 @@
#include <stdexcept>
#include <vector>
#include "bitstream.h"
#include "bitstream/bitstream.h"
/**
* @class FECEncoder
@ -30,26 +30,30 @@ public:
* @param data The input BitStream to be encoded.
* @return The encoded BitStream.
*/
BitStream encode(const BitStream& data) {
BitStream input_data(data);
BitStream output_data;
void encode(bitstream::growing_bit_writer<std::vector<uint8_t>>& output_data, bitstream::fixed_bit_reader& input_data) {
std::vector<uint8_t> intermediate_buffer;
bitstream::growing_bit_writer<std::vector<uint8_t>> intermediate_data(intermediate_buffer);
while (input_data.get_remaining_bits() > 0) {
uint32_t bit;
input_data.serialize_bits(bit, 1);
while (input_data.hasNext()) {
uint8_t bit = input_data.getNextBit();
// Shift the input bit into the shift register
shift_register = ((shift_register << 1) | bit) & 0x7F;
// Calculate T1 and T2 using the generator polynomials
uint8_t t1 = calculateT1();
uint8_t t2 = calculateT2();
uint32_t t1 = calculateT1();
uint32_t t2 = calculateT2();
// Append T1 and T2 to the encoded data
output_data.putBit(t1);
output_data.putBit(t2);
intermediate_data.serialize_bits(t1, 1);
intermediate_data.serialize_bits(t2, 1);
}
bitstream::fixed_bit_reader intermediate_reader(intermediate_buffer.data(), intermediate_data.get_num_bits_serialized());
// Apply repetition or puncturing based on baud rate and operation mode
return adjustRate(output_data);
return adjustRate(output_data, intermediate_reader);
}
private:
@ -62,7 +66,9 @@ private:
* @return The calculated T1 bit.
*/
uint8_t calculateT1() {
return (shift_register >> 6) ^ ((shift_register >> 4) & 0x01) ^ ((shift_register >> 3) & 0x01) ^ ((shift_register >> 1) & 0x01) ^ (shift_register & 0x01);
return (shift_register >> 6) ^ ((shift_register >> 4) & 0x01) ^
((shift_register >> 3) & 0x01) ^ ((shift_register >> 1) & 0x01) ^
(shift_register & 0x01);
}
/**
@ -70,7 +76,9 @@ private:
* @return The calculated T2 bit.
*/
uint8_t calculateT2() {
return (shift_register >> 6) ^ ((shift_register >> 5) & 0x01) ^ ((shift_register >> 4) & 0x01) ^ ((shift_register >> 1) & 0x01) ^ (shift_register & 0x01);
return (shift_register >> 6) ^ ((shift_register >> 5) & 0x01) ^
((shift_register >> 4) & 0x01) ^ ((shift_register >> 1) & 0x01) ^
(shift_register & 0x01);
}
/**
@ -78,32 +86,46 @@ private:
* @param encoded_data The encoded BitStream to be adjusted.
* @return The adjusted BitStream.
*/
BitStream adjustRate(const BitStream& encoded_data) {
BitStream adjusted_data;
void adjustRate(bitstream::growing_bit_writer<std::vector<uint8_t>>& adjusted_data, bitstream::fixed_bit_reader& encoded_data) {
size_t repetition_factor = getRepetitionFactor();
if ((baud_rate == 300 || baud_rate == 150 || baud_rate == 75) && is_frequency_hopping) {
// Repetition for frequency-hopping operation at lower baud rates
size_t repetition_factor = (baud_rate == 300) ? 2 : (baud_rate == 150) ? 4 : 8;
for (size_t i = 0; i < encoded_data.getMaxBitIndex(); i += 2) {
for (size_t j = 0; j < repetition_factor; j++) {
adjusted_data.putBit(encoded_data.getBitVal(i));
adjusted_data.putBit(encoded_data.getBitVal(i + 1));
}
if (repetition_factor == 1) {
while (encoded_data.get_remaining_bits() > 0) {
uint32_t bit;
encoded_data.serialize_bits(bit, 1);
adjusted_data.serialize_bits(bit, 1);
}
} else if ((baud_rate == 300 || baud_rate == 150) && !is_frequency_hopping) {
// Repetition for fixed-frequency operation at lower baud rates
size_t repetition_factor = (baud_rate == 300) ? 2 : 4;
for (size_t i = 0; i < encoded_data.getMaxBitIndex(); i += 2) {
for (size_t j = 0; j < repetition_factor; j++) {
adjusted_data.putBit(encoded_data.getBitVal(i));
adjusted_data.putBit(encoded_data.getBitVal(i + 1));
}
}
} else {
adjusted_data = encoded_data;
return;
}
return adjusted_data;
while (encoded_data.get_remaining_bits() >= 2) {
uint32_t t1, t2;
encoded_data.serialize_bits(t1, 1);
encoded_data.serialize_bits(t2, 1);
for (size_t j = 0; j < repetition_factor; j++) {
adjusted_data.serialize_bits(t1, 1);
adjusted_data.serialize_bits(t2, 1);
}
}
}
size_t getRepetitionFactor() const {
if (is_frequency_hopping) {
switch (baud_rate) {
case 300: return 2;
case 150: return 4;
case 75: return 8;
default: return 1;
}
} else {
switch (baud_rate) {
case 300: return 2;
case 150: return 4;
default: return 1;
}
}
}
};

View File

@ -6,7 +6,7 @@
#include <stdexcept>
#include <vector>
#include "bitstream.h"
#include "bitstream/bitstream.h"
/**
* @class Interleaver
@ -34,30 +34,43 @@ public:
* @param input_data The input BitStream to be interleaved.
* @return A new BitStream containing the interleaved data.
*/
std::vector<uint8_t> interleaveStream(const BitStream& input_data) {
BitStream data = input_data;
BitStream interleaved_data;
std::vector<uint8_t> interleaveStream(bitstream::fixed_bit_reader& input_data) {
std::vector<uint8_t> interleaved_buffer;
bitstream::growing_bit_writer<std::vector<uint8_t>> interleaved_data(interleaved_buffer);
size_t chunk_size = rows * columns;
size_t input_index = 0;
while (input_index < data.getMaxBitIndex()) {
size_t end_index = std::min(input_index + chunk_size, data.getMaxBitIndex());
BitStream chunk = data.getSubStream(input_index, end_index);
if (chunk.getMaxBitIndex() > rows * columns) {
throw std::invalid_argument("Input data exceeds interleaver matrix size in loadChunk()");
std::vector<uint8_t> chunk_data((chunk_size + 7) / 8, 0);
while (input_data.get_remaining_bits() >= chunk_size) {
std::fill(chunk_data.begin(), chunk_data.end(), 0);
if (!input_data.serialize_bytes(chunk_data.data(), chunk_size)) {
throw std::runtime_error("Failed to serialize chunk from input data");
}
BitStream interleaved_chunk = interleaveChunk(chunk);
interleaved_data += interleaved_chunk;
input_index = end_index;
bitstream::fixed_bit_reader chunk_reader(chunk_data.data(), chunk_size);
interleaveChunk(interleaved_data, chunk_reader);
}
// Apply puncturing for 2400 bps in frequency-hopping mode (Rate 2/3)
if (baud_rate == 2400 && is_frequency_hopping) {
return applyPuncturing(interleaved_data);
std::vector<uint8_t> punctured_buffer;
bitstream::growing_bit_writer<std::vector<uint8_t>> punctured_writer(punctured_buffer);
bitstream::fixed_bit_reader interleaved_reader(interleaved_buffer.data(), interleaved_buffer.size() * 8);
applyPuncturing(punctured_writer, interleaved_reader);
interleaved_buffer = punctured_buffer;
}
std::vector<uint8_t> final_interleaved_data = groupSymbols(interleaved_data);
return final_interleaved_data;
bitstream::fixed_bit_reader final_reader(interleaved_buffer.data(), interleaved_buffer.size() * 8);
return groupSymbols(final_reader);
}
/**
@ -65,7 +78,7 @@ public:
* @return The number of bits needed for a complete flush.
*/
size_t getFlushBits() const {
return rows * columns;
return (interleave_setting == 0) ? 0 : (rows * columns);
}
private:
@ -84,21 +97,16 @@ private:
* @param input_data The input BitStream to be grouped into symbols.
* @return A vector of grouped symbols.
*/
std::vector<uint8_t> groupSymbols(BitStream& input_data) {
std::vector<uint8_t> groupSymbols(bitstream::fixed_bit_reader& input_data) {
std::vector<uint8_t> grouped_data;
size_t max_index = input_data.getMaxBitIndex();
size_t bits_per_symbol = (baud_rate == 2400) ? 3 : (baud_rate == 1200 || (baud_rate == 75 && !is_frequency_hopping)) ? 2 : 1;
size_t current_index = 0;
while ((current_index + bits_per_symbol) < max_index) {
uint8_t symbol = 0;
while (input_data.get_remaining_bits() >= bits_per_symbol) {
uint32_t symbol = 0;
for (int i = 0; i < bits_per_symbol; i++) {
symbol = (symbol << 1) | input_data.getBitVal(current_index + i);
}
input_data.serialize_bits(symbol, bits_per_symbol);
grouped_data.push_back(symbol);
current_index += bits_per_symbol;
grouped_data.push_back(static_cast<uint8_t>(symbol));
}
return grouped_data;
@ -109,16 +117,16 @@ private:
* @param input_data The input BitStream chunk.
* @return A BitStream representing the interleaved chunk.
*/
BitStream interleaveChunk(const BitStream& input_data) {
void interleaveChunk(bitstream::growing_bit_writer<std::vector<uint8_t>>& interleaved_writer, bitstream::fixed_bit_reader& input_data) {
loadChunk(input_data);
return fetchChunk();
return fetchChunk(interleaved_writer);
}
/**
* @brief Loads bits from the input BitStream into the interleaver matrix.
* @param data The input BitStream to load.
*/
void loadChunk(const BitStream& data) {
void loadChunk(bitstream::fixed_bit_reader& data) {
size_t row = 0;
size_t col = 0;
size_t index = 0;
@ -127,13 +135,19 @@ private:
// Load bits into the matrix
std::fill(matrix.begin(), matrix.end(), 0); // Clear previous data
while (index < data.getMaxBitIndex() && col < columns) {
while (data.get_remaining_bits() > 0 && col < columns) {
size_t matrix_idx = row * columns + col;
if (matrix_idx >= matrix.size()) {
throw std::out_of_range("Matrix index out of bounds in loadChunk()");
}
matrix[matrix_idx] = data.getBitVal(index++);
uint32_t bit = 0;
if (!data.serialize_bits(bit, 1)) {
throw std::runtime_error("Failed to read bit from chunk_reader in loadChunk()");
}
matrix[matrix_idx] = static_cast<uint8_t>(bit);
row = (row + row_increment) % rows;
if (row == 0) {
@ -146,21 +160,23 @@ private:
* @brief Fetches bits from the interleaver matrix in the interleaved order.
* @return A BitStream containing the fetched interleaved data.
*/
BitStream fetchChunk() {
BitStream fetched_data;
void fetchChunk(bitstream::growing_bit_writer<std::vector<uint8_t>>& interleaved_writer) {
size_t row = 0;
size_t col = 0;
size_t column_decrement = (baud_rate == 75 && interleave_setting == 2) ? 7 : 17;
// Fetch bits from the matrix
while (fetched_data.getMaxBitIndex() < rows * columns) {
for (size_t i = 0; i < rows * columns; i++) {
size_t matrix_idx = row * columns + col;
if (matrix_idx >= matrix.size()) {
throw std::out_of_range("Matrix index out of bounds in fetchChunk()");
}
fetched_data.putBit(matrix[matrix_idx]);
uint32_t bit = static_cast<uint32_t>(matrix[matrix_idx]);
if (!interleaved_writer.serialize_bits(bit, 1)) {
throw std::runtime_error("Failed to write bit to interleaved_writer in fetchChunk()");
}
row++;
if (row == rows) {
@ -170,8 +186,6 @@ private:
col = (col + columns - column_decrement) % columns;
}
}
return fetched_data;
}
@ -180,10 +194,7 @@ private:
* @brief Sets the matrix dimensions based on baud rate and interleave setting.
*/
void setMatrixDimensions() {
if (baud_rate == 4800) {
rows = 0;
columns = 0;
} else if (baud_rate == 2400) {
if (baud_rate == 2400) {
rows = 40;
columns = (interleave_setting == 2) ? 576 : 72;
} else if (baud_rate == 1200) {
@ -213,14 +224,18 @@ private:
* @param interleaved_data The interleaved data to be punctured.
* @return A BitStream containing punctured data.
*/
BitStream applyPuncturing(const BitStream& interleaved_data) {
BitStream punctured_data;
for (size_t i = 0; i < interleaved_data.getMaxBitIndex(); i++) {
if ((i % 4) != 1) { // Skip every fourth bit (the second value of T2)
punctured_data.putBit(interleaved_data.getBitVal(i));
void applyPuncturing(bitstream::growing_bit_writer<std::vector<uint8_t>>& punctured_data, bitstream::fixed_bit_reader& interleaved_data) {
size_t bit_index = 0;
while (interleaved_data.get_remaining_bits() > 0) {
uint32_t bit = 0;
interleaved_data.serialize_bits(bit, 1);
if ((bit_index % 4) != 3) {
punctured_data.serialize_bits(bit, 1);
}
bit_index++;
}
return punctured_data;
}
};

View File

@ -6,7 +6,7 @@
#include <memory>
#include <vector>
#include "bitstream.h"
#include "bitstream/bitstream.h"
#include "FECEncoder.h"
#include "Interleaver.h"
#include "MGDDecoder.h"
@ -45,50 +45,54 @@ public:
is_voice(_is_voice),
is_frequency_hopping(_is_frequency_hopping),
interleave_setting(_interleave_setting),
symbol_formation(_baud_rate, _interleave_setting, _is_voice, _is_frequency_hopping),
symbol_formation(baud_rate, interleave_setting, is_voice, is_frequency_hopping),
scrambler(),
fec_encoder(_baud_rate, _is_frequency_hopping),
interleaver(_baud_rate, _interleave_setting, _is_frequency_hopping),
mgd_decoder(_baud_rate, _is_frequency_hopping),
modulator(48000, _is_frequency_hopping, 48) {}
fec_encoder(baud_rate, is_frequency_hopping),
interleaver(baud_rate, interleave_setting, is_frequency_hopping),
mgd_decoder(baud_rate, is_frequency_hopping),
modulator(baud_rate, 48000, 0.5, is_frequency_hopping) {}
/**
* @brief Transmits the input data by processing it through different phases like FEC encoding, interleaving, symbol formation, scrambling, and modulation.
* @return The scrambled data ready for modulation.
* @note The modulated signal is generated internally but is intended to be handled externally.
*/
std::vector<int16_t> transmit(const BitStream& input_data) {
// Step 1: Append EOM Symbols
BitStream eom_appended_data = appendEOMSymbols(input_data);
std::vector<int16_t> transmit(bitstream::fixed_bit_reader& input_data) {
// Step 1: Append EOM Symbols using a uint32_t aligned output buffer
std::vector<uint8_t> output_buffer;
bitstream::growing_bit_writer<std::vector<uint8_t>> output_writer(output_buffer);
appendEOMSymbols(output_writer, input_data);
// Step 2: Handle Baud Rate Specific Encoding
std::vector<uint8_t> processed_data;
if (baud_rate == 4800) {
processed_data = splitTribitSymbols(eom_appended_data);
// For 4800 baud, perform tribit symbol splitting
bitstream::fixed_bit_reader eom_appended_reader(output_buffer.data(), output_writer.get_num_bits_serialized());
processed_data = splitTribitSymbols(eom_appended_reader);
} else {
// Step 2: FEC Encoding
BitStream fec_encoded_data = fec_encoder.encode(eom_appended_data);
// Step 3: FEC Encoding
bitstream::fixed_bit_reader eom_appended_reader(output_buffer.data(), output_writer.get_num_bits_serialized());
std::vector<uint8_t> fec_encoded_buffer;
bitstream::growing_bit_writer<std::vector<uint8_t>> fec_encoded_writer(fec_encoded_buffer);
fec_encoder.encode(fec_encoded_writer, eom_appended_reader);
// Step 3: Interleaving
processed_data = interleaver.interleaveStream(fec_encoded_data);
// Step 4: Interleaving
bitstream::fixed_bit_reader fec_encoded_reader(fec_encoded_buffer.data(), fec_encoded_writer.get_num_bits_serialized());
processed_data = interleaver.interleaveStream(fec_encoded_reader);
}
// Step 4: MGD Decoding
// Step 5: MGD Decoding
std::vector<uint8_t> mgd_decoded_data = mgd_decoder.mgdDecode(processed_data);
// Step 5: Symbol Formation. This function injects the sync preamble symbols. Scrambling is handled internally.
// Step 6: Symbol Formation (including sync preamble and scrambling)
std::vector<uint8_t> symbol_stream = symbol_formation.formSymbols(mgd_decoded_data);
// Step 6. Modulation. The symbols are applied via 2400-bps 8-PSK modulation, with a 48 KHz sample rate.
// Step 7: Modulation
std::vector<int16_t> modulated_signal = modulator.modulate(symbol_stream);
return modulated_signal;
}
BitStream receive(const std::vector<int16_t>& passband_signal) {
// Step one: Demodulate the passband signal and retrieve decoded symbols
std::vector<uint8_t> demodulated_symbols = modulator.demodulate(passband_signal, baud_rate, interleave_setting, is_voice);
return BitStream();
}
private:
size_t baud_rate; ///< The baud rate for the modem.
@ -112,39 +116,38 @@ private:
* the FEC encoder and interleaver matrices. The function calculates the number of flush bits required
* based on the FEC and interleaver settings.
*/
BitStream appendEOMSymbols(const BitStream& input_data) const {
BitStream eom_data = input_data;
void appendEOMSymbols(bitstream::growing_bit_writer<std::vector<uint8_t>>& output_data, bitstream::fixed_bit_reader& input_data) const {
while (input_data.get_num_bits_serialized() < input_data.get_total_bits()) {
uint32_t value;
uint32_t bits_to_read = std::min(32U, input_data.get_remaining_bits());
input_data.serialize_bits(value, bits_to_read);
output_data.serialize_bits(value, bits_to_read);
}
// Append the EOM sequence (4B65A5B2 in hexadecimal)
BitStream eom_sequence({0x4B, 0x65, 0xA5, 0xB2}, 32);
eom_data += eom_sequence;
uint32_t eom_sequence = 0x4B65A5B2;
output_data.serialize_bits(eom_sequence, 32);
// Append additional zeros to flush the FEC encoder and interleaver
size_t fec_flush_bits = 144; // FEC encoder flush bits
size_t interleave_flush_bits = interleaver.getFlushBits();
size_t total_flush_bits = fec_flush_bits + ((interleave_setting == 0) ? 0 : interleave_flush_bits);
if (interleave_flush_bits > 0) {
while ((eom_data.getMaxBitIndex() + total_flush_bits) % interleave_flush_bits)
total_flush_bits++;
}
size_t total_bytes = (total_flush_bits + 7) / 8; // Round up to ensure we have enough bytes to handle all bits.
BitStream flush_bits(std::vector<uint8_t>(total_bytes, 0), total_flush_bits);
eom_data += flush_bits;
return eom_data;
size_t current_bit_index = output_data.get_num_bits_serialized();
size_t alignment_bits_needed = (interleave_flush_bits - (current_bit_index + fec_flush_bits) % interleave_flush_bits) % interleave_flush_bits;
total_flush_bits += alignment_bits_needed;
}
std::vector<uint8_t> splitTribitSymbols(const BitStream& input_data) {
std::vector<uint8_t> splitTribitSymbols(bitstream::fixed_bit_reader& input_data) {
std::vector<uint8_t> return_vector;
size_t max_index = input_data.getMaxBitIndex();
size_t current_index = 0;
size_t total_bits = input_data.get_total_bits();
size_t num_bits_serialized = input_data.get_num_bits_serialized();
while (current_index + 2 < max_index) {
uint8_t symbol = 0;
for (int i = 0; i < 3; i++) {
symbol = (symbol << 1) | input_data.getBitVal(current_index + i);
}
return_vector.push_back(symbol);
current_index += 3;
while (num_bits_serialized + 3 <= total_bits) {
uint32_t symbol = 0;
input_data.serialize_bits(symbol, 3);
return_vector.push_back(static_cast<uint8_t>(symbol));
num_bits_serialized = input_data.get_num_bits_serialized();
}
return return_vector;

View File

@ -15,14 +15,14 @@ public:
/**
* @brief Constructor initializes the scrambler with a predefined register value.
*/
Scrambler() : data_sequence_register(0x0BAD), symbol_count(0), preamble_table_index(0) {}
Scrambler() : data_sequence_register(0x0BAD), symbol_count(0) {}
/**
* @brief Scrambles a synchronization preamble using a fixed randomizer sequence.
* @param preamble The synchronization preamble to scramble.
* @return The scrambled synchronization preamble.
*/
std::vector<uint8_t> scrambleSyncPreamble(const std::vector<uint8_t>& preamble) {
std::vector<uint8_t> scrambleSyncPreamble(const std::vector<uint8_t>& preamble) const {
static const std::array<uint8_t, 32> sync_randomizer_sequence = {
7, 4, 3, 0, 5, 1, 5, 0, 2, 2, 1, 1,
5, 7, 4, 3, 5, 0, 2, 6, 2, 1, 6, 2,
@ -33,9 +33,8 @@ public:
scrambled_preamble.reserve(preamble.size()); // Preallocate to improve efficiency
for (size_t i = 0; i < preamble.size(); ++i) {
uint8_t scrambled_value = (preamble[i] + sync_randomizer_sequence[preamble_table_index]) % 8;
uint8_t scrambled_value = (preamble[i] + sync_randomizer_sequence[i % sync_randomizer_sequence.size()]) % 8;
scrambled_preamble.push_back(scrambled_value);
preamble_table_index = (preamble_table_index + 1) % sync_randomizer_sequence.size();
}
return scrambled_preamble;
@ -62,7 +61,6 @@ public:
private:
uint16_t data_sequence_register;
size_t symbol_count;
size_t preamble_table_index;
/**
* @brief Generates the next value from the data sequence randomizing generator.

View File

@ -21,10 +21,16 @@ std::vector<uint8_t> baud75_normal_3 = {0, 4, 4, 0};
class SymbolFormation {
public:
SymbolFormation(size_t baud_rate, size_t interleave_setting, bool is_voice, bool is_frequency_hopping) : interleave_setting(interleave_setting), baud_rate(baud_rate), is_voice(is_voice), is_frequency_hopping(is_frequency_hopping) {
SymbolFormation(size_t baud_rate, size_t interleave_setting, bool is_voice, bool is_frequency_hopping) : interleave_setting(interleave_setting), baud_rate(baud_rate), is_voice(is_voice), is_frequency_hopping(is_frequency_hopping) {}
std::vector<uint8_t> formSymbols(std::vector<uint8_t>& symbol_data) {
// Generate and scramble the sync preamble
std::vector<uint8_t> sync_preamble = generateSyncPreamble();
sync_preamble = scrambler.scrambleSyncPreamble(sync_preamble);
// Determine the block sizes
unknown_data_block_size = (baud_rate >= 2400) ? 32 : 20;
known_data_block_size = (baud_rate >= 2400) ? 16 : 20;
size_t unknown_data_block_size = (baud_rate >= 2400) ? 32 : 20;
size_t interleaver_block_size;
if (baud_rate == 2400) {
interleaver_block_size = (interleave_setting == 2) ? (40 * 576) : (40 * 72);
@ -36,43 +42,38 @@ class SymbolFormation {
interleaver_block_size = (interleave_setting == 2) ? (20 * 36) : (10 * 9);
}
total_frames = interleaver_block_size / (unknown_data_block_size + known_data_block_size);
}
std::vector<uint8_t> formSymbols(std::vector<uint8_t>& symbol_data) {
// Generate and scramble the sync preamble
std::vector<uint8_t> sync_preamble = generateSyncPreamble();
sync_preamble = scrambler.scrambleSyncPreamble(sync_preamble);
size_t set_count = 0;
size_t symbol_count = 0;
std::vector<uint8_t> data_stream;
if (baud_rate == 75) {
size_t set_count = 0;
for (size_t i = 0; i < symbol_data.size(); i++) {
bool is_exceptional_set = (set_count % ((interleave_setting == 1) ? 45 : 360)) == 0;
append75bpsMapping(data_stream, symbol_data[i], is_exceptional_set);
set_count++;
}
} else {
size_t symbol_count = 0;
size_t current_frame = 0;
size_t current_index = 0;
size_t current_index = 0;
while (current_index < symbol_data.size()) {
// Determine the size of the current unknown data block
size_t block_size = std::min(unknown_data_block_size, symbol_data.size() - current_index);
std::vector<uint8_t> unknown_data_block(symbol_data.begin() + current_index, symbol_data.begin() + current_index + block_size);
current_index += block_size;
while (current_index < symbol_data.size()) {
// Determine the size of the current unknown data block
size_t block_size = std::min(unknown_data_block_size, symbol_data.size() - current_index);
std::vector<uint8_t> unknown_data_block(symbol_data.begin() + current_index, symbol_data.begin() + current_index + block_size);
current_index += block_size;
// Map the unknown data based on baud rate
// Map the unknown data based on baud rate
if (baud_rate == 75) {
size_t set_size = (interleave_setting == 2) ? 360 : 32;
for (size_t i = 0; i < unknown_data_block.size(); i += set_size) {
bool is_exceptional_set = (set_count % ((interleave_setting == 1) ? 45 : 360)) == 0;
std::vector<uint8_t> mapped_set = map75bpsSet(unknown_data_block, i, set_size, is_exceptional_set);
data_stream.insert(data_stream.end(), mapped_set.begin(), mapped_set.end());
set_count++;
}
} else {
// For baud rates greater than 75 bps
std::vector<uint8_t> mapped_unknown_data = mapUnknownData(unknown_data_block);
symbol_count += mapped_unknown_data.size();
data_stream.insert(data_stream.end(), mapped_unknown_data.begin(), mapped_unknown_data.end());
}
// Insert probe data if we are at an interleaver block boundary
std::vector<uint8_t> probe_data = generateProbeData(current_frame, total_frames);
// Insert probe data if we are at an interleaver block boundary
if (baud_rate > 75) {
bool is_at_boundary = (symbol_count % interleaver_block_size) == 0;
std::vector<uint8_t> probe_data = generateProbeData(!is_at_boundary);
data_stream.insert(data_stream.end(), probe_data.begin(), probe_data.end());
current_frame = (current_frame + 1) % total_frames;
}
}
@ -92,10 +93,6 @@ class SymbolFormation {
int interleave_setting;
bool is_voice;
bool is_frequency_hopping;
size_t interleaver_block_size;
size_t unknown_data_block_size;
size_t known_data_block_size;
size_t total_frames;
Scrambler scrambler = Scrambler();
std::vector<uint8_t> mapChannelSymbolToTribitPattern(uint8_t symbol, bool repeat_twice = false) {
@ -130,14 +127,17 @@ class SymbolFormation {
throw std::invalid_argument("Invalid channel symbol");
}
size_t repetitions = repeat_twice ? 2 : 4;
std::vector<uint8_t> repeated_pattern;
for (size_t i = 0; i < repetitions; i++) {
repeated_pattern.insert(repeated_pattern.end(), tribit_pattern.begin(), tribit_pattern.end());
if (repeat_twice) {
// Repeat the pattern twice instead of four times for known symbols
tribit_pattern.insert(tribit_pattern.end(), tribit_pattern.begin(), tribit_pattern.end());
} else {
// Repeat the pattern four times as per Table XIII
tribit_pattern.insert(tribit_pattern.end(), tribit_pattern.begin(), tribit_pattern.end());
tribit_pattern.insert(tribit_pattern.end(), tribit_pattern.begin(), tribit_pattern.end());
tribit_pattern.insert(tribit_pattern.end(), tribit_pattern.begin(), tribit_pattern.end());
}
return repeated_pattern;
return tribit_pattern;
}
std::vector<uint8_t> generateSyncPreamble() {
@ -212,60 +212,59 @@ class SymbolFormation {
return preamble;
}
std::vector<uint8_t> generateProbeData(size_t current_frame, size_t total_frames) {
std::vector<uint8_t> generateProbeData(bool is_inside_block) {
std::vector<uint8_t> probe_data;
// Set the known symbol patterns for D1 and D2 based on Table XI
uint8_t D1, D2;
if (baud_rate == 4800) {
D1 = 7; D2 = 6;
} else if (baud_rate == 2400 && is_voice) {
D1 = 7; D2 = 7;
} else if (baud_rate == 2400) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 4;
// Determine interleaver block size based on baud rate and interleave setting
size_t interleaver_block_size;
if (baud_rate == 2400) {
interleaver_block_size = (interleave_setting == 2) ? (40 * 576) : (40 * 72);
} else if (baud_rate == 1200) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 5;
} else if (baud_rate == 600) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 6;
} else if (baud_rate == 300) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 7;
} else if (baud_rate == 150) {
D1 = (interleave_setting <= 1) ? 7 : 5;
D2 = 4;
} else if (baud_rate == 75) {
D1 = (interleave_setting <= 1) ? 7 : 5;
D2 = 5;
interleaver_block_size = (interleave_setting == 2) ? (40 * 288) : (40 * 36);
} else if ((baud_rate >= 150) || (baud_rate == 75 && is_frequency_hopping)) {
interleaver_block_size = (interleave_setting == 2) ? (40 * 144) : (40 * 18);
} else {
throw std::invalid_argument("Invalid baud rate for generateProbeData");
interleaver_block_size = (interleave_setting == 2) ? (20 * 36) : (10 * 9);
}
// If the current frame is not the last two frames, set probe data to zeros
if (current_frame < total_frames - 2) {
probe_data.resize(known_data_block_size, 0x00);
}
// If the current frame is the second-to-last frame, set probe data to D1 pattern
else if (current_frame == total_frames - 2) {
// If we are inside an interleaver block, the probe data is filled with zeros
if (is_inside_block) {
probe_data.resize(interleaver_block_size, 0x00);
} else {
// Set the known symbol patterns for D1 and D2 based on Table XI
uint8_t D1, D2;
if (baud_rate == 4800) {
D1 = 7; D2 = 6;
} else if (baud_rate == 2400 && is_voice) {
D1 = 7; D2 = 7;
} else if (baud_rate == 2400) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 4;
} else if (baud_rate == 1200) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 5;
} else if (baud_rate == 600) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 6;
} else if (baud_rate == 300) {
D1 = (interleave_setting <= 1) ? 6 : 4;
D2 = 7;
} else if (baud_rate == 150) {
D1 = (interleave_setting <= 1) ? 7 : 5;
D2 = 4;
} else if (baud_rate == 75) {
D1 = (interleave_setting <= 1) ? 7 : 5;
D2 = 5;
} else {
throw std::invalid_argument("Invalid baud rate for generateProbeData");
}
// Generate the known symbol patterns D1 and D2, repeated twice
std::vector<uint8_t> d1_pattern = mapChannelSymbolToTribitPattern(D1, true);
probe_data.insert(probe_data.end(), d1_pattern.begin(), d1_pattern.end());
// Fill the remaining symbols with zeros if necessary
if (probe_data.size() < known_data_block_size) {
probe_data.resize(known_data_block_size, 0x00);
}
}
// If the current frame is the last frame, set probe data to D2 pattern
else if (current_frame == total_frames - 1) {
std::vector<uint8_t> d2_pattern = mapChannelSymbolToTribitPattern(D2, true);
probe_data.insert(probe_data.end(), d2_pattern.begin(), d2_pattern.end());
// Fill the remaining symbols with zeros if necessary
if (probe_data.size() < known_data_block_size) {
probe_data.resize(known_data_block_size, 0x00);
}
probe_data.insert(probe_data.end(), d1_pattern.begin(), d1_pattern.end());
probe_data.insert(probe_data.end(), d2_pattern.begin(), d2_pattern.end());
}
return probe_data;
@ -333,6 +332,19 @@ class SymbolFormation {
}
}
std::vector<uint8_t> map75bpsSet(const std::vector<uint8_t>& data, size_t start_index, size_t set_size, bool is_exceptional_set) {
std::vector<uint8_t> mapped_set;
// Make sure we do not exceed the size of the data vector
size_t end_index = std::min(start_index + set_size, data.size());
for (size_t i = start_index; i < end_index; ++i) {
append75bpsMapping(mapped_set, data[i], is_exceptional_set);
}
return mapped_set;
}
std::vector<uint8_t> mapUnknownData(const std::vector<uint8_t>& data) {
std::vector<uint8_t> mapped_data;

View File

@ -0,0 +1,131 @@
#ifndef FSK_DEMODULATOR_H
#define FSK_DEMODULATOR_H
#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
#include "BitStreamWriter.h"
class FSKDemodulatorConfig {
public:
int freq_lo;
int freq_hi;
int sample_rate;
int baud_rate;
std::shared_ptr<BitStreamWriter> bitstreamwriter;
};
namespace milstd {
class FSKDemodulator {
public:
FSKDemodulator(const FSKDemodulatorConfig& s) : freq_lo(s.freq_lo), freq_hi(s.freq_hi), sample_rate(s.sample_rate), baud_rate(s.baud_rate), bit_writer(s.bitstreamwriter) {
initialize();
}
void demodulate(const std::vector<int16_t>& samples) {
size_t nb = samples.size();
for (size_t i = 0; i < nb; i++) {
filter_buf[buf_ptr++] = samples[i];
if (buf_ptr == filter_buf.size()) {
std::copy(filter_buf.begin() + filter_buf.size() - filter_size, filter_buf.end(), filter_buf.begin());
buf_ptr = filter_size;
}
int corr;
int sum = 0;
corr = dotProduct(&filter_buf[buf_ptr - filter_size], filter_hi_i.data(), filter_size);
sum += corr * corr;
corr = dotProduct(&filter_buf[buf_ptr - filter_size], filter_hi_q.data(), filter_size);
sum += corr * corr;
corr = dotProduct(&filter_buf[buf_ptr - filter_size], filter_lo_i.data(), filter_size);
sum -= corr * corr;
corr = dotProduct(&filter_buf[buf_ptr - filter_size], filter_lo_q.data(), filter_size);
sum -= corr * corr;
int new_sample = (sum > 0) ? 1 : 0;
if (last_sample != new_sample) {
last_sample = new_sample;
if (baud_pll < 0.5)
baud_pll += baud_pll_adj;
else
baud_pll -= baud_pll_adj;
}
baud_pll += baud_incr;
if (baud_pll >= 1.0) {
baud_pll -= 1.0;
bit_writer->putBit(last_sample);
}
}
}
private:
int freq_lo;
int freq_hi;
int sample_rate;
int baud_rate;
std::shared_ptr<BitStreamWriter> bit_writer;
int filter_size;
std::vector<double> filter_lo_i;
std::vector<double> filter_lo_q;
std::vector<double> filter_hi_i;
std::vector<double> filter_hi_q;
std::vector<double> filter_buf;
size_t buf_ptr;
double baud_incr;
double baud_pll;
double baud_pll_adj;
int last_sample;
void initialize() {
baud_incr = static_cast<double>(baud_rate) / sample_rate;
baud_pll = 0.0;
baud_pll_adj = baud_incr / 4;
filter_size = sample_rate / baud_rate;
filter_buf.resize(filter_size * 2, 0.0);
buf_ptr = filter_size;
last_sample = 0;
filter_lo_i.resize(filter_size);
filter_lo_q.resize(filter_size);
filter_hi_i.resize(filter_size);
filter_hi_q.resize(filter_size);
for (int i = 0; i < filter_size; i++) {
double phase_lo = 2.0 * M_PI * freq_lo * i / sample_rate;
filter_lo_i[i] = std::cos(phase_lo);
filter_lo_q[i] = std::sin(phase_lo);
double phase_hi = 2.0 * M_PI * freq_hi * i / sample_rate;
filter_hi_i[i] = std::cos(phase_hi);
filter_hi_q[i] = std::sin(phase_hi);
}
}
double dotProduct(const double* x, const double* y, size_t size) {
double sum = 0.0;
for (size_t i = 0; i < size; i++) {
sum += x[i] * y[i];
}
return sum;
}
};
} // namespace milstd
#endif /* FSK_DEMODULATOR_H */

View File

@ -0,0 +1,87 @@
#ifndef FSK_MODULATOR_H
#define FSK_MODULATOR_H
#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
#include "BitStreamReader.h"
class FSKModulatorConfig {
public:
int freq_lo;
int freq_hi;
int sample_rate;
int baud_rate;
std::shared_ptr<BitStreamReader> bitstreamreader;
};
namespace milstd {
class FSKModulator {
public:
FSKModulator(const FSKModulatorConfig& s) : freq_lo(s.freq_lo), freq_hi(s.freq_hi), sample_rate(s.sample_rate), baud_rate(s.baud_rate), bit_reader(s.bitstreamreader) {
omega[0] = (2.0 * M_PI * freq_lo) / sample_rate;
omega[1] = (2.0 * M_PI * freq_hi) / sample_rate;
baud_incr = static_cast<double>(baud_rate) / sample_rate;
phase = 0.0;
baud_frac = 0.0;
current_bit = 0;
}
std::vector<int16_t> modulate(unsigned int num_samples) {
std::vector<int16_t> samples;
samples.reserve(num_samples);
int bit = current_bit;
for (unsigned int i = 0; i < num_samples; i++) {
baud_frac += baud_incr;
if (baud_frac >= 1.0) {
baud_frac -= 1.0;
try
{
bit = bit_reader->getNextBit();
}
catch(const std::out_of_range&)
{
bit = 0;
}
if (bit != 0 && bit != 1)
bit = 0;
}
double sample = std::cos(phase);
int16_t sample_int = static_cast<int16_t>(sample * 32767);
samples.push_back(sample_int);
phase += omega[bit];
if (phase >= 2.0 * M_PI) {
phase -= 2.0 * M_PI;
}
}
current_bit = bit;
return samples;
}
private:
// parameters
int freq_lo, freq_hi;
int sample_rate;
int baud_rate;
std::shared_ptr<BitStreamReader> bit_reader;
// state variables
double phase;
double baud_frac;
double baud_incr;
std::array<double, 2> omega;
int current_bit;
};
} // namespace milstd
#endif /* FSK_MODULATOR_H */

View File

View File

@ -1,366 +1,136 @@
#ifndef PSK_MODULATOR_H
#define PSK_MODULATOR_H
#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>
#include <fftw3.h>
#include <map>
#include <tuple>
#include "costasloop.h"
#include "filters.h"
#include "Scrambler.h"
static constexpr double CARRIER_FREQ = 1800.0;
static constexpr size_t SYMBOL_RATE = 2400;
static constexpr double ROLLOFF_FACTOR = 0.35;
static constexpr double SCALE_FACTOR = 32767.0;
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <complex>
#include <algorithm>
class PSKModulator {
public:
PSKModulator(const double _sample_rate, const bool _is_frequency_hopping, const size_t _num_taps)
: sample_rate(validateSampleRate(_sample_rate)), gain(1.0/sqrt(2.0)), is_frequency_hopping(_is_frequency_hopping), samples_per_symbol(static_cast<size_t>(sample_rate / SYMBOL_RATE)), srrc_filter(8, _sample_rate, SYMBOL_RATE, ROLLOFF_FACTOR) {
initializeSymbolMap();
phase_detector = PhaseDetector(symbolMap);
PSKModulator(double baud_rate, double sample_rate, double energy_per_bit, bool is_frequency_hopping)
: sample_rate(sample_rate), carrier_freq(1800), phase(0.0) {
initializeSymbolMap();
symbol_rate = 2400; // Fixed symbol rate as per specification (2400 symbols per second)
samples_per_symbol = static_cast<size_t>(sample_rate / symbol_rate);
}
std::vector<int16_t> modulate(const std::vector<uint8_t>& symbols) {
std::vector<std::complex<double>> baseband_components(symbols.size() * samples_per_symbol);
size_t symbol_index = 0;
std::vector<std::complex<double>> modulated_signal;
for (const auto& symbol : symbols) {
const double phase_increment = 2 * M_PI * carrier_freq / sample_rate;
for (auto symbol : symbols) {
if (symbol >= symbolMap.size()) {
throw std::out_of_range("Invalid symbol value for 8-PSK modulation. Symbol must be between 0 and 7.");
throw std::out_of_range("Invalid symbol value for 8-PSK modulation");
}
const std::complex<double> target_symbol = symbolMap[symbol];
std::complex<double> target_symbol = symbolMap[symbol];
for (size_t i = 0; i < samples_per_symbol; ++i) {
baseband_components[symbol_index * samples_per_symbol + i] = target_symbol;
double in_phase = std::cos(phase + target_symbol.real());
double quadrature = std::sin(phase + target_symbol.imag());
modulated_signal.emplace_back(in_phase, quadrature);
phase = std::fmod(phase + phase_increment, 2 * M_PI);
}
symbol_index++;
}
// Filter the I/Q phase components
std::vector<std::complex<double>> filtered_components = srrc_filter.applyFilter(baseband_components);
// Apply raised-cosine filter
auto filter_taps = sqrtRaisedCosineFilter(201, symbol_rate); // Adjusted number of filter taps to 201 for balance
auto filtered_signal = applyFilter(modulated_signal, filter_taps);
// Combine the I and Q components
std::vector<double> passband_signal;
passband_signal.reserve(baseband_components.size());
// Normalize the filtered signal
double max_value = 0.0;
for (const auto& sample : filtered_signal) {
max_value = std::max(max_value, std::abs(sample.real()));
max_value = std::max(max_value, std::abs(sample.imag()));
}
double gain = (max_value > 0) ? (32767.0 / max_value) : 1.0;
double carrier_phase = 0.0;
double carrier_phase_increment = 2 * M_PI * CARRIER_FREQ / sample_rate;
for (const auto& sample : filtered_components) {
double carrier_cos = std::cos(carrier_phase);
double carrier_sin = -std::sin(carrier_phase);
double passband_value = sample.real() * carrier_cos + sample.imag() * carrier_sin;
passband_signal.emplace_back(passband_value * SCALE_FACTOR); // Scale to int16_t
carrier_phase += carrier_phase_increment;
if (carrier_phase >= 2 * M_PI)
carrier_phase -= 2 * M_PI;
// Combine the I and Q components and apply gain for audio output
std::vector<int16_t> combined_signal;
for (auto& sample : filtered_signal) {
int16_t combined_sample = static_cast<int16_t>(std::clamp(gain * (sample.real() + sample.imag()), -32768.0, 32767.0));
combined_signal.push_back(combined_sample);
}
std::vector<int16_t> final_signal;
final_signal.reserve(passband_signal.size());
for (const auto& sample : passband_signal) {
int16_t value = static_cast<int16_t>(sample);
value = std::clamp(value, (int16_t)-32768, (int16_t)32767);
final_signal.emplace_back(value);
}
return final_signal;
return combined_signal;
}
std::vector<uint8_t> demodulate(const std::vector<int16_t> passband_signal, size_t& baud_rate, size_t& interleave_setting, bool& is_voice) {
// Carrier recovery. initialize the Costas loop.
CostasLoop costas_loop(CARRIER_FREQ, sample_rate, symbolMap, 5.0, 0.05, 0.01);
std::vector<double> sqrtRaisedCosineFilter(size_t num_taps, double symbol_rate) {
double rolloff = 0.35; // Fixed rolloff factor as per specification
std::vector<double> filter_taps(num_taps);
double norm_factor = 0.0;
double sampling_interval = 1.0 / sample_rate;
double symbol_duration = 1.0 / symbol_rate;
double half_num_taps = static_cast<double>(num_taps - 1) / 2.0;
// Convert passband signal to doubles.
std::vector<double> normalized_passband(passband_signal.size());
for (size_t i = 0; i < passband_signal.size(); i++) {
normalized_passband[i] = passband_signal[i] / 32767.0;
}
// Downmix passband to baseband
std::vector<std::complex<double>> baseband_IQ = costas_loop.process(normalized_passband);
std::vector<uint8_t> detected_symbols;
// Phase detection and symbol formation
size_t samples_per_symbol = sample_rate / SYMBOL_RATE;
bool sync_found = false;
size_t sync_segments_detected;
size_t window_size = 32*15;
for (size_t i = 0; i < baseband_IQ.size(); i += samples_per_symbol) {
std::complex<double> symbol_avg(0.0, 0.0);
for (size_t j = 0; j < samples_per_symbol; j++) {
symbol_avg += baseband_IQ[i + j];
for (size_t i = 0; i < num_taps; ++i) {
double t = (i - half_num_taps) * sampling_interval;
if (std::abs(t) < 1e-10) {
filter_taps[i] = 1.0;
} else {
double numerator = std::sin(M_PI * t / symbol_duration * (1.0 - rolloff)) +
4.0 * rolloff * t / symbol_duration * std::cos(M_PI * t / symbol_duration * (1.0 + rolloff));
double denominator = M_PI * t * (1.0 - std::pow(4.0 * rolloff * t / symbol_duration, 2));
filter_taps[i] = numerator / denominator;
}
symbol_avg /= static_cast<double>(samples_per_symbol);
uint8_t detected_symbol = phase_detector.getSymbol(symbol_avg);
detected_symbols.push_back(detected_symbol);
norm_factor += filter_taps[i] * filter_taps[i];
}
if (processSyncSegments(detected_symbols, baud_rate, interleave_setting, is_voice)) {
return processDataSymbols(detected_symbols);
norm_factor = std::sqrt(norm_factor);
std::for_each(filter_taps.begin(), filter_taps.end(), [&norm_factor](double &tap) { tap /= norm_factor; });
return filter_taps;
}
std::vector<std::complex<double>> applyFilter(const std::vector<std::complex<double>>& signal, const std::vector<double>& filter_taps) {
std::vector<std::complex<double>> filtered_signal(signal.size());
size_t filter_length = filter_taps.size();
size_t half_filter_length = filter_length / 2;
// Convolve the signal with the filter taps
for (size_t i = 0; i < signal.size(); ++i) {
double filtered_i = 0.0;
double filtered_q = 0.0;
for (size_t j = 0; j < filter_length; ++j) {
if (i >= j) {
filtered_i += filter_taps[j] * signal[i - j].real();
filtered_q += filter_taps[j] * signal[i - j].imag();
} else {
// Handle edge case by zero-padding
filtered_i += filter_taps[j] * 0.0;
filtered_q += filter_taps[j] * 0.0;
}
}
filtered_signal[i] = std::complex<double>(filtered_i, filtered_q);
}
return std::vector<uint8_t>();
return filtered_signal;
}
private:
const double sample_rate; ///< The sample rate of the system.
const double gain; ///< The gain of the modulated signal.
double sample_rate; ///< The sample rate of the system.
double carrier_freq; ///< The frequency of the carrier, set to 1800 Hz as per standard.
double phase; ///< Current phase of the carrier waveform.
size_t samples_per_symbol; ///< Number of samples per symbol, calculated to match symbol duration with cycle.
PhaseDetector phase_detector;
SRRCFilter srrc_filter;
size_t symbol_rate;
std::vector<std::complex<double>> symbolMap; ///< The mapping of tribit symbols to I/Q components.
const bool is_frequency_hopping; ///< Whether to use frequency hopping methods. Not implemented (yet?)
static inline double validateSampleRate(const double rate) {
if (rate <= 2 * (CARRIER_FREQ + SYMBOL_RATE * (1 + ROLLOFF_FACTOR) / 2)) {
throw std::out_of_range("Sample rate must be above the Nyquist frequency (PSKModulator.h)");
}
return rate;
}
inline void initializeSymbolMap() {
void initializeSymbolMap() {
symbolMap = {
{gain * std::cos(2.0*M_PI*(0.0/8.0)), gain * std::sin(2.0*M_PI*(0.0/8.0))}, // 0 (000) corresponds to I = 1.0, Q = 0.0
{gain * std::cos(2.0*M_PI*(1.0/8.0)), gain * std::sin(2.0*M_PI*(1.0/8.0))}, // 1 (001) corresponds to I = cos(45), Q = sin(45)
{gain * std::cos(2.0*M_PI*(2.0/8.0)), gain * std::sin(2.0*M_PI*(2.0/8.0))}, // 2 (010) corresponds to I = 0.0, Q = 1.0
{gain * std::cos(2.0*M_PI*(3.0/8.0)), gain * std::sin(2.0*M_PI*(3.0/8.0))}, // 3 (011) corresponds to I = cos(135), Q = sin(135)
{gain * std::cos(2.0*M_PI*(4.0/8.0)), gain * std::sin(2.0*M_PI*(4.0/8.0))}, // 4 (100) corresponds to I = -1.0, Q = 0.0
{gain * std::cos(2.0*M_PI*(5.0/8.0)), gain * std::sin(2.0*M_PI*(5.0/8.0))}, // 5 (101) corresponds to I = cos(225), Q = sin(225)
{gain * std::cos(2.0*M_PI*(6.0/8.0)), gain * std::sin(2.0*M_PI*(6.0/8.0))}, // 6 (110) corresponds to I = 0.0, Q = -1.0
{gain * std::cos(2.0*M_PI*(7.0/8.0)), gain * std::sin(2.0*M_PI*(7.0/8.0))} // 7 (111) corresponds to I = cos(315), Q = sin(315)
{1.0, 0.0}, // 0 (000) corresponds to I = 1.0, Q = 0.0
{std::sqrt(2.0) / 2.0, std::sqrt(2.0) / 2.0}, // 1 (001) corresponds to I = cos(45), Q = sin(45)
{0.0, 1.0}, // 2 (010) corresponds to I = 0.0, Q = 1.0
{-std::sqrt(2.0) / 2.0, std::sqrt(2.0) / 2.0}, // 3 (011) corresponds to I = cos(135), Q = sin(135)
{-1.0, 0.0}, // 4 (100) corresponds to I = -1.0, Q = 0.0
{-std::sqrt(2.0) / 2.0, -std::sqrt(2.0) / 2.0}, // 5 (101) corresponds to I = cos(225), Q = sin(225)
{0.0, -1.0}, // 6 (110) corresponds to I = 0.0, Q = -1.0
{std::sqrt(2.0) / 2.0, -std::sqrt(2.0) / 2.0} // 7 (111) corresponds to I = cos(315), Q = sin(315)
};
}
uint8_t extractBestTribit(const std::vector<uint8_t>& stream, const size_t start, const size_t window_size) const {
if (start + window_size > stream.size()) {
throw std::out_of_range("Window size exceeds symbol stream size.");
}
Scrambler scrambler;
std::vector<uint8_t> symbol(stream.begin() + start, stream.begin() + start + window_size);
std::vector<uint8_t> descrambled_symbol = scrambler.scrambleSyncPreamble(symbol);
const size_t split_len = window_size / 4;
std::array<uint8_t, 8> tribit_counts = {0}; // Counts for each channel symbol (000 to 111)
// Loop through each split segment (4 segments)
for (size_t i = 0; i < 4; ++i) {
// Extract the range for this split
size_t segment_start = start + i * split_len;
size_t segment_end = segment_start + split_len;
// Compare this segment to the predefined patterns from the table and map to a channel symbol
uint8_t tribit_value = mapSegmentToChannelSymbol(descrambled_symbol, segment_start, segment_end);
// Increment the corresponding channel symbol count
tribit_counts[tribit_value]++;
}
// Find the channel symbol with the highest count (majority vote)
uint8_t best_symbol = std::distance(tribit_counts.begin(), std::max_element(tribit_counts.begin(), tribit_counts.end()));
return best_symbol;
}
// Function to map a segment of the stream back to a channel symbol based on the repeating patterns
uint8_t mapSegmentToChannelSymbol(const std::vector<uint8_t>& segment, size_t start, size_t end) const {
std::vector<uint8_t> extracted_pattern(segment.begin() + start, segment.begin() + end);
// Compare the extracted pattern with known patterns from the table
if (matchesPattern(extracted_pattern, {0, 0, 0, 0, 0, 0, 0, 0})) return 0b000;
if (matchesPattern(extracted_pattern, {0, 4, 0, 4, 0, 4, 0, 4})) return 0b001;
if (matchesPattern(extracted_pattern, {0, 0, 4, 4, 0, 0, 4, 4})) return 0b010;
if (matchesPattern(extracted_pattern, {0, 4, 4, 0, 0, 4, 4, 0})) return 0b011;
if (matchesPattern(extracted_pattern, {0, 0, 0, 0, 4, 4, 4, 4})) return 0b100;
if (matchesPattern(extracted_pattern, {0, 4, 0, 4, 4, 0, 4, 0})) return 0b101;
if (matchesPattern(extracted_pattern, {0, 0, 4, 4, 4, 4, 0, 0})) return 0b110;
if (matchesPattern(extracted_pattern, {0, 4, 4, 0, 4, 0, 0, 4})) return 0b111;
throw std::invalid_argument("Invalid segment pattern");
}
// Helper function to compare two patterns
bool matchesPattern(const std::vector<uint8_t>& segment, const std::vector<uint8_t>& pattern) const {
return std::equal(segment.begin(), segment.end(), pattern.begin());
}
bool configureModem(uint8_t D1, uint8_t D2, size_t& baud_rate, size_t& interleave_setting, bool& is_voice) {
// Predefine all the valid combinations in a lookup map
static const std::map<std::pair<uint8_t, uint8_t>, std::tuple<size_t, size_t, bool>> modemConfig = {
{{7, 6}, {4800, 1, false}}, // 4800 bps
{{7, 7}, {2400, 1, true}}, // 2400 bps, voice
{{6, 4}, {2400, 1, false}}, // 2400 bps, data
{{6, 5}, {1200, 1, false}}, // 1200 bps
{{6, 6}, {600, 1, false}}, // 600 bps
{{6, 7}, {300, 1, false}}, // 300 bps
{{7, 4}, {150, 1, false}}, // 150 bps
{{7, 5}, {75, 1, false}}, // 75 bps
{{4, 4}, {2400, 2, false}}, // 2400 bps, long interleave
{{4, 5}, {1200, 2, false}}, // 1200 bps, long interleave
{{4, 6}, {600, 2, false}}, // 600 bps, long interleave
{{4, 7}, {300, 2, false}}, // 300 bps, long interleave
{{5, 4}, {150, 2, false}}, // 150 bps, long interleave
{{5, 5}, {75, 2, false}}, // 75 bps, long interleave
};
// Use D1 and D2 to look up the correct configuration
auto it = modemConfig.find({D1, D2});
if (it != modemConfig.end()) {
// Set the parameters if found
std::tie(baud_rate, interleave_setting, is_voice) = it->second;
return true;
} else {
return false;
}
}
uint8_t calculateSegmentCount(const uint8_t C1, const uint8_t C2, const uint8_t C3) {
uint8_t extracted_C1 = C1 & 0b11;
uint8_t extracted_C2 = C2 & 0b11;
uint8_t extracted_C3 = C3 & 0b11;
uint8_t segment_count = (extracted_C1 << 4) | (extracted_C2 << 2) | extracted_C3;
return segment_count;
}
bool processSegment(const std::vector<uint8_t>& detected_symbols, size_t& start, size_t symbol_size, size_t& segment_count, uint8_t& D1, uint8_t& D2) {
size_t sync_pattern_length = 9;
if (start + symbol_size * sync_pattern_length > detected_symbols.size()) {
start = detected_symbols.size();
return false;
}
std::vector<uint8_t> window(detected_symbols.begin() + start, detected_symbols.begin() + start + sync_pattern_length * symbol_size);
std::vector<uint8_t> extracted_window;
for (size_t i = 0; i < sync_pattern_length; i++) {
extracted_window.push_back(extractBestTribit(window, i * symbol_size, symbol_size));
}
if (!matchesPattern(extracted_window, {0, 1, 3, 0, 1, 3, 1, 2, 0})) {
start += symbol_size;
return false;
}
start += sync_pattern_length * symbol_size;
size_t D1_index = start + symbol_size;
size_t D2_index = D1_index + symbol_size;
if (D2_index + symbol_size > detected_symbols.size()) {
start = detected_symbols.size();
return false;
}
D1 = extractBestTribit(detected_symbols, D1_index, symbol_size);
D2 = extractBestTribit(detected_symbols, D2_index, symbol_size);
// Process the count symbols (C1, C2, C3)
size_t C1_index = D2_index + symbol_size;
size_t C2_index = C1_index + symbol_size;
size_t C3_index = C2_index + symbol_size;
if (C3_index + symbol_size > detected_symbols.size()) {
start = detected_symbols.size();
return false;
}
uint8_t C1 = extractBestTribit(detected_symbols, C1_index, symbol_size);
uint8_t C2 = extractBestTribit(detected_symbols, C2_index, symbol_size);
uint8_t C3 = extractBestTribit(detected_symbols, C3_index, symbol_size);
segment_count = calculateSegmentCount(C1, C2, C3);
// Check for the constant zero pattern
size_t constant_zero_index = C3_index + symbol_size;
if (constant_zero_index + symbol_size > detected_symbols.size()) {
start = detected_symbols.size();
return false;
}
uint8_t constant_zero = extractBestTribit(detected_symbols, constant_zero_index, symbol_size);
if (constant_zero != 0) {
start = constant_zero_index + symbol_size;
return false; // Failed zero check, resync
}
// Successfully processed the segment
start = constant_zero_index + symbol_size; // Move start to next segment
return true;
}
bool processSyncSegments(const std::vector<uint8_t>& detected_symbols, size_t& baud_rate, size_t& interleave_setting, bool& is_voice) {
size_t symbol_size = 32;
size_t start = 0;
size_t segment_count = 0;
std::map<std::pair<uint8_t, uint8_t>, int> vote_map;
const int short_interleave_threshold = 2;
const int long_interleave_threshold = 5;
// Attempt to detect interleave setting dynamically
bool interleave_detected = false;
int current_threshold = short_interleave_threshold; // Start by assuming short interleave
while (start + symbol_size * 15 < detected_symbols.size()) {
uint8_t D1 = 0, D2 = 0;
if (processSegment(detected_symbols, start, symbol_size, segment_count, D1, D2)) {
vote_map[{D1, D2}]++;
// Check if we have enough votes to make a decision based on current interleave assumption
if (vote_map.size() >= current_threshold) {
auto majority_vote = std::max_element(vote_map.begin(), vote_map.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
if (configureModem(majority_vote->first.first, majority_vote->first.second, baud_rate, interleave_setting, is_voice)) {
interleave_detected = true;
break; // Successfully configured modem, exit loop
} else {
// If configuration fails, retry with the other interleave type
if (current_threshold == short_interleave_threshold) {
current_threshold = long_interleave_threshold; // Switch to long interleave
vote_map.clear(); // Clear the vote map and start fresh
start = 0; // Restart segment processing
} else {
continue; // Both short and long interleave attempts failed, signal is not usable
}
}
}
if (segment_count > 0) {
while (segment_count > 0 && start < detected_symbols.size()) {
uint8_t dummy_D1, dummy_D2;
if (!processSegment(detected_symbols, start, symbol_size, segment_count, dummy_D1, dummy_D2)) {
continue;
}
}
}
} else {
start += symbol_size; // Move to the next segment
}
}
return interleave_detected;
}
std::vector<uint8_t> processDataSymbols(const std::vector<uint8_t>& detected_symbols) {
return std::vector<uint8_t>();
}
};
#endif

View File

@ -1,220 +0,0 @@
#ifndef BITSTREAM_H
#define BITSTREAM_H
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>
/**
* @class BitStream
* @brief A class to represent a stream of bits with bit-level read and write access.
*
* The BitStream class provides functionality to manipulate a byte stream at the bit level.
* It derives from std::vector<uint8_t> to utilize the benefits of a byte vector while providing
* additional methods for bit manipulation.
*/
class BitStream : public std::vector<uint8_t> {
public:
/**
* @brief Default constructor.
*/
BitStream() : std::vector<uint8_t>(), bit_index(0), max_bit_idx(0) {}
/**
* @brief Constructs a BitStream from an existing vector of bytes.
* @param data The byte stream to be used for initializing the BitStream.
*/
BitStream(const std::vector<uint8_t>& data) : std::vector<uint8_t>(data), bit_index(0), max_bit_idx(data.size() * 8) {}
/**
* @brief Constructs a BitStream from an existing vector of bytes with a specified bit size.
* @param data The byte stream to be used for initializing the BitStream.
* @param size_in_bits The number of bits to consider in the stream.
*/
BitStream(const std::vector<uint8_t>& data, size_t size_in_bits) : std::vector<uint8_t>(data), bit_index(0), max_bit_idx(size_in_bits) {}
/**
* @brief Copy constructor from another BitStream.
* @param data The BitStream to copy from.
*/
BitStream(const BitStream& data) : std::vector<uint8_t>(data), bit_index(0), max_bit_idx(data.max_bit_idx) {}
/**
* @brief Constructs a BitStream from a substream of another BitStream.
* @param other The original BitStream.
* @param start_bit The starting bit index of the substream.
* @param end_bit The ending bit index of the substream (exclusive).
* @throws std::out_of_range if start or end indices are out of bounds.
*/
BitStream(const BitStream& other, size_t start_bit, size_t end_bit) : bit_index(0) {
if (start_bit >= other.max_bit_idx || end_bit > other.max_bit_idx || start_bit > end_bit) {
throw std::out_of_range("BitStream substream indices are out of range.");
}
max_bit_idx = end_bit - start_bit;
for (size_t i = start_bit; i < end_bit; i++) {
putBit(other.getBitVal(i));
}
}
/**
* @brief Reads the next bit from the stream.
* @return The next bit (0 or 1).
* @throws std::out_of_range if no more bits are available in the stream.
*/
int getNextBit() {
if (bit_index >= max_bit_idx) {
throw std::out_of_range("No more bits available in the stream.");
}
int bit = getBitVal(bit_index++);
return bit;
}
/**
* @brief Gets the value of a bit at a specific index.
* @param idx The index of the bit to be retrieved.
* @return The value of the bit (0 or 1).
* @throws std::out_of_range if the bit index is out of range.
*/
int getBitVal(const size_t idx) const {
if (idx >= max_bit_idx) {
throw std::out_of_range("Bit index out of range in getBitVal.");
}
size_t byte_idx = idx / 8;
size_t bit_idx = idx % 8;
uint8_t tmp = this->at(byte_idx);
uint8_t mask = 0x80 >> bit_idx;
uint8_t result = tmp & mask;
return result ? 1 : 0;
}
/**
* @brief Checks if there are more bits available in the stream.
* @return True if there are more bits available, otherwise false.
*/
bool hasNext() const {
return bit_index < max_bit_idx;
}
/**
* @brief Sets a specific bit value in the stream.
* @param idx The index of the bit to set.
* @param val The value to set the bit to (0 or 1).
*
* This function ensures that the stream has enough bytes to accommodate
* the given bit index. If the bit index is out of bounds, the stream is
* resized accordingly.
*/
void setBitVal(const size_t idx, uint8_t val) {
size_t byte_idx = idx / 8;
size_t bit_idx = idx % 8;
uint8_t mask = 0x80 >> bit_idx;
if (byte_idx >= this->size()) {
this->resize(byte_idx + 1, 0);
}
if (val == 0) {
this->at(byte_idx) &= ~mask;
} else {
this->at(byte_idx) |= mask;
}
}
/**
* @brief Appends a bit to the end of the stream.
* @param bit The value of the bit to append (0 or 1).
*
* This function keeps track of the current bit index and appends bits
* sequentially. If necessary, the stream is resized to accommodate the new bit.
*/
void putBit(uint8_t bit) {
size_t byte_idx = max_bit_idx / 8;
if (byte_idx >= this->size()) {
this->push_back(0);
}
size_t bit_idx = max_bit_idx % 8;
setBitVal(max_bit_idx, bit);
max_bit_idx += 1;
}
/**
* @brief Resets the bit index to the beginning of the stream.
*/
void resetBitIndex() {
bit_index = 0;
}
/**
* @brief Returns the maximum bit index value (total number of bits in the stream).
* @return The total number of bits in the stream.
*/
size_t getMaxBitIndex() const {
return max_bit_idx;
}
BitStream& operator=(const BitStream& other) {
this->clear();
this->resize(other.size());
std::copy(other.begin(), other.end(), this->begin());
this->bit_index = other.bit_index;
this->max_bit_idx = other.max_bit_idx;
return *this;
}
/**
* @brief Adds the contents of another BitStream to the current BitStream.
* @param other The BitStream to be added.
* @return Reference to the current BitStream after adding.
*/
BitStream& operator+=(const BitStream& other) {
size_t other_max_bit_idx = other.getMaxBitIndex();
for (size_t i = 0; i < other_max_bit_idx; i++) {
this->putBit(other.getBitVal(i));
}
return *this;
}
/**
* @brief Gets a substream from the current BitStream.
* @param start_bit The starting bit index of the substream.
* @param end_bit The ending bit index of the substream (exclusive).
* @return A new BitStream containing the specified substream.
* @throws std::out_of_range if start or end indices are out of bounds.
*/
BitStream getSubStream(size_t start_bit, size_t end_bit) const {
if (start_bit >= max_bit_idx || end_bit > max_bit_idx || start_bit > end_bit) {
throw std::out_of_range("BitStream substream indices are out of range.");
}
BitStream substream;
for (size_t i = start_bit; i < end_bit; i++) {
substream.putBit(getBitVal(i));
}
return substream;
}
/**
* @brief Returns the current bit index in the stream.
* @return The current bit index.
*/
size_t getCurrentBitIndex() const {
return bit_index;
}
private:
size_t bit_index; ///< The current bit index in the stream.
size_t max_bit_idx; ///< The total number of bits in the stream.
};
BitStream operator+(const BitStream& lhs, const BitStream& rhs) {
BitStream result = lhs;
result += rhs;
return result;
}
#endif

View File

@ -1,113 +0,0 @@
#include <complex>
#include <cmath>
#include <vector>
#include <iostream>
#include "filters.h"
class PhaseDetector {
public:
PhaseDetector() {}
PhaseDetector(const std::vector<std::complex<double>>& _symbolMap) : symbolMap(_symbolMap) {}
uint8_t getSymbol(const std::complex<double>& input) {
double phase = std::atan2(input.imag(), input.real());
return symbolFromPhase(phase);
}
private:
std::vector<std::complex<double>> symbolMap;
uint8_t symbolFromPhase(const double phase) {
// Calculate the closest symbol based on phase difference
double min_distance = 2 * M_PI; // Maximum possible phase difference
uint8_t closest_symbol = 0;
for (uint8_t i = 0; i < symbolMap.size(); ++i) {
double symbol_phase = std::atan2(symbolMap[i].imag(), symbolMap[i].real());
double distance = std::abs(symbol_phase - phase);
if (distance < min_distance) {
min_distance = distance;
closest_symbol = i;
}
}
return closest_symbol;
}
};
class CostasLoop {
public:
CostasLoop(const double _carrier_freq, const double _sample_rate, const std::vector<std::complex<double>>& _symbolMap, const double _vco_gain, const double _alpha, const double _beta)
: carrier_freq(_carrier_freq), sample_rate(_sample_rate), vco_gain(_vco_gain), alpha(_alpha), beta(_beta), freq_error(0.0), k_factor(-1 / (_sample_rate * _vco_gain)),
prev_in_iir(0), prev_out_iir(0), prev_in_vco(0), feedback(1.0, 0.0),
error_total(0), out_iir_total(0), in_vco_total(0),
srrc_filter(SRRCFilter(48, _sample_rate, 2400, 0.35)) {}
std::vector<std::complex<double>> process(const std::vector<double>& input_signal) {
std::vector<std::complex<double>> output_signal(input_signal.size());
double current_phase = 0.0;
error_total = 0;
out_iir_total = 0;
in_vco_total = 0;
for (size_t i = 0; i < input_signal.size(); ++i) {
// Multiply input by feedback signal
std::complex<double> multiplied = input_signal[i] * feedback;
// Filter signal
std::complex<double> filtered = srrc_filter.filterSample(multiplied);
// Output best-guess corrected I/Q components
output_signal[i] = filtered;
// Generate limited components
std::complex<double> limited = limiter(filtered);
// IIR Filter
double error_real = (limited.real() > 0 ? 1.0 : -1.0) * limited.imag();
double error_imag = (limited.imag() > 0 ? 1.0 : -1.0) * limited.real();
double phase_error = error_real - error_imag;
phase_error = 0.5 * (std::abs(phase_error + 1) - std::abs(phase_error - 1));
freq_error += beta * phase_error;
double phase_adjust = alpha * phase_error + freq_error;
current_phase += 2 * M_PI * carrier_freq / sample_rate + k_factor * phase_adjust;
if (current_phase > M_PI) current_phase -= 2 * M_PI;
else if (current_phase < -M_PI) current_phase += 2 * M_PI;
// Generate feedback signal for next iteration
double feedback_real = std::cos(current_phase);
double feedback_imag = -std::sin(current_phase);
feedback = std::complex<double>(feedback_real, feedback_imag);
}
return output_signal;
}
private:
double carrier_freq;
double sample_rate;
double k_factor;
double prev_in_iir;
double prev_out_iir;
double prev_in_vco;
std::complex<double> feedback;
double error_total;
double out_iir_total;
double in_vco_total;
SRRCFilter srrc_filter;
double vco_gain;
double alpha;
double beta;
double freq_error;
std::complex<double> limiter(const std::complex<double>& sample) const {
double limited_I = std::clamp(sample.real(), -1.0, 1.0);
double limited_Q = std::clamp(sample.imag(), -1.0, 1.0);
return std::complex<double>(limited_I, limited_Q);
}
};

View File

@ -1,151 +0,0 @@
#ifndef FILTERS_H
#define FILTERS_H
#include <cmath>
#include <cstdint>
#include <fftw3.h>
#include <numeric>
#include <vector>
class TapGenerators {
public:
std::vector<double> generateSRRCTaps(size_t num_taps, double sample_rate, double symbol_rate, double rolloff) const {
std::vector<double> taps(num_taps);
double T = 1.0 / symbol_rate; // Symbol period
double dt = 1.0 / sample_rate; // Time step
double t_center = (num_taps - 1) / 2.0;
for (size_t i = 0; i < num_taps; ++i) {
double t = (i - t_center) * dt;
double sinc_part = (t == 0.0) ? 1.0 : std::sin(M_PI * t / T * (1 - rolloff)) / (M_PI * t / T * (1 - rolloff));
double cos_part = (t == 0.0) ? std::cos(M_PI * t / T * (1 + rolloff)) : std::cos(M_PI * t / T * (1 + rolloff));
double denominator = 1.0 - (4.0 * rolloff * t / T) * (4.0 * rolloff * t / T);
if (std::fabs(denominator) < 1e-8) {
// Handle singularity at t = T / (4R)
taps[i] = rolloff * (std::sin(M_PI / (4.0 * rolloff)) + (1.0 / (4.0 * rolloff)) * std::cos(M_PI / (4.0 * rolloff))) / (M_PI / (4.0 * rolloff));
} else {
taps[i] = (4.0 * rolloff / (M_PI * std::sqrt(T))) * (cos_part / denominator);
}
taps[i] *= sinc_part;
}
// Normalize filter taps
double sum = std::accumulate(taps.begin(), taps.end(), 0.0);
for (auto& tap : taps) {
tap /= sum;
}
return taps;
}
std::vector<double> generateLowpassTaps(size_t num_taps, double cutoff_freq, double sample_rate) const {
std::vector<double> taps(num_taps);
double fc = cutoff_freq / (sample_rate / 2.0); // Normalized cutoff frequency (0 < fc < 1)
double M = num_taps - 1;
double mid = M / 2.0;
for (size_t n = 0; n < num_taps; ++n) {
double n_minus_mid = n - mid;
double h;
if (n_minus_mid == 0.0) {
h = fc;
} else {
h = fc * (std::sin(M_PI * fc * n_minus_mid) / (M_PI * fc * n_minus_mid));
}
// Apply window function (e.g., Hamming window)
double window = 0.54 - 0.46 * std::cos(2.0 * M_PI * n / M);
taps[n] = h * window;
}
// Normalize filter taps
double sum = std::accumulate(taps.begin(), taps.end(), 0.0);
for (auto& tap : taps) {
tap /= sum;
}
return taps;
}
};
class Filter {
public:
Filter(const std::vector<double>& _filter_taps) : filter_taps(_filter_taps), buffer(_filter_taps.size(), 0.0), buffer_index(0) {}
double filterSample(const double sample) {
buffer[buffer_index] = std::complex<double>(sample, 0.0);
double filtered_val = 0.0;
size_t idx = buffer_index;
for (size_t j = 0; j < filter_taps.size(); j++) {
filtered_val += filter_taps[j] * buffer[idx].real();
if (idx == 0) {
idx = buffer.size() - 1;
} else {
idx--;
}
}
buffer_index = (buffer_index + 1) % buffer.size();
return filtered_val;
}
std::complex<double> filterSample(const std::complex<double> sample) {
buffer[buffer_index] = sample;
std::complex<double> filtered_val = std::complex<double>(0.0, 0.0);
size_t idx = buffer_index;
for (size_t j = 0; j < filter_taps.size(); j++) {
filtered_val += filter_taps[j] * buffer[idx];
if (idx == 0) {
idx = buffer.size() - 1;
} else {
idx--;
}
}
buffer_index = (buffer_index + 1) % buffer.size();
return filtered_val;
}
std::vector<double> applyFilter(const std::vector<double>& signal) {
std::vector<double> filtered_signal(signal.size(), 0.0);
// Convolve the signal with the filter taps
for (size_t i = 0; i < signal.size(); ++i) {
filtered_signal[i] = filterSample(signal[i]);
}
return filtered_signal;
}
std::vector<std::complex<double>> applyFilter(const std::vector<std::complex<double>>& signal) {
std::vector<std::complex<double>> filtered_signal(signal.size(), std::complex<double>(0.0, 0.0));
// Convolve the signal with the filter taps
for (size_t i = 0; i < signal.size(); ++i) {
filtered_signal[i] = filterSample(signal[i]);
}
return filtered_signal;
}
private:
std::vector<double> filter_taps;
std::vector<std::complex<double>> buffer;
size_t buffer_index;
};
class SRRCFilter : public Filter {
public:
SRRCFilter(const size_t num_taps, const double sample_rate, const double symbol_rate, const double rolloff) : Filter(TapGenerators().generateSRRCTaps(num_taps, sample_rate, symbol_rate, rolloff)) {}
};
class LowPassFilter : public Filter {
public:
LowPassFilter(const size_t num_taps, const double cutoff_freq, const double sample_rate) : Filter(TapGenerators().generateLowpassTaps(num_taps, cutoff_freq, sample_rate)) {}
};
#endif /* FILTERS_H */

View File

@ -1,410 +0,0 @@
#ifndef WATTERSONCHANNEL_H
#define WATTERSONCHANNEL_H
#include <iostream>
#include <complex>
#include <vector>
#include <cmath>
#include <random>
#include <algorithm>
#include <functional>
#include <fftw3.h> // FFTW library for FFT-based Hilbert transform
constexpr double PI = 3.14159265358979323846;
class WattersonChannel {
public:
WattersonChannel(double sampleRate, double symbolRate, double delaySpread, double fadingBandwidth, double SNRdB, int numSamples, int numpaths, bool isFading);
// Process a block of input samples
void process(const std::vector<double>& inputSignal, std::vector<double>& outputSignal);
private:
double Fs; // Sample rate
double Rs; // Symbol rate
double delaySpread; // Delay spread in seconds
std::vector<int> delays = {0, L};
double fadingBandwidth; // Fading bandwidth d in Hz
double SNRdB; // SNR in dB
int L; // Length of the simulated channel
std::vector<double> f_jt; // Filter impulse response
std::vector<std::vector<std::complex<double>>> h_j; // Fading tap gains over time for h_0 and h_(L-1)
double Ts; // Sample period
double k; // Normalization constant for filter
double tau; // Truncation width
double fadingSampleRate; // Sample rate for fading process
std::vector<std::vector<double>> wgnFadingReal; // WGN samples for fading (double part)
std::vector<std::vector<double>> wgnFadingImag; // WGN samples for fading (imaginary part)
std::vector<std::complex<double>> n_i; // WGN samples for noise
std::mt19937 rng; // Random number generator
int numSamples; // Number of samples in the simulation
int numFadingSamples; // Number of fading samples
int numPaths;
bool isFading;
void normalizeTapGains();
void generateFilter();
void generateFadingTapGains();
void generateNoise(const std::vector<std::complex<double>>& x_i);
void generateWGN(std::vector<double>& wgn, int numSamples);
void resampleFadingTapGains();
void hilbertTransform(const std::vector<double>& input, std::vector<std::complex<double>>& output);
};
WattersonChannel::WattersonChannel(double sampleRate, double symbolRate, double delaySpread, double fadingBandwidth, double SNRdB, int numSamples, int numPaths, bool isFading)
: Fs(sampleRate), Rs(symbolRate), delaySpread(delaySpread), fadingBandwidth(fadingBandwidth), SNRdB(SNRdB), numSamples(numSamples), rng(std::random_device{}()), numPaths(numPaths), isFading(isFading)
{
Ts = 1.0 / Fs;
// Compute L
if (numPaths == 1) {
L = 1;
} else {
L = static_cast<int>(std::round(delaySpread / Ts));
if (L < 1) L = 1;
}
// Compute truncation width tau
double ln100 = std::log(100.0);
tau = std::sqrt(ln100) / (PI * fadingBandwidth);
// Initialize k (will be normalized later)
k = 1.0;
// Fading sample rate, at least 32 times the fading bandwidth
fadingSampleRate = std::max(32.0 * fadingBandwidth, Fs);
h_j.resize(numPaths);
wgnFadingReal.resize(numPaths);
wgnFadingImag.resize(numPaths);
if (isFading) {
// Generate filter impulse response
generateFilter();
// Number of fading samples
double simulationTime = numSamples / Fs;
numFadingSamples = static_cast<int>(std::ceil(simulationTime * fadingSampleRate));
// Generate WGN for fading
for (int pathIndex = 0; pathIndex < numPaths; ++pathIndex) {
generateWGN(wgnFadingReal[pathIndex], numFadingSamples);
generateWGN(wgnFadingImag[pathIndex], numFadingSamples);
}
// Generate fading tap gains
generateFadingTapGains();
// Resample fading tap gains to match sample rate Fs
resampleFadingTapGains();
} else {
// For fixed channel, set tap gains directly
generateFadingTapGains();
}
// Generate noise n_i
}
void WattersonChannel::normalizeTapGains() {
double totalPower = 0.0;
int numValidSamples = h_j[0].size();
for (int i = 0; i < numValidSamples; i++) {
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
totalPower += std::norm(h_j[pathIndex][i]);
}
}
totalPower /= numValidSamples;
double normFactor = 1.0 / std::sqrt(totalPower);
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
for (auto& val : h_j[pathIndex]) {
val *= normFactor;
}
}
}
void WattersonChannel::generateFilter()
{
// Generate filter impulse response f_j(t) = k * sqrt(2) * e^{-π² * t² * d²}, -tau < t < tau
// Number of filter samples
int numFilterSamples = static_cast<int>(std::ceil(2 * tau * fadingSampleRate)) + 1; // Include center point
f_jt.resize(numFilterSamples);
double dt = 1.0 / fadingSampleRate;
int halfSamples = numFilterSamples / 2;
double totalEnergy = 0.0;
for (int n = 0; n < numFilterSamples; ++n) {
double t = (n - halfSamples) * dt;
double val = k * std::sqrt(2.0) * std::exp(-PI * PI * t * t * fadingBandwidth * fadingBandwidth);
f_jt[n] = val;
totalEnergy += val * val * dt;
}
// Normalize k so that total energy is 1.0
double k_new = k / std::sqrt(totalEnergy);
for (auto& val : f_jt) {
val *= k_new;
}
k = k_new;
}
void WattersonChannel::generateFadingTapGains()
{
if (!isFading) {
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
h_j[pathIndex].assign(numSamples, std::complex<double>(1.0, 0.0));
}
} else {
// Prepare for FFT-based convolution
int convSize = numFadingSamples + f_jt.size() - 1;
int fftSize = 1;
while (fftSize < convSize) {
fftSize <<= 1; // Next power of two
}
std::vector<double> f_jtPadded(fftSize, 0.0);
std::copy(f_jt.begin(), f_jt.end(), f_jtPadded.begin());
fftw_complex* f_jtFFT = fftw_alloc_complex(fftSize);
fftw_plan planF_jt = fftw_plan_dft_r2c_1d(fftSize, f_jtPadded.data(), f_jtFFT, FFTW_ESTIMATE);
fftw_execute(planF_jt);
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
// Zero-pad inputs
std::vector<double> wgnRealPadded(fftSize, 0.0);
std::vector<double> wgnImagPadded(fftSize, 0.0);
std::copy(wgnFadingReal[pathIndex].begin(), wgnFadingReal[pathIndex].end(), wgnRealPadded.begin());
std::copy(wgnFadingImag[pathIndex].begin(), wgnFadingImag[pathIndex].end(), wgnImagPadded.begin());
// Perform FFTs
fftw_complex* WGNRealFFT = fftw_alloc_complex(fftSize);
fftw_complex* WGNImagFFT = fftw_alloc_complex(fftSize);
fftw_plan planWGNReal = fftw_plan_dft_r2c_1d(fftSize, wgnRealPadded.data(), WGNRealFFT, FFTW_ESTIMATE);
fftw_plan planWGNImag = fftw_plan_dft_r2c_1d(fftSize, wgnImagPadded.data(), WGNImagFFT, FFTW_ESTIMATE);
fftw_execute(planWGNReal);
fftw_execute(planWGNImag);
// Multiply in frequency domain
int fftComplexSize = fftSize / 2 + 1;
for (int i = 0; i < fftComplexSize; ++i) {
// Multiply WGNRealFFT and f_jtFFT
double realPart = WGNRealFFT[i][0] * f_jtFFT[i][0] - WGNRealFFT[i][1] * f_jtFFT[i][1];
double imagPart = WGNRealFFT[i][0] * f_jtFFT[i][1] + WGNRealFFT[i][1] * f_jtFFT[i][0];
WGNRealFFT[i][0] = realPart;
WGNRealFFT[i][1] = imagPart;
// Multiply WGNImagFFT and f_jtFFT
realPart = WGNImagFFT[i][0] * f_jtFFT[i][0] - WGNImagFFT[i][1] * f_jtFFT[i][1];
imagPart = WGNImagFFT[i][0] * f_jtFFT[i][1] + WGNImagFFT[i][1] * f_jtFFT[i][0];
WGNImagFFT[i][0] = realPart;
WGNImagFFT[i][1] = imagPart;
}
// Perform inverse FFTs
fftw_plan planInvReal = fftw_plan_dft_c2r_1d(fftSize, WGNRealFFT, wgnRealPadded.data(), FFTW_ESTIMATE);
fftw_plan planInvImag = fftw_plan_dft_c2r_1d(fftSize, WGNImagFFT, wgnImagPadded.data(), FFTW_ESTIMATE);
fftw_execute(planInvReal);
fftw_execute(planInvImag);
// Normalize
double scale = 1.0 / fftSize;
for (int i = 0; i < fftSize; ++i) {
wgnRealPadded[i] *= scale;
wgnImagPadded[i] *= scale;
}
// Assign h_j[0] and h_j[1]
int numValidSamples = numFadingSamples;
h_j[pathIndex].resize(numValidSamples);
for (int i = 0; i < numValidSamples; i++) {
h_j[pathIndex][i] = std::complex<double>(wgnRealPadded[i], wgnImagPadded[i]);
}
// Clean up
fftw_destroy_plan(planWGNReal);
fftw_destroy_plan(planWGNImag);
fftw_destroy_plan(planInvReal);
fftw_destroy_plan(planInvImag);
fftw_free(WGNRealFFT);
fftw_free(WGNImagFFT);
}
fftw_destroy_plan(planF_jt);
fftw_free(f_jtFFT);
normalizeTapGains();
}
}
void WattersonChannel::resampleFadingTapGains()
{
// Resample h_j[0] and h_j[1] from fadingSampleRate to Fs
int numOutputSamples = numSamples;
double resampleRatio = fadingSampleRate / Fs;
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
std::vector<std::complex<double>> resampled_h(numOutputSamples);
for (int i = 0; i < numOutputSamples; ++i) {
double t = i * (1.0 / Fs);
double index = t * fadingSampleRate;
int idx = static_cast<int>(index);
double frac = index - idx;
// Simple linear interpolation
if (idx + 1 < h_j[pathIndex].size()) {
resampled_h[i] = h_j[pathIndex][idx] * (1.0 - frac) + h_j[pathIndex][idx + 1] * frac;
}
else if (idx < h_j[pathIndex].size()) {
resampled_h[i] = h_j[pathIndex][idx];
}
else {
resampled_h[i] = std::complex<double>(0.0, 0.0);
}
}
h_j[pathIndex] = std::move(resampled_h);
}
}
void WattersonChannel::generateNoise(const std::vector<std::complex<double>>& x_i)
{
// Generate WGN samples for noise n_i with appropriate power to achieve the specified SNR
n_i.resize(numSamples);
double inputSignalPower = 0.0;
for (const auto& sample : x_i) {
inputSignalPower += std::norm(sample);
}
inputSignalPower /= x_i.size();
// Compute signal power (assuming average power of input signal x_i is normalized to 1.0)
double channelGainPower = 0.0;
for (int i = 0; i < numSamples; i++) {
std::complex<double> combinedGain = std::complex<double>(0.0, 0.0);
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
combinedGain += h_j[pathIndex][i];
}
channelGainPower += std::norm(combinedGain);
}
channelGainPower /= numSamples;
double signalPower = inputSignalPower * channelGainPower;
// Compute noise power
double SNR_linear = std::pow(10.0, SNRdB / 10.0);
double noisePower = signalPower / SNR_linear;
std::normal_distribution<double> normalDist(0.0, std::sqrt(noisePower / 2.0)); // Divided by 2 for double and imag parts
for (int i = 0; i < numSamples; ++i) {
double realPart = normalDist(rng);
double imagPart = normalDist(rng);
n_i[i] = std::complex<double>(realPart, imagPart);
}
}
void WattersonChannel::generateWGN(std::vector<double>& wgn, int numSamples)
{
wgn.resize(numSamples);
std::normal_distribution<double> normalDist(0.0, 1.0); // Standard normal distribution
for (int i = 0; i < numSamples; ++i) {
wgn[i] = normalDist(rng);
}
}
void WattersonChannel::hilbertTransform(const std::vector<double>& input, std::vector<std::complex<double>>& output)
{
// Implement Hilbert transform using FFT method
int N = input.size();
// Allocate input and output arrays for FFTW
double* in = fftw_alloc_real(N);
fftw_complex* out = fftw_alloc_complex(N);
// Copy input signal to in array
for (int i = 0; i < N; ++i) {
in[i] = input[i];
}
// Create plan for forward FFT
fftw_plan plan_forward = fftw_plan_dft_r2c_1d(N, in, out, FFTW_ESTIMATE);
// Execute forward FFT
fftw_execute(plan_forward);
// Apply the Hilbert transform in frequency domain
// For positive frequencies, multiply by 2; for zero and negative frequencies, set to zero
int N_half = N / 2 + 1;
for (int i = 0; i < N_half; ++i) {
if (i == 0 || i == N / 2) { // DC and Nyquist frequency components
out[i][0] = 0.0;
out[i][1] = 0.0;
}
else {
out[i][0] *= 2.0;
out[i][1] *= 2.0;
}
}
// Create plan for inverse FFT
fftw_plan plan_backward = fftw_plan_dft_c2r_1d(N, out, in, FFTW_ESTIMATE);
// Execute inverse FFT
fftw_execute(plan_backward);
// Normalize and store result in output vector
output.resize(N);
double scale = 1.0 / N;
for (int i = 0; i < N; ++i) {
output[i] = std::complex<double>(input[i], in[i] * scale);
}
// Clean up
fftw_destroy_plan(plan_forward);
fftw_destroy_plan(plan_backward);
fftw_free(in);
fftw_free(out);
}
void WattersonChannel::process(const std::vector<double>& inputSignal, std::vector<double>& outputSignal)
{
// Apply Hilbert transform to input signal to get complex x_i
std::vector<std::complex<double>> x_i;
hilbertTransform(inputSignal, x_i);
generateNoise(x_i);
// Process the signal through the channel
std::vector<std::complex<double>> y_i(numSamples);
// For each sample, compute y_i = h_j[0][i] * x_i + h_j[1][i] * x_{i - (L - 1)} + n_i[i]
for (int i = 0; i < numSamples; ++i) {
std::complex<double> y = n_i[i];
for (int pathIndex = 0; pathIndex < numPaths; pathIndex++) {
int delay = delays[pathIndex];
int idx = i - delay;
if (idx >= 0) {
y += h_j[pathIndex][i] * x_i[idx];
}
}
y_i[i] = y;
}
// Output the double part of y_i
outputSignal.resize(numSamples);
for (int i = 0; i < numSamples; ++i) {
outputSignal[i] = y_i[i].real();
}
}
#endif

256
main.cpp
View File

@ -1,242 +1,52 @@
// main.cpp
#include <bitset>
#include <fstream>
#include <iostream>
#include <string>
#include <sndfile.h>
#include <vector>
#include <cmath>
#include <complex>
#include <random>
#include <sndfile.h> // For WAV file handling
// GNU Radio headers
#include <gnuradio/top_block.h>
#include <gnuradio/blocks/vector_source.h>
#include <gnuradio/blocks/vector_sink.h>
#include <gnuradio/blocks/wavfile_sink.h>
#include <gnuradio/blocks/wavfile_source.h>
#include <gnuradio/blocks/multiply.h>
#include <gnuradio/blocks/complex_to_real.h>
#include <gnuradio/blocks/add_blk.h>
#include <gnuradio/analog/sig_source.h>
#include <gnuradio/analog/noise_source.h>
#include <gnuradio/filter/hilbert_fc.h>
#include <gnuradio/channels/selective_fading_model.h>
// Include your ModemController and BitStream classes
#include "ModemController.h"
#include "bitstream.h"
// Function to generate Bernoulli data
BitStream generateBernoulliData(const size_t length, const double p = 0.5, const unsigned int seed = 0) {
BitStream random_data;
std::mt19937 gen(seed);
std::bernoulli_distribution dist(p);
for (size_t i = 0; i < length * 8; ++i) {
random_data.putBit(dist(gen));
}
return random_data;
}
// Function to write int16_t data to a WAV file
void writeWavFile(const std::string& filename, const std::vector<int16_t>& data, float sample_rate) {
SF_INFO sfinfo;
sfinfo.channels = 1;
sfinfo.samplerate = static_cast<int>(sample_rate);
sfinfo.format = SF_FORMAT_WAV | SF_FORMAT_PCM_16;
SNDFILE* outfile = sf_open(filename.c_str(), SFM_WRITE, &sfinfo);
if (!outfile) {
std::cerr << "Error opening output file: " << sf_strerror(nullptr) << std::endl;
return;
}
sf_count_t frames_written = sf_write_short(outfile, data.data(), data.size());
if (frames_written != static_cast<sf_count_t>(data.size())) {
std::cerr << "Error writing to output file: " << sf_strerror(outfile) << std::endl;
}
sf_close(outfile);
}
int main() {
// Step 1: Gather parameters and variables
// Sample test data
std::string sample_string = "The quick brown fox jumps over the lazy dog 1234567890";
std::vector<uint8_t> sample_data(sample_string.begin(), sample_string.end());
// Define the preset based on your table (e.g., 4800 bps, 2 fading paths)
struct ChannelPreset {
size_t user_bit_rate;
int num_paths;
bool is_fading;
float multipath_ms;
float fading_bw_hz;
float snr_db;
double target_ber;
};
// Convert sample data to a BitStream object
BitStream bitstream(sample_data, sample_data.size() * 8);
// For this example, let's use the second preset from your table
ChannelPreset preset = {
4800, // user_bit_rate
2, // num_paths
true, // is_fading
2.0f, // multipath_ms
0.5f, // fading_bw_hz
27.0f, // snr_db
1e-3 // target_ber
};
// Configuration for modem
size_t baud_rate = 150;
bool is_voice = false; // False indicates data mode
bool is_frequency_hopping = false; // Fixed frequency operation
size_t interleave_setting = 2; // Short interleave
// Sampling rate (Hz)
double Fs = 48000.0; // Adjust to match your modem's sampling rate
double Ts = 1.0 / Fs;
// Create ModemController instance
ModemController modem(baud_rate, is_voice, is_frequency_hopping, interleave_setting, bitstream);
// Carrier frequency (Hz)
float carrier_freq = 1800.0f; // Adjust to match your modem's carrier frequency
const char* file_name = "modulated_signal_150bps_longinterleave.wav";
// Step 2: Initialize the modem
size_t baud_rate = preset.user_bit_rate;
bool is_voice = false;
bool is_frequency_hopping = false;
size_t interleave_setting = 2; // Adjust as necessary
// Perform transmit operation to generate modulated signal
std::vector<int16_t> modulated_signal = modem.transmit();
ModemController modem(baud_rate, is_voice, is_frequency_hopping, interleave_setting);
// Output modulated signal to a WAV file using libsndfile
SF_INFO sfinfo;
sfinfo.channels = 1;
sfinfo.samplerate = 48000;
sfinfo.format = SF_FORMAT_WAV | SF_FORMAT_PCM_16;
// Step 3: Generate input modulator data
size_t data_length = 28800; // Length in bytes
unsigned int data_seed = 42; // Random seed
BitStream input_data = generateBernoulliData(data_length, 0.5, data_seed);
// Step 4: Use the modem to modulate the input data
std::vector<int16_t> passband_signal = modem.transmit(input_data);
// Write the raw passband audio to a WAV file
writeWavFile("modem_output_raw.wav", passband_signal, Fs);
// Step 5: Process the modem output through the channel model
// Convert passband audio to float and normalize
std::vector<float> passband_signal_float(passband_signal.size());
for (size_t i = 0; i < passband_signal.size(); ++i) {
passband_signal_float[i] = passband_signal[i] / 32768.0f;
SNDFILE* sndfile = sf_open(file_name, SFM_WRITE, &sfinfo);
if (sndfile == nullptr) {
std::cerr << "Unable to open WAV file for writing modulated signal: " << sf_strerror(sndfile) << "\n";
return 1;
}
// Create GNU Radio top block
auto tb = gr::make_top_block("Passband to Baseband and Channel Model");
sf_write_short(sndfile, modulated_signal.data(), modulated_signal.size());
sf_close(sndfile);
std::cout << "Modulated signal written to " << file_name << '\n';
// Create vector source from passband signal
auto src = gr::blocks::vector_source_f::make(passband_signal_float, false);
// Apply Hilbert Transform to get analytic signal
int hilbert_taps = 129; // Number of taps
auto hilbert = gr::filter::hilbert_fc::make(hilbert_taps);
// Multiply by complex exponential to shift to baseband
auto freq_shift_down = gr::analog::sig_source_c::make(
Fs, gr::analog::GR_COS_WAVE, -carrier_freq, 1.0f, 0.0f);
auto multiplier_down = gr::blocks::multiply_cc::make();
// Connect the blocks for downconversion
tb->connect(src, 0, hilbert, 0);
tb->connect(hilbert, 0, multiplier_down, 0);
tb->connect(freq_shift_down, 0, multiplier_down, 1);
// At this point, multiplier_down outputs the complex baseband signal
// Configure the channel model parameters
std::vector<float> delays = {0.0f};
std::vector<float> mags = {1.0f};
if (preset.num_paths == 2 && preset.multipath_ms > 0.0f) {
delays.push_back(preset.multipath_ms / 1000.0f); // Convert ms to seconds
float path_gain = 1.0f / sqrtf(2.0f); // Equal average power
mags[0] = path_gain;
mags.push_back(path_gain);
}
int N = 8; // Number of sinusoids
bool LOS = false; // Rayleigh fading
float K = 0.0f; // K-factor
unsigned int seed = 0;
int ntaps = 64; // Number of taps
float fD = preset.fading_bw_hz; // Maximum Doppler frequency in Hz
float fDTs = fD * Ts; // Normalized Doppler frequency
auto channel_model = gr::channels::selective_fading_model::make(
N, fDTs, LOS, K, seed, delays, mags, ntaps);
// Add AWGN to the signal
float SNR_dB = preset.snr_db;
float SNR_linear = powf(10.0f, SNR_dB / 10.0f);
float signal_power = 0.0f; // Assume normalized
for (const auto& sample : passband_signal_float) {
signal_power += sample * sample;
}
signal_power /= passband_signal_float.size();
float noise_power = signal_power / SNR_linear;
float noise_voltage = sqrtf(noise_power);
auto noise_src = gr::analog::noise_source_c::make(
gr::analog::GR_GAUSSIAN, noise_voltage, seed);
auto adder = gr::blocks::add_cc::make();
// Connect the blocks for channel model and noise addition
tb->connect(multiplier_down, 0, channel_model, 0);
tb->connect(channel_model, 0, adder, 0);
tb->connect(noise_src, 0, adder, 1);
// Multiply by complex exponential to shift back to passband
auto freq_shift_up = gr::analog::sig_source_c::make(
Fs, gr::analog::GR_COS_WAVE, carrier_freq, 1.0f, 0.0f);
auto multiplier_up = gr::blocks::multiply_cc::make();
// Connect the blocks for upconversion
tb->connect(adder, 0, multiplier_up, 0);
tb->connect(freq_shift_up, 0, multiplier_up, 1);
// Convert to real signal
auto complex_to_real = gr::blocks::complex_to_real::make();
// Connect the blocks
tb->connect(multiplier_up, 0, complex_to_real, 0);
// Collect the output samples
auto sink = gr::blocks::vector_sink_f::make();
tb->connect(complex_to_real, 0, sink, 0);
// Run the flowgraph
tb->run();
// Retrieve the output data
std::vector<float> received_passband_audio = sink->data();
// Normalize and convert to int16_t
// Find maximum absolute value
float max_abs_value = 0.0f;
for (const auto& sample : received_passband_audio) {
if (fabs(sample) > max_abs_value) {
max_abs_value = fabs(sample);
}
}
if (max_abs_value == 0.0f) {
max_abs_value = 1.0f;
}
float scaling_factor = 0.9f / max_abs_value; // Prevent clipping at extremes
// Apply scaling and convert to int16_t
std::vector<int16_t> received_passband_signal(received_passband_audio.size());
for (size_t i = 0; i < received_passband_audio.size(); ++i) {
float sample = received_passband_audio[i] * scaling_factor;
// Ensure the sample is within [-1.0, 1.0]
if (sample > 1.0f) sample = 1.0f;
if (sample < -1.0f) sample = -1.0f;
received_passband_signal[i] = static_cast<int16_t>(sample * 32767.0f);
}
// Step 6: Write the received passband audio to another WAV file
writeWavFile("modem_output_received.wav", received_passband_signal, Fs);
std::cout << "Processing complete. Output files generated." << std::endl;
// Success message
std::cout << "Modem test completed successfully.\n";
return 0;
}

View File

@ -1,19 +0,0 @@
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
# Find the installed gtest package
find_package(GTest REQUIRED)
find_package(SndFile REQUIRED)
find_package(FFTW3 REQUIRED)
# Add test executable
add_executable(PSKModulatorTest PSKModulatorTests.cpp)
# Link the test executable with the GTest libraries
target_link_libraries(PSKModulatorTest GTest::GTest GTest::Main FFTW3::fftw3 SndFile::sndfile)
# Enable C++17 standard
set_target_properties(PSKModulatorTest PROPERTIES CXX_STANDARD 17)
# Add test cases
include(GoogleTest)
gtest_discover_tests(PSKModulatorTest)

View File

@ -0,0 +1,23 @@
#include "gtest/gtest.h"
#include "FSKModulator.h"
#include <vector>
TEST(FSKModulatorTest, SignalLength) {
using namespace milstd;
// Parameters
FSKModulator modulator(FSKModulator::ShiftType::NarrowShift, 75.0, 8000.0);
// Input data bits
std::vector<uint8_t> dataBits = {1, 0, 1, 1, 0};
// Modulate the data
std::vector<double> signal = modulator.modulate(dataBits);
// Calculate expected number of samples
size_t samplesPerSymbol = static_cast<size_t>(modulator.getSampleRate() * modulator.getSymbolDuration());
size_t expectedSamples = dataBits.size() * samplesPerSymbol;
// Verify signal length
EXPECT_EQ(signal.size(), expectedSamples);
}

View File

@ -1,68 +0,0 @@
#include <gtest/gtest.h>
#include "modulation/PSKModulator.h"
// Fixture for PSK Modulator tests
class PSKModulatorTest : public ::testing::Test {
protected:
double sample_rate = 48000;
size_t num_taps = 48;
bool is_frequency_hopping = false;
PSKModulator modulator{sample_rate, is_frequency_hopping, num_taps};
std::vector<uint8_t> symbols = {0, 3, 5, 7};
};
TEST_F(PSKModulatorTest, ModulationOutputLength) {
auto signal = modulator.modulate(symbols);
size_t expected_length = symbols.size() * (sample_rate / SYMBOL_RATE);
ASSERT_EQ(signal.size(), expected_length);
for (auto& sample : signal) {
EXPECT_GE(sample, -32768);
EXPECT_LE(sample, 32767);
}
}
TEST_F(PSKModulatorTest, DemodulationOutput) {
auto passband_signal = modulator.modulate(symbols);
// Debug: Print modulated passband signal
std::cout << "Modulated Passband Signal: ";
for (const auto& sample : passband_signal) {
std::cout << sample << " ";
}
std::cout << std::endl;
size_t baud_rate;
size_t interleave_setting;
bool is_voice;
auto decoded_symbols = modulator.demodulate(passband_signal, baud_rate, interleave_setting, is_voice);
// Debug: Print decoded symbols
std::cout << "Decoded Symbols: ";
for (const auto& symbol : decoded_symbols) {
std::cout << (int)symbol << " ";
}
std::cout << std::endl;
// Debug: Print expected symbols
std::cout << "Expected Symbols: ";
for (const auto& symbol : symbols) {
std::cout << (int)symbol << " ";
}
std::cout << std::endl;
ASSERT_EQ(symbols.size(), decoded_symbols.size());
for (size_t i = 0; i < symbols.size(); i++) {
EXPECT_EQ(symbols[i], decoded_symbols[i]) << " at index " << i;
}
}
TEST_F(PSKModulatorTest, InvalidSymbolInput) {
std::vector<uint8_t> invalid_symbols = {0, 8, 9};
EXPECT_THROW(modulator.modulate(invalid_symbols), std::out_of_range);
}