Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 94 additions & 30 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,50 +328,74 @@ class VariableEffect(jax.core.Effect): ...
hjx.control_flow_allowed_effects.add_type(VariableEffect)


def _bind_new_variable(
  *leaves, treedef, var_type, has_qdd, mutable, ref
) -> HijaxVariable:
  """Binds new_variable_p after instantiating any Zero tangents.

  Every positional leaf is passed through ``hjx.instantiate_zeros`` before
  binding; the keyword parameters are forwarded to ``new_variable_p.bind``
  unchanged.

  Args:
    *leaves: leaf values to bind.
    treedef: forwarded as the ``treedef`` bind parameter.
    var_type: forwarded as the ``var_type`` bind parameter.
    has_qdd: forwarded as the ``has_qdd`` bind parameter.
    mutable: forwarded as the ``mutable`` bind parameter.
    ref: forwarded as the ``ref`` bind parameter.

  Returns:
    Whatever ``new_variable_p.bind`` returns for the instantiated leaves.
  """
  instantiated = tuple(map(hjx.instantiate_zeros, leaves))
  bind_params = dict(
    treedef=treedef,
    var_type=var_type,
    has_qdd=has_qdd,
    mutable=mutable,
    ref=ref,
  )
  return new_variable_p.bind(*instantiated, **bind_params)


def _new_hijax_from_variable(variable: Variable) -> HijaxVariable:
  """Creates a ``HijaxVariable`` from ``variable`` via ``_bind_new_variable``.

  The variable is flattened with ``jax.tree.flatten``; the resulting leaves
  are bound together with the treedef, the variable's concrete type, and its
  ``mutable`` / ``ref`` flags.

  Args:
    variable: the ``Variable`` to convert.

  Returns:
    The value produced by ``_bind_new_variable``.
  """
  # has_qdd is truthy only when the variable is mutable and `ref` is falsy.
  has_qdd = variable.mutable and not variable.ref
  leaves, treedef = jax.tree.flatten(variable)
  # Bind through the helper instead of calling `new_variable_p.bind` directly:
  # the primitive's methods take `mutable` and `ref` as required keyword
  # parameters, which the direct call did not supply.
  return _bind_new_variable(
    *leaves,
    treedef=treedef,
    var_type=type(variable),
    has_qdd=has_qdd,
    mutable=variable.mutable,
    ref=variable.ref,
  )


class NewVariable(hjx.HiPrimitive):
def is_high(self, *leaves, treedef, var_type, has_qdd) -> bool:
def is_high(self, *leaves, treedef, var_type, has_qdd, mutable, ref) -> bool:
return True # type: ignore

def impl(self, *leaves, treedef, var_type, has_qdd):
return HijaxVariable._new(leaves, treedef, var_type, has_qdd)
def impl(self, *leaves, treedef, var_type, has_qdd, mutable, ref):
return HijaxVariable._new(
leaves, treedef, var_type, has_qdd, mutable=mutable, ref=ref
)

def abstract_eval(self, *leaves, treedef, var_type, has_qdd):
aval = AbstractVariable(var_type, treedef, leaves, has_qdd)
def abstract_eval(self, *leaves, treedef, var_type, has_qdd, mutable, ref):
aval = AbstractVariable(
var_type, treedef, leaves, has_qdd, mutable=mutable, ref=ref
)
if has_qdd:
qdd = VariableQDD(tuple(leaves), treedef, var_type)
aval_qdd = hjx.AvalQDD(aval, qdd) # type: ignore
return aval_qdd, {variable_effect}
else:
return aval, set()

def to_lojax(self, *leaves, treedef, var_type, has_qdd):
return HijaxVariable._new(leaves, treedef, var_type, has_qdd)
def to_lojax(self, *leaves, treedef, var_type, has_qdd, mutable, ref):
return HijaxVariable._new(leaves, treedef, var_type, has_qdd, mutable=mutable, ref=ref)

