# S1005403_RisCC/target_simulator/gui/analysis_window.py
#
# 708 lines
# 27 KiB
# Python
#
# target_simulator/gui/analysis_window.py
"""
A Toplevel window for displaying performance analysis by processing
an efficient trail file.
"""
import tkinter as tk
from tkinter import ttk, messagebox
import json
import os
import csv
import math
import statistics
import warnings
from typing import Optional, Dict, List, Any, Tuple
from target_simulator.gui.performance_analysis_window import PerformanceAnalysisWindow
try:
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import (
FigureCanvasTkAgg,
NavigationToolbar2Tk,
)
MATPLOTLIB_AVAILABLE = True
except ImportError:
np = None
MATPLOTLIB_AVAILABLE = False
# Constants for analysis
# Maximum number of plotted points before interval-based decimation is applied
# (see _downsample_data); keeps matplotlib redraws responsive on long runs.
DOWNSAMPLE_THRESHOLD = 4000 # Number of points before downsampling is applied
# Threshold for spike filtering (e.g., 20 times the median error, or minimum 500ft)
# used by _filter_spikes; the minimum floor avoids over-filtering when the
# median error is near zero.
SPIKE_THRESHOLD_FACTOR = 20
SPIKE_THRESHOLD_MIN_FT = 500
class AnalysisWindow(tk.Toplevel):
"""
A window that displays tracking performance analysis by loading data
from an archive's main JSON file and its associated `.trail.csv`.
"""
def __init__(self, master, archive_filepath: str):
    """Initialize the analysis window and kick off deferred data loading.

    Args:
        master: Parent Tk widget.
        archive_filepath: Path to the archive's main JSON file. The
            companion `.trail.csv` / `.latency.csv` / `.perf.csv` files are
            resolved from the same base name in _load_data_and_setup.
    """
    super().__init__(master)
    self.title(f"Analysis for: {os.path.basename(archive_filepath)}")
    self.geometry("1100x800")
    self.archive_filepath = archive_filepath
    # Paths to companion data files; filled in by _load_data_and_setup.
    self.trail_filepath: Optional[str] = None
    self.latency_filepath: Optional[str] = None
    self.performance_data_path: Optional[str] = None
    self.scenario_name = "Unknown"
    # Target IDs discovered while scanning the trail file.
    self.target_ids: List[int] = []
    self.selected_target_id = tk.IntVar()
    # Shows a modal progress dialog and schedules the heavy load via after().
    self._show_loading_window(archive_filepath)
