SXXXXXXX_GeoElevation/geoelevation/visualizer.py
VALLONGOL 13ae515707 Chore: Stop tracking files based on .gitignore update.
Untracked files matching the following rules:
- Rule "_build/": 16 files
2025-05-06 09:24:45 +02:00

366 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# visualizer.py
import logging
import os
from typing import Optional, Union, TYPE_CHECKING
import time # For benchmarking processing time
# --- Dependency Checks ---
try:
    import matplotlib.pyplot as plt
    # Importing Axes3D registers the '3d' projection used by add_subplot(projection='3d').
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import numpy as np  # needed by Matplotlib and for array handling in this module
except ImportError:
    MATPLOTLIB_AVAILABLE = False

    # Fallback stand-ins so the module still imports without Matplotlib/NumPy.
    # The plotting functions check MATPLOTLIB_AVAILABLE and return early, so
    # these are only placeholders.
    class plt:  # type: ignore
        """No-op replacement for matplotlib.pyplot and the Figure/Axes it returns."""

        # -- pyplot-level entry points, called on the class itself --
        @staticmethod
        def figure(*args, **kwargs):
            return plt

        @staticmethod
        def subplots(*args, **kwargs):
            return plt, plt  # stands in for the (fig, ax) pair

        @staticmethod
        def show(*args, **kwargs):
            logging.warning("Matplotlib not available, cannot show plot.")

        # -- Figure/Axes methods that hand back another stand-in --
        def add_subplot(self, *args, **kwargs):
            return plt

        def plot_surface(self, *args, **kwargs):
            return plt

        # -- Figure/Axes methods with nothing to return --
        def set_xlabel(self, *args, **kwargs):
            return None

        def set_ylabel(self, *args, **kwargs):
            return None

        def set_zlabel(self, *args, **kwargs):
            return None

        def set_title(self, *args, **kwargs):
            return None

        def set_zlim(self, *args, **kwargs):
            return None

        def colorbar(self, *args, **kwargs):
            return None

        def imshow(self, *args, **kwargs):
            return None

        def axis(self, *args, **kwargs):
            return None

        def tight_layout(self, *args, **kwargs):
            return None

    class Axes3D:  # placeholder for the 3D axes class
        pass

    class np:  # type: ignore
        """Minimal NumPy placeholder returning neutral values."""

        # -- type and constant attributes --
        ndarray = type(None)
        floating = float
        float64 = float
        nan = float('nan')

        # -- array constructors --
        @staticmethod
        def array(*args, **kwargs):
            return None

        @staticmethod
        def arange(*args, **kwargs):
            return []

        @staticmethod
        def linspace(*args, **kwargs):
            return []

        @staticmethod
        def meshgrid(*args, **kwargs):
            return [], []

        # -- reductions --
        @staticmethod
        def nanmin(*args, **kwargs):
            return 0

        @staticmethod
        def nanmax(*args, **kwargs):
            return 0

        @staticmethod
        def nanmean(*args, **kwargs):
            return 0

        @staticmethod
        def sum(*args, **kwargs):
            return 0

        # -- element tests and conversions --
        @staticmethod
        def nan_to_num(*args, **kwargs):
            return np.array([])

        @staticmethod
        def isnan(*args, **kwargs):
            return False

        @staticmethod
        def isfinite(*args, **kwargs):
            return True

        @staticmethod
        def issubdtype(*args, **kwargs):
            return False

        @staticmethod
        def any(*args, **kwargs):
            return False

    logging.warning(
        "Matplotlib or NumPy not found. "
        "Visualization features (2D/3D plots) will be disabled."
    )
else:
    MATPLOTLIB_AVAILABLE = True
try:
    import scipy.ndimage  # provides gaussian_filter
    from scipy.interpolate import RectBivariateSpline  # bicubic resampling of the grid
