Skip to content

Commit ea5f763

Browse files
committed
api: enforce valid coordinates for inject/interp
1 parent 1745806 commit ea5f763

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

devito/operations/interpolators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,34 @@ def wrapper(interp, *args, **kwargs):
3434
return wrapper
3535

3636

37+
def check_coords(func):
38+
@wraps(func)
39+
def wrapper(interp, *args, **kwargs):
40+
inputs = args + as_tuple(kwargs.get('expr', ()))
41+
42+
# SubFunction of the SparseFunction use to create the interpolator
43+
sfunc = interp.sfunction
44+
45+
# SubFunctions found in the arguments of the interpolation/injection operation
46+
a_sfuncs = {f for f in retrieve_functions(inputs)
47+
if f.is_SparseFunction} - {sfunc}
48+
if not a_sfuncs:
49+
# Only uses the the interpolator's SparseFunction, so no need to check
50+
return func(interp, *args, **kwargs)
51+
52+
# Check that it uses the same coordinates as the interpolator's SparseFunction
53+
subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions}
54+
for f in a_sfuncs:
55+
for s in f._sub_functions:
56+
if getattr(f, s, None) not in subfuncs:
57+
raise ValueError(f"Interpolation/injection with {sfunc}"
58+
f"requires {f} "
59+
f"to use the same {s} as {sfunc}")
60+
61+
return func(interp, *args, **kwargs)
62+
return wrapper
63+
64+
3765
def _extract_subdomain(variables):
3866
"""
3967
Check if any of the variables provided are defined on a SubDomain
@@ -322,6 +350,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
322350
return idx_subs, temps
323351

324352
@check_radius
353+
@check_coords
325354
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
326355
"""
327356
Generate equations interpolating an arbitrary expression into ``self``.
@@ -342,6 +371,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
342371
return Interpolation(expr, increment, implicit_dims, self_subs, self)
343372

344373
@check_radius
374+
@check_coords
345375
def inject(self, field, expr, implicit_dims=None):
346376
"""
347377
Generate equations injecting an arbitrary expression into a field.

devito/operator/operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
647647
else:
648648
args[k] = args.unique(k, candidate=v)
649649

650-
kwargs['args'] = args.reduce_inplace()
650+
args.reduce_inplace()
651+
kwargs['args'] = args
651652

652653
for i in discretizations:
653654
args.update(i._arg_values(**kwargs))

devito/tools/data_structures.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,6 @@ def reduce_inplace(self):
223223
for k, v in self.reduce_all().items():
224224
self[k] = v
225225

226-
return self
227-
228226

229227
class DefaultOrderedDict(OrderedDict):
230228
# Source: http://stackoverflow.com/a/6190500/562769

tests/test_interpolation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode):
12551255
assert data1 == None # noqa
12561256
assert data2 == None # noqa
12571257
assert data3 == None # noqa
1258+
1259+
1260+
def test_wrong_coords():
1261+
grid = Grid(shape=(11, 11))
1262+
s = SparseFunction(name='src', npoint=1, grid=grid)
1263+
s2 = SparseFunction(name='src2', npoint=1, grid=grid)
1264+
u = Function(name='u', grid=grid)
1265+
1266+
with pytest.raises(ValueError) as vinfo:
1267+
s.inject(u, expr=s2)
1268+
assert "Interpolation/injection with" in str(vinfo.value)
1269+
1270+
with pytest.raises(ValueError) as vinfo:
1271+
s.interpolate(u + s2)
1272+
assert "Interpolation/injection with" in str(vinfo.value)

0 commit comments

Comments
 (0)