S1005403_RisCC/target_simulator/utils/tftp_client.py
2025-10-02 14:37:38 +02:00

116 lines
5.0 KiB
Python

class TFTPError(Exception):
    """Exception raised when a TFTP transfer fails.

    Attributes:
        code: the TFTP error code carried by a server ERROR packet; the client
            in this module uses -1 for protocol violations it detects itself.
        message: the human-readable error text.
    """

    def __init__(self, code, message):
        self.code = code
        self.message = message
        # The base Exception receives the single formatted string as its only arg.
        super().__init__(f"TFTP Error {code}: {message}")
import socket
import struct
import io
TFTP_PORT = 69
TFTP_BLOCK_SIZE = 512

# Packet opcodes defined by RFC 1350, section 5.
_OP_RRQ = 1
_OP_WRQ = 2
_OP_DATA = 3
_OP_ACK = 4
_OP_ERROR = 5

# RFC 1350 error code 5: "Unknown transfer ID".
_ERR_UNKNOWN_TID = 5

# recvfrom() buffer; comfortably larger than a 4-byte header + 512-byte payload.
_RECV_BUFSIZE = 1024


class TFTPClient:
    """Minimal RFC 1350 TFTP client for single-file download and upload.

    Each transfer opens its own UDP socket. The client does not retransmit:
    if an expected packet does not arrive within ``timeout`` seconds, the
    socket timeout exception raised by ``recvfrom`` propagates to the caller.
    """

    def __init__(self, server_ip, server_port=TFTP_PORT, timeout=5):
        """
        server_ip: str, IP address of TFTP server
        server_port: int, port (default 69)
        timeout: int, socket timeout in seconds
        """
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout

    def download(self, filename, fileobj, mode="octet"):
        """Download ``filename`` from the server into ``fileobj``.

        filename: remote filename on server (non-empty str without NUL)
        fileobj: writable file-like object. It receives bytes in 'octet' mode
            and str in 'netascii' mode, so open it in binary or text mode to
            match.
        mode: 'octet' (binary) or 'netascii' (text; CR LF -> LF, CR NUL -> CR)
        Returns True on success.
        Raises ValueError for bad arguments, TFTPError on server errors or
        protocol violations, and the socket timeout exception if the server
        stops responding.
        """
        self._validate_params(filename, mode)
        netascii = mode == "netascii"
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(self.timeout)
        try:
            sock.sendto(self._request_packet(_OP_RRQ, filename, mode),
                        (self.server_ip, self.server_port))
            peer = None          # server transfer ID, fixed by the first reply
            expected_block = 1
            last_block = None    # last block written, to recognise retransmissions
            pending_cr = False   # netascii CR held back from the previous block
            while True:
                opcode, block_num, payload, addr = self._receive(sock, peer)
                if opcode != _OP_DATA:
                    raise TFTPError(-1, 'Unexpected response to RRQ')
                if peer is None:
                    peer = addr
                if block_num == last_block:
                    # The server resent a block because our ACK was lost:
                    # re-acknowledge it without writing the data twice.
                    sock.sendto(struct.pack('!HH', _OP_ACK, block_num), peer)
                    continue
                if block_num != expected_block:
                    raise TFTPError(-1, f'Unexpected block number: {block_num}')
                is_last = len(payload) < TFTP_BLOCK_SIZE
                if netascii:
                    text, pending_cr = self._decode_netascii(payload, pending_cr, is_last)
                    fileobj.write(text)
                else:
                    fileobj.write(payload)
                sock.sendto(struct.pack('!HH', _OP_ACK, block_num), peer)
                if is_last:
                    break
                last_block = block_num
                expected_block = (expected_block + 1) % 65536
            return True
        finally:
            sock.close()

    def upload(self, filename, fileobj, mode="octet"):
        """Upload the contents of ``fileobj`` to ``filename`` on the server.

        filename: remote filename on server (non-empty str without NUL)
        fileobj: readable file-like object. In 'netascii' mode, str data is
            converted (bare CR -> CR NUL, LF -> CR LF) and ASCII-encoded, while
            bytes data is sent unchanged. In 'octet' mode it must yield bytes.
        mode: 'octet' (binary) or 'netascii' (text)
        Returns True on success.
        Raises ValueError for bad arguments, TFTPError on server errors or
        protocol violations, and the socket timeout exception if the server
        stops responding.
        """
        self._validate_params(filename, mode)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(self.timeout)
        try:
            sock.sendto(self._request_packet(_OP_WRQ, filename, mode),
                        (self.server_ip, self.server_port))
            opcode, block, _, peer = self._receive(sock, None)
            if opcode != _OP_ACK or block != 0:
                raise TFTPError(-1, 'Unexpected response to WRQ')
            block_num = 1
            for chunk in self._iter_blocks(fileobj, mode == "netascii"):
                sock.sendto(struct.pack('!HH', _OP_DATA, block_num) + chunk, peer)
                self._await_ack(sock, peer, block_num)
                block_num = (block_num + 1) % 65536
            return True
        finally:
            sock.close()

    def _validate_params(self, filename, mode):
        """Reject arguments that cannot be encoded into a TFTP request.

        Raises ValueError for an empty or non-str filename, a filename that
        contains NUL (the request's field terminator), or an unsupported mode.
        """
        if not filename or not isinstance(filename, str) or "\0" in filename:
            raise ValueError("Invalid filename")
        if mode not in ("octet", "netascii"):
            raise ValueError("Invalid mode: must be 'octet' or 'netascii'")

    @staticmethod
    def _request_packet(opcode, filename, mode):
        """Build an RRQ/WRQ packet: opcode, filename NUL, mode NUL."""
        return (struct.pack('!H', opcode) + filename.encode('ascii') + b'\0'
                + mode.encode('ascii') + b'\0')

    @staticmethod
    def _error_packet(code, message):
        """Build an ERROR packet: opcode, error code, message NUL."""
        return struct.pack('!HH', _OP_ERROR, code) + message.encode('ascii') + b'\0'

    def _receive(self, sock, peer):
        """Return the next packet from the transfer peer as (opcode, field, payload, addr).

        ``field`` is the second 16-bit header word (block number or error
        code). When ``peer`` is None, the first packet from any address is
        accepted. Otherwise, packets from other addresses get an "Unknown
        transfer ID" error reply and are ignored, as RFC 1350 prescribes.
        Raises TFTPError for ERROR packets and for packets too short to
        carry a header.
        """
        while True:
            data, addr = sock.recvfrom(_RECV_BUFSIZE)
            if peer is not None and addr != peer:
                sock.sendto(self._error_packet(_ERR_UNKNOWN_TID, 'Unknown transfer ID'), addr)
                continue
            if len(data) < 4:
                raise TFTPError(-1, 'Malformed packet')
            opcode, field = struct.unpack('!HH', data[:4])
            if opcode == _OP_ERROR:
                error_msg = data[4:].split(b'\0', 1)[0].decode(errors='replace')
                raise TFTPError(field, error_msg)
            return opcode, field, data[4:], addr

    def _await_ack(self, sock, peer, block_num):
        """Wait until the peer acknowledges ``block_num``.

        A duplicate ACK for the previous block is ignored rather than treated
        as an error and is not answered with a retransmission (RFC 1123,
        "Sorcerer's Apprentice" rule).
        """
        previous = (block_num - 1) % 65536
        while True:
            opcode, ack_block, _, _ = self._receive(sock, peer)
            if opcode == _OP_ACK and ack_block == previous:
                continue
            if opcode != _OP_ACK or ack_block != block_num:
                raise TFTPError(-1, 'Unexpected response to DATA')
            return

    @staticmethod
    def _iter_blocks(fileobj, netascii):
        """Yield the DATA payloads for uploading ``fileobj``.

        Every payload except the last is exactly TFTP_BLOCK_SIZE bytes. The
        last is shorter and may be empty, which tells the server the transfer
        is complete. Netascii conversion can make the data longer than what
        read() returned, so it is re-blocked through a buffer instead of being
        sent one read() at a time.
        """
        pending = b""
        while True:
            chunk = fileobj.read(TFTP_BLOCK_SIZE)
            if netascii and isinstance(chunk, str):
                # Bare CR must become CR NUL before LF is expanded to CR LF.
                chunk = chunk.replace("\r", "\r\0").replace("\n", "\r\n").encode("ascii")
            pending += chunk
            while len(pending) >= TFTP_BLOCK_SIZE:
                yield pending[:TFTP_BLOCK_SIZE]
                pending = pending[TFTP_BLOCK_SIZE:]
            if not chunk:
                yield pending
                return

    @staticmethod
    def _decode_netascii(payload, pending_cr, is_last):
        """Convert one netascii DATA payload to local text.

        Translates CR LF -> LF and CR NUL -> CR. A CR at the end of a non-final
        block is held back because its partner byte arrives in the next block.
        ``pending_cr`` says whether the previous call held one back.
        Returns (text, pending_cr_for_next_block).
        """
        if pending_cr:
            payload = b"\r" + payload
        pending_cr = not is_last and payload.endswith(b"\r")
        if pending_cr:
            payload = payload[:-1]
        text = payload.replace(b"\r\n", b"\n").replace(b"\r\0", b"\r")
        return text.decode("ascii"), pending_cr