except ImportError:
    SCIPY_AVAILABLE = False

    class RectBivariateSpline:  # type: ignore
        """Placeholder spline: accepts any arguments; evaluating yields one zero sample."""

        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, *args, **kwargs):
            return np.array([[0.0]])

    class scipy_ndimage_dummy:  # type: ignore
        """Placeholder for scipy.ndimage whose filter returns its input unchanged."""

        @staticmethod
        def gaussian_filter(*args, **kwargs):
            return args[0]

    # Instance exposing `.ndimage`, so `scipy.ndimage.gaussian_filter(...)` still resolves.
    scipy = type('SciPyDummy', (), {'ndimage': scipy_ndimage_dummy})()  # type: ignore
    logging.warning(
        "SciPy library not found. "
        "Advanced smoothing/interpolation for 3D plots will be disabled."
    )
else:
    SCIPY_AVAILABLE = True
# Pillow (PIL) is required to load images given as a file path or as a PIL object.
try:
    from PIL import Image
except ImportError:
    PIL_AVAILABLE_VIS = False

    class Image:  # type: ignore
        """Placeholder for PIL.Image; any attempt to open an image fails."""

        # Keeps `isinstance(x, Image.Image)` expressions valid without Pillow.
        Image = type(None)

        @staticmethod
        def open(*args, **kwargs):
            raise ImportError("Pillow not available")
else:
    PIL_AVAILABLE_VIS = True
# Use TYPE_CHECKING to hint dependencies without runtime import errors.
# TYPE_CHECKING is False at runtime, so these imports run only under a static
# type checker; the aliases must therefore be used inside quoted annotations
# (e.g. "np_typing.ndarray" in the function signatures below).
if TYPE_CHECKING:
    import numpy as np_typing  # Use an alias to avoid conflict with dummy np
    from PIL import Image as PILImage_typing
# === Visualization Functions ===
# PIL modes whose raw pixel arrays imshow would misrender, mapped to the mode to
# convert to first:
# - palette images ("P"/"PA") store color indices, not colors;
# - "LA" yields a 2-channel array, a shape imshow does not accept;
# - CMYK/YCbCr channels would be read as RGB(A);
# - 1-bit images become boolean arrays, so they are promoted to 8-bit grayscale.
_PIL_DISPLAY_CONVERSIONS = {
    "1": "L",
    "P": "RGBA",
    "PA": "RGBA",
    "LA": "RGBA",
    "CMYK": "RGB",
    "YCbCr": "RGB",
}


def _pil_to_display_array(img_pil):
    """Return `img_pil` as a NumPy array in a pixel layout imshow renders faithfully."""
    target_mode = _PIL_DISPLAY_CONVERSIONS.get(img_pil.mode)
    if target_mode is not None:
        img_pil = img_pil.convert(target_mode)
    return np.array(img_pil)


