399 lines
19 KiB
Python
399 lines
19 KiB
Python
# visualizer.py
|
||
|
||
import logging
|
||
import os
|
||
from typing import Optional, Union, TYPE_CHECKING, List # Import List for extent type hint
|
||
import time # For benchmarking processing time
|
||
|
||
# --- Dependency Checks ---
|
||
try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # Required for 3D projection
    import numpy as np  # Required by Matplotlib & for data handling
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

    # Fallback stand-ins so the rest of the module can still be imported
    # (and its names resolved) when Matplotlib/NumPy are not installed.
    class plt:  # type: ignore
        """Inert replacement for ``matplotlib.pyplot`` / figure / axes objects."""

        # --- calls made on the pyplot module itself ---
        @staticmethod
        def figure(*args, **kwargs):
            return plt

        @staticmethod
        def subplots(*args, **kwargs):
            return plt, plt  # Dummy fig, ax

        @staticmethod
        def show(*args, **kwargs):
            logging.warning("Matplotlib not available, cannot show plot.")

        # --- calls made on figure / axes objects ---
        def add_subplot(self, *args, **kwargs):
            return plt

        def plot_surface(self, *args, **kwargs):
            return plt

        def imshow(self, *args, **kwargs):
            return None

        def axis(self, *args, **kwargs):
            return None

        def colorbar(self, *args, **kwargs):
            return None

        def tight_layout(self, *args, **kwargs):
            return None

        def set_title(self, *args, **kwargs):
            return None

        def set_xlabel(self, *args, **kwargs):
            return None

        def set_ylabel(self, *args, **kwargs):
            return None

        def set_zlabel(self, *args, **kwargs):
            return None

        def set_zlim(self, *args, **kwargs):
            return None

    class Axes3D:  # Dummy class
        pass

    class np:  # Minimal numpy dummy # type: ignore
        """Tiny subset of the NumPy namespace returning fixed placeholder values."""

        # Type / constant placeholders
        ndarray = type(None)
        floating = float
        float64 = float
        nan = float('nan')

        # Array construction
        @staticmethod
        def array(*args, **kwargs):
            return None

        @staticmethod
        def arange(*args, **kwargs):
            return []

        @staticmethod
        def linspace(*args, **kwargs):
            return []

        @staticmethod
        def meshgrid(*args, **kwargs):
            return [], []

        @staticmethod
        def nan_to_num(*args, **kwargs):
            return np.array([])

        # Reductions
        @staticmethod
        def nanmin(*args, **kwargs):
            return 0

        @staticmethod
        def nanmax(*args, **kwargs):
            return 0

        @staticmethod
        def nanmean(*args, **kwargs):
            return 0

        @staticmethod
        def sum(*args, **kwargs):
            return 0

        # Predicates
        @staticmethod
        def any(*args, **kwargs):
            return False

        @staticmethod
        def isnan(*args, **kwargs):
            return False

        @staticmethod
        def isfinite(*args, **kwargs):
            return True

        @staticmethod
        def issubdtype(*args, **kwargs):
            return False

    logging.warning(
        "Matplotlib or NumPy not found. "
        "Visualization features (2D/3D plots) will be disabled."
    )
|
||
|
||
try:
    import scipy.ndimage  # For gaussian_filter
    from scipy.interpolate import RectBivariateSpline  # For smooth 2D interpolation
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False

    class RectBivariateSpline:  # type: ignore
        """Placeholder spline used when SciPy is missing; evaluates to a 1x1 zero grid."""

        def __init__(self, *args, **kwargs):
            return None

        def __call__(self, *args, **kwargs):
            return np.array([[0.0]])

    class scipy_ndimage_dummy:  # type: ignore
        """Placeholder for the ``scipy.ndimage`` module."""

        @staticmethod
        def gaussian_filter(*args, **kwargs):
            # Identity "filter": hand back the input array unchanged.
            return args[0]

    # Object exposing ``.ndimage`` so that ``scipy.ndimage.<name>`` still resolves.
    scipy = type('SciPyDummy', (), dict(ndimage=scipy_ndimage_dummy))()  # type: ignore

    logging.warning(
        "SciPy library not found. "
        "Advanced smoothing/interpolation for 3D plots will be disabled."
    )
|
||
|
||
# Check for Pillow (PIL) needed for loading images from paths or PIL objects
|
||
try:
    from PIL import Image
    PIL_AVAILABLE_VIS = True
