68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
"""Spline utilities for Catmull-Rom interpolation and trajectory sampling.

Provides a small helper :func:`catmull_rom_spline` that performs Catmull-Rom
interpolation over a sequence of control points and samples the resulting
curve into a list of points. The implementation supports 2D or 3D points and
falls back to a polyline when there are fewer than four control points.
"""
|
|
|
|
import numpy as np
|
|
|
|
|
|
def catmull_rom_spline(points, num_points=100):
    """Sample a uniform Catmull-Rom spline through the given control points.

    The curve passes through every control point. The first and last control
    points are duplicated as phantom neighbours so that the end segments have
    the four points the Catmull-Rom basis needs.

    Args:
        points: Sequence of control points, each a 2D or 3D coordinate
            (tuples, lists or arrays). Converted to a float array of shape
            ``(n, dim)``.
        num_points: Number of samples to produce, spread evenly in the spline
            parameter from the first to the last control point (both
            included). Ignored when fewer than 4 control points are given.

    Returns:
        A list of sampled points, each a list of floats.

        * With fewer than 4 control points, the control points themselves
          are returned unchanged (a polyline), regardless of ``num_points``.
        * With ``num_points <= 0`` an empty list is returned.
        * With ``num_points == 1`` only the first control point is returned.
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    if n < 4:
        # Not enough points for a spline; return the input as a polyline.
        return points.tolist()

    # Degenerate sample counts: the parameter step below divides by
    # (num_points - 1), and the endpoint pinning needs at least one sample.
    if num_points <= 0:
        return []
    if num_points == 1:
        return [points[0].tolist()]

    # Pad with duplicated endpoints so that extended_points[i + 1] == points[i]
    # and every segment i can read the window extended_points[i : i + 4].
    extended_points = np.vstack([points[0], points, points[-1]])

    # Uniform Catmull-Rom basis; rows correspond to powers 1, t, t^2, t^3.
    basis = 0.5 * np.array(
        [
            [0, 2, 0, 0],
            [-1, 0, 1, 0],
            [2, -5, 4, -1],
            [-1, 3, -3, 1],
        ]
    )

    total_segments = n - 1

    # Global parameter in [0, total_segments]; the integer part selects the
    # segment and the fractional part is the local parameter t.
    s = np.arange(num_points) / (num_points - 1) * total_segments
    # Clamp so the final sample (s == total_segments) is evaluated at t == 1
    # on the last segment rather than on a non-existent one.
    seg = np.minimum(np.floor(s).astype(int), total_segments - 1)
    t = s - seg

    powers = np.stack([np.ones_like(t), t, t**2, t**3], axis=1)
    weights = powers @ basis  # (num_points, 4) blend weights per sample
    # Gather the four control points of each sample's segment:
    # shape (num_points, 4, dim).
    windows = extended_points[seg[:, None] + np.arange(4)]
    samples = np.einsum("ij,ijk->ik", weights, windows)

    # Ensure the endpoints match the control points exactly (no round-off).
    samples[0] = points[0]
    samples[-1] = points[-1]
    return samples.tolist()
|