Skip to content

Commit f45955b

Browse files
Hhh 118 update interpolation method in waverespons as todays method is deprecated (#65)
1 parent d0bd0d8 commit f45955b

3 files changed

Lines changed: 17 additions & 12 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ classifiers = [
2121
"Operating System :: Microsoft :: Windows",
2222
]
2323
dependencies = [
24-
"numpy<2.0.0",
24+
"numpy",
2525
"pandas",
26-
"scipy<1.14.0",
26+
"scipy",
2727
"pyarrow"
2828
]
2929

src/waveresponse/_core.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
from scipy.integrate import trapezoid
7-
from scipy.interpolate import interp2d
7+
from scipy.interpolate import RegularGridInterpolator as RGI
88
from scipy.special import gamma
99

1010

@@ -601,11 +601,12 @@ def rotate(self, angle, degrees=False):
601601

602602
def _interpolate_function(self, complex_convert="rectangular", **kw):
603603
"""
604-
Interpolation function based on ``scipy.interpolate.interp2d``.
604+
Interpolation function based on ``scipy.interpolate.RegularGridInterpolator``.
605605
"""
606606
xp = np.concatenate(
607607
(self._dirs[-1:] - 2 * np.pi, self._dirs, self._dirs[:1] + 2.0 * np.pi)
608608
)
609+
609610
yp = self._freq
610611
zp = np.concatenate(
611612
(
@@ -617,11 +618,11 @@ def _interpolate_function(self, complex_convert="rectangular", **kw):
617618
)
618619

619620
if np.all(np.isreal(zp)):
620-
return interp2d(xp, yp, zp, **kw)
621+
return RGI((xp, yp), zp.T, **kw)
621622
elif complex_convert.lower() == "polar":
622623
amp, phase = complex_to_polar(zp, phase_degrees=False)
623-
interp_amp = interp2d(xp, yp, amp, **kw)
624-
interp_phase = interp2d(xp, yp, phase, **kw)
624+
interp_amp = RGI((xp, yp), amp.T, **kw)
625+
interp_phase = RGI((xp, yp), phase.T, **kw)
625626
return lambda *args, **kwargs: (
626627
polar_to_complex(
627628
interp_amp(*args, **kwargs),
@@ -630,8 +631,8 @@ def _interpolate_function(self, complex_convert="rectangular", **kw):
630631
)
631632
)
632633
elif complex_convert.lower() == "rectangular":
633-
interp_real = interp2d(xp, yp, np.real(zp), **kw)
634-
interp_imag = interp2d(xp, yp, np.imag(zp), **kw)
634+
interp_real = RGI((xp, yp), np.real(zp.T), **kw)
635+
interp_imag = RGI((xp, yp), np.imag(zp.T), **kw)
635636
return lambda *args, **kwargs: (
636637
interp_real(*args, **kwargs) + 1j * interp_imag(*args, **kwargs)
637638
)
@@ -702,10 +703,14 @@ def interpolate(
702703
self._check_dirs(dirs)
703704

704705
interp_fun = self._interpolate_function(
705-
complex_convert=complex_convert, kind="linear", fill_value=fill_value
706+
complex_convert=complex_convert,
707+
method="linear",
708+
bounds_error=False,
709+
fill_value=fill_value,
706710
)
707711

708-
return interp_fun(dirs, freq, assume_sorted=True)
712+
dirsnew, freqnew = np.meshgrid(dirs, freq, indexing="ij", sparse=True)
713+
return interp_fun((dirsnew, freqnew)).T
709714

710715
def reshape(
711716
self,

tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,7 @@ def test_interpolate_single_coordinate(self):
14001400

14011401
vals_out = grid.interpolate(1.8, 12.1, freq_hz=True, degrees=True)
14021402

1403-
vals_expect = np.array([a * 12.1 + b * 1.8])
1403+
vals_expect = np.array(a * 12.1 + b * 1.8)
14041404

14051405
np.testing.assert_array_almost_equal(vals_out, vals_expect)
14061406

0 commit comments

Comments
 (0)