diff --git a/cirkit/backend/torch/parameters/nodes.py b/cirkit/backend/torch/parameters/nodes.py index 8d5a1cd3..e98a6b7f 100644 --- a/cirkit/backend/torch/parameters/nodes.py +++ b/cirkit/backend/torch/parameters/nodes.py @@ -728,6 +728,17 @@ def forward(self, x: Tensor) -> Tensor: return torch.clamp(x, min=self.vmin, max=self.vmax) +class TorchSoftplusParameter(TorchEntrywiseParameterOp): + """Softmax reparameterization. + + Range: (0, + inf), 0 available if input is masked. + Constraints: Positive. + """ + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.softplus(x) + + class TorchConjugateParameter(TorchEntrywiseParameterOp): """Conjugate parameterization.""" diff --git a/cirkit/backend/torch/rules/parameters.py b/cirkit/backend/torch/rules/parameters.py index 227c1dba..0a9968c9 100644 --- a/cirkit/backend/torch/rules/parameters.py +++ b/cirkit/backend/torch/rules/parameters.py @@ -31,6 +31,7 @@ TorchScaledSigmoidParameter, TorchSigmoidParameter, TorchSoftmaxParameter, + TorchSoftplusParameter, TorchSquareParameter, TorchSumParameter, TorchTensorParameter, @@ -61,6 +62,7 @@ ScaledSigmoidParameter, SigmoidParameter, SoftmaxParameter, + SoftplusParameter, SquareParameter, SumParameter, TensorParameter, @@ -188,6 +190,13 @@ def compile_clamp_parameter(compiler: "TorchCompiler", p: ClampParameter) -> Tor return TorchClampParameter(in_shape, vmin=p.vmin, vmax=p.vmax) +def compile_softplus_parameter( + compiler: "TorchCompiler", p: SoftplusParameter +) -> TorchSoftplusParameter: + (in_shape,) = p.in_shapes + return TorchSoftplusParameter(in_shape) + + def compile_conjugate_parameter( compiler: "TorchCompiler", p: ClampParameter ) -> TorchConjugateParameter: @@ -285,6 +294,7 @@ def compile_polynomial_differential( SigmoidParameter: compile_sigmoid_parameter, ScaledSigmoidParameter: compile_scaled_sigmoid_parameter, ClampParameter: compile_clamp_parameter, + SoftplusParameter: compile_softplus_parameter, ConjugateParameter: compile_conjugate_parameter, ReduceSumParameter: compile_reduce_sum_parameter, ReduceProductParameter: compile_reduce_product_parameter, diff --git a/cirkit/templates/utils.py b/cirkit/templates/utils.py index b26ef9fa..01a87f4c 100644 --- a/cirkit/templates/utils.py +++ b/cirkit/templates/utils.py @@ -25,6 +25,7 @@ ParameterFactory, SigmoidParameter, SoftmaxParameter, + SoftplusParameter, TensorParameter, UnaryParameterOp, ) @@ -189,6 +190,8 @@ def name_to_parameter_activation( if "vmin" not in kwargs: kwargs["vmin"] = 1e-18 return functools.partial(ClampParameter, **kwargs) + case "softplus": + return functools.partial(SoftplusParameter, **kwargs) case _: raise ValueError