diff --git a/cirkit/backend/torch/layers/__init__.py b/cirkit/backend/torch/layers/__init__.py index 658ac8ad..90d4f614 100644 --- a/cirkit/backend/torch/layers/__init__.py +++ b/cirkit/backend/torch/layers/__init__.py @@ -6,6 +6,7 @@ from .input import TorchCategoricalLayer as TorchCategoricalLayer from .input import TorchConstantValueLayer as TorchLogPartitionLayer from .input import TorchExpFamilyLayer as TorchExpFamilyLayer +from .input import TorchDiscretizedLogisticLayer as TorchDiscretizedLogisticLayer from .input import TorchGaussianLayer as TorchGaussianLayer from .input import TorchInputLayer as TorchInputLayer from .input import TorchPolynomialLayer as TorchPolynomialLayer diff --git a/cirkit/backend/torch/layers/input.py b/cirkit/backend/torch/layers/input.py index 5cb02330..26d2bc5c 100644 --- a/cirkit/backend/torch/layers/input.py +++ b/cirkit/backend/torch/layers/input.py @@ -685,6 +685,198 @@ def sample(self, num_samples: int = 1) -> Tensor: return samples +class TorchDiscretizedLogisticLayer(TorchExpFamilyLayer): + """The discretized logistic distribution layer. Optionally, this layer can encode unnormalized discretized logistic + distributions with the specification of a log-partition function parameter.""" + + def __init__( + self, + scope_idx: Tensor, + num_output_units: int, + *, + marginal_mean: float, + marginal_stddev: float, + mean: TorchParameter, + stddev: TorchParameter, + log_partition: TorchParameter | None = None, + semiring: Semiring | None = None, + ) -> None: + r"""Initialize a discretized logistic layer. + + Args: + marginal_mean: The mean of the fitted data, which is used to rescale the learned parameters. + marginal_stddev: The standard deviation of the fitted data, which is used to rescale the learned parameters. + Instead of rescaling the data to be fitted, the learned parameters are rescaled instead. + This allows preserving the discreteness of data.
+ scope_idx: A tensor of shape $(F, D)$, where $F$ is the number of folds, and + $D$ is the number of variables on which the input layers in each fold are defined. + Alternatively, a tensor of shape $(D,)$ can be specified, which will be interpreted + as a tensor of shape $(1, D)$, i.e., with $F = 1$. + num_output_units: The number of output units. + mean: The mean parameter, having shape $(F, K)$, where $K$ is the number of + output units. + stddev: The standard deviation parameter, having shape $(F, K)$, where $K$ is the + number of output units. + log_partition: An optional parameter of shape $(F, K)$, encoding the log-partition + function. If this is not None, then the discretized logistic layer encodes unnormalized + discretized logistic likelihoods, which are then normalized with the given log-partition + function. + semiring: The evaluation semiring. + Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring]. + + Raises: + ValueError: If the scope contains more than one variable. + ValueError: If the mean and standard deviation parameter shapes are incorrect. + ValueError: If the log-partition function parameter shape is incorrect.
+ """ + num_variables = scope_idx.shape[-1] + if num_variables != 1: + raise ValueError("The Gaussian layer encodes a univariate distribution") + super().__init__( + scope_idx, + num_output_units, + semiring=semiring, + ) + if not self._valid_mean_stddev_shape(mean): + raise ValueError( + f"Expected number of folds {self.num_folds} " + f"and shape {self._mean_stddev_shape} for 'mean', found" + f"{mean.num_folds} and {mean.shape}, respectively" + ) + if not self._valid_mean_stddev_shape(stddev): + raise ValueError( + f"Expected number of folds {self.num_folds} " + f"and shape {self._mean_stddev_shape} for 'stddev', found" + f"{stddev.num_folds} and {stddev.shape}, respectively" + ) + if log_partition is not None and not self._valid_log_partition_shape(log_partition): + raise ValueError( + f"Expected number of folds {self.num_folds} " + f"and shape {self._log_partition_shape} for 'log_partition', found" + f"{log_partition.num_folds} and {log_partition.shape}, respectively" + ) + self.mean = mean + self.stddev = stddev + self.log_partition = log_partition + self.marginal_mean = marginal_mean + self.marginal_stddev = marginal_stddev + + def _valid_mean_stddev_shape(self, p: TorchParameter) -> bool: + if p.num_folds != self.num_folds: + return False + return p.shape == self._mean_stddev_shape + + def _valid_log_partition_shape(self, log_partition: TorchParameter) -> bool: + if log_partition.num_folds != self.num_folds: + return False + return log_partition.shape == self._log_partition_shape + + @property + def _mean_stddev_shape(self) -> tuple[int, ...]: + return (self.num_output_units,) + + @property + def _log_partition_shape(self) -> tuple[int, ...]: + return (self.num_output_units,) + + @property + def config(self) -> Mapping[str, Any]: + return { + "num_output_units": self.num_output_units, + "marginal_mean": self.marginal_mean, + "marginal_stddev": self.marginal_stddev, + } + + @property + def params(self) -> Mapping[str, TorchParameter]: + params = {"mean": 
self.mean, "stddev": self.stddev} + if self.log_partition is not None: + params["log_partition"] = self.log_partition + return params + + def _log1mexp(self, x: Tensor) -> Tensor: + """ + Computes log(1 - exp(-x)) with x>0 in a numerically stable way. + This is based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/9d8bf8b/bernoulli/gumbel.py#L7-L11 + """ + return torch.where( + x < 0.693, + torch.log(-torch.expm1(-x)), + torch.log1p(-torch.exp(-x)), + ) + + def _discrete_logistic_ll(self, x: Tensor, loc: Tensor, scale: Tensor) -> Tensor: + """ + Computes the log-likelihood of the discretized logistic distribution: + ll(x) = log(cdf(x + 0.5) - cdf(x - 0.5)) + where cdf(x) = 1 / (1 + exp(-(x - loc)/scale)) + Check https://en.wikipedia.org/wiki/Logistic_distribution for more details on the logistic distribution. + + This is a full derivation of the function: + ll(x) = log(cdf(x + 0.5) - cdf(x - 0.5)) + = log{1 / (1 + exp(-(x + 0.5 - loc)/scale)) - 1 / (1 + exp(-(x - 0.5 - loc)/scale))} + = log{[(1 + exp(-(x - 0.5 - loc)/scale)) - (1 + exp(-(x + 0.5 - loc)/scale))] / (1 + exp(-(x + 0.5 - loc)/scale)) * (1 + exp(-(x - 0.5 - loc)/scale))} + = log{[ exp(-(x - 0.5 - loc)/scale) - exp(-(x + 0.5 - loc)/scale) ]} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)} + = log{exp(-(x - 0.5 - loc)/scale)*[ 1 - exp(-(x + 0.5 - loc)/scale + (x - 0.5 - loc)/scale)]} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)} + = -(x - 0.5 - loc)/scale + log{1 - exp(-1/scale)} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)} + + let's set + precision = 1 / scale + a = (x - 0.5 - loc) * precision + + = -a + log(1 - exp(-precision)) + log(sigmoid((x + 0.5 - loc) * precision)) + log(sigmoid(a)) + """ + precision = 1 / scale + a = (x - 0.5 - loc) * precision + return ( + -a + + self._log1mexp(precision) + + torch.nn.functional.logsigmoid((x + 0.5 - loc) * precision) + + 
torch.nn.functional.logsigmoid(a) + ) + + def _rescaled_mean(self) -> Tensor: + return self.mean() * self.marginal_stddev + self.marginal_mean + + def _rescaled_stddev(self) -> Tensor: + return self.stddev() * self.marginal_stddev + + def log_unnormalized_likelihood(self, x: Tensor) -> Tensor: + rescaled_mean = self._rescaled_mean().unsqueeze(dim=1) # (F, 1, K) + rescaled_stddev = self._rescaled_stddev().unsqueeze(dim=1) # (F, 1, K) + + x_out = self._discrete_logistic_ll(x.round(), rescaled_mean, rescaled_stddev) # (F, B, K) + if self.log_partition is not None: + log_partition = self.log_partition() # (F, K) + x_out = x_out + log_partition.unsqueeze(dim=1) + return x_out + + def log_partition_function(self) -> Tensor: + if self.log_partition is None: + return torch.zeros( + size=(self.num_folds, 1, self.num_output_units), device=self.mean.device + ) + log_partition = self.log_partition() # (F, K) + return log_partition.unsqueeze(dim=1) # (F, 1, K) + + def sample(self, num_samples: int = 1) -> Tensor: + """Sample from the discretized logistic distribution.""" + + def quantile_function(p: Tensor, loc: Tensor, scale: Tensor) -> Tensor: + return loc + scale * torch.log(p / (1 - p)) + + return ( + quantile_function( + torch.rand(size=(num_samples, *self.mean().shape), device=self.mean().device), + self._rescaled_mean(), + self._rescaled_stddev(), + ) + .detach() + .round() + .permute(1, 2, 0) # (F, K, N) + ) + + class TorchConstantValueLayer(TorchConstantLayer): """An input layer having empty scope and computing a constant value.""" diff --git a/cirkit/backend/torch/rules/layers.py b/cirkit/backend/torch/rules/layers.py index 0df235b2..39418160 100644 --- a/cirkit/backend/torch/rules/layers.py +++ b/cirkit/backend/torch/rules/layers.py @@ -10,6 +10,7 @@ TorchBinomialLayer, TorchCategoricalLayer, TorchConstantValueLayer, + TorchDiscretizedLogisticLayer, TorchEmbeddingLayer, TorchEvidenceLayer, TorchGaussianLayer, @@ -20,6 +21,7 @@ BinomialLayer, CategoricalLayer, 
ConstantValueLayer, + DiscretizedLogisticLayer, EmbeddingLayer, EvidenceLayer, GaussianLayer, @@ -101,6 +103,27 @@ def compile_gaussian_layer(compiler: "TorchCompiler", sl: GaussianLayer) -> Torc ) +def compile_discretized_logistic_layer( + compiler: "TorchCompiler", sl: DiscretizedLogisticLayer +) -> TorchDiscretizedLogisticLayer: + mean = compiler.compile_parameter(sl.mean) + stddev = compiler.compile_parameter(sl.stddev) + if sl.log_partition is not None: + log_partition = compiler.compile_parameter(sl.log_partition) + else: + log_partition = None + return TorchDiscretizedLogisticLayer( + torch.tensor(tuple(sl.scope)), + sl.num_output_units, + mean=mean, + stddev=stddev, + log_partition=log_partition, + semiring=compiler.semiring, + marginal_mean=sl.marginal_mean, + marginal_stddev=sl.marginal_stddev, + ) + + def compile_polynomial_layer( compiler: "TorchCompiler", sl: PolynomialLayer ) -> TorchPolynomialLayer: @@ -158,6 +181,7 @@ def compile_evidence_layer(compiler: "TorchCompiler", sl: EvidenceLayer) -> Torc CategoricalLayer: compile_categorical_layer, BinomialLayer: compile_binomial_layer, GaussianLayer: compile_gaussian_layer, + DiscretizedLogisticLayer: compile_discretized_logistic_layer, PolynomialLayer: compile_polynomial_layer, HadamardLayer: compile_hadamard_layer, KroneckerLayer: compile_kronecker_layer, diff --git a/cirkit/symbolic/layers.py b/cirkit/symbolic/layers.py index 794b7e46..6606c1ff 100644 --- a/cirkit/symbolic/layers.py +++ b/cirkit/symbolic/layers.py @@ -529,6 +529,109 @@ def params(self) -> Mapping[str, Parameter]: return params +class DiscretizedLogisticLayer(InputLayer): + """A symbolic discretized logistic layer, which is parameterized by mean and standard deviations. 
+ Optionally, it can represent an unnormalized discretized logistic layer by specifying the log partition + function.""" + + def __init__( + self, + scope: Scope, + num_output_units: int, + *, + marginal_mean: float, + marginal_stddev: float, + mean: Parameter | None = None, + stddev: Parameter | None = None, + log_partition: Parameter | None = None, + mean_factory: ParameterFactory | None = None, + stddev_factory: ParameterFactory | None = None, + ): + r"""Initializes a discretized logistic layer. + + Args: + marginal_mean: The mean of the fitted data, which is used to rescale the learned parameters. + marginal_stddev: The standard deviation of the fitted data, which is used to rescale the learned parameters. + Instead of rescaling the data to be fitted, the learned parameters are rescaled instead. + This allows preserving the discreteness of data. + scope: The variables scope the layer depends on. + num_output_units: The number of discretized logistic units in the layer. + mean: The mean parameter of shape $(K)$, where $K$ is the number of output units. + If it is None, then a default symbolic parameter will be instantiated with a + [NormalInitializer][cirkit.symbolic.initializers.NormalInitializer] as + symbolic initializer. + stddev: The standard deviation parameter of shape $(K)$, where $K$ is the number of + output units. If it is None, then a default symbolic parameter will be instantiated + with a [NormalInitializer][cirkit.symbolic.initializers.NormalInitializer] as + symbolic initializer, which is then re-parameterized to be positive using a + [ScaledSigmoidParameter][cirkit.symbolic.parameters.ScaledSigmoidParameter]. + log_partition: The log-partition parameter of the discretized logistic, of shape $(K,)$. + If the discretized logistic is a normalized discretized logistic, then this should be None. + mean_factory: A factory used to construct the mean parameter, if it is not specified.
+ stddev_factory: A factory used to construct the standard deviation parameter, if it is + not specified. + """ + if len(scope) != 1: + raise ValueError("The discretized logistic layer encodes a univariate distribution") + super().__init__(scope, num_output_units) + if mean is None: + if mean_factory is None: + mean = Parameter.from_input( + TensorParameter(*self._mean_stddev_shape, initializer=NormalInitializer()) + ) + else: + mean = mean_factory(self._mean_stddev_shape) + if stddev is None: + if stddev_factory is None: + stddev = Parameter.from_unary( + ScaledSigmoidParameter(self._mean_stddev_shape, vmin=1e-5, vmax=1.0), + TensorParameter(*self._mean_stddev_shape, initializer=NormalInitializer()), + ) + else: + stddev = stddev_factory(self._mean_stddev_shape) + if mean.shape != self._mean_stddev_shape: + raise ValueError( + f"Expected parameter shape {self._mean_stddev_shape}, found {mean.shape}" + ) + if stddev.shape != self._mean_stddev_shape: + raise ValueError( + f"Expected parameter shape {self._mean_stddev_shape}, found {stddev.shape}" + ) + if log_partition is not None and log_partition.shape != self._log_partition_shape: + raise ValueError( + f"Expected parameter shape {self._log_partition_shape}, found {log_partition.shape}" + ) + self.mean = mean + self.stddev = stddev + self.log_partition = log_partition + self.marginal_mean = marginal_mean + self.marginal_stddev = marginal_stddev + + @property + def _mean_stddev_shape(self) -> tuple[int, ...]: + return (self.num_output_units,) + + @property + def _log_partition_shape(self) -> tuple[int, ...]: + return (self.num_output_units,) + + @property + def config(self) -> Mapping[str, Any]: + return { + "scope": self.scope, + "num_output_units": self.num_output_units, + "marginal_mean": self.marginal_mean, + "marginal_stddev": self.marginal_stddev, + } + + @property + def params(self) -> Mapping[str, Parameter]: + params = {"mean": self.mean, "stddev": self.stddev} + if self.log_partition is not None: + 
params.update(log_partition=self.log_partition) + return params + + class PolynomialLayer(InputLayer): """A symbolic layer that evaluates polynomials.""" diff --git a/cirkit/templates/data_modalities.py b/cirkit/templates/data_modalities.py index 2ea9afe5..66968b70 100644 --- a/cirkit/templates/data_modalities.py +++ b/cirkit/templates/data_modalities.py @@ -86,7 +86,13 @@ def image_data( "poon-domingos", ]: raise ValueError(f"Unknown region graph called {region_graph}") - if input_layer not in ["categorical", "binomial", "embedding", "gaussian"]: + if input_layer not in [ + "categorical", + "binomial", + "embedding", + "gaussian", + "discretized_logistic", + ]: raise ValueError(f"Unknown input layer called {input_layer}") # Construct the image-tailored region graph @@ -116,6 +122,8 @@ def image_data( input_kwargs = {"num_states": 256} case "gaussian": input_kwargs = {} + case "discretized_logistic": + input_kwargs = {} case _: assert False if input_params is not None: diff --git a/cirkit/templates/pgms.py b/cirkit/templates/pgms.py index 2550feb8..c7ea44b6 100644 --- a/cirkit/templates/pgms.py +++ b/cirkit/templates/pgms.py @@ -22,8 +22,8 @@ def fully_factorized( Args: num_variables: The number of variables. - input_layer: The input layer to use for the factors. It can be 'categorical', 'binomial' or - 'gaussian'. Defaults to 'categorical'. + input_layer: The input layer to use for the factors. It can be 'categorical', 'binomial', + 'discretized_logistic' or 'gaussian'. Defaults to 'categorical'. input_params: A dictionary mapping each name of a parameter of the input layer to its parameterization. If it is None, then the default parameterization of the chosen input layer will be chosen. 
@@ -35,7 +35,7 @@ def fully_factorized( """ if num_variables <= 0: raise ValueError("The number of variables should be a positive integer") - if input_layer not in ["categorical", "binomial", "gaussian"]: + if input_layer not in ["categorical", "binomial", "gaussian", "discretized_logistic"]: raise ValueError(f"Unknown input layer called {input_layer}") input_layer_kwargs_ls: list[Mapping[str, Any]] if input_layer_kwargs is None: @@ -89,8 +89,8 @@ def hmm( Args: ordering: The input order of variables of the HMM. - input_layer: The input layer to use for the factors. It can be 'categorical', 'binomial' or - 'gaussian'. Defaults to 'categorical'. + input_layer: The input layer to use for the factors. It can be 'categorical', 'binomial', + 'discretized_logistic' or 'gaussian'. Defaults to 'categorical'. num_latent_states: The number of states the latent variables can assume or, equivalently, the number of sum units per sum layer. input_params: A dictionary mapping each name of a parameter of the input layer to @@ -116,7 +116,7 @@ def hmm( num_variables = len(ordering) if set(ordering) != set(range(num_variables)): raise ValueError("The 'ordering' of variables is not valid") - if input_layer not in ["categorical", "binomial", "gaussian"]: + if input_layer not in ["categorical", "binomial", "gaussian", "discretized_logistic"]: raise ValueError(f"Unknown input layer called {input_layer}") input_layer_kwargs_ls: list[Mapping[str, Any]] if input_layer_kwargs is None: diff --git a/cirkit/templates/utils.py b/cirkit/templates/utils.py index b26ef9fa..75a83e91 100644 --- a/cirkit/templates/utils.py +++ b/cirkit/templates/utils.py @@ -13,6 +13,7 @@ from cirkit.symbolic.layers import ( BinomialLayer, CategoricalLayer, + DiscretizedLogisticLayer, EmbeddingLayer, GaussianLayer, InputLayer, @@ -114,7 +115,7 @@ def name_to_input_layer_factory(name: str, **kwargs: Any) -> InputLayerFactory: Args: name: The name of the input layer. 
It can be one of the following: - 'embedding', 'categorical', 'gaussian', 'binomial'. + 'embedding', 'categorical', 'gaussian', 'binomial', 'discretized_logistic'. **kwargs: Arguments to pass to the factory. Returns: @@ -132,6 +133,8 @@ def name_to_input_layer_factory(name: str, **kwargs: Any) -> InputLayerFactory: return functools.partial(BinomialLayer, **kwargs) case "gaussian": return functools.partial(GaussianLayer, **kwargs) + case "discretized_logistic": + return functools.partial(DiscretizedLogisticLayer, **kwargs) case _: raise ValueError(f"Unknown input layer called {name}")