Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions dedalus/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,18 +430,31 @@ def build_cartesian_ncc_matrix(self, subproblem, ncc_cutoff, max_ncc_terms):
Gamma = Gamma.transpose((2, 0, 1))
# Loop over NCC modes
shape = (subproblem.field_size(out), subproblem.field_size(arg))
matrix = sparse.csr_matrix(shape, dtype=self.dtype)
subproblem_shape = subproblem.coeff_shape(out.domain)
ncc_rank = len(ncc.tensorsig)
select_all_comps = tuple(slice(None) for i in range(ncc_rank))
# Optimization: batch accumulate matrices instead of sequential addition
all_rows = []
all_cols = []
all_data = []
if np.any(self._ncc_data):
for ncc_mode in np.ndindex(self._ncc_data.shape[ncc_rank:]):
ncc_coeffs = self._ncc_data[select_all_comps + ncc_mode]
if np.max(np.abs(ncc_coeffs)) > ncc_cutoff:
mode_matrix = self.cartesian_mode_matrix(subproblem_shape, ncc.domain, arg.domain, out.domain, ncc_mode)
mode_matrix = sparse.kron(np.dot(Gamma, ncc_coeffs.ravel()), mode_matrix, format='csr')
matrix = matrix + mode_matrix
return matrix
mode_matrix = sparse.kron(np.dot(Gamma, ncc_coeffs.ravel()), mode_matrix, format='coo')
all_rows.append(mode_matrix.row)
all_cols.append(mode_matrix.col)
all_data.append(mode_matrix.data)
# Batch merge all mode matrices
if all_rows:
combined_row = np.concatenate(all_rows)
combined_col = np.concatenate(all_cols)
combined_data = np.concatenate(all_data)
matrix = sparse.coo_matrix((combined_data, (combined_row, combined_col)), shape=shape)
return matrix.tocsr()
else:
return sparse.csr_matrix(shape, dtype=self.dtype)

@classmethod
def cartesian_mode_matrix(cls, subproblem_shape, ncc_domain, arg_domain, out_domain, ncc_mode):
Expand Down