# visualizer.py
"""Matplotlib-based visualisation of images and elevation grids.

Optional dependencies (NumPy, Matplotlib, SciPy, Pillow) are probed at import
time. When one is missing, a minimal stand-in is bound to the same name so the
module still imports; the public functions then log an error, or disable the
feature that needs the missing library, instead of raising ImportError.
"""
import logging
import os
from typing import Optional, Union, TYPE_CHECKING
import time  # For benchmarking processing time

# --- Dependency Checks ---
# NumPy is probed on its own so that a missing Matplotlib does not also
# replace an installed NumPy with the dummy below.
try:
    import numpy as np  # Required by Matplotlib & for data handling
    NUMPY_AVAILABLE = True
except ImportError:
    NUMPY_AVAILABLE = False

    class np:  # Minimal numpy dummy  # type: ignore
        ndarray = type(None)

        @staticmethod
        def array(*args, **kwargs):
            return None

        @staticmethod
        def arange(*args, **kwargs):
            return []

        @staticmethod
        def meshgrid(*args, **kwargs):
            return [], []

        @staticmethod
        def linspace(*args, **kwargs):
            return []

        @staticmethod
        def nanmin(*args, **kwargs):
            return 0

        @staticmethod
        def nanmax(*args, **kwargs):
            return 0

        @staticmethod
        def nanmean(*args, **kwargs):
            return 0

        @staticmethod
        def nan_to_num(*args, **kwargs):
            return np.array([])

        @staticmethod
        def isnan(*args, **kwargs):
            return False

        @staticmethod
        def issubdtype(*args, **kwargs):
            return False

        @staticmethod
        def any(*args, **kwargs):
            return False

        @staticmethod
        def sum(*args, **kwargs):
            return 0

        @staticmethod
        def isfinite(*args, **kwargs):
            return True

        floating = float
        float64 = float
        nan = float('nan')

try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # Required for 3D projection
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

    # Define dummy/fallback classes and functions if Matplotlib is missing
    class plt:  # type: ignore
        @staticmethod
        def figure(*args, **kwargs):
            return plt

        def add_subplot(self, *args, **kwargs):
            return plt

        def plot_surface(self, *args, **kwargs):
            return plt

        def set_xlabel(self, *args, **kwargs):
            pass

        def set_ylabel(self, *args, **kwargs):
            pass

        def set_zlabel(self, *args, **kwargs):
            pass

        def set_title(self, *args, **kwargs):
            pass

        def set_zlim(self, *args, **kwargs):
            pass

        def colorbar(self, *args, **kwargs):
            pass

        @staticmethod
        def show(*args, **kwargs):
            logging.warning("Matplotlib not available, cannot show plot.")

        @staticmethod
        def subplots(*args, **kwargs):
            return plt, plt  # Dummy fig, ax

        def imshow(self, *args, **kwargs):
            pass

        def axis(self, *args, **kwargs):
            pass

        def tight_layout(self, *args, **kwargs):
            pass

    class Axes3D:  # Dummy class
        pass

    logging.warning(
        "Matplotlib or NumPy not found. "
        "Visualization features (2D/3D plots) will be disabled."
    )

try:
    import scipy.ndimage  # For gaussian_filter
    from scipy.interpolate import RectBivariateSpline  # For smooth 2D interpolation
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False

    # Dummy class for RectBivariateSpline if SciPy missing
    class RectBivariateSpline:  # type: ignore
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, *args, **kwargs):
            return np.array([[0.0]])

    # Dummy module for scipy.ndimage
    class scipy_ndimage_dummy:  # type: ignore
        @staticmethod
        def gaussian_filter(*args, **kwargs):
            return args[0]  # Return input array

    scipy = type('SciPyDummy', (), {'ndimage': scipy_ndimage_dummy})()  # type: ignore
    logging.warning(
        "SciPy library not found. "
        "Advanced smoothing/interpolation for 3D plots will be disabled."
    )

# Check for Pillow (PIL) needed for loading images from paths or PIL objects
try:
    from PIL import Image
    PIL_AVAILABLE_VIS = True
except ImportError:
    PIL_AVAILABLE_VIS = False

    class Image:  # Dummy class  # type: ignore
        Image = type(None)

        @staticmethod
        def open(*args, **kwargs):
            raise ImportError("Pillow not available")

# Use TYPE_CHECKING to hint dependencies without runtime import errors
if TYPE_CHECKING:
    import numpy as np_typing  # Use an alias to avoid conflict with dummy np
    from PIL import Image as PILImage_typing


