# target_simulator/gui/analysis_window.py
"""
A Toplevel window that displays tracking performance analysis by
efficiently processing a simulation trail file.
"""
import tkinter as tk
from tkinter import ttk, messagebox
import json
import os
import csv
import math
import statistics
import warnings
from typing import Optional, Dict, List, Any, Tuple
from target_simulator.gui.performance_analysis_window import PerformanceAnalysisWindow
try:
    import numpy as np
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_tkagg import (
        FigureCanvasTkAgg,
        NavigationToolbar2Tk,
    )

    MATPLOTLIB_AVAILABLE = True
except ImportError:
    np = None
    MATPLOTLIB_AVAILABLE = False

# Constants for analysis
DOWNSAMPLE_THRESHOLD = 4000  # Number of points before downsampling is applied
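
# Typical usage (a sketch -- the Tk root and the archive path below are
# illustrative, not defined by this module):
#
#   root = tk.Tk()
#   AnalysisWindow(root, "captures/run_001.json")
#   root.mainloop()
#
# The window looks for companion .trail.csv / .latency.csv / .perf.csv files
# next to the archive JSON and plots tracking error and latency from them.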


class AnalysisWindow(tk.Toplevel):
    """
    A window that displays tracking performance analysis by loading data
    from an archive's main JSON file and its associated `.trail.csv`.
    """

    def __init__(self, master, archive_filepath: str):
        super().__init__(master)
        self.title(f"Analysis for: {os.path.basename(archive_filepath)}")
        self.geometry("1100x800")
        self.archive_filepath = archive_filepath
        self.trail_filepath: Optional[str] = None
        self.latency_filepath: Optional[str] = None
        self.performance_data_path: Optional[str] = None
        self.scenario_name = "Unknown"
        self.target_ids: List[int] = []
        self.selected_target_id = tk.IntVar()
        self._show_loading_window(archive_filepath)

    def _load_data_and_setup(self, filepath: str):
        """Loads metadata from the main archive and finds associated data files."""
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                archive_data = json.load(f)
        except Exception as e:
            raise IOError(f"Could not load archive file: {e}") from e

        metadata = archive_data.get("metadata", {})
        self.scenario_name = metadata.get("scenario_name", "Unknown")
        self.title(f"Analysis - {self.scenario_name}")

        # Find the associated trail, latency, and performance files
        base_path, _ = os.path.splitext(filepath)
        self.trail_filepath = f"{base_path}.trail.csv"
        self.latency_filepath = f"{base_path}.latency.csv"
        self.performance_data_path = f"{base_path}.perf.csv"
        if not os.path.exists(self.trail_filepath):
            raise FileNotFoundError(
                f"Required trail file not found: {self.trail_filepath}"
            )

        # Get available target IDs from the trail file header
        with open(self.trail_filepath, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            headers = next(reader, [])
            if "target_id" not in headers:
                raise ValueError("Trail file missing 'target_id' column.")
            target_id_index = headers.index("target_id")
            ids = set()
            for row in reader:
                if row and not row[0].startswith("#"):
                    try:
                        ids.add(int(row[target_id_index]))
                    except (ValueError, IndexError):
                        continue
        self.target_ids = sorted(ids)

    def _show_loading_window(self, archive_filepath: str):
        """Shows a loading dialog and defers the data load via `after` so the
        dialog can render before the (potentially slow) file parsing runs."""
        loading_dialog = tk.Toplevel(self)
        loading_dialog.title("Loading Analysis")
        loading_dialog.geometry("400x150")
        loading_dialog.transient(self)
        loading_dialog.grab_set()
        loading_dialog.update_idletasks()

        # Center the dialog over the parent window
        x = (
            self.winfo_x()
            + (self.winfo_width() // 2)
            - (loading_dialog.winfo_width() // 2)
        )
        y = (
            self.winfo_y()
            + (self.winfo_height() // 2)
            - (loading_dialog.winfo_height() // 2)
        )
        loading_dialog.geometry(f"+{x}+{y}")

        ttk.Label(
            loading_dialog, text="Loading simulation data...", font=("Segoe UI", 11)
        ).pack(pady=(20, 10))
        progress_label = ttk.Label(
            loading_dialog, text="Please wait", font=("Segoe UI", 9)
        )
        progress_label.pack(pady=5)
        progress = ttk.Progressbar(loading_dialog, mode="indeterminate", length=300)
        progress.pack(pady=10)
        progress.start(10)

        def load_and_display():
            try:
                progress_label.config(text="Locating data files...")
                self.update()
                self._load_data_and_setup(archive_filepath)
                progress_label.config(text="Creating widgets...")
                self.update()
                self._create_widgets()
                progress_label.config(text="Ready.")
                self.update()
                loading_dialog.destroy()
                # Trigger initial analysis
                self._on_target_select()
            except Exception as e:
                loading_dialog.destroy()
                messagebox.showerror(
                    "Analysis Error", f"Failed to load analysis:\n{e}", parent=self
                )
                self.destroy()

        self.after(100, load_and_display)

    def _create_widgets(self):
        main_pane = ttk.PanedWindow(self, orient=tk.VERTICAL)
        main_pane.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        stats_frame = ttk.LabelFrame(main_pane, text="Error Statistics (feet)")
        main_pane.add(stats_frame, weight=1)
        self._create_stats_widgets(stats_frame)

        plot_frame = ttk.LabelFrame(main_pane, text="Error Over Time (feet)")
        main_pane.add(plot_frame, weight=4)
        self._create_plot_widgets(plot_frame)

    def _create_stats_widgets(self, parent):
        # Configure the grid layout
        parent.rowconfigure(0, weight=0)  # Top bar
        parent.rowconfigure(1, weight=1)  # Content area
        parent.columnconfigure(0, weight=1)

        # Top bar with the target combobox and the performance button
        top_bar = ttk.Frame(parent, padding=5)
        top_bar.grid(row=0, column=0, sticky="ew")
        ttk.Label(top_bar, text="Select Target ID:").pack(side=tk.LEFT)
        self.target_selector = ttk.Combobox(
            top_bar,
            textvariable=self.selected_target_id,
            state="readonly",
            width=5,
            values=self.target_ids,
        )
        self.target_selector.pack(side=tk.LEFT, padx=5)
        self.target_selector.bind("<<ComboboxSelected>>", self._on_target_select)
        if self.target_ids:
            self.selected_target_id.set(self.target_ids[0])

        # Performance Analysis button (always visible, disabled if no data)
        perf_button = ttk.Button(
            top_bar,
            text="Open Performance Analysis",
            command=self._open_performance_window,
        )
        perf_button.pack(side=tk.LEFT, padx=(20, 0))
        if not os.path.exists(self.performance_data_path):
            perf_button.config(state="disabled")

        # Content container split into two columns
        content_frame = ttk.Frame(parent)
        content_frame.grid(row=1, column=0, sticky="nsew", padx=5, pady=5)
        content_frame.columnconfigure(0, weight=1, uniform="half")
        content_frame.columnconfigure(1, weight=1, uniform="half")
        content_frame.rowconfigure(0, weight=1)

        # Left: stats table
        table_container = ttk.Frame(content_frame)
        table_container.grid(row=0, column=0, sticky="nsew", padx=(0, 2))
        columns = ("error_type", "mean", "std_dev", "rmse")
        self.stats_tree = ttk.Treeview(
            table_container, columns=columns, show="headings", height=4
        )
        self.stats_tree.heading("error_type", text="")
        self.stats_tree.heading("mean", text="Mean (ft)")
        self.stats_tree.heading("std_dev", text="Std Dev (ft)")
        self.stats_tree.heading("rmse", text="RMSE (ft)")
        self.stats_tree.column("error_type", width=100, anchor=tk.W)
        self.stats_tree.column("mean", anchor=tk.E, width=120)
        self.stats_tree.column("std_dev", anchor=tk.E, width=120)
        self.stats_tree.column("rmse", anchor=tk.E, width=120)
        self.stats_tree.pack(fill=tk.BOTH, expand=True)

        # Right: legend frame
        legend_frame = ttk.Frame(content_frame)
        legend_frame.grid(row=0, column=1, sticky="nsew", padx=(2, 0))
        legend_title = ttk.Label(
            legend_frame, text="How to Interpret Results:", font=("Segoe UI", 9, "bold")
        )
        legend_title.pack(anchor=tk.NW, pady=(0, 5))
        explanation_text = (
            "Error = Real - Simulated Position\n\n"
            "Sign (e.g., X axis):\n"
            "• Positive: Real target at larger X\n"
            "• Negative: Real target at smaller X\n\n"
            "Spike Filtering:\n"
            "Transients >20x median filtered\n"
            "from plots and statistics.\n\n"
            "Latency:\n"
            "Time from packet generation\n"
            "(server) to reception (client)."
        )
        ttk.Label(
            legend_frame, text=explanation_text, justify=tk.LEFT, font=("Segoe UI", 9)
        ).pack(anchor=tk.NW, fill=tk.BOTH, expand=True)

    def _create_plot_widgets(self, parent):
        if not MATPLOTLIB_AVAILABLE:
            ttk.Label(parent, text="Matplotlib is required for plotting.").pack()
            return

        fig = Figure(figsize=(5, 7), dpi=100)
        # Check if the latency file exists to determine the subplot layout
        has_latency = os.path.exists(self.latency_filepath)
        if has_latency:
            # Two subplots: errors (top) and latency (bottom)
            gs = fig.add_gridspec(2, 1, height_ratios=[2, 1], hspace=0.3, top=0.95)
            self.ax = fig.add_subplot(gs[0, 0])
            self.ax_latency = fig.add_subplot(gs[1, 0], sharex=self.ax)
        else:
            # Single subplot: just errors
            self.ax = fig.add_subplot(111)
            self.ax_latency = None

        # Error plot
        self.ax.set_title("Instantaneous Error")
        self.ax.set_ylabel("Error (ft)")
        (self.line_x,) = self.ax.plot([], [], lw=1.5, label="Error X", color="#1f77b4")
        (self.line_y,) = self.ax.plot([], [], lw=1.5, label="Error Y", color="#ff7f0e")
        (self.line_z,) = self.ax.plot([], [], lw=1.5, label="Error Z", color="#2ca02c")
        self.ax.grid(True, alpha=0.3)
        self.ax.axhline(0.0, color="black", lw=1, linestyle="--", alpha=0.5)
        self.ax.legend(loc="upper right", fontsize=9)
        if not has_latency:
            self.ax.set_xlabel("Elapsed Time (s)")

        # Latency plot (if the file exists)
        if has_latency:
            self.ax_latency.set_title("Latency Evolution")
            self.ax_latency.set_xlabel("Elapsed Time (s)")
            self.ax_latency.set_ylabel("Latency (ms)")
            (self.line_latency,) = self.ax_latency.plot(
                [], [], lw=1.5, color="#d62728", label="Latency"
            )
            self.ax_latency.grid(True, alpha=0.3)
            self.ax_latency.legend(loc="upper right", fontsize=9)
        else:
            self.line_latency = None

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            fig.tight_layout()

        canvas_frame = ttk.Frame(parent)
        canvas_frame.pack(fill=tk.BOTH, expand=True)
        toolbar_frame = ttk.Frame(canvas_frame)
        toolbar_frame.pack(side=tk.TOP, fill=tk.X)
        self.canvas = FigureCanvasTkAgg(fig, master=canvas_frame)
        toolbar = NavigationToolbar2Tk(self.canvas, toolbar_frame)
        toolbar.update()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    def _on_target_select(self, event=None):
        """Initiates analysis for the selected target."""
        if not self.trail_filepath:
            return
        target_id = self.selected_target_id.get()
        # Analyze data (a single efficient pass over the trail file)
        timestamps, errors, stats = self._analyze_trail_file(target_id)
        # Update the UI -- load latency first so the stats table can include it
        self._update_latency_plot()
        self._update_stats_table(stats)
        self._update_plot(timestamps, errors)

    def _analyze_trail_file(
        self, target_id: int
    ) -> Tuple[List[float], Dict[str, List[float]], Dict[str, Dict[str, float]]]:
        """
        Analyzes the trail file for a specific target using an efficient
        two-pointer algorithm.
        """
        sim_points = []
        real_points = []
        with open(self.trail_filepath, "r", encoding="utf-8") as f:
            reader = csv.DictReader(line for line in f if not line.startswith("#"))
            for row in reader:
                try:
                    if int(row["target_id"]) == target_id:
                        point = (
                            float(row["timestamp"]),
                            float(row["x_ft"]),
                            float(row["y_ft"]),
                            float(row["z_ft"]),
                        )
                        if row["source"] == "simulated":
                            sim_points.append(point)
                        elif row["source"] == "real":
                            real_points.append(point)
                except (ValueError, KeyError):
                    continue

        if not sim_points or not real_points:
            # Return empty-but-well-formed structures so downstream plotting
            # code can index errors["x"/"y"/"z"] without a KeyError.
            return [], {"x": [], "y": [], "z": []}, {}

        # --- Two-Pointer Algorithm for Error Calculation ---
        # Both lists are assumed to be in chronological file order, so one
        # forward scan with a single shared index suffices: O(n + m) overall.
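        # Worked example (illustrative numbers): with simulated points at
        # t=10.0 -> x=100.0 and t=10.5 -> x=104.0, a real sample at t=10.2
        # gives factor = (10.2 - 10.0) / 0.5 = 0.4, so the interpolated x is
        # 100.0 + (104.0 - 100.0) * 0.4 = 101.6 and the X error is
        # real_x - 101.6.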
        timestamps, errors_x, errors_y, errors_z = [], [], [], []
        sim_idx = 0
        for real_p in real_points:
            real_ts, real_x, real_y, real_z = real_p
            # Advance sim_idx to find the bracketing segment for this real point
            while (
                sim_idx + 1 < len(sim_points) and sim_points[sim_idx + 1][0] < real_ts
            ):
                sim_idx += 1
            if sim_idx + 1 < len(sim_points):
                p1 = sim_points[sim_idx]
                p2 = sim_points[sim_idx + 1]
                # Check if the real point is within this segment
                if p1[0] <= real_ts <= p2[0]:
                    # Interpolate linearly between the bracketing sim points
                    ts1, x1, y1, z1 = p1
                    ts2, x2, y2, z2 = p2
                    duration = ts2 - ts1
                    if duration == 0:
                        continue
                    factor = (real_ts - ts1) / duration
                    interp_x = x1 + (x2 - x1) * factor
                    interp_y = y1 + (y2 - y1) * factor
                    interp_z = z1 + (z2 - z1) * factor
                    timestamps.append(real_ts)
                    errors_x.append(real_x - interp_x)
                    errors_y.append(real_y - interp_y)
                    errors_z.append(real_z - interp_z)

        errors = {"x": errors_x, "y": errors_y, "z": errors_z}

        # Calculate final statistics on the full (non-downsampled) data
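        # Note: per axis, RMSE^2 = mean^2 + population variance of the errors;
        # statistics.stdev() below is the *sample* standard deviation, so the
        # identity is only approximate for small sample counts.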
        stats = {}
        for axis, err_list in errors.items():
            if not err_list:
                stats[axis] = {"mean": 0, "std_dev": 0, "rmse": 0}
                continue
            mean = statistics.mean(err_list)
            stdev = statistics.stdev(err_list) if len(err_list) > 1 else 0
            rmse = math.sqrt(sum(e**2 for e in err_list) / len(err_list))
            stats[axis] = {"mean": mean, "std_dev": stdev, "rmse": rmse}
        return timestamps, errors, stats

    def _downsample_data(self, timestamps: List, errors: Dict) -> Tuple[List, Dict]:
        """Reduces the number of points for plotting while preserving shape."""
        if len(timestamps) <= DOWNSAMPLE_THRESHOLD:
            return timestamps, errors
        # Simple interval-based (stride) downsampling
        step = len(timestamps) // DOWNSAMPLE_THRESHOLD
        ts_down = timestamps[::step]
        err_down = {
            "x": errors["x"][::step],
            "y": errors["y"][::step],
            "z": errors["z"][::step],
        }
        return ts_down, err_down

    def _update_stats_table(self, stats: Dict):
        """Populates the stats Treeview with calculated metrics."""
        self.stats_tree.delete(*self.stats_tree.get_children())
        for axis, data in stats.items():
            self.stats_tree.insert(
                "",
                "end",
                values=(
                    f"Error {axis.upper()}",
                    f"{data['mean']:.3f}",
                    f"{data['std_dev']:.3f}",
                    f"{data['rmse']:.3f}",
                ),
            )

        # Add latency statistics if available. Note: the latency row reuses
        # the "rmse" column to show the min-max range, not an RMSE value.
        if hasattr(self, "_latency_data") and self._latency_data:
            lat_mean = statistics.mean(self._latency_data)
            lat_std = (
                statistics.stdev(self._latency_data)
                if len(self._latency_data) > 1
                else 0.0
            )
            lat_min = min(self._latency_data)
            lat_max = max(self._latency_data)
            self.stats_tree.insert(
                "",
                "end",
                values=(
                    "Latency (ms)",
                    f"{lat_mean:.2f}",
                    f"{lat_std:.2f}",
                    f"{lat_min:.2f} - {lat_max:.2f}",
                ),
            )

    def _update_plot(self, timestamps: List[float], errors: Dict[str, List[float]]):
        """Updates the matplotlib plot with (potentially downsampled) data."""
        # Apply spike filtering
        filtered_ts, filtered_errors, spike_count, max_spike_error, max_spike_time = (
            self._filter_spikes(timestamps, errors)
        )

        # Convert to elapsed time (seconds from start)
        if filtered_ts:
            start_time = min(filtered_ts)
            elapsed_times = [t - start_time for t in filtered_ts]
        else:
            elapsed_times = []

        # Downsample if needed
        ts_plot, errors_plot = self._downsample_data(elapsed_times, filtered_errors)
        self.line_x.set_data(ts_plot, errors_plot["x"])
        self.line_y.set_data(ts_plot, errors_plot["y"])
        self.line_z.set_data(ts_plot, errors_plot["z"])

        # Remove old spike annotations
        for txt in getattr(self.ax, "_spike_annotations", []):
            txt.remove()
        self.ax._spike_annotations = []

        # Add a spike annotation if any points were filtered
        if spike_count > 0:
            annotation_text = (
                f"{spike_count} acquisition spike(s) filtered\n"
                f"(max error: {max_spike_error:.0f} ft at t={max_spike_time:.1f}s)\n"
                f"Spikes excluded from statistics"
            )
            txt = self.ax.text(
                0.02,
                0.98,
                annotation_text,
                transform=self.ax.transAxes,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="yellow", alpha=0.7),
                fontsize=8,
            )
            self.ax._spike_annotations.append(txt)

        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _filter_spikes(
        self, timestamps: List[float], errors: Dict[str, List[float]]
    ) -> tuple:
        """Filters acquisition spikes from error data."""
        if not timestamps:
            return timestamps, errors, 0, 0.0, 0.0

        # Calculate the 3D error magnitude for each point
        magnitudes = []
        for i in range(len(timestamps)):
            mag = math.sqrt(
                errors["x"][i] ** 2 + errors["y"][i] ** 2 + errors["z"][i] ** 2
            )
            magnitudes.append(mag)

        # Sample a window 5-15 seconds into the simulation to compute the threshold
        min_time = min(timestamps)
        sample_mags = []
        for i, t in enumerate(timestamps):
            if min_time + 5.0 <= t <= min_time + 15.0:
                sample_mags.append(magnitudes[i])
        if not sample_mags:
            return timestamps, errors, 0, 0.0, 0.0

        # Threshold: 20x the median error magnitude in the sample window,
        # floored at 500 ft
        threshold = max(statistics.median(sample_mags) * 20, 500.0)
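        # Example: a steady-state median error of 3 ft gives 3 * 20 = 60 ft,
        # which the 500 ft floor overrides (threshold = 500 ft); a run with a
        # 40 ft median would use 40 * 20 = 800 ft instead.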
        # Filter out spikes, tracking the largest one for the plot annotation
        filtered_ts = []
        filtered_errors = {"x": [], "y": [], "z": []}
        spike_count = 0
        max_spike_error = 0.0
        max_spike_time = 0.0
        for i in range(len(timestamps)):
            if magnitudes[i] > threshold:
                spike_count += 1
                if magnitudes[i] > max_spike_error:
                    max_spike_error = magnitudes[i]
                    max_spike_time = timestamps[i]
            else:
                filtered_ts.append(timestamps[i])
                filtered_errors["x"].append(errors["x"][i])
                filtered_errors["y"].append(errors["y"][i])
                filtered_errors["z"].append(errors["z"][i])
        return (
            filtered_ts,
            filtered_errors,
            spike_count,
            max_spike_error,
            max_spike_time,
        )

    def _update_latency_plot(self):
        """Updates the latency subplot with data from the latency CSV file."""
        if not self.ax_latency or not self.line_latency:
            self._latency_data = []
            return
        if not os.path.exists(self.latency_filepath):
            self.line_latency.set_data([], [])
            self._latency_data = []
            self.ax_latency.relim()
            self.ax_latency.autoscale_view()
            self.canvas.draw_idle()
            return

        timestamps = []
        latencies = []
        try:
            with open(self.latency_filepath, "r", encoding="utf-8") as f:
                reader = csv.DictReader(line for line in f if not line.startswith("#"))
                for row in reader:
                    try:
                        timestamps.append(float(row["timestamp"]))
                        latencies.append(float(row["latency_ms"]))
                    except (ValueError, KeyError):
                        continue

            # Save the full data for statistics
            self._latency_data = latencies

            # Convert to elapsed time (seconds from start)
            if timestamps:
                start_time = min(timestamps)
                elapsed_times = [t - start_time for t in timestamps]
            else:
                elapsed_times = []

            # Downsample for plotting if needed
            ts_plot = elapsed_times
            lat_plot = latencies
            if len(elapsed_times) > DOWNSAMPLE_THRESHOLD:
                step = len(elapsed_times) // DOWNSAMPLE_THRESHOLD
                ts_plot = elapsed_times[::step]
                lat_plot = latencies[::step]
            self.line_latency.set_data(ts_plot, lat_plot)
            self.ax_latency.relim()
            self.ax_latency.autoscale_view()
            self.canvas.draw_idle()
        except Exception as e:
            self.line_latency.set_data([], [])
            self._latency_data = []
            print(f"Warning: Failed to load latency data: {e}")

    def _open_performance_window(self):
        """Opens the dedicated performance analysis window."""
        if not self.performance_data_path or not os.path.exists(
            self.performance_data_path
        ):
            messagebox.showinfo(
                "No Data", "No performance data file found for this run.", parent=self
            )
            return
        try:
            PerformanceAnalysisWindow(
                parent=self, performance_csv_path=self.performance_data_path
            )
        except Exception as e:
            messagebox.showerror(
                "Error", f"Failed to open performance analysis:\n{e}", parent=self
            )