Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirkit/backend/torch/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .input import TorchCategoricalLayer as TorchCategoricalLayer
from .input import TorchConstantValueLayer as TorchLogPartitionLayer
from .input import TorchExpFamilyLayer as TorchExpFamilyLayer
from .input import TorchDiscretizedLogisticLayer as TorchDiscretizedLogisticLayer
from .input import TorchGaussianLayer as TorchGaussianLayer
from .input import TorchInputLayer as TorchInputLayer
from .input import TorchPolynomialLayer as TorchPolynomialLayer
Expand Down
192 changes: 192 additions & 0 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,198 @@ def sample(self, num_samples: int = 1) -> Tensor:
return samples


class TorchDiscretizedLogisticLayer(TorchExpFamilyLayer):
    """The discretized logistic distribution layer. Optionally, this layer can encode
    unnormalized discretized logistic distributions with the specification of a
    log-partition function parameter."""

    def __init__(
        self,
        scope_idx: Tensor,
        num_output_units: int,
        *,
        marginal_mean: float,
        marginal_stddev: float,
        mean: TorchParameter,
        stddev: TorchParameter,
        log_partition: TorchParameter | None = None,
        semiring: Semiring | None = None,
    ) -> None:
        r"""Initialize a discretized logistic layer.

        Args:
            marginal_mean: The mean of the fitted data, which is used to rescale the
                learned parameters.
            marginal_stddev: The standard deviation of the fitted data, which is used
                to rescale the learned parameters.
                Instead of rescaling the data to be fitted, the learned parameters are
                rescaled instead. This allows to preserve the discreteness of data.
            scope_idx: A tensor of shape $(F, D)$, where $F$ is the number of folds, and
                $D$ is the number of variables on which the input layers in each fold are defined on.
                Alternatively, a tensor of shape $(D,)$ can be specified, which will be interpreted
                as a tensor of shape $(1, D)$, i.e., with $F = 1$.
            num_output_units: The number of output units.
            mean: The mean parameter, having shape $(F, K)$, where $K$ is the number of
                output units.
            stddev: The standard deviation parameter, having shape $(F, K)$, where $K$ is the
                number of output units.
            log_partition: An optional parameter of shape $(F, K)$, encoding the log-partition
                function. If this is not None, then the discretized logistic layer encodes
                unnormalized discretized logistic likelihoods, which are then normalized with
                the given log-partition function.
            semiring: The evaluation semiring.
                Defaults to [SumProductSemiring][cirkit.backend.torch.semiring.SumProductSemiring].

        Raises:
            ValueError: If the scope contains more than one variable.
            ValueError: If the mean and standard deviation parameter shapes are incorrect.
            ValueError: If the log-partition function parameter shape is incorrect.
        """
        num_variables = scope_idx.shape[-1]
        if num_variables != 1:
            # Name the correct layer in the error (was copy-pasted from the Gaussian layer).
            raise ValueError("The discretized logistic layer encodes a univariate distribution")
        super().__init__(
            scope_idx,
            num_output_units,
            semiring=semiring,
        )
        if not self._valid_mean_stddev_shape(mean):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._mean_stddev_shape} for 'mean', found "
                f"{mean.num_folds} and {mean.shape}, respectively"
            )
        if not self._valid_mean_stddev_shape(stddev):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._mean_stddev_shape} for 'stddev', found "
                f"{stddev.num_folds} and {stddev.shape}, respectively"
            )
        if log_partition is not None and not self._valid_log_partition_shape(log_partition):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._log_partition_shape} for 'log_partition', found "
                f"{log_partition.num_folds} and {log_partition.shape}, respectively"
            )
        self.mean = mean
        self.stddev = stddev
        self.log_partition = log_partition
        self.marginal_mean = marginal_mean
        self.marginal_stddev = marginal_stddev

    def _valid_mean_stddev_shape(self, p: TorchParameter) -> bool:
        # A parameter is valid iff it has both the expected fold count and per-fold shape.
        if p.num_folds != self.num_folds:
            return False
        return p.shape == self._mean_stddev_shape

    def _valid_log_partition_shape(self, log_partition: TorchParameter) -> bool:
        if log_partition.num_folds != self.num_folds:
            return False
        return log_partition.shape == self._log_partition_shape

    @property
    def _mean_stddev_shape(self) -> tuple[int, ...]:
        # Per-fold shape of the mean/stddev parameters: one value per output unit.
        return (self.num_output_units,)

    @property
    def _log_partition_shape(self) -> tuple[int, ...]:
        # Per-fold shape of the optional log-partition parameter.
        return (self.num_output_units,)

    @property
    def config(self) -> Mapping[str, Any]:
        return {
            "num_output_units": self.num_output_units,
            "marginal_mean": self.marginal_mean,
            "marginal_stddev": self.marginal_stddev,
        }

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        params = {"mean": self.mean, "stddev": self.stddev}
        if self.log_partition is not None:
            params["log_partition"] = self.log_partition
        return params

    def _log1mexp(self, x: Tensor) -> Tensor:
        """
        Computes log(1 - exp(-x)) with x>0 in a numerically stable way.
        This is based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/9d8bf8b/bernoulli/gumbel.py#L7-L11
        """
        # Switch at x = log(2) ~ 0.693: expm1 is accurate for small x, log1p for large x.
        return torch.where(
            x < 0.693,
            torch.log(-torch.expm1(-x)),
            torch.log1p(-torch.exp(-x)),
        )

    def _discrete_logistic_ll(self, x: Tensor, loc: Tensor, scale: Tensor) -> Tensor:
        """
        Computes the log-likelihood of the discretized logistic distribution:
            ll(x) = log(cdf(x + 0.5) - cdf(x - 0.5))
        where cdf(x) = 1 / (1 + exp(-(x - loc)/scale))
        Check https://en.wikipedia.org/wiki/Logistic_distribution for more details on the logistic distribution.

        This is a full derivation of the function:
        ll(x) = log(cdf(x + 0.5) - cdf(x - 0.5))
              = log{1 / (1 + exp(-(x + 0.5 - loc)/scale)) - 1 / (1 + exp(-(x - 0.5 - loc)/scale))}
              = log{[(1 + exp(-(x - 0.5 - loc)/scale)) - (1 + exp(-(x + 0.5 - loc)/scale))] / (1 + exp(-(x + 0.5 - loc)/scale)) * (1 + exp(-(x - 0.5 - loc)/scale))}
              = log{[ exp(-(x - 0.5 - loc)/scale) - exp(-(x + 0.5 - loc)/scale) ]} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)}
              = log{exp(-(x - 0.5 - loc)/scale)*[ 1 - exp(-(x + 0.5 - loc)/scale + (x - 0.5 - loc)/scale)]} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)}
              = -(x - 0.5 - loc)/scale + log{1 - exp(-1/scale)} - log{1 + exp(-(x + 0.5 - loc)/scale)} - log{1 + exp(-(x - 0.5 - loc)/scale)}

        let's set
            precision = 1 / scale
            a = (x - 0.5 - loc) * precision

              = -a + log(1 - exp(-precision)) + log(sigmoid((x + 0.5 - loc) * precision)) + log(sigmoid(a))
        """
        precision = 1 / scale
        a = (x - 0.5 - loc) * precision
        return (
            -a
            + self._log1mexp(precision)
            + torch.nn.functional.logsigmoid((x + 0.5 - loc) * precision)
            + torch.nn.functional.logsigmoid(a)
        )

    def _rescaled_mean(self) -> Tensor:
        # Undo the data standardization on the learned mean: m * sigma_data + mu_data.
        return self.mean() * self.marginal_stddev + self.marginal_mean

    def _rescaled_stddev(self) -> Tensor:
        # Undo the data standardization on the learned stddev: s * sigma_data.
        return self.stddev() * self.marginal_stddev

    def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
        rescaled_mean = self._rescaled_mean().unsqueeze(dim=1)  # (F, 1, K)
        rescaled_stddev = self._rescaled_stddev().unsqueeze(dim=1)  # (F, 1, K)

        # Inputs are rounded so the discretized likelihood is evaluated at integer bins.
        x_out = self._discrete_logistic_ll(x.round(), rescaled_mean, rescaled_stddev)  # (F, B, K)
        if self.log_partition is not None:
            log_partition = self.log_partition()  # (F, K)
            x_out = x_out + log_partition.unsqueeze(dim=1)
        return x_out

    def log_partition_function(self) -> Tensor:
        if self.log_partition is None:
            # Normalized case: the log-partition is identically zero.
            # NOTE(review): `self.mean.device` vs. `self.mean().device` used in sample() —
            # confirm TorchParameter exposes a .device attribute.
            return torch.zeros(
                size=(self.num_folds, 1, self.num_output_units), device=self.mean.device
            )
        log_partition = self.log_partition()  # (F, K)
        return log_partition.unsqueeze(dim=1)  # (F, 1, K)

    def sample(self, num_samples: int = 1) -> Tensor:
        """Sample from the discretized logistic distribution."""

        # Inverse-CDF sampling: the logistic quantile function applied to uniform noise.
        def quantile_function(p: Tensor, loc: Tensor, scale: Tensor) -> Tensor:
            return loc + scale * torch.log(p / (1 - p))

        return (
            quantile_function(
                torch.rand(size=(num_samples, *self.mean().shape), device=self.mean().device),
                self._rescaled_mean(),
                self._rescaled_stddev(),
            )
            .detach()
            .round()  # discretize continuous logistic samples to integers
            .permute(1, 2, 0)  # (F, K, N)
        )


