240 lines
9.9 KiB
Python
240 lines
9.9 KiB
Python
# core/sfp_transport.py
|
|
"""
Reusable transport layer for the Simple Fragmentation Protocol (SFP).

This module handles UDP socket communication, SFP header parsing, fragment
reassembly (per SFP_FLOW / SFP_TID transaction), and ACK generation. It is
application-agnostic: each fully reassembled payload is passed to the
application-level handler callback registered for the packet's SFP_FLOW
identifier. Handlers are invoked on the transport's background receiver
thread.
"""
|
|
|
|
import socket
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Dict, Callable, Optional
|
|
|
|
from controlpanel import config
|
|
from controlpanel.utils.network import create_udp_socket, close_udp_socket
|
|
from controlpanel.core.sfp_structures import SFPHeader
|
|
|
|
# Signature of a per-flow payload handler. It is called with the fully
# reassembled payload of one transaction; its return value is ignored.
PayloadHandler = Callable[[bytearray], None]


class SfpTransport:
    """Manages SFP communication and payload reassembly.

    Owns a bound UDP socket and a daemon receiver thread. Every datagram is
    parsed as an SFP header followed by payload bytes. Fragments are
    reassembled per ``(SFP_FLOW, SFP_TID)`` key and, once every fragment of a
    transaction has arrived, the complete payload is passed to the handler
    registered for its flow. Handlers run on the receiver thread.
    """

    def __init__(self, host: str, port: int, payload_handlers: Dict[int, PayloadHandler]):
        """
        Initializes the SFP Transport layer. No socket is opened until start().

        Args:
            host (str): The local IP address to bind the UDP socket to.
            port (int): The local port to listen on.
            payload_handlers (Dict[int, PayloadHandler]): A dictionary mapping
                SFP_FLOW IDs (as integers) to their corresponding handler functions.
                Each handler will be called with the complete payload (bytearray).
        """
        self._log_prefix = "[SfpTransport]"
        logging.info(f"{self._log_prefix} Initializing for {host}:{port}...")

        self._host = host
        self._port = port
        self._payload_handlers = payload_handlers
        self._socket: Optional[socket.socket] = None
        self._receiver_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Reassembly state keyed by (flow, tid), updated from the receive loop:
        #   _fragments[key]: received fragment index -> total_frags it advertised
        #   _buffers[key]:   pre-sized buffer the fragment payloads are copied into
        self._fragments: Dict[tuple, Dict[int, int]] = {}
        self._buffers: Dict[tuple, bytearray] = {}

        logging.debug(
            f"{self._log_prefix} Registered handlers for flows: "
            f"{[chr(k) if 32 <= k <= 126 else k for k in self._payload_handlers.keys()]}"
        )

    def start(self) -> bool:
        """
        Starts the transport layer by creating the socket and launching the receiver thread.

        Calling start() while the receiver thread is alive is a no-op.

        Returns:
            bool: True if started (or already running), False if the socket
            could not be created.
        """
        if self._receiver_thread is not None and self._receiver_thread.is_alive():
            logging.warning(f"{self._log_prefix} Start called, but receiver is already running.")
            return True

        # The receive loop can exit on a socket error without releasing the
        # socket; close any stale one so it is not leaked (and does not keep
        # holding the port) before a new one is created.
        if self._socket is not None:
            close_udp_socket(self._socket)
            self._socket = None

        self._socket = create_udp_socket(self._host, self._port)
        if not self._socket:
            logging.critical(f"{self._log_prefix} Failed to create and bind socket. Cannot start.")
            return False

        self._stop_event.clear()
        self._receiver_thread = threading.Thread(
            target=self._receive_loop, name="SfpTransportThread", daemon=True
        )
        self._receiver_thread.start()
        logging.info(f"{self._log_prefix} Receiver thread started.")
        return True

    def shutdown(self):
        """Stops the receiver thread and closes the socket.

        Safe to call when the transport was never started. Waits up to
        2 seconds for the receiver thread to exit.
        """
        logging.info(f"{self._log_prefix} Shutdown initiated.")
        self._stop_event.set()

        # Closing the socket is meant to interrupt a blocking recvfrom; the
        # receive loop treats the resulting OSError as a normal exit once the
        # stop event is set.
        # NOTE(review): on Linux, closing a socket from another thread does not
        # reliably wake a thread blocked in recvfrom; the loop's socket.timeout
        # handling suggests create_udp_socket sets a timeout -- confirm.
        if self._socket:
            close_udp_socket(self._socket)
            self._socket = None

        receiver = self._receiver_thread
        # A payload handler calling shutdown() runs on the receiver thread
        # itself, and Thread.join() raises RuntimeError when joining the
        # current thread.
        if receiver and receiver.is_alive() and receiver is not threading.current_thread():
            logging.debug(f"{self._log_prefix} Waiting for receiver thread to join...")
            receiver.join(timeout=2.0)
            if receiver.is_alive():
                logging.warning(f"{self._log_prefix} Receiver thread did not join cleanly.")

        logging.info(f"{self._log_prefix} Shutdown complete.")

    def _receive_loop(self):
        """The main loop that listens for UDP packets and processes them.

        Runs on the receiver thread until the stop event is set, the socket
        is gone, or the socket raises OSError.
        """
        log_prefix = f"{self._log_prefix} Loop"
        logging.info(f"{log_prefix} Starting receive loop.")

        while not self._stop_event.is_set():
            # Read the attribute once: shutdown() may set it to None concurrently.
            sock = self._socket
            if not sock:
                logging.error(f"{log_prefix} Socket is not available. Stopping loop.")
                break

            try:
                data, addr = sock.recvfrom(65535)
            except socket.timeout:
                continue
            except OSError:
                # This is expected when the socket is closed during shutdown
                if not self._stop_event.is_set():
                    logging.error(f"{log_prefix} Socket error.", exc_info=True)
                break
            except Exception:
                logging.exception(f"{log_prefix} Unexpected error in recvfrom.")
                time.sleep(0.01)
                continue

            if not data:
                continue

            # A single malformed or unexpected packet must not kill the
            # receiver thread.
            try:
                self._process_packet(data, addr)
            except Exception:
                logging.exception(f"{log_prefix} Unexpected error while processing packet from {addr}.")

        logging.info(f"{log_prefix} Receive loop terminated.")

    def _process_packet(self, data: bytes, addr: tuple):
        """Parses an SFP packet, ACKs it if requested, and feeds it to reassembly.

        Args:
            data: The raw datagram: an SFP header followed by payload bytes.
            addr: The sender address returned by recvfrom; ACKs are sent there.
        """
        header_size = SFPHeader.size()
        if len(data) < header_size:
            logging.warning(f"Packet from {addr} is too small for SFP header. Ignoring.")
            return

        try:
            header = SFPHeader.from_buffer_copy(data)
        except (ValueError, TypeError):
            logging.error(f"Failed to parse SFP header from {addr}. Ignoring.")
            return

        flow = header.SFP_FLOW

        # Handle ACK Request (flag bit 0). The ACK is sent before the
        # reassembly metadata is validated.
        if header.SFP_FLAGS & 0x01:
            self._send_ack(addr, data[:header_size])

        completed_payload = self._reassemble_fragment(
            flow=flow,
            tid=header.SFP_TID,
            frag=header.SFP_FRAG,
            total_frags=header.SFP_TOTFRGAS,
            total_size=header.SFP_TOTSIZE,
            pl_offset=header.SFP_PLOFFSET,
            pl_size=header.SFP_PLSIZE,
            payload=data[header_size:],
        )
        if completed_payload is not None:
            self._dispatch_payload(flow, completed_payload)

    def _reassemble_fragment(
        self,
        flow: int,
        tid: int,
        frag: int,
        total_frags: int,
        total_size: int,
        pl_offset: int,
        pl_size: int,
        payload: bytes,
    ) -> Optional[bytearray]:
        """Records one fragment of the transaction ``(flow, tid)``.

        Fragment 0 (re)starts the transaction: it discards other in-progress
        transactions of the same flow and allocates a ``total_size`` buffer.
        Other fragments are ignored until fragment 0 has been seen.

        Args:
            flow: SFP_FLOW of the packet (first half of the transaction key).
            tid: SFP_TID of the packet (second half of the transaction key).
            frag: Index of this fragment; must lie in ``[0, total_frags)``.
            total_frags: Number of fragments the transaction consists of.
            total_size: Size in bytes of the fully reassembled payload.
            pl_offset: Byte offset of this fragment's payload in the result.
            pl_size: Advertised payload length; at most ``len(payload)`` bytes are used.
            payload: The bytes that followed the header.

        Returns:
            The complete payload once every fragment index has been received
            (the transaction's state is dropped at that point), otherwise None.
        """
        key = (flow, tid)

        # Validate packet metadata
        if total_frags == 0 or total_frags > 60000 or total_size <= 0:
            logging.warning(f"Invalid metadata for {key}: total_frags={total_frags}, total_size={total_size}. Ignoring.")
            return None

        # An out-of-range index would otherwise count toward completion and let
        # a transaction be delivered while a real fragment is still missing.
        if not 0 <= frag < total_frags:
            logging.warning(f"Fragment index {frag} out of range for key={key} (total_frags={total_frags}). Ignoring.")
            return None

        # Start of a new transaction
        if frag == 0:
            self._cleanup_lingering_transactions(flow, tid)
            logging.debug(f"New transaction started for key={key}. Total size: {total_size} bytes.")
            self._fragments[key] = {}
            try:
                self._buffers[key] = bytearray(total_size)
            except (MemoryError, ValueError, OverflowError):
                logging.error(f"Failed to allocate {total_size} bytes for key={key}. Ignoring transaction.")
                # Drop both halves of the state so no buffer left over from an
                # earlier transaction with the same key is orphaned.
                self._fragments.pop(key, None)
                self._buffers.pop(key, None)
                return None

        # Check if we are tracking this transaction
        fragments = self._fragments.get(key)
        buffer = self._buffers.get(key)
        if fragments is None or buffer is None:
            logging.debug(f"Ignoring fragment {frag} for untracked transaction key={key}.")
            return None

        bytes_to_copy = min(pl_size, len(payload))
        # Negative values would make the slice assignment below resize the
        # buffer instead of overwriting a region of it.
        if pl_offset < 0 or bytes_to_copy < 0 or pl_offset + bytes_to_copy > len(buffer):
            logging.error(f"Payload for key={key}, frag={frag} would overflow buffer. Ignoring.")
            return None

        buffer[pl_offset : pl_offset + bytes_to_copy] = payload[:bytes_to_copy]
        # Record the fragment only after its payload has been stored, so a
        # rejected fragment cannot count toward completion.
        fragments[frag] = total_frags

        # Check for completion
        if len(fragments) != total_frags:
            return None

        # Complete: drop the state for this key and hand the buffer back.
        del self._fragments[key]
        return self._buffers.pop(key)

    def _dispatch_payload(self, flow: int, payload: bytearray) -> None:
        """Passes a completed payload to the handler registered for ``flow``.

        Exceptions raised by the handler are logged, not propagated.
        """
        handler = self._payload_handlers.get(flow)
        if not handler:
            logging.warning(f"No payload handler registered for flow ID {flow}.")
            return
        try:
            handler(payload)
        except Exception:
            logging.exception(f"Error executing payload handler for flow {flow}.")

    def _send_ack(self, dest_addr: tuple, original_header_bytes: bytes):
        """Sends an SFP ACK packet back to the sender.

        The ACK is a copy of the received header with SFP_DIRECTION set to
        '<', SFP_WIN set to the configured window for the flow ('M' -> MFD,
        'S' -> SAR, otherwise 0), flag bit 0 set and flag bit 1 cleared.
        Failures are logged, never raised.
        """
        log_prefix = f"{self._log_prefix} ACK"
        # Read the attribute once: shutdown() may set it to None concurrently.
        sock = self._socket
        if not sock:
            return

        try:
            ack_header = bytearray(original_header_bytes)
            flow = ack_header[SFPHeader.get_field_offset("SFP_FLOW")]

            # Determine window size based on flow
            if flow == ord('M'):  # MFD
                window_size = config.ACK_WINDOW_SIZE_MFD
            elif flow == ord('S'):  # SAR
                window_size = config.ACK_WINDOW_SIZE_SAR
            else:
                window_size = 0

            # Modify header for ACK response
            ack_header[SFPHeader.get_field_offset("SFP_DIRECTION")] = 0x3C  # '<'
            ack_header[SFPHeader.get_field_offset("SFP_WIN")] = window_size
            flags_offset = SFPHeader.get_field_offset("SFP_FLAGS")
            ack_header[flags_offset] = (ack_header[flags_offset] | 0x01) & ~0x02

            sock.sendto(ack_header, dest_addr)
            logging.debug(f"{log_prefix} Sent to {dest_addr} for flow {chr(flow) if 32<=flow<=126 else flow}.")
        except Exception:
            logging.exception(f"{log_prefix} Failed to send to {dest_addr}.")

    def _cleanup_lingering_transactions(self, current_flow: int, current_tid: int):
        """Removes incomplete transactions of ``current_flow`` with a TID other than ``current_tid``.

        Called when fragment 0 of a transaction arrives, so at most one
        transaction per flow is being reassembled at a time.
        """
        # Snapshot the keys first: the dicts are mutated inside the loop.
        stale_keys = [
            key for key in self._fragments
            if key[0] == current_flow and key[1] != current_tid
        ]
        for key in stale_keys:
            logging.warning(f"Cleaning up lingering/incomplete transaction for key={key}.")
            self._fragments.pop(key, None)
            self._buffers.pop(key, None)