Skip to content

Commit 10eac67

Browse files
committed
api: fix interp/eval of expressions
1 parent 676b8bf commit 10eac67

File tree

5 files changed

+113
-77
lines changed

5 files changed

+113
-77
lines changed

devito/finite_differences/derivative.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,9 @@ def _eval_at(self, func):
478478
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
479479
has to be computed at x=x + h_x/2.
480480
"""
481+
# No staggering, don't waste time
482+
if not self.expr.staggered and not func.staggered:
483+
return self
481484
# If an x0 already exists or evaluating at the same function (i.e u = u.dx)
482485
# do not overwrite it
483486
if self.x0 or self.side is not None or func.function is self.expr.function:

devito/finite_differences/differentiable.py

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,30 @@ def dtype(self):
9595

9696
@cached_property
9797
def indices(self):
98-
return tuple(filter_ordered(flatten(getattr(i, 'indices', ())
99-
for i in self._args_diff)))
98+
if not self._args_diff:
99+
return DimensionTuple()
100+
# Get indices of all args and merge them
101+
mapper = {}
102+
for a in self._args_diff:
103+
for d, i in a.indices.getters.items():
104+
mapper.setdefault(d, []).append(i)
105+
# Filter unique indices
106+
mapper = {k: v[0] if len(v) == 1 else tuple(filter_ordered(v))
107+
for k, v in mapper.items()}
108+
return DimensionTuple(*mapper.values(), getters=tuple(mapper.keys()))
100109

101110
@cached_property
102111
def dimensions(self):
103-
return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ())
104-
for i in self._args_diff)))
112+
if not self._args_diff:
113+
return DimensionTuple()
114+
return highest_priority(self).dimensions
115+
116+
@cached_property
117+
def staggered(self):
118+
if not self._args_diff:
119+
return None
120+
# Use the staggering of the highest priority function
121+
return highest_priority(self).staggered
105122

106123
@cached_property
107124
def root_dimensions(self):
@@ -117,11 +134,6 @@ def indices_ref(self):
117134
return DimensionTuple(*self.dimensions, getters=self.dimensions)
118135
return highest_priority(self).indices_ref
119136

120-
@cached_property
121-
def staggered(self):
122-
return tuple(filter_ordered(flatten(getattr(i, 'staggered', ())
123-
for i in self._args_diff)))
124-
125137
@cached_property
126138
def is_Staggered(self):
127139
return any([getattr(i, 'is_Staggered', False) for i in self._args_diff])
@@ -475,12 +487,19 @@ def has_free(self, *patterns):
475487

476488

477489
def highest_priority(DiffOp):
490+
if not DiffOp._args_diff:
491+
return DiffOp
478492
# We want to get the object with highest priority
479493
# We also need to make sure that the object with the largest
480494
# set of dimensions is used when multiple ones with the same
481495
# priority appear
482496
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
483-
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
497+
prio_func = sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
498+
499+
# The highest priority must be a Function
500+
if not isinstance(prio_func, AbstractFunction):
501+
return highest_priority(prio_func)
502+
return prio_func
484503

485504

486505
class DifferentiableOp(Differentiable):
@@ -548,8 +567,11 @@ class DifferentiableFunction(DifferentiableOp):
548567
def __new__(cls, *args, **kwargs):
549568
return cls.__sympy_class__.__new__(cls, *args, **kwargs)
550569

551-
def _eval_at(self, func):
552-
return self
570+
@property
571+
def _fd_priority(self):
572+
if highest_priority(self) is self:
573+
return super()._fd_priority
574+
return highest_priority(self)._fd_priority
553575

554576

555577
class Add(DifferentiableOp, sympy.Add):
@@ -633,26 +655,12 @@ def _gather_for_diff(self):
633655
if len(set(f.staggered for f in self._args_diff)) == 1:
634656
return self
635657

636-
func_args = highest_priority(self)
637-
new_args = []
638-
ref_inds = func_args.indices_ref.getters
639-
640-
for f in self.args:
641-
if f not in self._args_diff \
642-
or f is func_args \
643-
or isinstance(f, DifferentiableFunction):
644-
new_args.append(f)
645-
else:
646-
ind_f = f.indices_ref.getters
647-
mapper = {ind_f.get(d, d): ref_inds.get(d, d)
648-
for d in self.dimensions
649-
if ind_f.get(d, d) is not ref_inds.get(d, d)}
650-
if mapper:
651-
new_args.append(f.subs(mapper))
652-
else:
653-
new_args.append(f)
654-
655-
return self.func(*new_args, evaluate=False)
658+
derivs, other = split(self.args, lambda a: isinstance(a, sympy.Derivative))
659+
if len(derivs) == 0:
660+
return self._eval_at(highest_priority(self))
661+
else:
662+
other = self.func(*other)._eval_at(highest_priority(self))
663+
return self.func(other, *derivs)
656664

657665

658666
class Pow(DifferentiableOp, sympy.Pow):
@@ -1034,6 +1042,9 @@ def __new__(cls, *args, base=None, **kwargs):
10341042
obj = super().__new__(cls, *args, **kwargs)
10351043

10361044
try:
1045+
if base is obj:
1046+
# In some rare cases (rebuild?) base may be obj itself
1047+
base = base.base
10371048
obj.base = base
10381049
except AttributeError:
10391050
# This might happen if e.g. one attempts a (re)construction with
@@ -1061,6 +1072,10 @@ def _eval_at(self, func):
10611072
# and should not be re-evaluated at a different location
10621073
return self
10631074

1075+
@property
1076+
def indices_ref(self):
1077+
return self.base.indices_ref
1078+
10641079

10651080
class diffify:
10661081

@@ -1184,6 +1199,28 @@ def _(expr, x0, **kwargs):
11841199
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs))
11851200

11861201

1202+
@interp_for_fd.register(Mul)
1203+
def _(expr, x0, **kwargs):
1204+
# For a expression (e.g Mul or Add), we interpolate the whole expression
1205+
# Do we actually need interpolation
1206+
if all(expr.indices[d] is i for d, i in x0.items()):
1207+
return expr
1208+
1209+
# Split args between those that need interp and those that don't
1210+
def test0(a):
1211+
return all(a.indices[d] is i for d, i in x0.items() if d in a.dimensions)
1212+
oa, ia = split(expr._args_diff,
1213+
lambda a: isinstance(a, sympy.Derivative) or test0(a))
1214+
oa = oa + tuple(a for a in expr.args if a not in expr._args_diff)
1215+
1216+
# Interpolate the necessary args
1217+
d_dims = tuple((d, 0) for d in x0)
1218+
fd_order = tuple(expr.interp_order for d in x0)
1219+
iexpr = expr.func(*ia).diff(*d_dims, fd_order=fd_order, x0=x0, **kwargs)
1220+
1221+
return expr.func(iexpr, *oa)
1222+
1223+
11871224
@interp_for_fd.register(sympy.Expr)
11881225
def _(expr, x0, **kwargs):
11891226
if expr.args:
@@ -1194,7 +1231,8 @@ def _(expr, x0, **kwargs):
11941231

11951232
@interp_for_fd.register(AbstractFunction)
11961233
def _(expr, x0, **kwargs):
1197-
x0_expr = {d: v for d, v in x0.items() if v.has(d)}
1234+
x0_expr = {d: v for d, v in x0.items() if v.has(d)
1235+
and expr.indices[d] is not v}
11981236
if x0_expr:
11991237
return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()})
12001238
else:

devito/finite_differences/tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,9 @@ def check_input(func):
5050
def wrapper(expr, *args, **kwargs):
5151
try:
5252
return S.Zero if expr.is_Number else func(expr, *args, **kwargs)
53-
except AttributeError:
54-
raise ValueError(
55-
f"'{expr}' must be of type Differentiable, not {type(expr)}"
56-
) from None
53+
except Exception as e:
54+
raise type(e)(f"Error while computing finite-difference for expr={expr}: "
55+
f"{e}") from e
5756
return wrapper
5857

5958

examples/seismic/tti/operators.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ def kernel_staggered_2d(model, u, v, **kwargs):
280280
epsilon = 1 + 2 * epsilon
281281
delta = sqrt(1 + 2 * delta)
282282
s = model.grid.stepping_dim.spacing
283-
x, z = model.grid.dimensions
284283

285284
# Get source
286285
qu = kwargs.get('qu', 0)
@@ -291,31 +290,31 @@ def kernel_staggered_2d(model, u, v, **kwargs):
291290

292291
if forward:
293292
# Stencils
294-
phdx = costheta * u.dx - sintheta * u.dyc
293+
phdx = costheta * u.dx - sintheta * u.dy
295294
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)
296295

297-
pvdz = sintheta * v.dxc + costheta * v.dy
296+
pvdz = sintheta * v.dx + costheta * v.dy
298297
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)
299298

300-
dvx = costheta * vx.forward.dx - sintheta * vx.forward.dyc
301-
dvz = sintheta * vz.forward.dxc + costheta * vz.forward.dy
299+
dvx = costheta * vx.forward.dx - sintheta * vx.forward.dy
300+
dvz = sintheta * vz.forward.dx + costheta * vz.forward.dy
302301

303302
# u and v equations
304303
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * dvx + dvz)) + s / m * qv)
305304
ph_eq = Eq(u.forward, dampl * (u - s / m * (epsilon * dvx + delta * dvz)) +
306305
s / m * qu)
307306
else:
308307
# Stencils
309-
phdx = ((costheta*epsilon*u).dx - (sintheta*epsilon*u).dyc +
310-
(costheta*delta*v).dx - (sintheta*delta*v).dyc)
308+
a = epsilon * u + delta * v
309+
phdx = (costheta * a).dx - (sintheta * a).dy
311310
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)
312311

313-
pvdz = ((sintheta*delta*u).dxc + (costheta*delta*u).dy +
314-
(sintheta*v).dxc + (costheta*v).dy)
312+
b = delta * u + v
313+
pvdz = (sintheta * b).dx + (costheta * b).dy
315314
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)
316315

317-
dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dyc
318-
dvz = (sintheta * vz.backward).dxc + (costheta * vz.backward).dy
316+
dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dy
317+
dvz = (sintheta * vz.backward).dx + (costheta * vz.backward).dy
319318

320319
# u and v equations
321320
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))
@@ -356,24 +355,24 @@ def kernel_staggered_3d(model, u, v, **kwargs):
356355
if forward:
357356
# Stencils
358357
phdx = (costheta * cosphi * u.dx +
359-
costheta * sinphi * u.dyc -
360-
sintheta * u.dzc)
358+
costheta * sinphi * u.dy -
359+
sintheta * u.dz)
361360
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)
362361

363-
phdy = -sinphi * u.dxc + cosphi * u.dy
362+
phdy = -sinphi * u.dx + cosphi * u.dy
364363
u_vy = Eq(vy.forward, dampl * vy - dampl * s * phdy)
365364

366-
pvdz = (sintheta * cosphi * v.dxc +
367-
sintheta * sinphi * v.dyc +
365+
pvdz = (sintheta * cosphi * v.dx +
366+
sintheta * sinphi * v.dy +
368367
costheta * v.dz)
369368
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)
370369

371370
dvx = (costheta * cosphi * vx.forward.dx +
372-
costheta * sinphi * vx.forward.dyc -
373-
sintheta * vx.forward.dzc)
374-
dvy = -sinphi * vy.forward.dxc + cosphi * vy.forward.dy
375-
dvz = (sintheta * cosphi * vz.forward.dxc +
376-
sintheta * sinphi * vz.forward.dyc +
371+
costheta * sinphi * vx.forward.dy -
372+
sintheta * vx.forward.dz)
373+
dvy = -sinphi * vy.forward.dx + cosphi * vy.forward.dy
374+
dvz = (sintheta * cosphi * vz.forward.dx +
375+
sintheta * sinphi * vz.forward.dy +
377376
costheta * vz.forward.dz)
378377
# u and v equations
379378
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * (dvx + dvy) + dvz)) +
@@ -383,30 +382,27 @@ def kernel_staggered_3d(model, u, v, **kwargs):
383382
delta * dvz)) + s / m * qu)
384383
else:
385384
# Stencils
386-
phdx = ((costheta * cosphi * epsilon*u).dx +
387-
(costheta * sinphi * epsilon*u).dyc -
388-
(sintheta * epsilon*u).dzc + (costheta * cosphi * delta*v).dx +
389-
(costheta * sinphi * delta*v).dyc -
390-
(sintheta * delta*v).dzc)
385+
a = epsilon * u + delta * v
386+
phdx = ((costheta * cosphi * a).dx +
387+
(costheta * sinphi * a).dy -
388+
(sintheta * a).dz)
391389
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)
392390

393-
phdy = (-(sinphi * epsilon*u).dxc + (cosphi * epsilon*u).dy -
394-
(sinphi * delta*v).dxc + (cosphi * delta*v).dy)
391+
phdy = (-(sinphi * a).dx + (cosphi * a).dy)
395392
u_vy = Eq(vy.backward, dampl * vy + dampl * s * phdy)
396393

397-
pvdz = ((sintheta * cosphi * delta*u).dxc +
398-
(sintheta * sinphi * delta*u).dyc +
399-
(costheta * delta*u).dz + (sintheta * cosphi * v).dxc +
400-
(sintheta * sinphi * v).dyc +
401-
(costheta * v).dz)
394+
b = delta * u + v
395+
pvdz = ((sintheta * cosphi * b).dx +
396+
(sintheta * sinphi * b).dy +
397+
(costheta * b).dz)
402398
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)
403399

404400
dvx = ((costheta * cosphi * vx.backward).dx +
405-
(costheta * sinphi * vx.backward).dyc -
406-
(sintheta * vx.backward).dzc)
407-
dvy = (-sinphi * vy.backward).dxc + (cosphi * vy.backward).dy
408-
dvz = ((sintheta * cosphi * vz.backward).dxc +
409-
(sintheta * sinphi * vz.backward).dyc +
401+
(costheta * sinphi * vx.backward).dy -
402+
(sintheta * vx.backward).dz)
403+
dvy = (-sinphi * vy.backward).dx + (cosphi * vy.backward).dy
404+
dvz = ((sintheta * cosphi * vz.backward).dx +
405+
(sintheta * sinphi * vz.backward).dy +
410406
(costheta * vz.backward).dz)
411407
# u and v equations
412408
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))

tests/test_derivatives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from devito.finite_differences import Derivative, Differentiable, diffify
1111
from devito.finite_differences.differentiable import (
12-
Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights
12+
Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights, interp_for_fd
1313
)
1414
from devito.symbolics import indexify, retrieve_indexed
1515
from devito.types.dimension import StencilDimension
@@ -921,7 +921,7 @@ def test_param_stagg_add(self):
921921
assert simplify(eq0.evaluate.rhs - expect0) == 0
922922

923923
# Expects to evaluate c11 and txy at xp then the derivative at yp
924-
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
924+
expect1 = (interp_for_fd((c11 * txx), {x: xp}).evaluate).dy.evaluate
925925
assert simplify(eq1.evaluate.rhs - expect1) == 0
926926

927927
# Addition should apply the same logic as above for each term

0 commit comments

Comments
 (0)