
scico.linear_adjoint is incompatible with jax.ShapeDtypeStruct arguments #646

@bwohlberg


In a number of places in the scico source, a potentially large array of zeros is created as part of determining the adjoint of a linear function:

From scico/scico/linop/_linop.py (lines 187 to 190 at commit 5d0f9a4):

def _set_adjoint(self):
    """Automatically create adjoint method."""
    adj_fun = linear_adjoint(self.__call__, snp.zeros(self.input_shape, dtype=self.input_dtype))
    self._adj = lambda x: adj_fun(x)[0]

A similar construction appears elsewhere:

pad_adjoint = linear_adjoint(pad, snp.zeros(output_shape, dtype=input_dtype))
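For reference, the _set_adjoint pattern can be sketched in plain JAX. This is a minimal illustration, not scico code: ToyLinOp is a hypothetical class with a real forward-difference operator, so jax.linear_transpose stands in for scico.linear_adjoint (for a real operator the adjoint is just the transpose). It shows that the zero array only conveys shape and dtype, yet is fully allocated, and that the [0] is needed because the transpose function returns a tuple with one entry per primal argument.

```python
import jax
import jax.numpy as jnp


class ToyLinOp:
    """Hypothetical stand-in for a scico LinearOperator with a real forward map."""

    def __init__(self, input_shape, input_dtype=jnp.float32):
        self.input_shape = input_shape
        self.input_dtype = input_dtype
        self._set_adjoint()

    def __call__(self, x):
        # Forward differences: shape (n,) -> (n - 1,)
        return x[1:] - x[:-1]

    def _set_adjoint(self):
        # For a real operator the adjoint equals the transpose, so
        # jax.linear_transpose stands in for scico.linear_adjoint here.
        # The zeros array is only inspected for shape/dtype, but it is allocated.
        adj_fun = jax.linear_transpose(
            self.__call__, jnp.zeros(self.input_shape, dtype=self.input_dtype)
        )
        # adj_fun returns a tuple with one entry per primal argument, hence [0].
        self._adj = lambda y: adj_fun(y)[0]


op = ToyLinOp((5,))
y = jnp.arange(1.0, 5.0, dtype=jnp.float32)  # [1, 2, 3, 4]
print(op._adj(y))  # expected values: [-1, -1, -1, -1, 4]
```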

When the adjoint does not require complex conjugation (because the function is a real transform, so its adjoint is simply its transpose), the array creation can be avoided by passing a jax.ShapeDtypeStruct to jax.linear_transpose, as in

fun_t = jax.linear_transpose(fun, jax.ShapeDtypeStruct(input_shape, dtype=input_dtype))
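A small check of this, assuming a hypothetical real linear map given by a fixed 2x3 matrix M: the transpose built from an explicit zero array and the one built from a jax.ShapeDtypeStruct give the same result, since jax.linear_transpose only reads the shape and dtype of its primal arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical real linear map R^3 -> R^2 defined by a fixed matrix M.
M = jnp.array([[1.0, 2.0, 0.0], [0.0, -1.0, 3.0]], dtype=jnp.float32)
fun = lambda x: M @ x

input_shape, input_dtype = (3,), jnp.float32

# Transpose built from an explicit zero array (allocates input_shape zeros) ...
fun_t_zeros = jax.linear_transpose(fun, jnp.zeros(input_shape, dtype=input_dtype))
# ... and from a ShapeDtypeStruct, which carries only shape and dtype.
fun_t_sds = jax.linear_transpose(fun, jax.ShapeDtypeStruct(input_shape, dtype=input_dtype))

y = jnp.array([1.0, -2.0], dtype=jnp.float32)
(a,) = fun_t_zeros(y)
(b,) = fun_t_sds(y)
print(a, b)  # both equal M.T @ y
```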

However, scico.linear_adjoint, the scico extension of jax.linear_transpose that supports complex conjugation, does not accept jax.ShapeDtypeStruct arguments.
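One possible direction, sketched under stated assumptions: linear_adjoint_sds below is a hypothetical helper (not part of scico) that accepts arrays or jax.ShapeDtypeStruct primals. It uses jax.eval_shape to determine the output dtype abstractly and, for a complex-linear function, applies the identity A^H y = conj(A^T conj(y)) around jax.linear_transpose. Mixed real/complex inputs and outputs (e.g. a real-to-complex transform) are deliberately left unhandled, since that case needs more care than this sketch provides.

```python
import jax
import jax.numpy as jnp


def linear_adjoint_sds(fun, *primals):
    """Hypothetical sketch (not the scico API): adjoint of a linear ``fun``.

    Each primal may be an array or a jax.ShapeDtypeStruct; only shape and
    dtype are used, so no zero arrays are allocated.
    """
    # Normalize primals to ShapeDtypeStructs (arrays and structs both expose these).
    specs = [jax.ShapeDtypeStruct(p.shape, p.dtype) for p in primals]
    # Determine output shape/dtype abstractly, without evaluating fun on data.
    out_leaves = jax.tree_util.tree_leaves(jax.eval_shape(fun, *specs))
    flags = [jnp.issubdtype(s.dtype, jnp.complexfloating) for s in specs + out_leaves]
    fun_t = jax.linear_transpose(fun, *specs)
    if not any(flags):
        # Real -> real: the adjoint is the plain transpose.
        return fun_t
    if not all(flags):
        raise NotImplementedError("mixed real/complex maps are not handled in this sketch")

    def adj_fun(y):
        # Complex-linear case: A^H y = conj(A^T conj(y)).
        y_conj = jax.tree_util.tree_map(jnp.conj, y)
        return jax.tree_util.tree_map(jnp.conj, fun_t(y_conj))

    return adj_fun


# Usage: complex matrix-vector product, adjoint built from a ShapeDtypeStruct only.
A = jnp.array([[1 + 2j, 0.5j], [3.0, -1j], [2 - 1j, 4.0]], dtype=jnp.complex64)
matvec = lambda x: A @ x
adj = linear_adjoint_sds(matvec, jax.ShapeDtypeStruct((2,), jnp.complex64))

x = jnp.array([1 - 1j, 2 + 0.5j], dtype=jnp.complex64)
y = jnp.array([0.5j, -1.0, 2 + 2j], dtype=jnp.complex64)
(ahy,) = adj(y)
# Adjoint identity <A x, y> = <x, A^H y>; jnp.vdot conjugates its first argument.
print(jnp.vdot(matvec(x), y), jnp.vdot(x, ahy))
```

Returning a tuple, as jax.linear_transpose does, keeps the helper a drop-in for code that indexes the result with [0], as _set_adjoint does.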

Labels: bug (Something isn't working)