class TorchConstantValueLayer(TorchConstantLayer):
"""An input layer having empty scope and computing a constant value."""

Expand Down
24 changes: 24 additions & 0 deletions cirkit/backend/torch/rules/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TorchBinomialLayer,
TorchCategoricalLayer,
TorchConstantValueLayer,
TorchDiscretizedLogisticLayer,
TorchEmbeddingLayer,
TorchEvidenceLayer,
TorchGaussianLayer,
Expand All @@ -20,6 +21,7 @@
BinomialLayer,
CategoricalLayer,
ConstantValueLayer,
DiscretizedLogisticLayer,
EmbeddingLayer,
EvidenceLayer,
GaussianLayer,
Expand Down Expand Up @@ -101,6 +103,27 @@ def compile_gaussian_layer(compiler: "TorchCompiler", sl: GaussianLayer) -> Torc
)


def compile_discretized_logistic_layer(
    compiler: "TorchCompiler", sl: DiscretizedLogisticLayer
) -> TorchDiscretizedLogisticLayer:
    """Compile a symbolic discretized logistic layer into its torch implementation."""
    # The log-partition parameter is optional: compile it only when present.
    log_partition = (
        None if sl.log_partition is None else compiler.compile_parameter(sl.log_partition)
    )
    return TorchDiscretizedLogisticLayer(
        torch.tensor(tuple(sl.scope)),
        sl.num_output_units,
        marginal_mean=sl.marginal_mean,
        marginal_stddev=sl.marginal_stddev,
        mean=compiler.compile_parameter(sl.mean),
        stddev=compiler.compile_parameter(sl.stddev),
        log_partition=log_partition,
        semiring=compiler.semiring,
    )