def jvp(_, primals, tangents, *, treedef, var_type, has_qdd):
def jvp(_, primals, tangents, *, treedef, var_type, has_qdd, mutable, ref):
if has_qdd:
raise NotImplementedError(
"jvp not implemented for 'new_variable' with QDD"
)
primal_hijax_var = new_variable_p.bind(
*primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd
primal_hijax_var = _bind_new_variable(
*primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd, mutable=mutable, ref=ref
)
tangent_hijax_var = new_variable_p.bind(
*tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd
tangent_hijax_var = _bind_new_variable(
*tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd, mutable=mutable, ref=ref
)
return primal_hijax_var, tangent_hijax_var

def transpose(
_, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd
_, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd, mutable, ref
):
if has_qdd:
raise NotImplementedError(
Expand Down Expand Up @@ -553,8 +577,13 @@ def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd):
if hjx.is_undefined_primal(hijax_var)
else jax.typeof(hijax_var)
)
hijax_var_dot = new_variable_p.bind(
*out, treedef=abstract_var._treedef, var_type=var_type, has_qdd=has_qdd
hijax_var_dot = _bind_new_variable(
*out,
treedef=abstract_var._treedef,
var_type=var_type,
has_qdd=has_qdd,
mutable=abstract_var.mutable,
ref=abstract_var.ref,
)
return (hijax_var_dot,)

Expand Down Expand Up @@ -672,11 +701,13 @@ def __instancecheck__(self, instance):
class HijaxVariable(
tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta
): # type: ignore
__slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd')
__slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd', '_mutable', '_ref')
_treedef: PyTreeDef
_leaves: tuple[Leaf, ...]
_var_type: type[Variable[tp.Any]]
has_qdd: bool
_mutable: bool
_ref: bool

@classmethod
def _new(
Expand All @@ -685,12 +716,17 @@ def _new(
treedef: PyTreeDef,
var_type: type[Variable[A]],
has_qdd: bool,
*,
mutable: bool = True,
ref: bool = False,
):
hijax_var = object.__new__(cls)
object.__setattr__(hijax_var, '_treedef', treedef)
object.__setattr__(hijax_var, '_leaves', leaves)
object.__setattr__(hijax_var, '_var_type', var_type)
object.__setattr__(hijax_var, 'has_qdd', has_qdd)
object.__setattr__(hijax_var, '_mutable', mutable)
object.__setattr__(hijax_var, '_ref', ref)
return hijax_var

__init__ = _as_hijax_method('__init__')
Expand Down Expand Up @@ -726,8 +762,15 @@ def var_type(self) -> type[Variable[A]]:
type = _as_hijax_property('type', get=True, set=False)
type = _as_hijax_property('type', get=True, set=False)
hijax = _as_hijax_property('hijax', get=True, set=False)
ref = _as_hijax_property('ref', get=True, set=False)
mutable = _as_hijax_property('mutable', get=True, set=False)

@property
def ref(self) -> bool:
return self._ref

@property
def mutable(self) -> bool:
return self._mutable

get_metadata = _as_hijax_method('get_metadata')
set_metadata = _as_hijax_method('set_metadata')

Expand Down Expand Up @@ -849,7 +892,12 @@ def _to_abstract_variable(hijax_var: HijaxVariable):
leaves = tuple(map(jax.typeof, hijax_var._leaves))
treedef = hijax_var._treedef
return AbstractVariable(
hijax_var._var_type, treedef, leaves, hijax_var.has_qdd
hijax_var._var_type,
treedef,
leaves,
hijax_var.has_qdd,
mutable=hijax_var.mutable,
ref=hijax_var.ref,
)


Expand All @@ -860,19 +908,27 @@ def _to_abstract_variable(hijax_var: HijaxVariable):
# AbstractVariable
# ---------------------------------
class AbstractVariable(tp.Generic[A], hjx.MutableHiType):
__slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd']
__slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd', '_mutable', '_ref']
_var_type: type[Variable[A]]
_treedef: PyTreeDef | None
_leaves: tuple[hjx.AbstractValue, ...] | None
has_qdd: bool
_mutable: bool
_ref: bool
# forwarded to value
var_type = hjx.aval_property(lambda self: self.aval._var_type)
var_type = hjx.aval_property(lambda self: self.aval._var_type)
hijax = _as_aval_property(HijaxVariable.hijax)
ref = _as_aval_property(HijaxVariable.ref)
mutable = _as_aval_property(HijaxVariable.mutable)
_trace_state = _as_aval_property(HijaxVariable._trace_state)
_can_update = _as_aval_property(HijaxVariable._can_update)

@property
def ref(self) -> bool:
return self._ref

@property
def mutable(self) -> bool:
return self._mutable

@property
def hijax(self):
return True

_check_can_update = hjx.aval_method(HijaxVariable._check_can_update)

def __init__(
Expand All @@ -881,13 +937,18 @@ def __init__(
treedef: PyTreeDef | None,
leaves: tuple[hjx.AbstractValue, ...] | None,
has_qdd: bool,
*,
mutable: bool = True,
ref: bool = False,
):
if (treedef is None) ^ (leaves is None):
raise ValueError('treedef and leaves must be both provided or both None')
object.__setattr__(self, '_treedef', treedef)
object.__setattr__(self, '_leaves', leaves)
object.__setattr__(self, '_var_type', var_type)
object.__setattr__(self, 'has_qdd', has_qdd)
object.__setattr__(self, '_mutable', mutable)
object.__setattr__(self, '_ref', ref)

@property
def dtype(self):
Expand Down Expand Up @@ -1060,6 +1121,8 @@ def new_from_loval( # type: ignore[override]
variable_state.treedef,
self._var_type,
has_qdd=self.has_qdd,
mutable=self.mutable,
ref=self.ref,
) # will be mutated

def read_loval(self, variable_state: VariableQDD, variable) -> list: # type: ignore
Expand Down Expand Up @@ -1088,9 +1151,10 @@ def to_tangent_aval(self):
self._treedef,
self._leaves,
self.has_qdd,
mutable=self.mutable,
ref=self.ref,
)


# --------------------------------------------
# Variable
# --------------------------------------------
Expand Down
Loading
Loading