SXXXXXXX_PyMsc/pymsc/utils/profiler.py

70 lines
2.3 KiB
Python

import csv
import inspect
import logging
import time
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Mapping, Optional, TextIO, Tuple
def _empty_stats_entry() -> Dict[str, Any]:
    """Return a fresh accumulator for one stats key: no calls, no time yet."""
    return {"calls": 0, "total_time": 0.0}


# Global registry of timing statistics, populated by ``monitor_execution``.
# Key:   (function_name, defined_in, called_from)
# Value: {"calls": <int call count>, "total_time": <float seconds>}
# Missing keys are created on first access with a fresh zeroed entry.
function_stats: Dict[Tuple[str, str, str], Dict[str, Any]] = defaultdict(
    _empty_stats_entry
)
def monitor_execution(func):
    """
    Decorator that records call counts and cumulative execution time.

    Every call of the wrapped function is timed with ``time.perf_counter``
    and accumulated into the module-level ``function_stats`` mapping under
    the key ``(function name, defining module, calling module)``. Calls that
    raise are counted and timed as well; the exception propagates unchanged.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result; metadata is copied from ``func`` via ``functools.wraps``.
    """
    # These parts of the stats key never change between calls, so resolve
    # them once at decoration time. inspect.getmodule() can return None
    # (e.g. for functions created via exec()), so fall back to __module__
    # instead of failing with AttributeError on every call.
    module = inspect.getmodule(func)
    if module is not None:
        module_defined = module.__name__
    else:
        module_defined = getattr(func, "__module__", None) or "unknown"
    func_name = getattr(func, "__name__", type(func).__name__)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Only the immediate caller's module is needed. inspect.stack() would
        # build FrameInfo records (including source context lines) for the
        # entire stack on every call, which is far more expensive.
        frame = inspect.currentframe()  # this wrapper's own frame (or None)
        caller = frame.f_back if frame is not None else None
        if caller is not None:
            module_called = caller.f_globals.get("__name__", "unknown")
        else:
            module_called = "unknown"
        # Drop frame references promptly to avoid frame reference cycles.
        del frame, caller

        start_time = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Record in `finally` so calls that raise still show up in stats.
            duration = time.perf_counter() - start_time
            stats = function_stats[(func_name, module_defined, module_called)]
            stats["calls"] += 1
            stats["total_time"] += duration

    return wrapper
def save_stats_to_csv(
    file_path: str,
    *,
    stats: Optional[Mapping[Tuple[str, str, str], Mapping[str, Any]]] = None,
) -> None:
    """
    Export profiling statistics to a CSV file.

    Writes a header row followed by one row per
    ``(function, defined_in, called_from)`` key, with columns
    ``Function, Calls, Total Time, Avg Time, Defined In, Called From``.
    Times are formatted with six decimal places; the average is 0 when an
    entry has no calls.

    This is best-effort: any error while building or writing the file is
    logged (with traceback) and swallowed, so a failed export never
    interrupts the caller.

    Args:
        file_path: Destination path; an existing file is overwritten.
        stats: Statistics to export. Defaults to the module-level
            ``function_stats`` collected by ``monitor_execution``.
    """
    if stats is None:
        stats = function_stats
    header = ["Function", "Calls", "Total Time", "Avg Time", "Defined In", "Called From"]
    try:
        # Format every row before opening (and truncating) the file, so a
        # failure while formatting does not leave a partial CSV behind and
        # the mapping is not iterated across slow file writes.
        rows = []
        for (name, defined, called), data in list(stats.items()):
            calls = data["calls"]
            total = data["total_time"]
            avg = total / calls if calls > 0 else 0
            rows.append([name, calls, f"{total:.6f}", f"{avg:.6f}", defined, called])
        # newline="" is required by the csv module to avoid blank lines on Windows.
        with open(file_path, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(header)
            writer.writerows(rows)
    except Exception as e:
        # Deliberate best-effort boundary: keep the traceback for debugging.
        logging.exception("Failed to save profiler CSV: %s", e)
def print_stats(
    *,
    stats: Optional[Mapping[Tuple[str, str, str], Mapping[str, Any]]] = None,
    file: Optional[TextIO] = None,
) -> None:
    """
    Print a fixed-width summary table of profiling statistics.

    For each recorded entry the table shows the function name, call count,
    average time per call in seconds (0 when there were no calls) and the
    defining module. The calling module is part of each stats key but is
    not displayed, so one function called from several modules shows up as
    several rows that look alike apart from their numbers.

    Args:
        stats: Statistics to print. Defaults to the module-level
            ``function_stats`` collected by ``monitor_execution``.
        file: Text stream to write to; ``None`` (the default) writes to
            ``sys.stdout``, exactly like ``print``.
    """
    if stats is None:
        stats = function_stats
    header = f"{'Function':<30} {'Calls':<8} {'Avg Time (s)':<12} {'Defined In':<25}"
    print("\n" + "=" * 80, file=file)
    print(header, file=file)
    print("-" * 80, file=file)
    for (name, defined, _called), data in stats.items():
        calls = data["calls"]
        avg = data["total_time"] / calls if calls > 0 else 0
        print(f"{name:<30} {calls:<8} {avg:<12.6f} {defined:<25}", file=file)
    print("=" * 80 + "\n", file=file)