Skip to content

Commit 7d0c2bd

Browse files
committed
add NaN checks for input arrays in odr_fit
1 parent f531aa9 commit 7d0c2bd

2 files changed

Lines changed: 92 additions & 0 deletions

File tree

src/odrpack/odr_scipy.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,24 @@ def odr_fit(f: Callable[[F64Array, F64Array], F64Array],
404404
ldwe = 1
405405
ld2we = 1
406406

407+
# Check all arrays for NaNs
408+
for array, name in [(xdata, 'xdata'),
409+
(ydata, 'ydata'),
410+
(beta0, 'beta0'),
411+
(weight_x, 'weight_x'),
412+
(weight_y, 'weight_y'),
413+
(lower, 'bounds[0]'),
414+
(upper, 'bounds[1]'),
415+
(fix_beta, 'fix_beta'),
416+
(fix_x, 'fix_x'),
417+
(delta0, 'delta0'),
418+
(step_beta, 'step_beta'),
419+
(step_delta, 'step_delta'),
420+
(scale_beta, 'scale_beta'),
421+
(scale_delta, 'scale_delta')]:
422+
if array is not None and np.isnan(array).any():
423+
raise ValueError(f"`{name}` contains NaN values.")
424+
407425
# Check model function
408426
f0 = f(xdata, beta0)
409427
if f0.shape != ydata.shape:

tests/test_odr_fit.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,18 @@ def test_base_cases(case1, case2, case3):
176176
# y has invalid shape
177177
_ = odr_fit(f=case2['f'], xdata=np.ones((2, 10)), ydata=np.ones((1, 2, 10)),
178178
beta0=case2['beta0'])
179+
with pytest.raises(ValueError):
180+
# x has nan values
181+
xdata = case1['xdata'].copy()
182+
xdata[0] = np.nan
183+
_ = odr_fit(f=case1['f'], xdata=xdata, ydata=case1['ydata'],
184+
beta0=case1['beta0'])
185+
with pytest.raises(ValueError):
186+
# y has nan values
187+
ydata = case1['ydata'].copy()
188+
ydata[-1] = np.nan
189+
_ = odr_fit(f=case1['f'], xdata=case1['xdata'], ydata=ydata,
190+
beta0=case1['beta0'])
179191

180192

181193
def test_beta0_related(case1):
@@ -242,6 +254,36 @@ def test_beta0_related(case1):
242254
with pytest.raises(ValueError):
243255
# invalid task
244256
_ = odr_fit(**case1, task='invalid')
257+
with pytest.raises(ValueError):
258+
# beta0 has nan values
259+
beta0 = case1['beta0'].copy()
260+
beta0[0] = np.nan
261+
_ = odr_fit(f=case1['f'], xdata=case1['xdata'], ydata=case1['ydata'],
262+
beta0=beta0)
263+
with pytest.raises(ValueError):
264+
# lower has nan values
265+
lower = case1['beta0'].copy() - 1.0
266+
lower[-1] = np.nan
267+
_ = odr_fit(**case1, bounds=(lower, None))
268+
with pytest.raises(ValueError):
269+
# upper has nan values
270+
upper = case1['beta0'].copy() + 1.0
271+
upper[0] = np.nan
272+
_ = odr_fit(**case1, bounds=(None, upper))
273+
with pytest.raises(ValueError):
274+
# fix_beta has nan values
275+
fix_beta = np.array([True, np.nan, False, False])
276+
_ = odr_fit(**case1, fix_beta=fix_beta)
277+
with pytest.raises(ValueError):
278+
# step_beta has nan values
279+
step_beta = case1['beta0'].copy()
280+
step_beta[2] = np.nan
281+
_ = odr_fit(**case1, step_beta=step_beta)
282+
with pytest.raises(ValueError):
283+
# scale_beta has nan values
284+
scale_beta = case1['beta0'].copy()
285+
scale_beta[2] = np.nan
286+
_ = odr_fit(**case1, scale_beta=scale_beta)
245287

246288

247289
def test_delta0_related(case1, case3):
@@ -312,6 +354,26 @@ def test_delta0_related(case1, case3):
312354
with pytest.raises(ValueError):
313355
# delta0 has invalid shape
314356
_ = odr_fit(**case3, delta0=np.zeros_like(case1['ydata']))
357+
with pytest.raises(ValueError):
358+
# fix_x has nan values
359+
fix_x = np.ones_like(case1['xdata'])
360+
fix_x[0] = np.nan
361+
_ = odr_fit(**case1, fix_x=fix_x)
362+
with pytest.raises(ValueError):
363+
# step_delta has nan values
364+
step_delta = np.ones_like(case3['xdata'])
365+
step_delta[0, 0] = np.nan
366+
_ = odr_fit(**case3, step_delta=step_delta)
367+
with pytest.raises(ValueError):
368+
# scale_delta has nan values
369+
scale_delta = np.ones_like(case3['xdata'])
370+
scale_delta[1, 1] = np.nan
371+
_ = odr_fit(**case3, scale_delta=scale_delta)
372+
with pytest.raises(ValueError):
373+
# delta0 has nan values
374+
delta0 = np.ones_like(case3['xdata'])
375+
delta0[2, 2] = np.nan
376+
_ = odr_fit(**case3, delta0=delta0)
315377

316378

317379
def test_weight_x(case1, case3):
@@ -386,6 +448,12 @@ def test_weight_x(case1, case3):
386448
with pytest.raises(TypeError):
387449
_ = odr_fit(**case3, weight_x=[1.0, 1.0, 1.0])
388450

451+
# weight_x has nan values
452+
weight_x = np.ones_like(case1['xdata'])
453+
weight_x[0] = np.nan
454+
with pytest.raises(ValueError):
455+
_ = odr_fit(**case1, weight_x=weight_x)
456+
389457

390458
def test_weight_y(case1, case3):
391459

@@ -461,6 +529,12 @@ def test_weight_y(case1, case3):
461529
with pytest.raises(TypeError):
462530
_ = odr_fit(**case3, weight_y=[1.0, 1.0, 1.0])
463531

532+
# weight_y has nan values
533+
weight_y = np.ones_like(case1['ydata'])
534+
weight_y[-1] = np.nan
535+
with pytest.raises(ValueError):
536+
_ = odr_fit(**case1, weight_y=weight_y)
537+
464538

465539
def test_parameters(case1):
466540
# maxit

0 commit comments

Comments
 (0)