except ImportError:
    PIL_AVAILABLE_VIS = False

    class Image:  # Dummy class # type: ignore
        """Placeholder for the ``PIL.Image`` module when Pillow is not installed."""

        # Stand-in for the ``PIL.Image.Image`` class so the attribute exists.
        Image = type(None)

        @staticmethod
        def open(*args, **kwargs):
            # Opening anything is impossible without Pillow; fail loudly.
            raise ImportError("Pillow not available")
|
||
|
||
# Use TYPE_CHECKING to hint dependencies without runtime import errors
|
||
if TYPE_CHECKING:
|
||
import numpy as np_typing # Use an alias to avoid conflict with dummy np
|
||
from PIL import Image as PILImage_typing
|
||
|
||
|
||
# === Visualization Functions ===
|
||
|
||
def _load_image_as_array(image_source):
    """Turn a supported image source into a NumPy array.

    Returns the array, or None after logging the reason when the source
    cannot be loaded (missing Pillow, missing file, unreadable file,
    unsupported type).
    """
    if isinstance(image_source, str):
        if not PIL_AVAILABLE_VIS:
            logging.error("Cannot display image from path: Pillow (PIL) is required.")
            return None
        if not os.path.exists(image_source):
            logging.error(f"Image file not found: {image_source}")
            return None
        try:
            # The context manager closes the file handle once the pixels are copied.
            with Image.open(image_source) as img_pil:
                return np.array(img_pil)
        except Exception as e_load:
            logging.error(f"Failed to load image from path {image_source}: {e_load}", exc_info=True)
            return None

    if isinstance(image_source, np.ndarray):
        return image_source.copy()

    # Only test against Image.Image when Pillow is really installed; the
    # fallback `Image` object is just a placeholder.
    if PIL_AVAILABLE_VIS and isinstance(image_source, Image.Image):  # type: ignore
        return np.array(image_source)

    logging.error(f"Unsupported image source type for Matplotlib: {type(image_source)}")
    return None


def show_image_matplotlib(
    image_source: Union[str, "np_typing.ndarray", "PILImage_typing.Image"],
    title: str = "Image Preview",
    extent: Optional[List[float]] = None
) -> None:
    """
    Displays an image in a separate Matplotlib window with interactive zoom/pan.

    Supports loading from a file path (str), a NumPy array, or a PIL Image object.
    All failures are logged; nothing is raised to the caller.

    Args:
        image_source (Union[str, np.ndarray, PIL.Image.Image]): The image data or path.
        title (str): The title for the plot window.
        extent (Optional[List[float]]): A list [left, right, bottom, top] passed to
            ``imshow``. When given, the axes stay visible and are labelled
            Longitude/Latitude; when omitted, the axes are hidden.

    Returns:
        None. ``plt.show()`` is called, so the function returns when the
        backend's show call returns.
    """
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display image: Matplotlib is not available.")
        return

    source_type = type(image_source).__name__
    logging.info(f"Attempting to display image '{title}' from source type: {source_type}. Extent provided: {'Yes' if extent is not None else 'No'}")

    fig = None  # Tracked so a half-built figure can be discarded on failure.
    try:
        img_display_np = _load_image_as_array(image_source)
        if img_display_np is None:
            return  # The helper already logged the reason.

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(img_display_np, extent=extent)
        ax.set_title(title)

        if extent is None:
            ax.axis('off')  # Plain image preview: tick marks carry no meaning.
        else:
            # With an extent the axes show the supplied coordinates.
            ax.set_xlabel("Longitude")
            ax.set_ylabel("Latitude")
            # NOTE: aspect='auto' stretches the image to fill the axes box; it does
            # NOT keep one x-unit the same length as one y-unit (that would be
            # 'equal'). Kept as 'auto' to preserve the existing on-screen behavior.
            ax.set_aspect('auto', adjustable='box')

        plt.show()
        logging.debug(f"Plot window for '{title}' closed.")
    except Exception as e:
        logging.error(f"Error displaying image '{title}' with Matplotlib: {e}", exc_info=True)
        # Without this, a figure created before the failure stays registered
        # with pyplot and would pop up on the next plt.show() call.
        if fig is not None:
            plt.close(fig)
