S1005403_RisCC/target_simulator/gui/analysis_window.py
2025-11-14 15:47:48 +01:00

400 lines
17 KiB
Python

"""
A Toplevel window for displaying real-time performance analysis, including
error statistics and plots.
"""
import csv
import json
import logging
import math
import os
import statistics
import tkinter as tk
from tkinter import ttk, messagebox
from typing import Optional, Dict, List, Any

try:
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

from target_simulator.analysis.performance_analyzer import PerformanceAnalyzer
from target_simulator.analysis.simulation_state_hub import SimulationStateHub
from target_simulator.gui.performance_analysis_window import PerformanceAnalysisWindow
class AnalysisWindow(tk.Toplevel):
    """
    A window that displays the tracking-performance analysis of one archived
    simulation run.

    The JSON archive is loaded right after the window opens (behind a small
    "Loading" dialog). The window then shows per-axis error statistics
    (Error = Real position - Simulated position), an error-over-time plot with
    spike filtering, and the latency samples stored in the archive metadata.
    When a companion ``.perf.csv`` / ``.perf.json`` file exists next to the
    archive, a button opens the PerformanceAnalysisWindow for it.
    """

    _logger = logging.getLogger(__name__)

    # Spike-filter parameters used by _filter_spikes. The errors are shown in
    # the UI as feet, and the thresholds are in the same units as the errors.
    # Offsets (start, end) from the first sample time that define the window
    # used to estimate the "typical" error magnitude.
    SPIKE_SAMPLE_WINDOW_S = (5.0, 15.0)
    # A sample is a spike when its magnitude exceeds
    # max(median_of_window * SPIKE_MEDIAN_FACTOR, SPIKE_MIN_THRESHOLD).
    SPIKE_MEDIAN_FACTOR = 20.0
    SPIKE_MIN_THRESHOLD = 500.0
    # Threshold used when no sample falls inside the sampling window.
    SPIKE_FALLBACK_THRESHOLD = 1000.0

    def __init__(self, master, archive_filepath: str):
        super().__init__(master)
        self.title(f"Analysis for: {os.path.basename(archive_filepath)}")
        self.geometry("900x750")
        self.archive_filepath = archive_filepath
        self.performance_data_path: Optional[str] = None
        self.scenario_name = "Unknown"
        # Filled in by _load_data_and_setup. They are initialised here so the
        # attributes always exist, even if loading fails part-way.
        self.estimated_latency_ms: Optional[float] = None
        self.prediction_offset_ms: Optional[float] = None
        self.latency_timestamps: List[Any] = []
        self.latency_values_ms: List[Any] = []
        # State variables
        self.selected_target_id = tk.IntVar(value=0)
        # Cleared by destroy() so that a pending load stops touching widgets.
        self._active = True
        # Spike-filtered per-axis errors of the currently plotted target, or
        # None when no filtered data is available for it.
        self._filtered_errors: Optional[Dict[str, List[float]]] = None
        # Text artists currently drawn on the error plot ("N spike(s) filtered").
        self._spike_annotations: List[Any] = []
        # Stays None when Matplotlib is unavailable (see _create_plot_widgets).
        self.canvas = None
        self._show_loading_window(archive_filepath)

    def destroy(self):
        """Mark the window inactive (so an in-progress load aborts), then destroy it."""
        self._active = False
        super().destroy()

    def _load_data_and_setup(self, filepath: str):
        """
        Load the archive file and build the state hub and analyzer from it.

        Args:
            filepath: Path of the JSON archive of a simulation run.

        Raises:
            IOError: If the file cannot be read, is not valid JSON, or its
                top-level value is not a JSON object.
            ValueError: If a key of ``simulation_results`` is not an integer
                target ID.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                archive_data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            raise IOError(f"Could not load archive file.\n{e}") from e
        if not isinstance(archive_data, dict):
            raise IOError("Could not load archive file.\nThe archive does not contain a JSON object.")

        metadata = archive_data.get("metadata") or {}
        self.estimated_latency_ms = metadata.get("estimated_latency_ms")
        self.prediction_offset_ms = metadata.get("prediction_offset_ms")
        self.scenario_name = metadata.get("scenario_name", "Unknown")

        # Each latency sample is read as [timestamp, latency_ms, ...].
        # Malformed entries (non-sequences or fewer than 2 items) are skipped.
        samples = [
            s for s in (metadata.get("latency_samples") or [])
            if isinstance(s, (list, tuple)) and len(s) > 1
        ]
        self.latency_timestamps = [s[0] for s in samples]
        self.latency_values_ms = [s[1] for s in samples]

        self._hub = SimulationStateHub(history_size=100000)
        results = archive_data.get("simulation_results") or {}
        for target_id_str, data in results.items():
            target_id = int(target_id_str)
            # state[0] is the timestamp; the remaining values are passed on
            # to the hub as the state tuple.
            for state in data.get("simulated") or []:
                self._hub.add_simulated_state(target_id, state[0], tuple(state[1:]))
            for state in data.get("real") or []:
                self._hub.add_real_state(target_id, state[0], tuple(state[1:]))
        self._analyzer = PerformanceAnalyzer(self._hub)
        # Find the associated performance data file
        self.performance_data_path = self._find_performance_data_file(filepath)

    def _find_performance_data_file(self, archive_path: str) -> Optional[str]:
        """
        Return the performance-data file that sits next to an archive, or None.

        ``<archive base>.perf.csv`` (new format) is preferred over
        ``<archive base>.perf.json`` (old format, kept for backward
        compatibility).
        """
        base_path, _ = os.path.splitext(archive_path)
        for suffix in (".perf.csv", ".perf.json"):
            candidate = f"{base_path}{suffix}"
            if os.path.exists(candidate):
                return candidate
        return None

    def _show_loading_window(self, archive_filepath: str):
        """
        Show a "Loading" dialog, then load, build and populate the window.

        The work is deferred with ``after(100, ...)`` so this window can
        appear first. It still runs on the Tk event loop, not in a background
        thread. ``update()`` is called between steps so the dialog repaints.
        On failure, an error box is shown and this window is destroyed.
        """
        dialog_w, dialog_h = 400, 150
        loading_dialog = tk.Toplevel(self)
        loading_dialog.title("Loading Analysis")
        loading_dialog.transient(self)
        loading_dialog.update_idletasks()
        # Center the dialog over this window. The dialog's own size comes from
        # the constants because winfo_width() is unreliable before it is mapped.
        x = max(0, self.winfo_x() + (self.winfo_width() - dialog_w) // 2)
        y = max(0, self.winfo_y() + (self.winfo_height() - dialog_h) // 2)
        loading_dialog.geometry(f"{dialog_w}x{dialog_h}+{x}+{y}")
        try:
            loading_dialog.grab_set()
        except tk.TclError:
            # Some window managers refuse a grab on a not-yet-viewable window.
            # The dialog still works, just without being modal.
            self._logger.debug("Could not grab input for the loading dialog", exc_info=True)

        ttk.Label(loading_dialog, text="Loading simulation data...", font=("Segoe UI", 11)).pack(pady=(20, 10))
        progress_label = ttk.Label(loading_dialog, text="Please wait", font=("Segoe UI", 9))
        progress_label.pack(pady=5)
        progress = ttk.Progressbar(loading_dialog, mode="indeterminate", length=300)
        progress.pack(pady=10)
        progress.start(10)

        steps = (
            ("Reading archive file...", lambda: self._load_data_and_setup(archive_filepath)),
            ("Creating widgets...", self._create_widgets),
            ("Analyzing data...", self._populate_analysis),
        )

        def load_and_display():
            try:
                for message, step in steps:
                    progress_label.config(text=message)
                    # Let the dialog repaint before the blocking step runs.
                    self.update()
                    if not self._active:
                        return  # the window was closed while loading
                    step()
            except Exception as e:
                if not self._active:
                    return  # failure caused by the window being closed
                self._logger.exception("Failed to load analysis for %s", archive_filepath)
                loading_dialog.destroy()
                messagebox.showerror("Analysis Error", f"Failed to load analysis:\n{e}", parent=self)
                self.destroy()
                return
            loading_dialog.destroy()

        self.after(100, load_and_display)

    def _populate_analysis(self):
        """Fill the target selector, select the first target and show its analysis."""
        self._update_target_selector()
        target_ids = self.target_selector["values"]
        if target_ids:
            self.selected_target_id.set(target_ids[0])
        self._on_target_select()

    def _create_widgets(self):
        """Build the vertical pane: statistics on top, plots below."""
        main_pane = ttk.PanedWindow(self, orient=tk.VERTICAL)
        main_pane.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        stats_frame = ttk.LabelFrame(main_pane, text="Error Statistics (feet)")
        main_pane.add(stats_frame, weight=1)
        self._create_stats_widgets(stats_frame)
        plot_frame = ttk.LabelFrame(main_pane, text="Error Over Time (feet)")
        main_pane.add(plot_frame, weight=4)
        self._create_plot_widgets(plot_frame)

    def _create_stats_widgets(self, parent):
        """Build the target selector, sync info, statistics table and legend."""
        container = ttk.Frame(parent)
        container.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        left = ttk.Frame(container)
        left.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        right = ttk.Frame(container)
        right.pack(side=tk.RIGHT, fill=tk.Y)

        top_bar = ttk.Frame(left)
        top_bar.pack(fill=tk.X, padx=0, pady=(0, 6))
        ttk.Label(top_bar, text="Select Target ID:").pack(side=tk.LEFT)
        self.target_selector = ttk.Combobox(
            top_bar, textvariable=self.selected_target_id, state="readonly", width=5
        )
        self.target_selector.pack(side=tk.LEFT, padx=5)
        self.target_selector.bind("<<ComboboxSelected>>", self._on_target_select)

        sync_frame = ttk.Frame(top_bar)
        sync_frame.pack(side=tk.LEFT, padx=(20, 0))
        if self.estimated_latency_ms is not None:
            ttk.Label(sync_frame, text="Avg. Latency:").pack(side=tk.LEFT)
            ttk.Label(
                sync_frame, text=f"{self.estimated_latency_ms:.1f} ms",
                font=("Segoe UI", 9, "bold"), foreground="blue"
            ).pack(side=tk.LEFT, padx=4)
        if self.prediction_offset_ms is not None:
            ttk.Label(sync_frame, text="Prediction Offset:").pack(side=tk.LEFT, padx=(10, 0))
            ttk.Label(
                sync_frame, text=f"{self.prediction_offset_ms:.1f} ms",
                font=("Segoe UI", 9, "bold"), foreground="green"
            ).pack(side=tk.LEFT, padx=4)
        # Only offer the detailed window when a performance file exists.
        if self.performance_data_path:
            perf_button = ttk.Button(
                sync_frame, text="Performance Analysis...", command=self._open_performance_window
            )
            perf_button.pack(side=tk.LEFT, padx=(20, 0))

        columns = ("error_type", "mean", "std_dev", "rmse")
        self.stats_tree = ttk.Treeview(left, columns=columns, show="headings", height=4)
        self.stats_tree.heading("error_type", text="")
        self.stats_tree.heading("mean", text="Mean (ft)")
        self.stats_tree.heading("std_dev", text="Std Dev (ft)")
        self.stats_tree.heading("rmse", text="RMSE (ft)")
        self.stats_tree.column("error_type", width=100, anchor=tk.W)
        self.stats_tree.column("mean", anchor=tk.E, width=100)
        self.stats_tree.column("std_dev", anchor=tk.E, width=100)
        self.stats_tree.column("rmse", anchor=tk.E, width=100)
        self.stats_tree.pack(fill=tk.BOTH, expand=True)

        legend_title = ttk.Label(right, text="How to Interpret Results:", font=("Segoe UI", 9, "bold"))
        legend_title.pack(anchor=tk.NW, padx=(6, 6), pady=(4, 4))
        explanation_text = (
            "Formula: Error = Real Position - Simulated Position\n\n"
            "Sign of Error (e.g., on X axis):\n"
            "• Positive Error (+): Real target is at a larger X coordinate.\n"
            "• Negative Error (-): Real target is at a smaller X coordinate.\n\n"
            "Prediction Offset:\n"
            "A manual offset to compensate for server processing delay."
        )
        ttk.Label(right, text=explanation_text, justify=tk.LEFT, wraplength=280).pack(anchor=tk.NW, padx=(6, 6))

    def _create_plot_widgets(self, parent):
        """
        Build the error plot (top) and the latency plot (bottom, shared X axis).

        If Matplotlib could not be imported, a notice is shown instead and
        ``self.canvas`` stays None, so the plot-update methods skip drawing
        while the statistics table keeps working.
        """
        if not MATPLOTLIB_AVAILABLE:
            self.canvas = None
            ttk.Label(
                parent, text="Matplotlib is not installed: plots are unavailable."
            ).pack(expand=True, pady=20)
            return

        fig = Figure(figsize=(5, 6), dpi=100)
        gs = fig.add_gridspec(2, 1, height_ratios=[2, 1], hspace=0.35, top=0.95)

        self.ax = fig.add_subplot(gs[0, 0])
        self.ax.set_title("Instantaneous Error")
        self.ax.set_xlabel("Time (s)")
        self.ax.set_ylabel("Error (ft)")
        (self.line_x,) = self.ax.plot([], [], lw=2, label="Error X")
        (self.line_y,) = self.ax.plot([], [], lw=2, label="Error Y")
        (self.line_z,) = self.ax.plot([], [], lw=2, label="Error Z")
        self.ax.grid(True)
        self.ax.axhline(0.0, color="black", lw=1, linestyle="--", alpha=0.8)
        self.ax.legend(loc="upper right", fontsize=9)

        self.ax_latency = fig.add_subplot(gs[1, 0], sharex=self.ax)
        self.ax_latency.set_title("Latency Evolution")
        self.ax_latency.set_xlabel("Time (s)")
        self.ax_latency.set_ylabel("Latency (ms)")
        (self.line_latency,) = self.ax_latency.plot([], [], lw=2, color="orange", label="Latency")
        self.ax_latency.grid(True)
        self.ax_latency.legend(loc="upper right", fontsize=9)
        fig.tight_layout()

        plot_container = ttk.Frame(parent)
        plot_container.pack(fill=tk.BOTH, expand=True)
        toolbar_frame = ttk.Frame(plot_container)
        toolbar_frame.pack(side=tk.TOP, fill=tk.X)
        self.canvas = FigureCanvasTkAgg(fig, master=plot_container)
        toolbar = NavigationToolbar2Tk(self.canvas, toolbar_frame)
        toolbar.update()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        self.canvas.draw()

    def _plotting_enabled(self) -> bool:
        """Return True when the Matplotlib canvas has been created."""
        return getattr(self, "canvas", None) is not None

    def _update_target_selector(self):
        """Fill the target combobox with the sorted target IDs known to the hub."""
        try:
            target_ids = sorted(self._hub.get_all_target_ids())
            if target_ids:
                self.target_selector["values"] = target_ids
                if self.selected_target_id.get() not in target_ids:
                    self.selected_target_id.set(target_ids[0])
        except Exception:
            # Best effort: callers cope with an empty selector, but the cause
            # should not disappear silently.
            self._logger.exception("Could not populate the target selector")

    @staticmethod
    def _error_stats(errors: List[float]) -> Optional[tuple]:
        """
        Return ``(mean, population std dev, RMSE)`` of *errors*, or None if empty.
        """
        if not errors:
            return None
        n = len(errors)
        mean = sum(errors) / n
        variance = sum((e - mean) ** 2 for e in errors) / n
        rmse = math.sqrt(sum(e ** 2 for e in errors) / n)
        return mean, math.sqrt(variance), rmse

    def _update_stats_table(self, results: Dict):
        """
        Refill the statistics table for the selected target.

        When _update_plot has produced spike-filtered errors for this target,
        the statistics are recomputed from them. Otherwise the analyzer's
        per-axis ``results[axis]['mean' | 'std_dev' | 'rmse']`` are shown.
        If latency samples were loaded, a latency row is appended; its last
        column holds the "min - max" range rather than an RMSE.
        """
        self.stats_tree.delete(*self.stats_tree.get_children())
        filtered = getattr(self, "_filtered_errors", None)
        for axis in ("x", "y", "z"):
            label = f"Error {axis.upper()}"
            if filtered:
                stats = self._error_stats(filtered.get(axis, []))
                if stats is None:
                    row = (label, "N/A", "N/A", "N/A")
                else:
                    row = (label, *(f"{value:.3f}" for value in stats))
            else:
                axis_res = results[axis]
                row = (
                    label,
                    f"{axis_res['mean']:.3f}",
                    f"{axis_res['std_dev']:.3f}",
                    f"{axis_res['rmse']:.3f}",
                )
            self.stats_tree.insert("", "end", values=row)

        latency = getattr(self, "latency_values_ms", None)
        if latency:
            lat_mean = statistics.mean(latency)
            lat_std = statistics.stdev(latency) if len(latency) > 1 else 0.0
            lat_min = min(latency)
            lat_max = max(latency)
            self.stats_tree.insert(
                "", "end",
                values=("Latency (ms)", f"{lat_mean:.2f}", f"{lat_std:.2f}", f"{lat_min:.2f} - {lat_max:.2f}"),
            )

    def _compute_errors(self, target_id: int) -> Optional[tuple]:
        """
        Compute ``Real - interpolated Simulated`` errors for one target.

        Each real sample ``(ts, x, y, z)`` is compared against the simulated
        track, interpolated at ``ts`` using the analyzer's private
        ``_find_bracketing_points`` / ``_interpolate`` helpers. Real samples
        without bracketing simulated points are skipped.

        Returns:
            ``(times, errors_x, errors_y, errors_z)``, or None when the target
            has no real samples or fewer than two simulated samples.
        """
        history = self._hub.get_target_history(target_id)
        if not history or not history["real"] or len(history["simulated"]) < 2:
            return None
        times, errors_x, errors_y, errors_z = [], [], [], []
        sim_hist = sorted(history["simulated"])
        for real_ts, real_x, real_y, real_z in history["real"]:
            p1, p2 = self._analyzer._find_bracketing_points(real_ts, sim_hist)
            if p1 and p2:
                _ts, interp_x, interp_y, interp_z = self._analyzer._interpolate(real_ts, p1, p2)
                times.append(real_ts)
                errors_x.append(real_x - interp_x)
                errors_y.append(real_y - interp_y)
                errors_z.append(real_z - interp_z)
        return times, errors_x, errors_y, errors_z

    @classmethod
    def _filter_spikes(cls, times, errors_x, errors_y, errors_z) -> tuple:
        """
        Drop samples whose 3D error magnitude is a spike.

        The typical magnitude is the median over samples whose time lies in
        ``[t0 + start, t0 + end]`` (``t0`` = first sample time, window from
        SPIKE_SAMPLE_WINDOW_S). Samples above
        ``max(median * SPIKE_MEDIAN_FACTOR, SPIKE_MIN_THRESHOLD)`` are removed.
        If no sample lies in the window, SPIKE_FALLBACK_THRESHOLD is used.

        Args:
            times, errors_x, errors_y, errors_z: Parallel lists of equal length.

        Returns:
            ``(kept_times, kept_x, kept_y, kept_z, outlier_count)``.
        """
        if not times:
            return [], [], [], [], 0
        magnitudes = [
            (ex ** 2 + ey ** 2 + ez ** 2) ** 0.5
            for ex, ey, ez in zip(errors_x, errors_y, errors_z)
        ]
        start = min(times)
        win_lo, win_hi = cls.SPIKE_SAMPLE_WINDOW_S
        sample = [
            mag for t, mag in zip(times, magnitudes)
            if start + win_lo <= t <= start + win_hi
        ]
        if sample:
            threshold = max(statistics.median(sample) * cls.SPIKE_MEDIAN_FACTOR, cls.SPIKE_MIN_THRESHOLD)
        else:
            threshold = cls.SPIKE_FALLBACK_THRESHOLD

        kept_t, kept_x, kept_y, kept_z = [], [], [], []
        outlier_count = 0
        for t, ex, ey, ez, mag in zip(times, errors_x, errors_y, errors_z, magnitudes):
            if mag > threshold:
                outlier_count += 1
                continue
            kept_t.append(t)
            kept_x.append(ex)
            kept_y.append(ey)
            kept_z.append(ez)
        return kept_t, kept_x, kept_y, kept_z, outlier_count

    def _clear_spike_annotations(self):
        """Remove the "N spike(s) filtered" note from the error plot, if present."""
        for txt in getattr(self, "_spike_annotations", []):
            txt.remove()
        self._spike_annotations = []

    def _update_plot(self, target_id: int):
        """
        Compute, spike-filter and plot the errors of *target_id*.

        Always resets ``self._filtered_errors`` first. This way a target
        without plottable data never shows the previous target's filtered
        statistics in the table.
        """
        self._filtered_errors = None
        computed = self._compute_errors(target_id)
        if not computed or not computed[0]:
            self._clear_views()
            return
        times, errors_x, errors_y, errors_z = computed
        kept_t, kept_x, kept_y, kept_z, outlier_count = self._filter_spikes(
            times, errors_x, errors_y, errors_z
        )
        self._filtered_errors = {"x": kept_x, "y": kept_y, "z": kept_z}

        if not self._plotting_enabled():
            return
        self.line_x.set_data(kept_t, kept_x)
        self.line_y.set_data(kept_t, kept_y)
        self.line_z.set_data(kept_t, kept_z)
        self._clear_spike_annotations()
        if outlier_count > 0:
            txt = self.ax.text(
                0.02, 0.98, f"{outlier_count} spike(s) filtered", transform=self.ax.transAxes,
                verticalalignment="top", bbox=dict(boxstyle="round", facecolor="yellow", alpha=0.7), fontsize=9,
            )
            self._spike_annotations.append(txt)
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _update_latency_plot(self):
        """Plot the latency samples loaded from the archive metadata (if any)."""
        if not self._plotting_enabled():
            return
        if self.latency_values_ms and self.latency_timestamps:
            self.line_latency.set_data(self.latency_timestamps, self.latency_values_ms)
        else:
            self.line_latency.set_data([], [])
        self.ax_latency.relim()
        self.ax_latency.autoscale_view()
        self.canvas.draw_idle()

    def _clear_views(self):
        """Empty the statistics table and, if available, every plot and annotation."""
        self.stats_tree.delete(*self.stats_tree.get_children())
        if not self._plotting_enabled():
            return
        for line in (self.line_x, self.line_y, self.line_z, self.line_latency):
            line.set_data([], [])
        self._clear_spike_annotations()
        for ax in (self.ax, self.ax_latency):
            ax.relim()
            ax.autoscale_view()
        self.canvas.draw_idle()

    def _open_performance_window(self):
        """Open the dedicated performance analysis window."""
        if not self.performance_data_path:
            messagebox.showinfo("No Data", "No performance data file found for this simulation run.", parent=self)
            return
        try:
            # Pass the path of the performance file to the performance window.
            PerformanceAnalysisWindow(parent=self, performance_csv_path=self.performance_data_path)
        except Exception as e:
            messagebox.showerror("Performance Analysis Error", f"Failed to open performance analysis:\n{e}", parent=self)

    def _on_target_select(self, event=None):
        """Handle combobox selection changes and update stats/plots."""
        try:
            sel_id = self.selected_target_id.get()
            analysis_results = self._analyzer.analyze()
            if sel_id in analysis_results:
                # The plot runs first because it computes the filtered errors
                # that the table then uses.
                self._update_plot(sel_id)
                self._update_stats_table(analysis_results[sel_id])
            else:
                self._clear_views()
            self._update_latency_plot()
        except Exception:
            self._logger.exception("Failed to update the analysis for the selected target")
            self._clear_views()