def show_image_matplotlib(
    image_source: Union[str, "np_typing.ndarray", "PILImage_typing.Image"],
    title: str = "Image Preview"
):
    """
    Display an image in a separate Matplotlib window with interactive zoom/pan.

    Args:
        image_source: A file path (loaded with Pillow), a NumPy array (copied and
            shown as given), or a PIL Image object.
        title: Title drawn above the image.

    Images coming from Pillow (a path or a PIL object) are normalized first.
    Palette/"LA" images are expanded to RGBA, CMYK/YCbCr to RGB, and 1-bit to
    8-bit grayscale. Single-channel Pillow images are drawn with a gray colormap.

    Errors are logged, never raised. The function ends with plt.show(), which
    blocks until the window is closed unless Matplotlib is in interactive mode.
    """
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display image: Matplotlib is not available.")
        return
    img_display_np: Optional["np_typing.ndarray"] = None
    from_pil = False  # True when the pixels were produced by Pillow
    source_type = type(image_source).__name__
    logging.info(f"Attempting to display image '{title}' from source type: {source_type}")
    try:
        if isinstance(image_source, str):
            if not PIL_AVAILABLE_VIS:
                logging.error("Cannot display image from path: Pillow (PIL) is required.")
                return
            # isfile (not exists) so a directory path is reported here instead of
            # failing later inside Image.open.
            if not os.path.isfile(image_source):
                logging.error(f"Image file not found: {image_source}")
                return
            try:
                with Image.open(image_source) as img_pil:
                    # Convert inside the `with` so pixels are read before the file closes.
                    img_display_np = _pil_to_display_array(img_pil)
            except Exception as e_load:
                logging.error(f"Failed to load image from path {image_source}: {e_load}", exc_info=True)
                return
            from_pil = True
        elif isinstance(image_source, np.ndarray):
            # Copy so the displayed data is independent of the caller's array.
            img_display_np = image_source.copy()
        elif PIL_AVAILABLE_VIS and isinstance(image_source, Image.Image):  # type: ignore
            img_display_np = _pil_to_display_array(image_source)
            from_pil = True
        else:
            logging.error(f"Unsupported image source type for Matplotlib: {type(image_source)}")
            return
        if img_display_np is None:
            logging.error("Failed to get image data as NumPy array.")
            return

        imshow_kwargs = {}
        if from_pil and img_display_np.ndim == 2:
            # Grayscale pixels: show them as gray rather than through imshow's
            # default color map.
            imshow_kwargs["cmap"] = "gray"
            if img_display_np.dtype == np.uint8:
                # Pin the full 8-bit range so contrast is not auto-stretched.
                imshow_kwargs.update(vmin=0, vmax=255)

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(img_display_np, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
        plt.show()
        logging.debug(f"Plot window for '{title}' closed.")
    except Exception as e:
        logging.error(f"Error displaying image '{title}' with Matplotlib: {e}", exc_info=True)
def _plot_stride(length: int, target_points: int) -> int:
    """Return the step that thins `length` samples to roughly `target_points` (at least 1)."""
    return max(1, int(length / target_points))


def _nan_aware_gaussian_filter(data, sigma):
    """Gaussian-smooth a float 2-D array without letting NaN cells spread.

    scipy.ndimage.gaussian_filter propagates NaN to every output cell within the
    kernel's reach, which would enlarge NoData holes. Instead this uses normalized
    convolution: it filters the zero-filled data and the validity weights
    separately, then divides. Each valid cell is therefore an average over valid
    neighbours only. Cells that were NaN on input remain NaN on output.
    """
    nan_mask = np.isnan(data)
    if not np.any(nan_mask):
        # No NoData: identical to a plain Gaussian filter.
        return scipy.ndimage.gaussian_filter(data, sigma=sigma, mode='nearest')
    filled = np.where(nan_mask, 0.0, data)
    weights = (~nan_mask).astype(np.float64)
    smoothed_values = scipy.ndimage.gaussian_filter(filled, sigma=sigma, mode='nearest')
    smoothed_weights = scipy.ndimage.gaussian_filter(weights, sigma=sigma, mode='nearest')
    # Weights can only be ~0 far inside NoData regions, which are re-masked below.
    with np.errstate(invalid='ignore', divide='ignore'):
        result = smoothed_values / smoothed_weights
    result[nan_mask] = np.nan
    return result


def _nearest_cell_mask(mask, y_coords, x_coords):
    """Sample a boolean grid `mask` at fractional (row, col) coordinates by nearest cell.

    Returns a (len(y_coords), len(x_coords)) boolean array.
    """
    rows = np.clip(np.rint(y_coords).astype(np.intp), 0, mask.shape[0] - 1)
    cols = np.clip(np.rint(x_coords).astype(np.intp), 0, mask.shape[1] - 1)
    return mask[np.ix_(rows, cols)]


def _subsample_for_plot(data, target_points):
    """Thin a 2-D array to roughly `target_points` per axis for plotting.

    Returns (X, Y, Z): X/Y are meshgrids of the kept column/row indices of
    `data`, and Z is the matching strided slice of `data`.
    """
    rows, cols = data.shape
    stride_y = _plot_stride(rows, target_points)
    stride_x = _plot_stride(cols, target_points)
    X, Y = np.meshgrid(np.arange(cols)[::stride_x], np.arange(rows)[::stride_y])
    return X, Y, data[::stride_y, ::stride_x]


def show_3d_matplotlib(
    hgt_array: Optional["np_typing.ndarray"],
    title: str = "3D Elevation View",
    initial_subsample: int = 1,
    smooth_sigma: Optional[float] = None,
    interpolation_factor: int = 1,
    plot_grid_points: int = 500
):
    """
    Display elevation data as a 3D surface plot.

    Pipeline:
      1. Keep every `initial_subsample`-th row/column of the raw data.
      2. Convert to float and mark NoData (-32768) as NaN.
      3. Optionally apply NaN-aware Gaussian smoothing.
      4. Optionally apply bicubic spline interpolation.
      5. Thin to about `plot_grid_points` per axis and draw with plot_surface.
    NoData cells stay NaN, so they appear as gaps in the surface.

    Args:
        hgt_array: 2-D elevation grid; cells equal to -32768 are treated as NoData.
        title: Base plot title; applied processing steps are appended to it.
        initial_subsample: Raw-data stride (>= 1).
        smooth_sigma: Gaussian sigma in subsampled grid cells (> 0), or None to skip.
        interpolation_factor: Spline upsampling factor per axis (>= 1; 1 skips).
        plot_grid_points: Approximate plotted points per axis (>= 30).

    Invalid option values are logged and replaced with safe defaults; NumPy
    scalars are accepted. Smoothing and interpolation need SciPy and are
    disabled without it. Errors are logged, never raised. Ends with plt.show(),
    which blocks until the window is closed unless Matplotlib is interactive.
    """
    import numbers  # local: only needed for argument validation

    # --- Input and Dependency Checks ---
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display 3D plot: Matplotlib is not available.")
        return
    if hgt_array is None:
        logging.error("Cannot display 3D plot: Input data array is None.")
        return
    if not isinstance(hgt_array, np.ndarray) or hgt_array.ndim != 2 or hgt_array.size == 0:
        logging.error("Invalid input data for 3D plot (shape/type).")
        return
    # numbers.Integral / numbers.Real also accept NumPy scalars (np.int64, np.float32),
    # which plain int/float isinstance checks reject.
    if not isinstance(initial_subsample, numbers.Integral) or initial_subsample < 1:
        logging.warning(f"Invalid initial_subsample ({initial_subsample}). Using 1.")
        initial_subsample = 1
    initial_subsample = int(initial_subsample)
    if smooth_sigma is not None and (
        not isinstance(smooth_sigma, numbers.Real)
        or not 0 < smooth_sigma < float("inf")  # also rejects NaN
    ):
        logging.warning(f"Invalid smooth_sigma ({smooth_sigma}). Disabling smoothing.")
        smooth_sigma = None
    if smooth_sigma is not None:
        smooth_sigma = float(smooth_sigma)
    if not isinstance(interpolation_factor, numbers.Integral) or interpolation_factor < 1:
        logging.warning(f"Invalid interpolation_factor ({interpolation_factor}). Using 1.")
        interpolation_factor = 1
    interpolation_factor = int(interpolation_factor)
    if not isinstance(plot_grid_points, numbers.Integral) or plot_grid_points < 30:  # Min for a reasonable plot
        logging.warning(f"Invalid plot_grid_points ({plot_grid_points}). Setting to 200.")
        plot_grid_points = 200
    plot_grid_points = int(plot_grid_points)
    # Disable advanced features if SciPy is missing
    if (interpolation_factor > 1 or smooth_sigma is not None) and not SCIPY_AVAILABLE:
        logging.warning("SciPy not available. Disabling smoothing and/or interpolation.")
        smooth_sigma = None
        interpolation_factor = 1

    processing_start_time = time.perf_counter()
    try:
        plot_title_parts = [title]  # Build title dynamically

        # --- 1. Initial Subsampling (of raw data) ---
        data_to_process = hgt_array[::initial_subsample, ::initial_subsample].copy()
        if initial_subsample > 1:
            plot_title_parts.append(f"RawSub{initial_subsample}x")
            logging.info(f"Initial subsampling by {initial_subsample}. Data shape: {data_to_process.shape}")

        # --- 2. Handle NoData (convert to float, mark NaNs) ---
        if not np.issubdtype(data_to_process.dtype, np.floating):
            data_to_process = data_to_process.astype(np.float64)
        common_nodata_value = -32768.0
        nodata_mask_initial = data_to_process == common_nodata_value
        if np.any(nodata_mask_initial):
            data_to_process[nodata_mask_initial] = np.nan
            logging.debug(f"Marked {np.sum(nodata_mask_initial)} NoData points as NaN.")

        # --- 3. Gaussian Smoothing (optional, before interpolation) ---
        if smooth_sigma is not None and SCIPY_AVAILABLE:
            logging.info(f"Applying Gaussian smoothing (sigma={smooth_sigma})...")
            try:
                data_to_process = _nan_aware_gaussian_filter(data_to_process, smooth_sigma)
                plot_title_parts.append(f"Smooth σ{smooth_sigma:.1f}")
            except Exception as e_smooth:
                logging.error(f"Gaussian smoothing failed: {e_smooth}", exc_info=True)
                plot_title_parts.append("(SmoothFail)")

        # --- 4. Interpolation (if requested) and thinning for plotting ---
        rows_proc, cols_proc = data_to_process.shape
        if interpolation_factor > 1 and SCIPY_AVAILABLE:
            plot_title_parts.append(f"Interp{interpolation_factor}x")
            logging.info(f"Performing spline interpolation (factor={interpolation_factor}). Input shape: {data_to_process.shape}")
            y_proc_coords = np.arange(rows_proc)
            x_proc_coords = np.arange(cols_proc)
            dense_rows = rows_proc * interpolation_factor
            dense_cols = cols_proc * interpolation_factor
            # Coordinates of the virtual dense grid, in processed-grid index units.
            y_dense_eval_coords = np.linspace(0, rows_proc - 1, dense_rows)
            x_dense_eval_coords = np.linspace(0, cols_proc - 1, dense_cols)

            # RectBivariateSpline cannot fit NaNs: fill them temporarily and
            # re-mask the output afterwards.
            nan_mask = np.isnan(data_to_process)
            has_nan = bool(np.any(nan_mask))
            data_for_spline_fit = data_to_process.copy()
            if has_nan:
                fill_value = np.nanmean(data_for_spline_fit)
                if np.isnan(fill_value):
                    fill_value = 0.0  # Fallback if all data was NaN
                data_for_spline_fit[nan_mask] = fill_value
                logging.debug(f"Filled {np.sum(nan_mask)} NaNs with {fill_value:.2f} for spline fitting.")
            try:
                # Bicubic (kx=ky=3), interpolating (s=0) spline through the grid.
                spline = RectBivariateSpline(y_proc_coords, x_proc_coords, data_for_spline_fit, kx=3, ky=3, s=0)
                plot_stride_y = _plot_stride(dense_rows, plot_grid_points)
                plot_stride_x = _plot_stride(dense_cols, plot_grid_points)
                logging.info(f"Subsampling dense interpolated grid for plotting with Y-stride:{plot_stride_y}, X-stride:{plot_stride_x}")
                final_y_coords_for_plot = y_dense_eval_coords[::plot_stride_y]
                final_x_coords_for_plot = x_dense_eval_coords[::plot_stride_x]
                # Evaluate only where we plot: the spline is evaluated pointwise, so
                # this equals slicing the full dense grid without allocating it.
                Z_for_plot = spline(final_y_coords_for_plot, final_x_coords_for_plot)
                logging.info(f"Interpolation complete. Dense grid shape: ({dense_rows}, {dense_cols}), evaluated at {Z_for_plot.shape}")
                if has_nan:
                    # Show NoData as gaps (as in the non-interpolated path), not as a mean-valued plateau.
                    Z_for_plot[_nearest_cell_mask(nan_mask, final_y_coords_for_plot, final_x_coords_for_plot)] = np.nan
                X_for_plot, Y_for_plot = np.meshgrid(final_x_coords_for_plot, final_y_coords_for_plot)
            except Exception as e_interp:
                logging.error(f"Spline interpolation or subsequent subsampling failed: {e_interp}", exc_info=True)
                plot_title_parts.append("(InterpFail)")
                # Fallback: plot the processed (maybe smoothed) data, thinned.
                X_for_plot, Y_for_plot, Z_for_plot = _subsample_for_plot(data_to_process, plot_grid_points)
        else:
            logging.info("Skipping interpolation. Subsampling processed data for plotting.")
            X_for_plot, Y_for_plot, Z_for_plot = _subsample_for_plot(data_to_process, plot_grid_points)

        # Construct final plot title, including the actual plotted grid size (Y x X).
        final_plot_title = " ".join(plot_title_parts)
        final_plot_title += f" (PlotGrid {Z_for_plot.shape[0]}x{Z_for_plot.shape[1]})"
        processing_end_time = time.perf_counter()
        logging.info(f"Data processing for 3D plot took {processing_end_time - processing_start_time:.2f} seconds.")

        # --- 5. Plotting the Result ---
        logging.info(f"Generating Matplotlib 3D plot. Final plot grid size: {Z_for_plot.shape}")
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        # Z/color limits from the finite plotted values; fixed fallback if none.
        finite_z = Z_for_plot[np.isfinite(Z_for_plot)]
        if finite_z.size:
            z_min, z_max = float(finite_z.min()), float(finite_z.max())
        else:
            z_min, z_max = 0.0, 100.0
        if z_min >= z_max:
            z_max = z_min + 100.0  # Ensure z_max > z_min

        # rstride/cstride=1: the grid is already thinned to the desired density.
        surf = ax.plot_surface(
            X_for_plot, Y_for_plot, Z_for_plot,
            rstride=1, cstride=1,
            cmap='terrain',        # Standard colormap for terrain
            linewidth=0.1,         # Thin edges for dense meshes
            antialiased=False,     # No edge antialiasing (cheaper to render)
            shade=False,           # Colors come from the colormap only, no light-source shading
            vmin=z_min,            # Color limits match the data range
            vmax=z_max
        )

        # --- Customize Plot Appearance ---
        ax.set_xlabel("X Index (Scaled/Processed)")
        ax.set_ylabel("Y Index (Scaled/Processed)")
        ax.set_zlabel("Elevation (m)")
        ax.set_title(final_plot_title, fontsize=10)  # Smaller font for the potentially long title
        # Z-axis limits with 10% padding
        z_range = z_max - z_min
        padding = z_range * 0.1 if z_range > 0 else 10.0
        ax.set_zlim(z_min - padding, z_max + padding)
        fig.colorbar(surf, shrink=0.5, aspect=10, label="Elevation (m)", pad=0.1)
        # Improve layout to prevent labels from overlapping
        try:
            fig.tight_layout()
        except Exception:
            logging.warning("fig.tight_layout() failed, plot might have overlapping elements.")
        plotting_end_time = time.perf_counter()
        logging.info(f"Plotting setup complete. Total time: {plotting_end_time - processing_start_time:.2f}s. Showing plot...")
        plt.show()  # BLOCKING CALL (in non-interactive mode)
    except Exception as e:
        # Catch any unexpected errors during the entire process
        logging.error(f"Critical error in show_3d_matplotlib ('{title}'): {e}", exc_info=True)