def compile_polynomial_layer(
compiler: "TorchCompiler", sl: PolynomialLayer
) -> TorchPolynomialLayer:
Expand Down Expand Up @@ -158,6 +181,7 @@ def compile_evidence_layer(compiler: "TorchCompiler", sl: EvidenceLayer) -> Torc
CategoricalLayer: compile_categorical_layer,
BinomialLayer: compile_binomial_layer,
GaussianLayer: compile_gaussian_layer,
DiscretizedLogisticLayer: compile_discretized_logistic_layer,
PolynomialLayer: compile_polynomial_layer,
HadamardLayer: compile_hadamard_layer,
KroneckerLayer: compile_kronecker_layer,
Expand Down
103 changes: 103 additions & 0 deletions cirkit/symbolic/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,109 @@ def params(self) -> Mapping[str, Parameter]:
return params


class DiscretizedLogisticLayer(InputLayer):
    """A symbolic discretized logistic layer, which is parameterized by mean and standard deviations.
    Optionally, it can represent an unnormalized discretized logistic layer by specifying the log
    partition function."""

    def __init__(
        self,
        scope: Scope,
        num_output_units: int,
        *,
        marginal_mean: float,
        marginal_stddev: float,
        mean: Parameter | None = None,
        stddev: Parameter | None = None,
        log_partition: Parameter | None = None,
        mean_factory: ParameterFactory | None = None,
        stddev_factory: ParameterFactory | None = None,
    ):
        r"""Initializes a discretized logistic layer.

        Args:
            marginal_mean: The mean of the fitted data, which is used to rescale the learned
                parameters.
            marginal_stddev: The standard deviation of the fitted data, which is used to rescale
                the learned parameters.
                Instead of rescaling the data to be fitted, the learned parameters are rescaled
                instead. This allows to preserve the discreteness of data.
            scope: The variables scope the layer depends on.
            num_output_units: The number of discretized logistic units in the layer.
            mean: The mean parameter of shape $(K)$, where $K$ is the number of output units.
                If it is None, then a default symbolic parameter will be instantiated with a
                [NormalInitializer][cirkit.symbolic.initializers.NormalInitializer] as
                symbolic initializer.
            stddev: The standard deviation parameter of shape $(K)$, where $K$ is the number of
                output units. If it is None, then a default symbolic parameter will be instantiated
                with a [NormalInitializer][cirkit.symbolic.initializers.NormalInitializer] as
                symbolic initializer, which is then re-parameterized to be positive using a
                [ScaledSigmoidParameter][cirkit.symbolic.parameters.ScaledSigmoidParameter].
            log_partition: The log-partition parameter of the discretized logistic, of shape $(K,)$.
                If the discretized logistic is a normalized discretized logistic, then this should
                be None.
            mean_factory: A factory used to construct the mean parameter, if it is not specified.
            stddev_factory: A factory used to construct the standard deviation parameter, if it is
                not specified.

        Raises:
            ValueError: If the scope contains more than one variable.
            ValueError: If a given parameter has an incorrect shape.
        """
        if len(scope) != 1:
            # Name the correct distribution in the error (was "laussian", a typo).
            raise ValueError("The discretized logistic layer encodes a univariate distribution")
        super().__init__(scope, num_output_units)
        if mean is None:
            if mean_factory is None:
                mean = Parameter.from_input(
                    TensorParameter(*self._mean_stddev_shape, initializer=NormalInitializer())
                )
            else:
                mean = mean_factory(self._mean_stddev_shape)
        if stddev is None:
            if stddev_factory is None:
                # Re-parameterize an unconstrained tensor to a positive stddev in (1e-5, 1.0).
                stddev = Parameter.from_unary(
                    ScaledSigmoidParameter(self._mean_stddev_shape, vmin=1e-5, vmax=1.0),
                    TensorParameter(*self._mean_stddev_shape, initializer=NormalInitializer()),
                )
            else:
                stddev = stddev_factory(self._mean_stddev_shape)
        if mean.shape != self._mean_stddev_shape:
            raise ValueError(
                f"Expected parameter shape {self._mean_stddev_shape}, found {mean.shape}"
            )
        if stddev.shape != self._mean_stddev_shape:
            raise ValueError(
                f"Expected parameter shape {self._mean_stddev_shape}, found {stddev.shape}"
            )
        if log_partition is not None and log_partition.shape != self._log_partition_shape:
            raise ValueError(
                f"Expected parameter shape {self._log_partition_shape}, found {log_partition.shape}"
            )
        self.mean = mean
        self.stddev = stddev
        self.log_partition = log_partition
        self.marginal_mean = marginal_mean
        self.marginal_stddev = marginal_stddev

    @property
    def _mean_stddev_shape(self) -> tuple[int, ...]:
        # Shape of the mean/stddev parameters: one value per output unit.
        return (self.num_output_units,)

    @property
    def _log_partition_shape(self) -> tuple[int, ...]:
        # Shape of the optional log-partition parameter.
        return (self.num_output_units,)

    @property
    def config(self) -> Mapping[str, Any]:
        return {
            "scope": self.scope,
            "num_output_units": self.num_output_units,
            "marginal_mean": self.marginal_mean,
            "marginal_stddev": self.marginal_stddev,
        }

    @property
    def params(self) -> Mapping[str, Parameter]:
        params = {"mean": self.mean, "stddev": self.stddev}
        if self.log_partition is not None:
            params.update(log_partition=self.log_partition)
        return params


class PolynomialLayer(InputLayer):
"""A symbolic layer that evaluates polynomials."""

Expand Down
10 changes: 9 additions & 1 deletion cirkit/templates/data_modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@ def image_data(
"poon-domingos",
]:
raise ValueError(f"Unknown region graph called {region_graph}")
if input_layer not in ["categorical", "binomial", "embedding", "gaussian"]:
if input_layer not in [
"categorical",
"binomial",
"embedding",
"gaussian",
"discretized_logistic",
]:
raise ValueError(f"Unknown input layer called {input_layer}")

# Construct the image-tailored region graph
Expand Down Expand Up @@ -116,6 +122,8 @@ def image_data(
input_kwargs = {"num_states": 256}
case "gaussian":
input_kwargs = {}
case "discretized_logistic":
input_kwargs = {}
case _:
assert False
if input_params is not None:
Expand Down
Loading
Loading