|
||
|
||
|
||
def _nan_aware_gaussian_filter(data, sigma):
    """Gaussian-smooth a float array while ignoring NaN cells.

    A plain ``gaussian_filter`` propagates NaN to every output cell whose
    kernel window touches a NaN, which turns each NoData cell into a large
    hole. Here NaNs get zero weight instead (normalized convolution): the
    zero-filled values and the validity mask are filtered separately and
    divided. Cells that were NaN in the input stay NaN in the output.
    """
    nan_mask = np.isnan(data)
    if not np.any(nan_mask):
        return scipy.ndimage.gaussian_filter(data, sigma=sigma, mode='nearest')

    zero_filled = np.where(nan_mask, 0.0, data)
    validity = (~nan_mask).astype(zero_filled.dtype)
    smoothed_values = scipy.ndimage.gaussian_filter(zero_filled, sigma=sigma, mode='nearest')
    smoothed_weights = scipy.ndimage.gaussian_filter(validity, sigma=sigma, mode='nearest')

    # The weight can only be 0 deep inside a NaN region, and those cells are
    # reset to NaN right below, so the division warnings are noise.
    with np.errstate(invalid='ignore', divide='ignore'):
        smoothed = smoothed_values / smoothed_weights
    smoothed[nan_mask] = np.nan
    return smoothed


