156 lines
5.7 KiB
Python
156 lines
5.7 KiB
Python
# target_simulator/utils/tftp_client.py
|
|
|
|
import socket
|
|
import struct
|
|
import io
|
|
from target_simulator.utils.logger import get_logger
|
|
|
|
|
|
class TFTPError(Exception):
    """Failure of a TFTP transfer.

    Raised both for ERROR packets sent by the server and for protocol
    violations detected locally by the client (the client uses code -1
    for the latter).
    """

    def __init__(self, code, message):
        # Numeric error code: the TFTP error code from the server, or -1.
        self.code = code
        # Human-readable description of the failure.
        self.message = message
        super().__init__(f"TFTP Error {code}: {message}")
|
|
|
|
|
|
# Default UDP port the initial read/write request is sent to.
TFTP_PORT: int = 69

# Payload size of a full DATA block; a shorter payload marks the final block.
TFTP_BLOCK_SIZE: int = 512
|
|
|
|
|
|
class TFTPClient:
    """Minimal TFTP client (RFC 1350): single-file upload (WRQ) and download (RRQ).

    Every transfer uses its own UDP socket and the fixed 512-byte block size.
    No option negotiation and no netascii line-ending translation is done:
    the mode string is only forwarded to the server.
    """

    # Packet opcodes defined by RFC 1350.
    _OP_RRQ = 1
    _OP_WRQ = 2
    _OP_DATA = 3
    _OP_ACK = 4
    _OP_ERROR = 5
    # RFC 1350 error code 5, "Unknown transfer ID" (sent to stray peers).
    _ERR_UNKNOWN_TID = 5
    # Large enough that no UDP datagram (DATA, long ERROR text) is truncated.
    _RECV_BUFSIZE = 65535

    def __init__(self, server_ip, server_port=TFTP_PORT, timeout=5, retries=3):
        """Create a client for the TFTP server at ``server_ip:server_port``.

        Args:
            server_ip: address of the server, as accepted by ``socket.sendto``.
            server_port: port the initial RRQ/WRQ is sent to.
            timeout: seconds to wait for each reply (socket timeout).
            retries: how many times a packet is retransmitted after a receive
                timeout before the timeout is propagated to the caller.
        """
        self.server_ip = server_ip
        self.server_port = int(server_port)
        self.timeout = timeout
        self.retries = max(0, int(retries))
        self.logger = get_logger(__name__)

    def _validate_params(self, filename, mode):
        """Raise ValueError for an unusable filename or an unsupported mode."""
        # A NUL character would terminate the filename field of the request early.
        if not filename or not isinstance(filename, str) or "\0" in filename:
            raise ValueError("Invalid filename")
        if mode not in ("octet", "netascii"):
            raise ValueError("Invalid mode: must be 'octet' or 'netascii'")

    @staticmethod
    def _build_request(opcode, filename, mode):
        """Return an RRQ/WRQ packet: opcode, filename, NUL, mode, NUL.

        Raises UnicodeEncodeError if filename is not ASCII.
        """
        return (
            struct.pack("!H", opcode)
            + filename.encode("ascii")
            + b"\0"
            + mode.encode("ascii")
            + b"\0"
        )

    @staticmethod
    def _error_from_packet(data):
        """Build a TFTPError from a received ERROR packet (opcode 5)."""
        code = struct.unpack("!H", data[2:4])[0] if len(data) >= 4 else -1
        if len(data) > 4:
            message = data[4:].split(b"\0", 1)[0].decode(errors="replace")
        else:
            message = "Unknown error (short packet)"
        return TFTPError(code, message)

    @staticmethod
    def _read_block(fileobj):
        """Read the next block (up to TFTP_BLOCK_SIZE bytes) from ``fileobj``.

        A block shorter than TFTP_BLOCK_SIZE ends the transfer, so short reads
        (pipes, unbuffered streams) are repeated until the block is full or
        ``read`` signals EOF. ``str`` chunks from text readers are ASCII-encoded.
        """
        parts = []
        size = 0
        while size < TFTP_BLOCK_SIZE:
            piece = fileobj.read(TFTP_BLOCK_SIZE - size)
            if not piece:
                break
            if isinstance(piece, str):
                piece = piece.encode("ascii")
            parts.append(piece)
            size += len(piece)
        return b"".join(parts)

    def _reject_stray(self, sock, addr):
        """Answer a datagram from a foreign address with "Unknown transfer ID"."""
        error = (
            struct.pack("!HH", self._OP_ERROR, self._ERR_UNKNOWN_TID)
            + b"Unknown transfer ID\0"
        )
        try:
            sock.sendto(error, addr)
        except OSError:
            # Best effort only: this must never disturb the ongoing transfer.
            self.logger.debug(f"Could not send 'Unknown transfer ID' to {addr}")

    def _recv_packet(self, sock, peer):
        """Receive the next datagram from ``peer`` (from anyone if ``peer`` is None).

        Datagrams from any other address are rejected and discarded.

        Returns:
            ``(data, addr)``.

        Raises:
            socket.timeout: nothing arrived within the socket timeout.
        """
        while True:
            data, addr = sock.recvfrom(self._RECV_BUFSIZE)
            self.logger.debug(
                f"Received {len(data)} bytes from {addr}, header: {data[:4].hex()}"
            )
            if peer is None or addr == peer:
                return data, addr
            self.logger.warning(
                f"Ignoring packet from unknown transfer ID {addr} (expected {peer})."
            )
            self._reject_stray(sock, addr)

    def _exchange(self, sock, packet, dest, peer, classify, *classify_args):
        """Send ``packet`` to ``dest`` and wait for a reply accepted by ``classify``.

        ``classify(data, *classify_args)`` returns True to accept a datagram,
        False to discard it and keep waiting, or raises TFTPError. When a
        receive times out, the packet is retransmitted up to ``self.retries``
        times; after that the socket timeout propagates.

        Returns:
            ``(data, addr)`` of the accepted datagram.
        """
        retransmissions_left = self.retries
        sock.sendto(packet, dest)
        while True:
            try:
                data, addr = self._recv_packet(sock, peer)
            except socket.timeout:
                if retransmissions_left <= 0:
                    raise
                retransmissions_left -= 1
                self.logger.warning(
                    f"Timeout waiting for a reply from {dest}; retransmitting "
                    f"({retransmissions_left} retries left)."
                )
                sock.sendto(packet, dest)
                continue
            if classify(data, *classify_args):
                return data, addr

    def _is_initial_ack(self, data):
        """Classify the first reply to a WRQ: accept ACK(0), raise otherwise."""
        if len(data) < 2:
            raise TFTPError(
                -1,
                f"Invalid packet received. Length is {len(data)} bytes, expected at least 2.",
            )
        opcode = struct.unpack("!H", data[:2])[0]
        if opcode == self._OP_ERROR:
            raise self._error_from_packet(data)
        if len(data) >= 4:
            block = struct.unpack("!H", data[2:4])[0]
        else:
            # Too short to carry a block number: tolerated as ACK(0) so that
            # servers sending a truncated initial ACK still work.
            self.logger.warning(
                f"Received a short packet ({len(data)} bytes). Assuming it's a malformed ACK/ERROR."
            )
            block = 0
        if opcode != self._OP_ACK or block != 0:
            raise TFTPError(
                -1, f"Unexpected response to WRQ. Opcode: {opcode}, Block: {block}"
            )
        return True

    def _is_ack_for(self, data, block_num):
        """Classify a reply while uploading: accept ACK(block_num).

        A repeat of the previous ACK is discarded without resending DATA:
        resending on duplicate ACKs causes the "Sorcerer's Apprentice"
        packet storm (RFC 1123, 4.2.3.1). Lost packets are instead handled
        by retransmission on timeout.
        """
        if len(data) >= 2 and struct.unpack("!H", data[:2])[0] == self._OP_ERROR:
            raise self._error_from_packet(data)
        if len(data) < 4:
            raise TFTPError(
                -1,
                f"Invalid ACK packet for block {block_num}. Length is {len(data)} bytes.",
            )
        opcode, ack_block = struct.unpack("!HH", data[:4])
        if opcode == self._OP_ACK and ack_block == block_num:
            return True
        # Block numbers are 16-bit, so "previous" must wrap around as well.
        if opcode == self._OP_ACK and ack_block == (block_num - 1) % 65536:
            self.logger.warning(
                f"Received duplicate ACK for block {ack_block}. Ignoring it."
            )
            return False
        raise TFTPError(
            -1, f"Unexpected ACK. Expected block {block_num}, got {ack_block}"
        )

    def _is_data_for(self, data, block_num, sock, last_ack, peer):
        """Classify a reply while downloading: accept DATA(block_num).

        A repeat of the previous DATA block means our ACK was lost, so
        ``last_ack`` (if any) is resent to ``peer`` and the datagram discarded.
        """
        if len(data) >= 2 and struct.unpack("!H", data[:2])[0] == self._OP_ERROR:
            raise self._error_from_packet(data)
        if len(data) < 4:
            raise TFTPError(
                -1,
                f"Invalid DATA packet for block {block_num}. Length is {len(data)} bytes.",
            )
        opcode, number = struct.unpack("!HH", data[:4])
        if opcode == self._OP_DATA and number == block_num:
            return True
        if (
            opcode == self._OP_DATA
            and last_ack is not None
            and number == (block_num - 1) % 65536
        ):
            self.logger.warning(
                f"Received duplicate DATA block {number}. Re-sending its ACK."
            )
            sock.sendto(last_ack, peer)
            return False
        raise TFTPError(
            -1,
            f"Unexpected packet. Expected DATA block {block_num}, "
            f"got opcode {opcode}, block {number}",
        )

    def upload(self, filename, fileobj, mode="octet"):
        """Upload the contents of ``fileobj`` to the server as ``filename`` (WRQ).

        Args:
            filename: remote file name (non-empty ASCII str without NUL).
            fileobj: object whose ``read(n)`` returns bytes, or str that is
                encoded as ASCII; an empty result means EOF.
            mode: "octet" or "netascii" (forwarded only, no translation).

        Returns:
            True once the final (short, possibly empty) block is acknowledged.

        Raises:
            ValueError: invalid filename or mode.
            UnicodeError: non-ASCII filename or text content.
            TFTPError: the server sent an ERROR packet, or a protocol
                violation was detected (code -1).
            socket.timeout: no valid reply after all retransmissions.
        """
        self._validate_params(filename, mode)
        # Built before the socket exists so an encoding error cannot leak it.
        wrq = self._build_request(self._OP_WRQ, filename, mode)

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)

            self.logger.debug(
                f"Sending WRQ to {self.server_ip}:{self.server_port} for '{filename}'"
            )
            self.logger.debug("Waiting for initial ACK(0)...")
            # The server answers from a new port (its transfer ID); all later
            # packets are sent to, and only accepted from, that address.
            _, server = self._exchange(
                sock,
                wrq,
                (self.server_ip, self.server_port),
                None,
                self._is_initial_ack,
            )
            self.logger.debug(
                "Initial ACK(0) received correctly. Starting data transfer."
            )

            block_num = 1
            while True:
                chunk = self._read_block(fileobj)
                pkt = struct.pack("!HH", self._OP_DATA, block_num) + chunk
                self.logger.debug(
                    f"Sending DATA block {block_num} ({len(pkt)} bytes) to {server}"
                )
                self.logger.debug(f"Waiting for ACK({block_num})...")
                self._exchange(sock, pkt, server, server, self._is_ack_for, block_num)

                # A block shorter than TFTP_BLOCK_SIZE (even empty) ends the transfer.
                if len(chunk) < TFTP_BLOCK_SIZE:
                    self.logger.debug("Last block sent and ACKed. Transfer complete.")
                    return True

                # Block numbers are 16-bit and wrap around.
                block_num = (block_num + 1) % 65536

    def download(self, filename, fileobj, mode="octet"):
        """Download remote ``filename`` into ``fileobj`` (RRQ).

        Received payloads are written with ``fileobj.write``; for a text
        stream (``io.TextIOBase``) they are ASCII-decoded first. The final
        block is acknowledged once; no dally period is implemented.

        Args:
            filename: remote file name (non-empty ASCII str without NUL).
            fileobj: writable binary or text stream.
            mode: "octet" or "netascii" (forwarded only, no translation).

        Returns:
            True once the final (short) block has been received and ACKed.

        Raises:
            ValueError: invalid filename or mode.
            UnicodeError: non-ASCII filename, or non-ASCII data for a text stream.
            TFTPError: the server sent an ERROR packet, or a protocol
                violation was detected (code -1).
            socket.timeout: no valid reply after all retransmissions.
        """
        self._validate_params(filename, mode)
        rrq = self._build_request(self._OP_RRQ, filename, mode)
        is_text_stream = isinstance(fileobj, io.TextIOBase)

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            self.logger.debug(
                f"Sending RRQ to {self.server_ip}:{self.server_port} for '{filename}'"
            )

            dest = (self.server_ip, self.server_port)
            server = None  # server transfer ID, learnt from its first DATA packet
            last_ack = None  # ACK of the last accepted block; None before block 1
            block_num = 1
            while True:
                outgoing = rrq if last_ack is None else last_ack
                data, addr = self._exchange(
                    sock,
                    outgoing,
                    dest,
                    server,
                    self._is_data_for,
                    block_num,
                    sock,
                    last_ack,
                    dest,
                )
                if server is None:
                    server = dest = addr

                payload = data[4:]
                if payload:
                    fileobj.write(payload.decode("ascii") if is_text_stream else payload)
                last_ack = struct.pack("!HH", self._OP_ACK, block_num)

                if len(payload) < TFTP_BLOCK_SIZE:
                    sock.sendto(last_ack, server)
                    self.logger.debug(
                        "Last block received and ACKed. Transfer complete."
                    )
                    return True

                # Block numbers are 16-bit and wrap around.
                block_num = (block_num + 1) % 65536
|