# === Private helpers ===

def _pil_image_to_array(img: "PILImage_typing.Image") -> "np_typing.ndarray":
    """Convert a PIL image into an array layout that ``imshow`` renders correctly.

    ``np.array`` on a palette ('P'/'PA') image yields palette indices, which
    ``imshow`` would colour through a colormap; 'LA' yields a 2-channel array
    that ``imshow`` rejects; 'CMYK'/'YCbCr' channels would be read as RGBA/RGB.
    Those modes are converted to RGB(A) first; every other mode passes through.
    """
    display_modes = {
        "P": "RGBA",
        "PA": "RGBA",
        "LA": "RGBA",
        "CMYK": "RGB",
        "YCbCr": "RGB",
    }
    target_mode = display_modes.get(img.mode)
    if target_mode is not None:
        img = img.convert(target_mode)
    return np.array(img)


def _plot_stride(n_samples: int, max_points: int) -> int:
    """Return the smallest stride (>= 1) that keeps ``n_samples`` at or below ``max_points``."""
    # Ceil division: slicing with [::stride] keeps ceil(n / stride) samples,
    # which is <= max_points for stride = ceil(n / max_points). Floor division
    # would let the plotted grid exceed the budget by up to ~2x per axis.
    return max(1, -(-n_samples // max_points))


def _subsample_for_plot(
    x_coords: "np_typing.ndarray",
    y_coords: "np_typing.ndarray",
    z: "np_typing.ndarray",
    max_points: int,
):
    """Stride-subsample ``z`` (rows = Y) and its axis coordinates to at most ``max_points`` per axis.

    Returns ``(X, Y, Z)`` where X/Y come from ``np.meshgrid`` and all three
    share Z's subsampled shape, ready for ``plot_surface``.
    """
    stride_y = _plot_stride(z.shape[0], max_points)
    stride_x = _plot_stride(z.shape[1], max_points)
    logging.debug(
        f"Subsampling grid {z.shape} for plotting with Y-stride:{stride_y}, X-stride:{stride_x}"
    )
    X, Y = np.meshgrid(x_coords[::stride_x], y_coords[::stride_y])
    return X, Y, z[::stride_y, ::stride_x]


def _nan_aware_gaussian_filter(data: "np_typing.ndarray", sigma: float) -> "np_typing.ndarray":
    """Gaussian-smooth a 2-D float grid while ignoring NaN cells (normalized convolution).

    ``scipy.ndimage.gaussian_filter`` propagates NaN into every output cell whose
    kernel touches a NaN, which would grow NoData holes. Here NaN cells get zero
    weight and each valid output is divided by the kernel mass that fell on
    valid cells. Cells that were NaN in the input stay NaN. NaN-free input goes
    through a single plain ``gaussian_filter`` call.
    """
    nan_mask = np.isnan(data)
    if not np.any(nan_mask):
        return scipy.ndimage.gaussian_filter(data, sigma=sigma, mode='nearest')

    valid_mask = ~nan_mask
    weighted_sum = scipy.ndimage.gaussian_filter(
        np.where(nan_mask, 0.0, data), sigma=sigma, mode='nearest'
    )
    weight_sum = scipy.ndimage.gaussian_filter(
        valid_mask.astype(np.float64), sigma=sigma, mode='nearest'
    )
    smoothed = np.full(data.shape, np.nan, dtype=weighted_sum.dtype)
    # A valid cell always carries its own (positive) kernel centre weight, so
    # weight_sum > 0 there; dividing only at valid cells avoids 0/0 warnings.
    smoothed[valid_mask] = weighted_sum[valid_mask] / weight_sum[valid_mask]
    return smoothed


def _spline_upsample(data: "np_typing.ndarray", factor: int):
    """Upsample a 2-D grid by ``factor`` per axis with an interpolating spline.

    Returns ``(x_dense, y_dense, z_dense)``: ``x_dense``/``y_dense`` are
    ``cols * factor`` / ``rows * factor`` evenly spaced coordinates (in source
    index units) spanning the source extent, and ``z_dense`` has shape
    ``(rows * factor, cols * factor)``.

    The spline degree is 3 (bicubic) on an axis with >= 4 samples, otherwise the
    highest degree that axis allows. NaN cells are filled with the mean of the
    valid cells (0.0 if none) for fitting, because RectBivariateSpline cannot
    fit NaN. They are then re-masked: each dense sample whose nearest source
    cell was NaN is set back to NaN, so NoData stays NaN instead of becoming a
    plateau at the mean.

    Raises:
        ValueError: if either axis has fewer than 2 samples.
    """
    rows, cols = data.shape
    if rows < 2 or cols < 2:
        raise ValueError(
            f"Spline interpolation needs at least 2x2 samples, got {rows}x{cols}."
        )
    y_src = np.arange(rows)
    x_src = np.arange(cols)
    y_dense = np.linspace(0, rows - 1, rows * factor)
    x_dense = np.linspace(0, cols - 1, cols * factor)

    nan_mask = np.isnan(data)
    has_nan = bool(np.any(nan_mask))
    fit_data = data
    if has_nan:
        valid_values = data[~nan_mask]
        fill_value = float(valid_values.mean()) if valid_values.size else 0.0
        fit_data = np.where(nan_mask, fill_value, data)
        logging.debug(
            f"Filled {int(nan_mask.sum())} NaNs with {fill_value:.2f} for spline fitting."
        )

    # SciPy's first coordinate ("x", degree kx) is our row axis here.
    spline = RectBivariateSpline(
        y_src, x_src, fit_data, kx=min(3, rows - 1), ky=min(3, cols - 1), s=0
    )
    z_dense = spline(y_dense, x_dense)

    if has_nan:
        nearest_rows = np.rint(y_dense).astype(int)
        nearest_cols = np.rint(x_dense).astype(int)
        z_dense[nan_mask[np.ix_(nearest_rows, nearest_cols)]] = np.nan
    return x_dense, y_dense, z_dense


def _finite_z_limits(z: "np_typing.ndarray"):
    """Return ``(z_min, z_max)`` over the finite values of ``z``, guaranteeing ``z_max > z_min``.

    With no finite values the range is (0.0, 100.0); for flat data
    ``z_max = z_min + 100.0``. NaN and +/-inf are ignored.
    """
    finite_values = z[np.isfinite(z)]
    if finite_values.size == 0:
        return 0.0, 100.0
    z_min = float(finite_values.min())
    z_max = float(finite_values.max())
    if z_min >= z_max:
        z_max = z_min + 100.0
    return z_min, z_max


# === Visualization Functions ===

def show_image_matplotlib(
    image_source: Union[str, "np_typing.ndarray", "PILImage_typing.Image"],
    title: str = "Image Preview"
):
    """Display an image in a separate Matplotlib window with interactive zoom/pan.

    Args:
        image_source: A file path (requires Pillow), a NumPy array (copied
            before display), or a PIL Image.
        title: Axes title.

    2-D (single-channel) data is drawn with a grey colormap. Palette and
    non-RGB PIL modes are converted to RGB(A) before display. Errors are
    logged, not raised. ``plt.show()`` blocks on interactive backends.
    """
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display image: Matplotlib is not available.")
        return

    img_display_np: Optional["np_typing.ndarray"] = None
    source_type = type(image_source).__name__
    logging.info(f"Attempting to display image '{title}' from source type: {source_type}")

    try:
        if isinstance(image_source, str):
            if not PIL_AVAILABLE_VIS:
                logging.error("Cannot display image from path: Pillow (PIL) is required.")
                return
            if not os.path.exists(image_source):
                logging.error(f"Image file not found: {image_source}")
                return
            try:
                # Convert inside the `with` so pixel data is read before close.
                with Image.open(image_source) as img_pil:
                    img_display_np = _pil_image_to_array(img_pil)
            except Exception as e_load:
                logging.error(f"Failed to load image from path {image_source}: {e_load}", exc_info=True)
                return
        elif isinstance(image_source, np.ndarray):
            img_display_np = image_source.copy()
        elif PIL_AVAILABLE_VIS and isinstance(image_source, Image.Image):  # type: ignore
            img_display_np = _pil_image_to_array(image_source)
        else:
            logging.error(f"Unsupported image source type for Matplotlib: {type(image_source)}")
            return

        if img_display_np is None:
            logging.error("Failed to get image data as NumPy array.")
            return

        # Without an explicit cmap, imshow would colour single-channel data
        # with its default false-colour map; RGB(A) arrays ignore cmap.
        cmap = "gray" if img_display_np.ndim == 2 else None

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(img_display_np, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        plt.show()
        logging.debug(f"Plot window for '{title}' closed.")
    except Exception as e:
        logging.error(f"Error displaying image '{title}' with Matplotlib: {e}", exc_info=True)


def show_3d_matplotlib(
    hgt_array: Optional["np_typing.ndarray"],
    title: str = "3D Elevation View",
    initial_subsample: int = 1,
    smooth_sigma: Optional[float] = None,
    interpolation_factor: int = 1,
    plot_grid_points: int = 500,
    nodata_value: Optional[float] = -32768.0,
):
    """Display elevation data as a 3D surface plot.

    Pipeline: raw subsampling -> NoData masking (NaN) -> optional NaN-aware
    Gaussian smoothing -> optional spline interpolation to a dense grid ->
    stride subsampling to at most ``plot_grid_points`` samples per axis ->
    ``plot_surface``.

    Args:
        hgt_array: 2-D elevation grid (rows = Y, columns = X). ``None``,
            non-arrays, non-2-D and empty arrays are rejected with a logged error.
        title: Base plot title. Processing steps and the final grid size are
            appended to it.
        initial_subsample: Keep every n-th row/column of the raw data before
            any processing. Invalid values fall back to 1.
        smooth_sigma: Gaussian sigma in (subsampled) grid cells. ``None``
            disables smoothing; non-numeric or non-positive values also
            disable it. Requires SciPy.
        interpolation_factor: Spline upsampling factor per axis (1 disables).
            Invalid values fall back to 1. Requires SciPy.
        plot_grid_points: Upper bound on plotted samples per axis. Non-int
            values or values below 30 fall back to 200.
        nodata_value: Cells equal to this value are set to NaN and stay NaN in
            the plotted grid. ``None`` disables NoData masking.

    Errors during processing/plotting are logged, not raised.
    ``plt.show()`` blocks until the window is closed on interactive backends.
    """
    # --- Input and Dependency Checks ---
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display 3D plot: Matplotlib is not available.")
        return
    if hgt_array is None:
        logging.error("Cannot display 3D plot: Input data array is None.")
        return
    if not isinstance(hgt_array, np.ndarray) or hgt_array.ndim != 2 or hgt_array.size == 0:
        logging.error("Invalid input data for 3D plot (shape/type).")
        return

    if not isinstance(initial_subsample, int) or initial_subsample < 1:
        logging.warning(f"Invalid initial_subsample ({initial_subsample}). Using 1.")
        initial_subsample = 1
    if smooth_sigma is not None and (
        not isinstance(smooth_sigma, (int, float)) or smooth_sigma <= 0
    ):
        logging.warning(f"Invalid smooth_sigma ({smooth_sigma}). Disabling smoothing.")
        smooth_sigma = None
    if not isinstance(interpolation_factor, int) or interpolation_factor < 1:
        logging.warning(f"Invalid interpolation_factor ({interpolation_factor}). Using 1.")
        interpolation_factor = 1
    if not isinstance(plot_grid_points, int) or plot_grid_points < 30:  # Min for a reasonable plot
        logging.warning(f"Invalid plot_grid_points ({plot_grid_points}). Setting to 200.")
        plot_grid_points = 200
    if nodata_value is not None:
        try:
            nodata_value = float(nodata_value)
        except (TypeError, ValueError):
            logging.warning(f"Invalid nodata_value ({nodata_value!r}). Disabling NoData masking.")
            nodata_value = None

    # Disable advanced features if SciPy is missing
    if (interpolation_factor > 1 or smooth_sigma is not None) and not SCIPY_AVAILABLE:
        logging.warning("SciPy not available. Disabling smoothing and/or interpolation.")
        smooth_sigma = None
        interpolation_factor = 1

    processing_start_time = time.time()

    try:
        plot_title_parts = [title]  # Build title dynamically

        # --- 1. Initial Subsampling (of raw data) ---
        data_to_process = hgt_array[::initial_subsample, ::initial_subsample].copy()
        if initial_subsample > 1:
            plot_title_parts.append(f"RawSub{initial_subsample}x")
            logging.info(f"Initial subsampling by {initial_subsample}. Data shape: {data_to_process.shape}")

        # --- 2. Handle NoData (Convert to float, mark NaNs) ---
        if not np.issubdtype(data_to_process.dtype, np.floating):
            data_to_process = data_to_process.astype(np.float64)  # Float so NaN can be stored
        if nodata_value is not None:
            nodata_mask = data_to_process == nodata_value
            if np.any(nodata_mask):
                data_to_process[nodata_mask] = np.nan  # Use NaN internally
                logging.debug(f"Marked {np.sum(nodata_mask)} NoData points as NaN.")

        # --- 3. Gaussian Smoothing (Optional, before interpolation) ---
        if smooth_sigma is not None and SCIPY_AVAILABLE:
            logging.info(f"Applying Gaussian smoothing (sigma={smooth_sigma})...")
            try:
                data_to_process = _nan_aware_gaussian_filter(data_to_process, smooth_sigma)
                plot_title_parts.append(f"Smooth σ{smooth_sigma:.1f}")
            except Exception as e_smooth:
                logging.error(f"Gaussian smoothing failed: {e_smooth}", exc_info=True)
                plot_title_parts.append("(SmoothFail)")

        # --- 4. Interpolation (if requested) and subsampling for plotting ---
        rows_proc, cols_proc = data_to_process.shape
        x_proc_coords = np.arange(cols_proc)
        y_proc_coords = np.arange(rows_proc)

        plot_grid = None
        if interpolation_factor > 1 and SCIPY_AVAILABLE:
            plot_title_parts.append(f"Interp{interpolation_factor}x")
            logging.info(
                f"Performing spline interpolation (factor={interpolation_factor}). "
                f"Input shape: {data_to_process.shape}"
            )
            try:
                x_dense, y_dense, z_dense = _spline_upsample(data_to_process, interpolation_factor)
                logging.info(f"Interpolation complete. Dense grid shape: {z_dense.shape}")
                plot_grid = _subsample_for_plot(x_dense, y_dense, z_dense, plot_grid_points)
            except Exception as e_interp:
                logging.error(f"Spline interpolation or subsequent subsampling failed: {e_interp}", exc_info=True)
                plot_title_parts.append("(InterpFail)")
        else:
            logging.info("Skipping interpolation. Subsampling processed data for plotting.")

        if plot_grid is None:
            # No interpolation, or it failed: plot the processed (maybe smoothed) data.
            plot_grid = _subsample_for_plot(
                x_proc_coords, y_proc_coords, data_to_process, plot_grid_points
            )
        X_for_plot, Y_for_plot, Z_for_plot = plot_grid

        # Construct final plot title; report the actual plot grid size (Y x X)
        final_plot_title = " ".join(plot_title_parts)
        final_plot_title += f" (PlotGrid {Z_for_plot.shape[0]}x{Z_for_plot.shape[1]})"

        processing_end_time = time.time()
        logging.info(f"Data processing for 3D plot took {processing_end_time - processing_start_time:.2f} seconds.")

        # --- 5. Plotting the Result ---
        logging.info(f"Generating Matplotlib 3D plot. Final plot grid size: {Z_for_plot.shape}")
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        # Colour/axis limits from finite values only (NaN NoData and +/-inf ignored).
        z_min, z_max = _finite_z_limits(Z_for_plot)

        # rstride/cstride=1 because the grid is already at the desired plot density.
        surf = ax.plot_surface(
            X_for_plot, Y_for_plot, Z_for_plot,
            rstride=1, cstride=1,
            cmap='terrain',       # Standard colormap for terrain
            linewidth=0.1,        # Thin edges for dense meshes
            antialiased=False,    # Antialiasing disabled for dense meshes
            shade=False,          # Colours come from cmap only (no light-source shading)
            vmin=z_min,           # Colour limits from the finite data range
            vmax=z_max,
        )

        # --- Customize Plot Appearance ---
        ax.set_xlabel("X Index (Scaled/Processed)")
        ax.set_ylabel("Y Index (Scaled/Processed)")
        ax.set_zlabel("Elevation (m)")
        ax.set_title(final_plot_title, fontsize=10)  # Smaller font for the potentially long title

        # Z-axis limits with 10% padding (z_max > z_min is guaranteed above)
        z_range = z_max - z_min
        padding = z_range * 0.1 if z_range > 0 else 10.0
        ax.set_zlim(z_min - padding, z_max + padding)

        fig.colorbar(surf, shrink=0.5, aspect=10, label="Elevation (m)", pad=0.1)

        # Improve layout to prevent labels from overlapping
        try:
            fig.tight_layout()
        except Exception:
            logging.warning("fig.tight_layout() failed, plot might have overlapping elements.")

        plotting_end_time = time.time()
        logging.info(f"Plotting setup complete. Total time: {plotting_end_time - processing_start_time:.2f}s. Showing plot...")
        plt.show()  # BLOCKING CALL

    except Exception as e:  # Catch any unexpected errors during the entire process
        logging.error(f"Critical error in show_3d_matplotlib ('{title}'): {e}", exc_info=True)