S1005403_RisCC/target_simulator/analysis/performance_analyzer.py

149 lines
5.0 KiB
Python

# target_simulator/analysis/performance_analyzer.py
"""
Provides the PerformanceAnalyzer class for calculating error metrics
by comparing simulated data with real-time radar data.
"""
import math
from bisect import bisect_left
from typing import Dict, List, Optional, Tuple

from target_simulator.analysis.simulation_state_hub import (
    SimulationStateHub,
    TargetState,
)
# Structure to hold analysis results for a single target.
# Outer keys are the axis names produced by PerformanceAnalyzer.analyze()
# ("x", "y", "z"); each inner dict holds that axis' error statistics under
# the keys "mean", "variance", "std_dev" and "rmse".
AnalysisResult = Dict[str, Dict[str, float]]
class PerformanceAnalyzer:
    """
    Analyzes the performance of the radar tracking by comparing simulated
    'ground truth' data against the real data received from the radar.

    Every state handled here is a 4-element sequence laid out as
    ``(timestamp, x, y, z)``. Real samples are compared against the simulated
    track linearly interpolated to the real sample's timestamp, and the
    per-axis differences (real - simulated) are summarized statistically.
    """

    def __init__(self, hub: SimulationStateHub):
        """
        Initializes the analyzer with a reference to the data hub.

        Args:
            hub: The SimulationStateHub containing the historical data. Only
                ``get_all_target_ids()`` and ``get_target_history(tid)`` are
                called on it.
        """
        self._hub = hub

    def analyze(self) -> Dict[int, AnalysisResult]:
        """
        Performs a full analysis on all targets currently in the hub.

        For each target, it aligns the real and simulated data streams
        temporally using linear interpolation and calculates key performance
        metrics like Mean Error and Root Mean Square Error (RMSE).

        Targets are skipped when their history is empty, has no real samples,
        or has fewer than two simulated samples. Real samples whose timestamp
        lies outside the simulated time span contribute no error, and a target
        with no usable sample at all is left out of the result.

        Returns:
            A dictionary where keys are target IDs and values are the
            analysis results for that target (one statistics dict per axis
            under the keys "x", "y" and "z").
        """
        results: Dict[int, AnalysisResult] = {}

        for tid in self._hub.get_all_target_ids():
            history = self._hub.get_target_history(tid)
            if not history or not history["real"] or len(history["simulated"]) < 2:
                # Not enough data to perform analysis
                continue

            # Ensure sorted by timestamp (the first element of each state).
            simulated_history = sorted(history["simulated"])
            # Extract the timestamps once per target so each real sample can
            # be located by binary search instead of a fresh linear scan.
            simulated_timestamps = [state[0] for state in simulated_history]

            errors_x: List[float] = []
            errors_y: List[float] = []
            errors_z: List[float] = []

            for real_ts, real_x, real_y, real_z in history["real"]:
                # Find the two simulated points that bracket the real point
                # in time.
                p1, p2 = self._find_bracketing_points(
                    real_ts, simulated_history, simulated_timestamps
                )
                if p1 is None or p2 is None:
                    # The real sample lies outside the simulated time span.
                    continue

                _ts, interp_x, interp_y, interp_z = self._interpolate(
                    real_ts, p1, p2
                )

                # Instantaneous error: real minus interpolated simulation.
                errors_x.append(real_x - interp_x)
                errors_y.append(real_y - interp_y)
                errors_z.append(real_z - interp_z)

            # All three lists grow together, so checking one is sufficient.
            if errors_x:
                results[tid] = {
                    "x": self._calculate_stats(errors_x),
                    "y": self._calculate_stats(errors_y),
                    "z": self._calculate_stats(errors_z),
                }

        return results

    def _find_bracketing_points(
        self,
        timestamp: float,
        history: List[TargetState],
        timestamps: Optional[List[float]] = None,
    ) -> Tuple[Optional[TargetState], Optional[TargetState]]:
        """
        Finds two points in a time-sorted history that surround a given
        timestamp.

        Args:
            timestamp: The instant to bracket.
            history: States sorted by ascending timestamp (element 0).
            timestamps: Optional pre-extracted list of ``history`` timestamps,
                in the same order. Passing it avoids rebuilding the list when
                many lookups are made against the same history.

        Returns:
            The first adjacent pair ``(history[i], history[i + 1])`` whose
            timestamps satisfy ``t[i] <= timestamp <= t[i + 1]``, or
            ``(None, None)`` when the history has fewer than two points or
            the timestamp falls outside its time span.
        """
        if len(history) < 2:
            return None, None

        if timestamps is None:
            timestamps = [state[0] for state in history]

        # Written as a positive range test so a NaN timestamp is rejected too.
        if not (timestamps[0] <= timestamp <= timestamps[-1]):
            return None, None

        # bisect_left yields the first index whose timestamp is >= the query;
        # the preceding element is then strictly earlier. Clamping to 1 maps a
        # query equal to the very first timestamp onto the first pair.
        upper = max(bisect_left(timestamps, timestamp), 1)
        return history[upper - 1], history[upper]

    def _interpolate(
        self, timestamp: float, p1: TargetState, p2: TargetState
    ) -> TargetState:
        """
        Performs linear interpolation between two state points (p1 and p2)
        to estimate the state at a given timestamp.

        Args:
            timestamp: The instant at which to estimate the state.
            p1: The earlier bracketing state ``(ts, x, y, z)``.
            p2: The later bracketing state ``(ts, x, y, z)``.

        Returns:
            ``(timestamp, x, y, z)`` with linearly interpolated coordinates,
            or ``p1`` unchanged when both points share the same timestamp.
        """
        ts1, x1, y1, z1 = p1
        ts2, x2, y2, z2 = p2

        # Avoid division by zero if timestamps are identical
        duration = ts2 - ts1
        if duration == 0:
            return p1

        # Interpolation factor: how far timestamp lies between ts1 and ts2.
        factor = (timestamp - ts1) / duration

        return (
            timestamp,
            x1 + (x2 - x1) * factor,
            y1 + (y2 - y1) * factor,
            z1 + (z2 - z1) * factor,
        )

    def _calculate_stats(self, errors: List[float]) -> Dict[str, float]:
        """
        Calculates mean, variance, standard deviation and RMSE for a list of
        errors.

        Args:
            errors: The error samples for one axis.

        Returns:
            A dict with the keys "mean", "variance", "std_dev" and "rmse".
            The variance is the population variance (divided by ``n``). All
            values are ``0.0`` when ``errors`` is empty.
        """
        n = len(errors)
        if n == 0:
            # Floats, to match the declared return type.
            return {"mean": 0.0, "variance": 0.0, "std_dev": 0.0, "rmse": 0.0}

        mean = sum(errors) / n

        # Variance and Standard Deviation
        variance = sum((x - mean) ** 2 for x in errors) / n
        std_dev = math.sqrt(variance)

        # Root Mean Square Error
        rmse = math.sqrt(sum(x**2 for x in errors) / n)

        return {
            "mean": mean,
            "variance": variance,
            "std_dev": std_dev,
            "rmse": rmse,
        }