@@ -95,13 +95,30 @@ def dtype(self):
9595
9696 @cached_property
9797 def indices (self ):
98- return tuple (filter_ordered (flatten (getattr (i , 'indices' , ())
99- for i in self ._args_diff )))
98+ if not self ._args_diff :
99+ return DimensionTuple ()
100+ # Get indices of all args and merge them
101+ mapper = {}
102+ for a in self ._args_diff :
103+ for d , i in a .indices .getters .items ():
104+ mapper .setdefault (d , []).append (i )
105+ # Filter unique indices
106+ mapper = {k : v [0 ] if len (v ) == 1 else tuple (filter_ordered (v ))
107+ for k , v in mapper .items ()}
108+ return DimensionTuple (* mapper .values (), getters = tuple (mapper .keys ()))
100109
101110 @cached_property
102111 def dimensions (self ):
103- return tuple (filter_ordered (flatten (getattr (i , 'dimensions' , ())
104- for i in self ._args_diff )))
112+ if not self ._args_diff :
113+ return DimensionTuple ()
114+ return highest_priority (self ).dimensions
115+
116+ @cached_property
117+ def staggered (self ):
118+ if not self ._args_diff :
119+ return None
120+ # Use the staggering of the highest priority function
121+ return highest_priority (self ).staggered
105122
106123 @cached_property
107124 def root_dimensions (self ):
@@ -117,11 +134,6 @@ def indices_ref(self):
117134 return DimensionTuple (* self .dimensions , getters = self .dimensions )
118135 return highest_priority (self ).indices_ref
119136
120- @cached_property
121- def staggered (self ):
122- return tuple (filter_ordered (flatten (getattr (i , 'staggered' , ())
123- for i in self ._args_diff )))
124-
125137 @cached_property
126138 def is_Staggered (self ):
127139 return any ([getattr (i , 'is_Staggered' , False ) for i in self ._args_diff ])
@@ -475,12 +487,19 @@ def has_free(self, *patterns):
475487
476488
477489def highest_priority (DiffOp ):
490+ if not DiffOp ._args_diff :
491+ return DiffOp
478492 # We want to get the object with highest priority
479493 # We also need to make sure that the object with the largest
480494 # set of dimensions is used when multiple ones with the same
481495 # priority appear
482496 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
483- return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
497+ prio_func = sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
498+
499+ # The highest priority must be a Function
500+ if not isinstance (prio_func , AbstractFunction ):
501+ return highest_priority (prio_func )
502+ return prio_func
484503
485504
486505class DifferentiableOp (Differentiable ):
@@ -548,8 +567,11 @@ class DifferentiableFunction(DifferentiableOp):
548567 def __new__ (cls , * args , ** kwargs ):
549568 return cls .__sympy_class__ .__new__ (cls , * args , ** kwargs )
550569
551- def _eval_at (self , func ):
552- return self
570+ @property
571+ def _fd_priority (self ):
572+ if highest_priority (self ) is self :
573+ return super ()._fd_priority
574+ return highest_priority (self )._fd_priority
553575
554576
555577class Add (DifferentiableOp , sympy .Add ):
@@ -633,26 +655,12 @@ def _gather_for_diff(self):
633655 if len (set (f .staggered for f in self ._args_diff )) == 1 :
634656 return self
635657
636- func_args = highest_priority (self )
637- new_args = []
638- ref_inds = func_args .indices_ref .getters
639-
640- for f in self .args :
641- if f not in self ._args_diff \
642- or f is func_args \
643- or isinstance (f , DifferentiableFunction ):
644- new_args .append (f )
645- else :
646- ind_f = f .indices_ref .getters
647- mapper = {ind_f .get (d , d ): ref_inds .get (d , d )
648- for d in self .dimensions
649- if ind_f .get (d , d ) is not ref_inds .get (d , d )}
650- if mapper :
651- new_args .append (f .subs (mapper ))
652- else :
653- new_args .append (f )
654-
655- return self .func (* new_args , evaluate = False )
658+ derivs , other = split (self .args , lambda a : isinstance (a , sympy .Derivative ))
659+ if len (derivs ) == 0 :
660+ return self ._eval_at (highest_priority (self ))
661+ else :
662+ other = self .func (* other )._eval_at (highest_priority (self ))
663+ return self .func (other , * derivs )
656664
657665
658666class Pow (DifferentiableOp , sympy .Pow ):
@@ -1034,6 +1042,9 @@ def __new__(cls, *args, base=None, **kwargs):
10341042 obj = super ().__new__ (cls , * args , ** kwargs )
10351043
10361044 try :
1045+ if base is obj :
1046+ # In some rare cases (rebuild?) base may be obj itself
1047+ base = base .base
10371048 obj .base = base
10381049 except AttributeError :
10391050 # This might happen if e.g. one attempts a (re)construction with
@@ -1061,6 +1072,10 @@ def _eval_at(self, func):
10611072 # and should not be re-evaluated at a different location
10621073 return self
10631074
1075+ @property
1076+ def indices_ref (self ):
1077+ return self .base .indices_ref
1078+
10641079
10651080class diffify :
10661081
@@ -1184,6 +1199,28 @@ def _(expr, x0, **kwargs):
11841199 return expr .func (interp_for_fd (expr .expr , x0_expr , ** kwargs ))
11851200
11861201
1202+ @interp_for_fd .register (Mul )
1203+ def _ (expr , x0 , ** kwargs ):
1204+ # For a expression (e.g Mul or Add), we interpolate the whole expression
1205+ # Do we actually need interpolation
1206+ if all (expr .indices [d ] is i for d , i in x0 .items ()):
1207+ return expr
1208+
1209+ # Split args between those that need interp and those that don't
1210+ def test0 (a ):
1211+ return all (a .indices [d ] is i for d , i in x0 .items () if d in a .dimensions )
1212+ oa , ia = split (expr ._args_diff ,
1213+ lambda a : isinstance (a , sympy .Derivative ) or test0 (a ))
1214+ oa = oa + tuple (a for a in expr .args if a not in expr ._args_diff )
1215+
1216+ # Interpolate the necessary args
1217+ d_dims = tuple ((d , 0 ) for d in x0 )
1218+ fd_order = tuple (expr .interp_order for d in x0 )
1219+ iexpr = expr .func (* ia ).diff (* d_dims , fd_order = fd_order , x0 = x0 , ** kwargs )
1220+
1221+ return expr .func (iexpr , * oa )
1222+
1223+
11871224@interp_for_fd .register (sympy .Expr )
11881225def _ (expr , x0 , ** kwargs ):
11891226 if expr .args :
@@ -1194,7 +1231,8 @@ def _(expr, x0, **kwargs):
11941231
11951232@interp_for_fd .register (AbstractFunction )
11961233def _ (expr , x0 , ** kwargs ):
1197- x0_expr = {d : v for d , v in x0 .items () if v .has (d )}
1234+ x0_expr = {d : v for d , v in x0 .items () if v .has (d )
1235+ and expr .indices [d ] is not v }
11981236 if x0_expr :
11991237 return expr .subs ({expr .indices [d ]: v for d , v in x0_expr .items ()})
12001238 else :
0 commit comments