def _load_data_and_setup(self, filepath: str):
"""Loads metadata from the main archive and finds associated data files."""
try:
with open(filepath, "r", encoding="utf-8") as f:
archive_data = json.load(f)
except Exception as e:
raise IOError(f"Could not load archive file: {e}")
metadata = archive_data.get("metadata", {})
self.scenario_name = metadata.get("scenario_name", "Unknown")
self.title(f"Analysis - {self.scenario_name}")
# Find the associated trail, latency, and performance files
base_path, _ = os.path.splitext(filepath)
self.trail_filepath = f"{base_path}.trail.csv"
self.latency_filepath = f"{base_path}.latency.csv"
self.performance_data_path = f"{base_path}.perf.csv"
if not os.path.exists(self.trail_filepath):
raise FileNotFoundError(
f"Required trail file not found: {self.trail_filepath}"
)
# Get available target IDs from the trail file header
with open(self.trail_filepath, "r", encoding="utf-8") as f:
reader = csv.reader(f)
headers = next(reader, [])
if "target_id" not in headers:
raise ValueError("Trail file missing 'target_id' column.")
target_id_index = headers.index("target_id")
ids = set()
for row in reader:
# --- MODIFIED PART START ---
# Check for comment lines beginning with '#'
if not row or row[0].strip().startswith("#"):
continue
# --- MODIFIED PART END ---
try:
ids.add(int(row[target_id_index]))
except (ValueError, IndexError):
continue
self.target_ids = sorted(list(ids))
def _show_loading_window(self, archive_filepath: str):
    """Shows a loading dialog and loads data in the background.

    The actual loading runs on the Tk main loop (scheduled with after),
    not on a separate thread; the dialog stays responsive because
    self.update() is called between the heavy steps.
    """
    loading_dialog = tk.Toplevel(self)
    loading_dialog.title("Loading Analysis")
    loading_dialog.geometry("400x150")
    # Keep the dialog on top of this window and make it modal.
    loading_dialog.transient(self)
    loading_dialog.grab_set()
    # Force geometry computation so the centering math below sees real sizes.
    loading_dialog.update_idletasks()
    x = (
        self.winfo_x()
        + (self.winfo_width() // 2)
        - (loading_dialog.winfo_width() // 2)
    )
    y = (
        self.winfo_y()
        + (self.winfo_height() // 2)
        - (loading_dialog.winfo_height() // 2)
    )
    # Center the dialog over the parent analysis window.
    loading_dialog.geometry(f"+{x}+{y}")
    ttk.Label(
        loading_dialog, text="Loading simulation data...", font=("Segoe UI", 11)
    ).pack(pady=(20, 10))
    progress_label = ttk.Label(
        loading_dialog, text="Please wait", font=("Segoe UI", 9)
    )
    progress_label.pack(pady=5)
    progress = ttk.Progressbar(loading_dialog, mode="indeterminate", length=300)
    progress.pack(pady=10)
    progress.start(10)
    def load_and_display():
        # Runs once on the Tk event loop; updates the progress text between
        # the heavy steps so the indeterminate bar keeps animating.
        try:
            progress_label.config(text="Locating data files...")
            self.update()
            self._load_data_and_setup(archive_filepath)
            progress_label.config(text="Creating widgets...")
            self.update()
            self._create_widgets()
            progress_label.config(text="Ready.")
            self.update()
            loading_dialog.destroy()
            # Trigger initial analysis
            if self.target_ids:
                self._on_target_select()
            else:
                messagebox.showinfo(
                    "No Target Data",
                    "No active target data found in this simulation run.",
                    parent=self,
                )
        except Exception as e:
            # Any failure tears down both the dialog and this window.
            loading_dialog.destroy()
            messagebox.showerror(
                "Analysis Error", f"Failed to load analysis:\n{e}", parent=self
            )
            self.destroy()
    # Defer the load slightly so the dialog can render first.
    self.after(100, load_and_display)
def _create_widgets(self):
    """Builds the two-pane layout: stats panel on top, plots below."""
    container = ttk.PanedWindow(self, orient=tk.VERTICAL)
    container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
    # Statistics pane (small share of the vertical space).
    statistics_section = ttk.LabelFrame(container, text="Error Statistics (feet)")
    container.add(statistics_section, weight=1)
    self._create_stats_widgets(statistics_section)
    # Plotting pane (dominant share of the vertical space).
    plotting_section = ttk.LabelFrame(container, text="Error Over Time (feet)")
    container.add(plotting_section, weight=4)
    self._create_plot_widgets(plotting_section)
def _create_stats_widgets(self, parent):
    """Builds the statistics panel: target selector, stats table and legend.

    Args:
        parent: The LabelFrame that hosts this panel.
    """
    # Configure the grid for the layout
    parent.rowconfigure(0, weight=0)  # Top bar
    parent.rowconfigure(1, weight=1)  # Content area
    parent.columnconfigure(0, weight=1)
    # Top bar with the target combobox and the performance button
    top_bar = ttk.Frame(parent, padding=5)
    top_bar.grid(row=0, column=0, sticky="ew")
    ttk.Label(top_bar, text="Select Target ID:").pack(side=tk.LEFT)
    self.target_selector = ttk.Combobox(
        top_bar,
        textvariable=self.selected_target_id,
        state="readonly",
        width=5,
        values=self.target_ids,
    )
    self.target_selector.pack(side=tk.LEFT, padx=5)
    # Re-run the analysis whenever a different target is picked.
    self.target_selector.bind("<<ComboboxSelected>>", self._on_target_select)
    if self.target_ids:
        self.selected_target_id.set(self.target_ids[0])
    # Performance Analysis button (always visible, disabled if no data)
    perf_button = ttk.Button(
        top_bar,
        text="Open Performance Analysis",
        command=self._open_performance_window,
    )
    perf_button.pack(side=tk.LEFT, padx=(20, 0))
    if not os.path.exists(self.performance_data_path):
        perf_button.config(state="disabled")
    # Content container split into two equal-width columns
    content_frame = ttk.Frame(parent)
    content_frame.grid(row=1, column=0, sticky="nsew", padx=5, pady=5)
    content_frame.columnconfigure(0, weight=1, uniform="half")
    content_frame.columnconfigure(1, weight=1, uniform="half")
    content_frame.rowconfigure(0, weight=1)
    # Left: Stats table
    table_container = ttk.Frame(content_frame)
    table_container.grid(row=0, column=0, sticky="nsew", padx=(0, 2))
    columns = ("error_type", "mean", "std_dev", "rmse")
    self.stats_tree = ttk.Treeview(
        table_container, columns=columns, show="headings", height=4
    )
    self.stats_tree.heading("error_type", text="")
    self.stats_tree.heading("mean", text="Mean (ft)")
    self.stats_tree.heading("std_dev", text="Std Dev (ft)")
    self.stats_tree.heading("rmse", text="RMSE (ft)")
    self.stats_tree.column("error_type", width=100, anchor=tk.W)
    self.stats_tree.column("mean", anchor=tk.E, width=120)
    self.stats_tree.column("std_dev", anchor=tk.E, width=120)
    self.stats_tree.column("rmse", anchor=tk.E, width=120)
    self.stats_tree.pack(fill=tk.BOTH, expand=True)
    # Right: Legend frame explaining how to read the numbers
    legend_frame = ttk.Frame(content_frame)
    legend_frame.grid(row=0, column=1, sticky="nsew", padx=(2, 0))
    legend_title = ttk.Label(
        legend_frame, text="How to Interpret Results:", font=("Segoe UI", 9, "bold")
    )
    legend_title.pack(anchor=tk.NW, pady=(0, 5))
    explanation_text = (
        "Error = Real - Simulated Position\n\n"
        "Sign (e.g., X axis):\n"
        "• Positive: Real target at larger X\n"
        "• Negative: Real target at smaller X\n\n"
        "Spike Filtering:\n"
        "Transients >20x median filtered\n"
        "from plots and statistics.\n\n"
        "Latency:\n"
        "Time from packet generation\n"
        "(server) to reception (client)."
    )
    ttk.Label(
        legend_frame, text=explanation_text, justify=tk.LEFT, font=("Segoe UI", 9)
    ).pack(anchor=tk.NW, fill=tk.BOTH, expand=True)
def _create_plot_widgets(self, parent):
    """Builds the matplotlib figure: error plot plus an optional latency subplot.

    Creates self.ax / self.ax_latency and the empty line artists that
    _update_plot later fills with data.

    Args:
        parent: The LabelFrame that hosts the canvas and toolbar.
    """
    if not MATPLOTLIB_AVAILABLE:
        ttk.Label(parent, text="Matplotlib is required for plotting.").pack()
        return
    fig = Figure(figsize=(5, 7), dpi=100)
    # Check if latency file exists to determine subplot layout
    has_latency = os.path.exists(self.latency_filepath)
    if has_latency:
        # Two subplots: errors (top) and latency (bottom), sharing the x axis
        gs = fig.add_gridspec(2, 1, height_ratios=[2, 1], hspace=0.3, top=0.95)
        self.ax = fig.add_subplot(gs[0, 0])
        self.ax_latency = fig.add_subplot(gs[1, 0], sharex=self.ax)
    else:
        # Single subplot: just errors
        self.ax = fig.add_subplot(111)
        self.ax_latency = None
    # Error plot: one empty line per axis, populated by _update_plot
    self.ax.set_title("Instantaneous Error")
    self.ax.set_ylabel("Error (ft)")
    (self.line_x,) = self.ax.plot([], [], lw=1.5, label="Error X", color="#1f77b4")
    (self.line_y,) = self.ax.plot([], [], lw=1.5, label="Error Y", color="#ff7f0e")
    (self.line_z,) = self.ax.plot([], [], lw=1.5, label="Error Z", color="#2ca02c")
    self.ax.grid(True, alpha=0.3)
    self.ax.axhline(0.0, color="black", lw=1, linestyle="--", alpha=0.5)
    self.ax.legend(loc="upper right", fontsize=9)
    if not has_latency:
        # No latency subplot below, so the error plot carries the x label.
        self.ax.set_xlabel("Elapsed Time (s)")
    # Latency plot (if file exists)
    if has_latency:
        self.ax_latency.set_title("Latency Evolution")
        self.ax_latency.set_xlabel("Elapsed Time (s)")
        self.ax_latency.set_ylabel("Latency (ms)")
        (self.line_latency,) = self.ax_latency.plot(
            [], [], lw=1.5, color="#d62728", label="Latency"
        )
        self.ax_latency.grid(True, alpha=0.3)
        self.ax_latency.legend(loc="upper right", fontsize=9)
    else:
        self.line_latency = None
    # tight_layout may warn about the gridspec geometry; the warning is
    # cosmetic here, so it is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        fig.tight_layout()
    canvas_frame = ttk.Frame(parent)
    canvas_frame.pack(fill=tk.BOTH, expand=True)
    toolbar_frame = ttk.Frame(canvas_frame)
    toolbar_frame.pack(side=tk.TOP, fill=tk.X)
    self.canvas = FigureCanvasTkAgg(fig, master=canvas_frame)
    toolbar = NavigationToolbar2Tk(self.canvas, toolbar_frame)
    toolbar.update()
    self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
def _on_target_select(self, event=None):
    """Runs the full analysis pipeline for the currently selected target."""
    if not self.trail_filepath:
        return
    target = self.selected_target_id.get()
    # Raw (unfiltered) error series for this target.
    raw_ts, raw_errors = self._analyze_trail_file(target)
    # Strip acquisition spikes from the error series.
    (
        clean_ts,
        clean_errors,
        n_spikes,
        worst_spike,
        worst_spike_ts,
    ) = self._filter_spikes(raw_ts, raw_errors)
    # Same treatment for the latency series; spike details are not needed.
    raw_lat_ts, raw_lat_vals = self._load_latency_data()
    clean_lat_ts, clean_lat_vals, _, _, _ = self._filter_spikes_latency(
        raw_lat_ts, raw_lat_vals
    )
    # Refresh the stats table and the plots with the filtered data.
    self._update_stats_table(clean_errors, clean_lat_vals, clean_lat_ts)
    self._update_plot(
        clean_ts,
        clean_errors,
        n_spikes,
        worst_spike,
        worst_spike_ts,
        clean_lat_ts,
        clean_lat_vals,
    )
def _analyze_trail_file(
self, target_id: int
) -> Tuple[List[float], Dict[str, List[float]]]:
"""
Analyzes the trail file for a specific target using an efficient
two-pointer algorithm. Returns raw timestamps and errors (non-filtered).
"""
sim_points = []
real_points = []
with open(self.trail_filepath, "r", encoding="utf-8") as f:
reader = csv.DictReader(line for line in f if not line.startswith("#"))
for row in reader:
try:
if int(row["target_id"]) == target_id:
point = (
float(row["timestamp"]),
float(row["x_ft"]),
float(row["y_ft"]),
float(row["z_ft"]),
)
if row["source"] == "simulated":
sim_points.append(point)
elif row["source"] == "real":
real_points.append(point)
except (ValueError, KeyError):
continue
if not sim_points or not real_points:
return [], {"x": [], "y": [], "z": []}
# --- Two-Pointer Algorithm for Error Calculation ---
timestamps, errors_x, errors_y, errors_z = [], [], [], []
sim_idx = 0
for real_p in real_points:
real_ts, real_x, real_y, real_z = real_p
# Advance sim_idx to find the bracketing segment for the current real point
while (
sim_idx + 1 < len(sim_points) and sim_points[sim_idx + 1][0] < real_ts
):
sim_idx += 1
if sim_idx + 1 < len(sim_points):
p1 = sim_points[sim_idx]
p2 = sim_points[sim_idx + 1]
# Check if the real point is within this segment
if p1[0] <= real_ts <= p2[0]:
# Interpolate
ts1, x1, y1, z1 = p1
ts2, x2, y2, z2 = p2
duration = ts2 - ts1
if duration == 0:
continue
factor = (real_ts - ts1) / duration
interp_x = x1 + (x2 - x1) * factor
interp_y = y1 + (y2 - y1) * factor
interp_z = z1 + (z2 - z1) * factor
timestamps.append(real_ts)
errors_x.append(real_x - interp_x)
errors_y.append(real_y - interp_y)
errors_z.append(real_z - interp_z)
errors = {"x": errors_x, "y": errors_y, "z": errors_z}
return timestamps, errors
def _downsample_data(self, timestamps: List, errors: Dict) -> Tuple[List, Dict]:
    """Reduces the number of points for plotting while preserving shape.

    Uses fixed-stride decimation. The stride is the ceiling of
    len / DOWNSAMPLE_THRESHOLD so the output never exceeds the threshold.
    The previous floor division produced a stride of 1 (no downsampling
    at all) for lengths between THRESHOLD+1 and 2*THRESHOLD-1, e.g. a
    7999-point series with a 4000 threshold was returned unchanged.

    Args:
        timestamps: Elapsed-time values, one per sample.
        errors: Error lists keyed by axis ("x"/"y"/"z"), parallel to
            `timestamps`.

    Returns:
        (timestamps, errors) decimated with the same stride so the
        series stay aligned.
    """
    if len(timestamps) <= DOWNSAMPLE_THRESHOLD:
        return timestamps, errors
    # Exact integer ceiling division: guarantees len(result) <= threshold.
    step = -(-len(timestamps) // DOWNSAMPLE_THRESHOLD)
    ts_down = timestamps[::step]
    err_down = {
        "x": errors["x"][::step],
        "y": errors["y"][::step],
        "z": errors["z"][::step],
    }
    return ts_down, err_down
def _update_stats_table(
    self,
    filtered_errors: Dict[str, List[float]],
    filtered_latency_values: List[float],
    filtered_latency_timestamps: List[float],
):
    """Populates the stats Treeview with calculated metrics on filtered data.

    Args:
        filtered_errors: Spike-filtered error lists keyed by axis ("x"/"y"/"z").
        filtered_latency_values: Spike-filtered latency samples in ms.
        filtered_latency_timestamps: Currently unused here; accepted so the
            call site can pass the latency series symmetrically.
    """
    # Clear rows left over from a previously selected target.
    self.stats_tree.delete(*self.stats_tree.get_children())
    stats = {}
    for axis, err_list in filtered_errors.items():
        # Axes with no data get placeholder stats and no table row.
        if not err_list:
            stats[axis] = {"mean": 0, "std_dev": 0, "rmse": 0}
            continue
        mean = statistics.mean(err_list)
        # stdev requires at least two samples.
        stdev = statistics.stdev(err_list) if len(err_list) > 1 else 0
        rmse = math.sqrt(sum(e**2 for e in err_list) / len(err_list))
        stats[axis] = {"mean": mean, "std_dev": stdev, "rmse": rmse}
        self.stats_tree.insert(
            "",
            "end",
            values=(
                f"Error {axis.upper()}",
                f"{mean:.3f}",
                f"{stdev:.3f}",
                f"{rmse:.3f}",
            ),
        )
    # Add latency statistics if available. Note: the "RMSE" column is
    # repurposed to show the min-max latency range for this row.
    if filtered_latency_values:
        lat_mean = statistics.mean(filtered_latency_values)
        lat_std = (
            statistics.stdev(filtered_latency_values)
            if len(filtered_latency_values) > 1
            else 0.0
        )
        lat_min = min(filtered_latency_values)
        lat_max = max(filtered_latency_values)
        self.stats_tree.insert(
            "",
            "end",
            values=(
                "Latency (ms)",
                f"{lat_mean:.2f}",
                f"{lat_std:.2f}",
                f"{lat_min:.2f} - {lat_max:.2f}",
            ),
        )
def _update_plot(
    self,
    timestamps: List[float],
    errors: Dict[str, List[float]],
    spike_count: int,
    max_spike_error: float,
    max_spike_time: float,
    latency_timestamps: List[float],
    latency_values: List[float],
):
    """Updates the matplotlib plots with (potentially downsampled) data.

    Args:
        timestamps: Spike-filtered error timestamps (absolute seconds).
        errors: Spike-filtered error lists keyed by axis ("x"/"y"/"z").
        spike_count: Number of samples removed by the spike filter.
        max_spike_error: Largest filtered spike magnitude (ft).
        max_spike_time: Absolute timestamp of the largest spike.
        latency_timestamps: Spike-filtered latency timestamps.
        latency_values: Spike-filtered latency samples (ms).
    """
    # Convert to elapsed time (seconds from start) for errors.
    # start_time_errors must be bound even when the filtered series is
    # empty (possible when every sample was classified as a spike);
    # previously it was left undefined in that case and the spike
    # annotation below raised UnboundLocalError.
    start_time_errors = min(timestamps) if timestamps else None
    if start_time_errors is not None:
        elapsed_times_errors = [t - start_time_errors for t in timestamps]
    else:
        elapsed_times_errors = []
    # Downsample error data if needed
    ts_plot_errors, errors_plot = self._downsample_data(
        elapsed_times_errors, errors
    )
    self.line_x.set_data(ts_plot_errors, errors_plot["x"])
    self.line_y.set_data(ts_plot_errors, errors_plot["y"])
    self.line_z.set_data(ts_plot_errors, errors_plot["z"])
    # Remove spike annotations left over from the previous target.
    for txt in getattr(self.ax, "_spike_annotations", []):
        txt.remove()
    self.ax._spike_annotations = []
    # Add spike annotation if any were filtered
    if spike_count > 0:
        # Show the spike time relative to the series start when known.
        display_spike_time = (
            max_spike_time - start_time_errors
            if start_time_errors is not None
            else max_spike_time
        )
        annotation_text = (
            f"{spike_count} acquisition spike(s) filtered\n"
            f"(max error: {max_spike_error:.0f} ft at t={display_spike_time:.1f}s)\n"
            f"Spikes excluded from statistics"
        )
        txt = self.ax.text(
            0.02,
            0.98,
            annotation_text,
            transform=self.ax.transAxes,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="yellow", alpha=0.7),
            fontsize=8,
        )
        self.ax._spike_annotations.append(txt)
    self.ax.relim()
    self.ax.autoscale_view()
    # Update latency plot (if available)
    if self.ax_latency and self.line_latency:
        # Convert to elapsed time (seconds from start) for latency
        if latency_timestamps:
            start_time_latency = min(latency_timestamps)
            elapsed_times_latency = [
                t - start_time_latency for t in latency_timestamps
            ]
        else:
            elapsed_times_latency = []
        # Downsample latency data for plotting if needed
        ts_plot_latency = elapsed_times_latency
        lat_plot = latency_values
        if len(elapsed_times_latency) > DOWNSAMPLE_THRESHOLD:
            # Ceiling division so the plotted series never exceeds the
            # threshold (floor division could leave it up to ~2x over).
            step = -(-len(elapsed_times_latency) // DOWNSAMPLE_THRESHOLD)
            ts_plot_latency = elapsed_times_latency[::step]
            lat_plot = latency_values[::step]
        self.line_latency.set_data(ts_plot_latency, lat_plot)
        self.ax_latency.relim()
        self.ax_latency.autoscale_view()
    self.canvas.draw_idle()
def _filter_spikes(
    self, timestamps: List[float], errors: Dict[str, List[float]]
) -> tuple:
    """Filters acquisition spikes from error data.

    A sample is a spike when its 3D error magnitude exceeds
    max(median_magnitude * SPIKE_THRESHOLD_FACTOR, SPIKE_THRESHOLD_MIN_FT).

    Returns:
        (filtered_timestamps, filtered_errors, spike_count,
         max_spike_error, max_spike_time)
    """
    if not timestamps:
        return timestamps, errors, 0, 0.0, 0.0
    # 3D error magnitude for each sample.
    magnitudes = [
        math.sqrt(
            errors["x"][i] ** 2 + errors["y"][i] ** 2 + errors["z"][i] ** 2
        )
        for i in range(len(timestamps))
    ]
    # Robust median over the whole series; require a minimum sample count
    # so a handful of points cannot define the threshold. Falls back to
    # the stdlib median when numpy is unavailable (previously the median
    # silently became 0.0 in that case, flattening the threshold to the
    # fixed SPIKE_THRESHOLD_MIN_FT floor).
    if len(magnitudes) >= 10:
        if np is not None:
            median_mag = float(np.median(magnitudes))
        else:
            median_mag = statistics.median(magnitudes)
    else:
        median_mag = 0.0
    # SPIKE_THRESHOLD_FACTOR times the median, clamped to a minimum value
    # so legitimate small errors are not over-filtered when the median is
    # near zero.
    threshold = max(median_mag * SPIKE_THRESHOLD_FACTOR, SPIKE_THRESHOLD_MIN_FT)
    # Partition samples into kept points and spikes.
    filtered_ts = []
    filtered_errors = {"x": [], "y": [], "z": []}
    spike_count = 0
    max_spike_error = 0.0
    max_spike_time = 0.0
    for i, mag in enumerate(magnitudes):
        if mag > threshold:
            spike_count += 1
            if mag > max_spike_error:
                max_spike_error = mag
                max_spike_time = timestamps[i]
        else:
            filtered_ts.append(timestamps[i])
            filtered_errors["x"].append(errors["x"][i])
            filtered_errors["y"].append(errors["y"][i])
            filtered_errors["z"].append(errors["z"][i])
    return (
        filtered_ts,
        filtered_errors,
        spike_count,
        max_spike_error,
        max_spike_time,
    )
def _load_latency_data(self) -> Tuple[List[float], List[float]]:
"""Loads raw latency data from the CSV file."""
timestamps = []
latencies = []
if not os.path.exists(self.latency_filepath):
return [], []
try:
with open(self.latency_filepath, "r", encoding="utf-8") as f:
reader = csv.DictReader(line for line in f if not line.startswith("#"))
for row in reader:
try:
timestamps.append(float(row["timestamp"]))
latencies.append(float(row["latency_ms"]))
except (ValueError, KeyError):
continue
return timestamps, latencies
except Exception as e:
print(f"Warning: Failed to load raw latency data: {e}")
return [], []
def _filter_spikes_latency(
self, timestamps: List[float], values: List[float]
) -> Tuple[List[float], List[float], int, float, float]:
"""Filters acquisition spikes from latency data."""
if not timestamps:
return timestamps, values, 0, 0.0, 0.0
# --- MODIFIED PART START ---
# Using numpy for robust median calculation
if np is not None and len(values) >= 10:
median_val = np.median(values)
else:
median_val = 0.0
# Threshold for latency: a higher factor or a fixed max. Example: 5x median or 50ms
# The SPIKE_THRESHOLD_FACTOR is designed for position errors (feet).
# For latency (ms), a different threshold might be more appropriate.
# Let's use a fixed maximum of 50ms (or a factor if median is very high, unlikely for latency)
latency_spike_threshold = max(median_val * 5, 50.0) # 5x median or 50ms, whichever is greater
# --- MODIFIED PART END ---
filtered_ts = []
filtered_values = []
spike_count = 0
max_spike_val = 0.0
max_spike_time = 0.0
for i in range(len(timestamps)):
if values[i] > latency_spike_threshold:
spike_count += 1
if values[i] > max_spike_val:
max_spike_val = values[i]
max_spike_time = timestamps[i]
else:
filtered_ts.append(timestamps[i])
filtered_values.append(values[i])
return filtered_ts, filtered_values, spike_count, max_spike_val, max_spike_time
def _open_performance_window(self):
    """Opens the dedicated performance analysis window."""
    perf_path = self.performance_data_path
    # Nothing to show unless the perf CSV was found next to the archive.
    if not (perf_path and os.path.exists(perf_path)):
        messagebox.showinfo(
            "No Data", "No performance data file found for this run.", parent=self
        )
        return
    try:
        PerformanceAnalysisWindow(
            parent=self, performance_csv_path=perf_path
        )
    except Exception as e:
        messagebox.showerror(
            "Error", f"Failed to open performance analysis:\n{e}", parent=self
        )