# Viewer metadata: 119 lines, 5.2 KiB, Python
# target_simulator/utils/tftp_client.py
|
|
|
|
import socket
|
|
import struct
|
|
import io
|
|
from target_simulator.utils.logger import get_logger
|
|
|
|
class TFTPError(Exception):
    """Exception carrying a TFTP error code and its message.

    The exception text (and ``args[0]``) is ``"TFTP Error <code>: <message>"``.
    The raw values stay available as the ``code`` and ``message`` attributes.
    """

    def __init__(self, code, message):
        """Build the error.

        Args:
            code: Numeric error code. Code in this module uses ``-1`` for
                failures detected locally rather than reported by the peer.
            message: Human-readable description of the failure.
        """
        super().__init__(f"TFTP Error {code}: {message}")
        self.code = code
        self.message = message

    def __reduce__(self):
        """Support ``pickle`` / ``copy`` despite the two-argument ``__init__``.

        The default reduction rebuilds the exception from ``self.args``, which
        holds only the single formatted string, so reconstruction would call
        ``TFTPError(<one str>)`` and fail with ``TypeError``. Rebuilding from
        ``(code, message)`` keeps the instance round-trippable.
        """
        return (self.__class__, (self.code, self.message))
|
|
|
|
# Default UDP port used by TFTPClient when the caller gives no server_port.
TFTP_PORT = 69
# Payload size (bytes) of a full DATA block; a shorter block ends a transfer.
TFTP_BLOCK_SIZE = 512
|
|
|
|
class TFTPClient:
    """Small TFTP client that talks to a server over UDP.

    ``upload`` implements a write request (WRQ) followed by lock-step
    DATA/ACK exchange. Socket timeouts are not retried: a ``socket.timeout``
    raised by the socket propagates to the caller.
    """

    # Upper bound on how many times one DATA block is resent in reaction to
    # duplicate ACKs, so a misbehaving peer cannot keep upload() spinning.
    _MAX_DUP_ACK_RESENDS = 5

    def __init__(self, server_ip, server_port=TFTP_PORT, timeout=5):
        """Store the connection parameters.

        Args:
            server_ip: Address of the TFTP server, passed to ``sendto``.
            server_port: UDP port for the initial request; coerced with
                ``int()``.
            timeout: Socket timeout handed to ``settimeout`` for every
                send/receive of a transfer.
        """
        self.server_ip = server_ip
        self.server_port = int(server_port)
        self.timeout = timeout
        self.logger = get_logger(__name__)

    def _validate_params(self, filename, mode):
        """Validate the request parameters before anything is sent.

        Raises:
            ValueError: If ``filename`` is empty, not a ``str`` or contains a
                NUL character, or if ``mode`` is neither ``'octet'`` nor
                ``'netascii'``.
        """
        if not filename or not isinstance(filename, str):
            raise ValueError("Invalid filename")
        # The request packet is NUL-delimited: an embedded NUL would end the
        # filename early and shift the mode field.
        if "\0" in filename:
            raise ValueError("Invalid filename: must not contain NUL characters")
        if mode not in ("octet", "netascii"):
            raise ValueError("Invalid mode: must be 'octet' or 'netascii'")

    @staticmethod
    def _read_block(fileobj, size, is_text_stream):
        """Read up to ``size`` bytes from ``fileobj``, stopping early only at EOF.

        A single ``read(size)`` may legitimately return less than ``size``
        before the end of the stream; since a short DATA block terminates a
        transfer, reads are repeated until the block is full or the stream
        returns nothing. Text streams are encoded as ASCII.
        """
        parts = []
        remaining = size
        while remaining > 0:
            piece = fileobj.read(remaining)
            if not piece:
                break
            if is_text_stream:
                piece = piece.encode("ascii")
            parts.append(piece)
            remaining -= len(piece)
        return b"".join(parts)

    def upload(self, filename, fileobj, mode="octet"):
        """Send the content of ``fileobj`` to the server as ``filename``.

        Args:
            filename: Remote file name (ASCII, no NUL characters).
            fileobj: Readable stream. Text streams (``io.TextIOBase``) are
                encoded as ASCII; any other stream must yield bytes.
            mode: Transfer mode placed in the request, ``'octet'`` or
                ``'netascii'``. The payload is sent as read; no line-ending
                translation is performed here.

        Returns:
            ``True`` once the last block has been acknowledged.

        Raises:
            ValueError: On invalid ``filename`` / ``mode``.
            TFTPError: If the server answers with an ERROR packet or with an
                unexpected/invalid packet.
        """
        self._validate_params(filename, mode)
        is_text_stream = isinstance(fileobj, io.TextIOBase)

        # The context manager closes the socket on every exit path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)

            # WRQ: opcode 2, then NUL-terminated filename and mode.
            wrq = (
                struct.pack('!H', 2)
                + filename.encode('ascii') + b'\0'
                + mode.encode('ascii') + b'\0'
            )

            self.logger.debug(f"Sending WRQ to {self.server_ip}:{self.server_port} for '{filename}'")
            sock.sendto(wrq, (self.server_ip, self.server_port))

            self.logger.debug("Waiting for initial ACK(0)...")
            # `server` is the address the reply came from; all DATA packets
            # are sent there rather than to the original request port.
            data, server = sock.recvfrom(1024)
            self.logger.debug(f"Received {len(data)} bytes from {server}: {data.hex()}")

            # --- Handling of anomalous responses ---
            if len(data) >= 4:
                opcode, block = struct.unpack('!HH', data[:4])
            elif len(data) >= 2:
                # The packet is too short to hold a block number.
                # It may be a malformed ACK or an error.
                self.logger.warning(f"Received a short packet ({len(data)} bytes). Assuming it's a malformed ACK/ERROR.")
                opcode = struct.unpack('!H', data[:2])[0]
                block = 0  # Assume block 0 for the initial ACK
            else:
                raise TFTPError(-1, f"Invalid packet received. Length is {len(data)} bytes, expected at least 2.")

            if opcode == 5:  # ERROR
                error_code = struct.unpack('!H', data[2:4])[0] if len(data) >= 4 else -1
                error_msg = (
                    data[4:].split(b'\0', 1)[0].decode(errors='replace')
                    if len(data) > 4
                    else "Unknown error (short packet)"
                )
                raise TFTPError(error_code, error_msg)

            if opcode != 4 or block != 0:
                raise TFTPError(-1, f'Unexpected response to WRQ. Opcode: {opcode}, Block: {block}')

            self.logger.debug("Initial ACK(0) received correctly. Starting data transfer.")

            block_num = 1
            while True:
                chunk_bytes = self._read_block(fileobj, TFTP_BLOCK_SIZE, is_text_stream)

                pkt = struct.pack('!HH', 3, block_num) + chunk_bytes
                self.logger.debug(f"Sending DATA block {block_num} ({len(pkt)} bytes) to {server}")
                sock.sendto(pkt, server)

                # Block numbers are 16-bit, so the "previous" block must be
                # computed modulo 65536 to stay correct after a wraparound.
                previous_block = (block_num - 1) % 65536
                resends = 0

                # Wait for the ACK of *this* block. Looping here (rather than
                # restarting the outer loop) guarantees that a duplicate ACK
                # never causes the next chunk to be read and sent under the
                # current block number.
                while True:
                    self.logger.debug(f"Waiting for ACK({block_num})...")
                    data, _ = sock.recvfrom(1024)
                    self.logger.debug(f"Received {len(data)} bytes for ACK({block_num}): {data.hex()}")

                    if len(data) < 4:
                        raise TFTPError(-1, f"Invalid ACK packet for block {block_num}. Length is {len(data)} bytes.")

                    opcode, ack_block = struct.unpack('!HH', data[:4])

                    if opcode == 5:
                        error_code = struct.unpack('!H', data[2:4])[0]
                        error_msg = data[4:].split(b'\0', 1)[0].decode(errors='replace')
                        raise TFTPError(error_code, error_msg)

                    if opcode == 4 and ack_block == block_num:
                        break

                    # Handling of duplicate ACKs (common on UDP).
                    if opcode == 4 and ack_block == previous_block:
                        if resends >= self._MAX_DUP_ACK_RESENDS:
                            raise TFTPError(
                                -1,
                                f"Too many duplicate ACKs for block {ack_block}; "
                                f"giving up on block {block_num}",
                            )
                        resends += 1
                        self.logger.warning(f"Received duplicate ACK for block {ack_block}. Resending block {block_num}.")
                        sock.sendto(pkt, server)  # Resend current packet
                        continue  # Back to recvfrom for the same block

                    raise TFTPError(-1, f'Unexpected ACK. Expected block {block_num}, got {ack_block}')

                # A block shorter than the block size (possibly empty) marks
                # the end of the transfer.
                if len(chunk_bytes) < TFTP_BLOCK_SIZE:
                    self.logger.debug("Last block sent and ACKed. Transfer complete.")
                    break

                block_num = (block_num + 1) % 65536

            return True

    def download(self, filename, fileobj, mode="octet"):
        # ... (implementation from your file, unchanged)
        # NOTE(review): the body is elided in this snapshot; as written here
        # the method does nothing and returns None.
        pass