def _subsample_grid_for_plot(z_values, plot_grid_points):
    """Stride a 2D array down to roughly `plot_grid_points` per axis; return (X, Y, Z)."""
    rows, cols = z_values.shape
    stride_y = max(1, rows // plot_grid_points)
    stride_x = max(1, cols // plot_grid_points)
    grid_x, grid_y = np.meshgrid(np.arange(cols)[::stride_x], np.arange(rows)[::stride_y])
    return grid_x, grid_y, z_values[::stride_y, ::stride_x]


def _interpolate_grid_for_plot(data, interpolation_factor, plot_grid_points):
    """Fit a bicubic spline to `data` and evaluate it on the plot grid; return (X, Y, Z).

    The plot grid is the `interpolation_factor`-times denser lattice, strided
    down to roughly `plot_grid_points` per axis. The spline is evaluated only
    at the strided coordinates, which yields the same values as evaluating the
    whole dense lattice and slicing it, without allocating the dense array.
    """
    rows, cols = data.shape
    dense_rows = rows * interpolation_factor
    dense_cols = cols * interpolation_factor
    y_dense_coords = np.linspace(0, rows - 1, dense_rows)
    x_dense_coords = np.linspace(0, cols - 1, dense_cols)

    # RectBivariateSpline cannot be fitted to NaNs: fill them with the mean of
    # the valid data (or 0 if nothing is valid).
    fit_data = data.copy()
    nan_mask = np.isnan(fit_data)
    if np.any(nan_mask):
        fill_value = np.nanmean(fit_data)
        if np.isnan(fill_value):
            fill_value = 0.0  # Fallback if all data was NaN
        fit_data[nan_mask] = fill_value
        logging.debug(f"Filled {np.sum(nan_mask)} NaNs with {fill_value:.2f} for spline fitting.")

    # kx=3, ky=3 -> bicubic; s=0 -> interpolating (no smoothing) spline.
    spline = RectBivariateSpline(np.arange(rows), np.arange(cols), fit_data, kx=3, ky=3, s=0)

    plot_stride_y = max(1, dense_rows // plot_grid_points)
    plot_stride_x = max(1, dense_cols // plot_grid_points)
    y_plot_coords = y_dense_coords[::plot_stride_y]
    x_plot_coords = x_dense_coords[::plot_stride_x]
    z_plot = spline(y_plot_coords, x_plot_coords)

    # The shape logged here is the nominal resolution of the dense lattice.
    logging.info(f"Interpolation complete. Dense grid shape: {(dense_rows, dense_cols)}")
    logging.info(f"Subsampling dense interpolated grid for plotting with Y-stride:{plot_stride_y}, X-stride:{plot_stride_x}")

    grid_x, grid_y = np.meshgrid(x_plot_coords, y_plot_coords)
    return grid_x, grid_y, z_plot


def show_3d_matplotlib(
    hgt_array: Optional["np_typing.ndarray"],
    title: str = "3D Elevation View",
    initial_subsample: int = 1,
    smooth_sigma: Optional[float] = None,
    interpolation_factor: int = 1,
    plot_grid_points: int = 500
) -> None:
    """
    Displays elevation data as a 3D surface plot.

    Pipeline: stride the raw array, convert to float and mark the value
    -32768 as NaN, optionally Gaussian-smooth (NaN-aware), optionally
    spline-interpolate to a denser lattice, then plot a strided version of
    the result so the surface has roughly `plot_grid_points` nodes per axis.
    All failures are logged; nothing is raised to the caller.

    Args:
        hgt_array: Non-empty 2D NumPy array of elevations. None or any other
            shape/type is rejected with an error log.
        title: Base plot title; tags describing the processing steps and the
            final plot grid size are appended.
        initial_subsample: Stride (>= 1) applied to the raw array on both axes.
            Invalid values fall back to 1.
        smooth_sigma: Gaussian sigma (> 0), or None for no smoothing. Invalid
            values disable smoothing. Requires SciPy.
        interpolation_factor: Densification factor (>= 1) for the bicubic
            spline; 1 disables interpolation. Invalid values fall back to 1.
            Requires SciPy.
        plot_grid_points: Approximate number of plotted nodes per axis.
            Non-integers and values below 30 are replaced by 200.

    Returns:
        None. ``plt.show()`` is called at the end.
    """
    # --- Input and Dependency Checks ---
    if not MATPLOTLIB_AVAILABLE:
        logging.error("Cannot display 3D plot: Matplotlib is not available.")
        return
    if hgt_array is None:
        logging.error("Cannot display 3D plot: Input data array is None.")
        return
    if not isinstance(hgt_array, np.ndarray) or hgt_array.ndim != 2 or hgt_array.size == 0:
        logging.error("Invalid input data for 3D plot (shape/type).")
        return
    if not isinstance(initial_subsample, int) or initial_subsample < 1:
        logging.warning(f"Invalid initial_subsample ({initial_subsample}). Using 1.")
        initial_subsample = 1
    if smooth_sigma is not None and (not isinstance(smooth_sigma, (int, float)) or smooth_sigma <= 0):
        logging.warning(f"Invalid smooth_sigma ({smooth_sigma}). Disabling smoothing.")
        smooth_sigma = None
    if not isinstance(interpolation_factor, int) or interpolation_factor < 1:
        logging.warning(f"Invalid interpolation_factor ({interpolation_factor}). Using 1.")
        interpolation_factor = 1
    if not isinstance(plot_grid_points, int) or plot_grid_points < 30:  # Min for a reasonable plot
        logging.warning(f"Invalid plot_grid_points ({plot_grid_points}). Setting to 200.")
        plot_grid_points = 200

    # Disable advanced features if SciPy is missing
    if (interpolation_factor > 1 or smooth_sigma is not None) and not SCIPY_AVAILABLE:
        logging.warning("SciPy not available. Disabling smoothing and/or interpolation.")
        smooth_sigma = None
        interpolation_factor = 1

    processing_start_time = time.time()
    fig = None  # Tracked so a half-built figure can be discarded on failure.
    try:
        plot_title_parts = [title]  # Build title dynamically

        # --- 1. Initial Subsampling (of raw data) ---
        # .copy() so that marking NoData below never writes into the caller's array.
        data_to_process = hgt_array[::initial_subsample, ::initial_subsample].copy()
        if initial_subsample > 1:
            plot_title_parts.append(f"RawSub{initial_subsample}x")
            logging.info(f"Initial subsampling by {initial_subsample}. Data shape: {data_to_process.shape}")

        # --- 2. Handle NoData (Convert to float, mark NaNs) ---
        if not np.issubdtype(data_to_process.dtype, np.floating):
            data_to_process = data_to_process.astype(np.float64)  # NaN needs a float dtype
        common_nodata_value = -32768.0
        nodata_mask_initial = (data_to_process == common_nodata_value)
        if np.any(nodata_mask_initial):
            data_to_process[nodata_mask_initial] = np.nan  # Use NaN internally
            logging.debug(f"Marked {np.sum(nodata_mask_initial)} NoData points as NaN.")

        # --- 3. Gaussian Smoothing (Optional, before interpolation) ---
        if smooth_sigma is not None and SCIPY_AVAILABLE:
            logging.info(f"Applying Gaussian smoothing (sigma={smooth_sigma})...")
            try:
                data_to_process = _nan_aware_gaussian_filter(data_to_process, smooth_sigma)
                plot_title_parts.append(f"Smooth σ{smooth_sigma:.1f}")
            except Exception as e_smooth:
                # Best effort: carry on with the unsmoothed data.
                logging.error(f"Gaussian smoothing failed: {e_smooth}", exc_info=True)
                plot_title_parts.append("(SmoothFail)")

        # --- 4. Interpolation (if requested) and reduction to the plot grid ---
        if interpolation_factor > 1 and SCIPY_AVAILABLE:
            plot_title_parts.append(f"Interp{interpolation_factor}x")
            logging.info(f"Performing spline interpolation (factor={interpolation_factor}). Input shape: {data_to_process.shape}")
            try:
                X_for_plot, Y_for_plot, Z_for_plot = _interpolate_grid_for_plot(
                    data_to_process, interpolation_factor, plot_grid_points
                )
            except Exception as e_interp:
                logging.error(f"Spline interpolation or subsequent subsampling failed: {e_interp}", exc_info=True)
                plot_title_parts.append("(InterpFail)")
                # Fallback: plot the processed (maybe smoothed) data without interpolation.
                X_for_plot, Y_for_plot, Z_for_plot = _subsample_grid_for_plot(
                    data_to_process, plot_grid_points
                )
        else:
            logging.info("Skipping interpolation. Subsampling processed data for plotting.")
            X_for_plot, Y_for_plot, Z_for_plot = _subsample_grid_for_plot(
                data_to_process, plot_grid_points
            )

        # Final title, including the actual plot grid size (rows x cols).
        final_plot_title = " ".join(plot_title_parts)
        final_plot_title += f" (PlotGrid {Z_for_plot.shape[0]}x{Z_for_plot.shape[1]})"

        processing_end_time = time.time()
        logging.info(f"Data processing for 3D plot took {processing_end_time - processing_start_time:.2f} seconds.")

        # --- 5. Plotting the Result ---
        logging.info(f"Generating Matplotlib 3D plot. Final plot grid size: {Z_for_plot.shape}")
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        # Z limits from the plotted data; NaN/inf results (e.g. all-NaN input)
        # fall back to an arbitrary 100-unit range starting at 0.
        z_min, z_max = np.nanmin(Z_for_plot), np.nanmax(Z_for_plot)
        if not np.isfinite(z_min):
            z_min = 0.0
        if not np.isfinite(z_max):
            z_max = z_min + 100.0
        if z_min >= z_max:
            z_max = z_min + 100.0  # Ensure z_max > z_min (flat surface)

        surf = ax.plot_surface(
            X_for_plot, Y_for_plot, Z_for_plot,
            rstride=1, cstride=1,  # Grid is already at plot density: draw every node
            cmap='terrain',        # Standard colormap for terrain
            linewidth=0.1,         # Thin facet edges
            antialiased=False,     # Antialiasing disabled
            shade=False,           # Lighting disabled: colors come from the colormap only
            vmin=z_min,            # Color limits follow the data range
            vmax=z_max
        )

        # --- Customize Plot Appearance ---
        ax.set_xlabel("X Index (Scaled/Processed)")
        ax.set_ylabel("Y Index (Scaled/Processed)")
        ax.set_zlabel("Elevation (m)")
        ax.set_title(final_plot_title, fontsize=10)  # Title can get long

        # Z-axis limits with 10% padding on each side.
        z_range = z_max - z_min
        padding = z_range * 0.1 if z_range > 0 else 10.0
        ax.set_zlim(z_min - padding, z_max + padding)

        fig.colorbar(surf, shrink=0.5, aspect=10, label="Elevation (m)", pad=0.1)

        # tight_layout is cosmetic; a failure here must not abort the plot.
        try:
            fig.tight_layout()
        except Exception:
            logging.warning("fig.tight_layout() failed, plot might have overlapping elements.")

        plotting_end_time = time.time()
        logging.info(f"Plotting setup complete. Total time: {plotting_end_time - processing_start_time:.2f}s. Showing plot...")
        plt.show()  # BLOCKING CALL

    except Exception as e:
        # Catch any unexpected errors during the entire process
        logging.error(f"Critical error in show_3d_matplotlib ('{title}'): {e}", exc_info=True)
        # Without this, a figure created before the failure stays registered
        # with pyplot and would pop up on the next plt.show() call.
        if fig is not None:
            plt.close(fig)