Add a functional cross entropy method for discrete search spaces #124
engintoklu wants to merge 2 commits into master from
Conversation
📝 Walkthrough

Adds a functional Discrete Cross Entropy Method (DCEM): a new module implementing probability computations, sampling, and ask/tell state updates for binary and categorical discrete optimization, and exports the DCEM API from the functional algorithms package.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Ask as dcem_ask
    participant Tell as dcem_tell
    participant State as DCEMState
    User->>Ask: dcem_ask(state, popsize=N)
    Ask->>Ask: Compute per-variable probabilities from state.center
    Ask->>Ask: Sample population (binary/categorical) using probabilities
    Ask-->>User: Return population (N × solution_length)
    User->>User: Evaluate population -> evals
    User->>Tell: dcem_tell(state, population, evals)
    Tell->>Tell: Select top parents by parenthood_ratio and objective
    Tell->>Tell: Compute new probabilities from parents, apply prob_min
    Tell->>State: Build updated DCEMState with new center/probabilities
    Tell-->>User: Return updated state
```
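The diagram above maps onto a plain ask/evaluate/tell loop in user code. The sketch below is illustrative only: the import path, the `popsize` keyword, and the toy objective are assumptions based on this walkthrough, not on the merged source.

```python
import torch
from evotorch.algorithms.functional import dcem, dcem_ask, dcem_tell  # assumed export path

# Hypothetical objective for a binary search space: maximize the number of True bits.
def count_true_bits(population: torch.Tensor) -> torch.Tensor:
    return population.to(torch.float32).sum(dim=-1)

# Initialize a binary DCEM state over 20 decision variables.
state = dcem(objective_sense="max", solution_length=20, parenthood_ratio=0.5)

for _ in range(100):
    population = dcem_ask(state, popsize=50)     # sample a (50, 20) boolean population
    evals = count_true_bits(population)          # user-side evaluation
    state = dcem_tell(state, population, evals)  # select parents, update probabilities

print(state.center)  # per-variable probabilities of sampling True
```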
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/evotorch/algorithms/functional/funcdcem.py`:
- Around line 304-307: The ValueError message for the branch handling invalid
objective_sense is written as a regular string so `{repr(objective_sense)}`
won't interpolate; update the raise ValueError call (the block referencing
objective_sense) to use an f-string (e.g., prefix the string with f and
interpolate the variable, or use {objective_sense!r}) so the actual value is
shown in the error message.
- Around line 118-122: The error message for the check "if num_choices < 2:"
incorrectly claims num_choices is omitted; update the ValueError in that branch
(in the funcdcem handling of non-boolean integer variables) to report the actual
invalid value and clearer text—e.g. state that num_choices must be >= 2 and
include the provided num_choices value (reference symbol: num_choices and the
conditional block guarded by "if num_choices < 2:").
- Around line 441-451: Docstring references the parameter as `values` but the
function signature uses `population`; update the docstring to consistently use
`population` (replace all occurrences of `values` with `population`) and ensure
the shape/description examples still match (`population` shaped (N, L) ->
`evals` length N; (B, N, L) -> `evals` shape (B, N)). Adjust any related
sentences in the function docstring in funcdcem.py (the docstring for the
function that accepts the `population` parameter) so names and shape examples
align.
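As a concrete illustration of the first two items above, here is a hedged sketch of the corrected error messages; the helper `_validate_args` is hypothetical, and only the two `raise` statements reflect the suggested fixes.

```python
def _validate_args(objective_sense: str, num_choices: int) -> None:
    # Interpolate the offending value via an f-string with the !r conversion flag.
    if objective_sense not in ("min", "max"):
        raise ValueError(
            f"`objective_sense` was expected as 'min' or 'max', but it was received as: {objective_sense!r}."
        )
    # Report the actual invalid value instead of a generic message.
    if num_choices < 2:
        raise ValueError(f"`num_choices` must be an integer >= 2, but it was given as {num_choices}.")

_validate_args("max", 3)    # passes silently
# _validate_args("Max", 3)  # would raise, with 'Max' shown in the message
```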
🧹 Nitpick comments (2)
src/evotorch/algorithms/functional/funcdcem.py (2)
139-181: Consider using `torch.searchsorted` for categorical sampling.

The current implementation builds a TensorFrame, broadcasts the random number, applies a row-wise check, then takes `max()`. This is correct but heavyweight for what is essentially an inverse-CDF sample. `torch.searchsorted` on the cumulative sum would be more efficient and concise. That said, if the TensorFrame/`each` approach is intentional for `vmap` compatibility, please disregard.
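For reference, a minimal sketch of the inverse-CDF sampling the reviewer alludes to, assuming a plain 1-D probability vector; the function name is illustrative and not part of funcdcem.py.

```python
import torch

def sample_categorical_index(probs: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # probs: 1-D probabilities summing to 1; u: uniform draws in [0, 1)
    cdf = torch.cumsum(probs, dim=-1)
    # searchsorted(right=True) returns the first index whose cumulative mass exceeds u,
    # i.e. the inverse-CDF interval that u falls into; clamp guards against round-off in cdf[-1].
    return torch.searchsorted(cdf, u, right=True).clamp(max=probs.shape[-1] - 1)

probs = torch.tensor([0.2, 0.3, 0.5])
idx = sample_categorical_index(probs, torch.rand(1000))  # empirical frequencies are roughly probs
```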
35-68: Potential division by zero on line 60 is guarded by the caller, but worth a note.

If all tail probabilities are zero while `relu(1 - total_so_far) > 0`, `torch.sum(tail)` is 0 and `rescaler` becomes `inf`. The caller `_apply_lower_bound` handles this via the `torch.isfinite` check on lines 89-90, falling back to the original probabilities. This is a reasonable safety net, but consider adding a brief inline comment near line 60 documenting this invariant so future maintainers understand the contract.
```
    Args:
        state: The old state of the cross entropy method search.
        values: The most recent population, as a PyTorch tensor.
        evals: Evaluation results (i.e. fitnesses) for the solutions expressed
            by `values`. For example, if `values` is shaped `(N, L)`, this means
            that there are `N` solutions (of length `L`). So, `evals` is
            expected as a 1-dimensional tensor of length `N`, where `evals[i]`
            expresses the fitness of the solution `values[i, :]`.
            If `values` is shaped `(B, N, L)`, then there is also a batch
            dimension, so, `evals` is expected as a 2-dimensional tensor of
            shape `(B, N)`.
```
Docstring parameter name mismatch: `values` vs `population`.
The function signature uses `population` (line 419), but the docstring refers to it as `values` (lines 443, 445, 448, 449). This will confuse users relying on the documentation.
Proposed fix

```diff
     Args:
         state: The old state of the cross entropy method search.
-        values: The most recent population, as a PyTorch tensor.
+        population: The most recent population, as a PyTorch tensor.
         evals: Evaluation results (i.e. fitnesses) for the solutions expressed
-            by `values`. For example, if `values` is shaped `(N, L)`, this means
+            by `population`. For example, if `population` is shaped `(N, L)`, this means
             that there are `N` solutions (of length `L`). So, `evals` is
-            expected as a 1-dimensional tensor of length `N`, where `evals[i]`
-            expresses the fitness of the solution `values[i, :]`.
-            If `values` is shaped `(B, N, L)`, then there is also a batch
+            expected as a 1-dimensional tensor of length `N`, where `evals[i]`
+            expresses the fitness of the solution `population[i, :]`.
+            If `population` is shaped `(B, N, L)`, then there is also a batch
             dimension, so, `evals` is expected as a 2-dimensional tensor of
             shape `(B, N)`.
```
🤖 Prompt for AI Agents
In `@src/evotorch/algorithms/functional/funcdcem.py` around lines 441 - 451,
Docstring references the parameter as `values` but the function signature uses
`population`; update the docstring to consistently use `population` (replace all
occurrences of `values` with `population`) and ensure the shape/description
examples still match (`population` shaped (N, L) -> `evals` length N; (B, N, L)
-> `evals` shape (B, N)). Adjust any related sentences in the function docstring
in funcdcem.py (the docstring for the function that accepts the `population`
parameter) so names and shape examples align.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master     #124      +/-   ##
==========================================
- Coverage   75.43%   74.44%   -1.00%
==========================================
  Files          59       60       +1
  Lines        9556     9723     +167
==========================================
+ Hits         7209     7238      +29
- Misses       2347     2485     +138
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/evotorch/algorithms/functional/funcdcem.py (2)
139-181: Categorical sampling via cumsum intervals is correct but potentially slow for large category counts.

The TensorFrame + `each` + `max()` approach creates `n` rows and vmaps `get_index_if_matches` across them to find which interval the random number lands in. This works correctly (using 0 as the sentinel for non-matching indices is safe since a true match at index 0 also returns 0, and `max()` correctly picks the unique non-zero match otherwise). For large `num_choices`, `torch.searchsorted` on the cumulative sum would be more efficient and is `vmap`-compatible. This is not a blocking issue given the expected category counts in discrete optimization, but worth considering if performance becomes a concern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/evotorch/algorithms/functional/funcdcem.py` around lines 139 - 181, The current _index_of_chosen_item builds a TensorFrame of n rows and uses choices.each(get_index_if_matches).max() to locate the sampled category, which is correct but inefficient for large n; replace that multi-row TensorFrame + each + max pipeline by using torch.searchsorted on the cumulative lower bounds/upper bounds (or just on the cumulative sums) to compute the chosen index vectorizedly. Specifically, in _index_of_chosen_item compute the cumulative sum tensor (e.g., lb or cumsum of probabilities), then call torch.searchsorted(cumsum, number_between_0_1) (adjusting for dtype/device) to get the index directly instead of constructing choices, get_index_if_matches, and calling each/max.
35-68: In-place mutation inside an `@expects_ndim`-decorated function is fragile under `vmap`.

`_apply_lower_bound_for_categorical` mutates `tbl` in-place via `tbl.pick[i_head:, "PROBABILITY"] = ...` (line 62). While the current call chain ensures this function is only invoked on already-unbatched tensors (called from inside `_apply_lower_bound`, which handles batching), `torch.func.vmap` generally does not support in-place tensor mutation. If this function is ever called directly with batch dimensions, it will fail at runtime.

Additionally, line 60 has a division-by-zero risk when `torch.sum(tail)` equals 0, though the `isfinite` guard in `_apply_lower_bound` (line 89) provides a fallback.

Consider adding a comment or assertion documenting that this function must not be called with batch dimensions, or restructuring the loop to avoid in-place mutation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/evotorch/algorithms/functional/funcdcem.py` around lines 35 - 68, The function _apply_lower_bound_for_categorical mutates TensorFrame via tbl.pick[i_head:, "PROBABILITY"] which is fragile under torch.func.vmap and also risks division-by-zero when torch.sum(tail)==0; change the loop to avoid in-place updates by working with a local probability vector (e.g., create a local tensor prob = tbl.PROBABILITY.clone() or construct a new tensor each iteration) and assign into tbl or a result only after finishing the rescaling for each step, and guard the rescaler computation in the loop by checking denom = torch.sum(tail) and using rescaler = 0 (or relu(1-total_so_far)) when denom == 0 to prevent division-by-zero; also keep or add the `@expects_ndim`(1, 0) assertion and a short comment above _apply_lower_bound_for_categorical noting it expects unbatched 1-D input (called from _apply_lower_bound) to make the requirement explicit.
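A runnable sketch of the restructuring suggested here, using a plain tensor rather than the project's TensorFrame; the helper name and the exact rescaling step are illustrative assumptions, not the actual `_apply_lower_bound_for_categorical` logic.

```python
import torch

def rescale_tail_out_of_place(prob: torch.Tensor, i_head: int) -> torch.Tensor:
    """Rescale prob[i_head:] so the full vector sums to 1, without in-place slice writes."""
    head, tail = prob[:i_head], prob[i_head:]
    denom = tail.sum()
    # Guard the division: if the tail has zero mass, keep it at zero instead of producing inf/NaN.
    rescaler = torch.where(denom > 0, torch.relu(1.0 - head.sum()) / denom, torch.zeros_like(denom))
    # Rebuild the vector functionally (vmap-friendly) rather than assigning into a slice.
    return torch.cat([head, tail * rescaler], dim=-1)

print(rescale_tail_out_of_place(torch.tensor([0.5, 0.1, 0.1]), 1))  # tensor([0.5000, 0.2500, 0.2500])
```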
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/evotorch/algorithms/functional/funcdcem.py`:
- Around line 210-370: The function dcem currently accepts parenthood_ratio
without bounds checking; add a validation just before constructing the
DCEMState: coerce parenthood_ratio to float (parenthood_ratio =
float(parenthood_ratio)) and raise ValueError if not (0.0 < parenthood_ratio <=
1.0) with a clear message about allowed range; this prevents surprising behavior
later in _dcem_tell which computes parent counts using parenthood_ratio and
relies on it being within (0,1].
---
Duplicate comments:
In `@src/evotorch/algorithms/functional/funcdcem.py`:
- Around line 419-470: Rename the public dcem_tell parameter from values to
population to match _dcem_tell and dcem_ask: update the function signature of
dcem_tell (currently def dcem_tell(state: DCEMState, values: torch.Tensor,
evals: torch.Tensor) -> DCEMState) to use population, replace all internal
references (the call to _dcem_tell should pass population instead of values),
and update the dcem_tell docstring occurrences of "values" to "population" so
the public API naming is consistent with _dcem_tell and dcem_ask.
---
Nitpick comments:
In `@src/evotorch/algorithms/functional/funcdcem.py`:
- Around line 139-181: The current _index_of_chosen_item builds a TensorFrame of
n rows and uses choices.each(get_index_if_matches).max() to locate the sampled
category, which is correct but inefficient for large n; replace that multi-row
TensorFrame + each + max pipeline by using torch.searchsorted on the cumulative
lower bounds/upper bounds (or just on the cumulative sums) to compute the chosen
index vectorizedly. Specifically, in _index_of_chosen_item compute the
cumulative sum tensor (e.g., lb or cumsum of probabilities), then call
torch.searchsorted(cumsum, number_between_0_1) (adjusting for dtype/device) to
get the index directly instead of constructing choices, get_index_if_matches,
and calling each/max.
- Around line 35-68: The function _apply_lower_bound_for_categorical mutates
TensorFrame via tbl.pick[i_head:, "PROBABILITY"] which is fragile under
torch.func.vmap and also risks division-by-zero when torch.sum(tail)==0; change
the loop to avoid in-place updates by working with a local probability vector
(e.g., create a local tensor prob = tbl.PROBABILITY.clone() or construct a new
tensor each iteration) and assign into tbl or a result only after finishing the
rescaling for each step, and guard the rescaler computation in the loop by
checking denom = torch.sum(tail) and using rescaler = 0 (or
relu(1-total_so_far)) when denom == 0 to prevent division-by-zero; also keep or
add the `@expects_ndim`(1, 0) assertion and a short comment above
_apply_lower_bound_for_categorical noting it expects unbatched 1-D input (called
from _apply_lower_bound) to make the requirement explicit.
```python
def dcem(
    *,
    objective_sense: str,
    center_init: Optional[Union[np.ndarray, torch.Tensor]] = None,
    solution_length: Optional[int] = None,
    num_choices: Optional[int] = None,
    prob_min: Optional[BatchableScalar] = None,
    parenthood_ratio: float,
    device: Optional[Device] = None,
) -> DCEMState:
    """
    Initialize a Discrete Cross Entropy Method.

    This discretized counterpart of cross entropy method can be used for
    heuristically solving optimization problems in which the decision
    variables are binary or categorical.

    **Binary case.**
    If the argument `num_choices` is left as None, then it is interpreted
    that the desired search space consists of binary variables. In this case,
    a solution of `n` decision variables is represented by a tensor
    of dtype `torch.bool`.

    **Categorical case.**
    If the argument `num_choices` is given as an integer `m`, then it is
    interpreted that the search space consists of categorical variables,
    and each categorical value is allowed to take an integer value
    between 0 and m-1. In this case, a solution of `n` decision variables
    is represented by a tensor of dtype `torch.int64`.

    References:

        Rubinstein, R. (1999). The cross-entropy method for combinatorial
        and continuous optimization.
        Methodology and computing in applied probability, 1(2), 127-190.

        Botev, Z. I., Kroese, D. P., Rubinstein, R. Y., & L'ecuyer, P. (2013).
        The cross-entropy method for optimization.
        In Handbook of statistics (Vol. 31, pp. 35-59).

    Args:
        objective_sense: Expected as a string, either as 'min' or as 'max'.
            Determines if the goal is to minimize or is to maximize.
        center_init: Optionally the starting point for the heuristic search,
            in which the values are real numbers between 0 and 1.
            Let us assume that the problem at hand has `n` decision variables.
            If the decision variables are binary (with dtype `torch.bool`),
            then `center_init` can be given as a tensor with at least 1
            dimension, with length `n`. Within this tensor, the i-th item
            represents the initial probability of setting the i-th variable
            as True during the phase of population sampling.
            If the decision variables are categorical with `m` categories,
            then `center_init` can be given as a tensor with at least 2
            dimensions, the shape of these rightmost 2 dimensions being
            `(n, m-1)`. For example, if we have 3 categories, and if the
            item `[..., i, :]` is `[0.2, 0.3]`, then, the i-th variable's
            first category has initial probability of 0.2, its second
            category has initial probability of 0.3, and its third
            category has initial probability of 0.5 (which is 1-(0.2+0.3)),
            during the phase of population sampling.
            Extra leftmost dimensions in the provided `center_init` will
            be interpreted as batch dimensions.
            Alternatively, `center_init` can be omitted altogether, and
            the argument `solution_length` can be provided instead.
        solution_length: Optionally the number of decision variables.
            To be given if `center_init` is omitted.
            If `center_init` is provided, this argument must be left as None,
            because the number of decision variables are then inferred from
            the shape of `center_init`.
        num_choices: Number of categories for each decision variable.
            If left as None, it will be assumed that the problem at hand
            is binary, and solutions will assume the dtype `torch.bool`.
            If given as an integer (at least 2), then the solutions will be
            integer-typed.
        prob_min: Optionally the lower bound for the probability
            of choosing a category belonging to all variables (if given as a
            scalar) or for each variable (if given as a vector whose length
            is equal to the number of decision variables).
            If any categorical choice's sampling probability is lower than
            this value, the discrete cross entropy method will attempt to
            re-adjust those sampling probabilities so that this lower bound
            is respected.
        parenthood_ratio: Proportion of the solutions that will be chosen as
            the parents for the next generation. For example, if this is
            given as 0.5, the top 50% of the solutions will be chosen as
            parents.
        device: If given as a string or as a `torch.device` instance, the
            evolutionary search will be performed on this specified device.
    """
    objective_sense = str(objective_sense)
    if objective_sense == "min":
        maximize = False
    elif objective_sense == "max":
        maximize = True
    else:
        raise ValueError(
            f"`objective_sense` was expected as 'min' or 'max', but it was received as: {repr(objective_sense)}."
        )

    if num_choices is not None:
        num_choices = int(num_choices)
        if num_choices < 2:
            raise ValueError("`num_choices` was encountered as an integer less than 2, which is invalid.")

    if device is None:
        device_kwargs = {}
    else:
        device_kwargs = {"device": device}

    if (center_init is None) and (solution_length is None):
        raise ValueError("Both `center_init` and `solution_length` are avoided. Please provide one of them.")
    elif (center_init is None) and (solution_length is not None):
        solution_length = int(solution_length)
        if solution_length < 1:
            raise ValueError("`solution_length` was given as an integer that is less than 1, which is invalid.")
        if num_choices is None:
            center = torch.ones(solution_length, **device_kwargs) * 0.5
        else:
            center = torch.ones((solution_length, (num_choices - 1)), **device_kwargs) * (1.0 / num_choices)
    elif (center_init is not None) and (solution_length is None):
        center = torch.as_tensor(center_init, **device_kwargs)
        if center.ndim == 0:
            raise ValueError("`center_init` was given as a scalar, which is not supported.")
        if center.numel() == 0:
            raise ValueError("`center_init` was given as an empty tensor, which is not supported.")
        if num_choices is None:
            solution_length = center.shape[-1]
        else:
            if center.ndim < 2:
                raise ValueError(
                    "With `num_choices` given as an integer,"
                    " `center_init` was expected as a tensor with at least 2 dimensions."
                    f" However, its number of dimensions is {center.ndim}."
                )
            if center.shape[-1] != (num_choices - 1):
                raise ValueError(
                    "With `num_choices` given as an integer,"
                    " the rightmost dimension size of `center_init` was expected as `num_choices` - 1."
                    " However, the received `center_init` seems to violate this rule."
                )
            solution_length = center.shape[-2]
    else:
        raise ValueError(
            "Both `center_init` and `solution_length` are provided as values other than None."
            " Please provide only one of them and leave the other one as None."
        )

    if prob_min is None:
        prob_min = torch.zeros(solution_length, dtype=center.dtype, device=center.device)
    else:
        prob_min = torch.as_tensor(prob_min, dtype=center.dtype, device=center.device).clamp(0.0, 1.0) * torch.ones(
            solution_length, dtype=center.dtype, device=center.device
        )

    return DCEMState(
        center=center,
        num_choices=num_choices,
        prob_min=prob_min,
        parenthood_ratio=float(parenthood_ratio),
        maximize=maximize,
    )
```
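For orientation, a short usage sketch of the two initialization modes documented in the docstring above; the import path is an assumption, everything else follows the documented behavior.

```python
import torch
from evotorch.algorithms.functional import dcem  # assumed export path

# Binary search space: 10 boolean variables, each initially sampled as True with probability 0.5.
binary_state = dcem(objective_sense="min", solution_length=10, parenthood_ratio=0.25)

# Categorical search space: 4 variables with 3 categories each. `center_init` holds the
# probabilities of the first m-1 = 2 categories per variable, so its rightmost shape is (4, 2).
categorical_state = dcem(
    objective_sense="max",
    center_init=torch.full((4, 2), 1.0 / 3.0),
    num_choices=3,
    prob_min=0.05,
    parenthood_ratio=0.5,
)
```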
Missing validation for `parenthood_ratio`.

`parenthood_ratio` is accepted as any float without bounds checking. A value ≤ 0 or > 1 is almost certainly a user error. While `_dcem_tell` uses `max(1, ceil(len(population) * parenthood_ratio))`, which prevents a zero-parent edge case, a negative or > 1 ratio silently produces surprising parent counts.
Proposed fix (add validation near line 368)

```diff
+    parenthood_ratio = float(parenthood_ratio)
+    if not (0.0 < parenthood_ratio <= 1.0):
+        raise ValueError(
+            f"`parenthood_ratio` must be in (0, 1], but it was given as {parenthood_ratio}."
+        )
+
     return DCEMState(
         center=center,
         num_choices=num_choices,
         prob_min=prob_min,
-        parenthood_ratio=float(parenthood_ratio),
+        parenthood_ratio=parenthood_ratio,
         maximize=maximize,
     )
```

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 305-307: Avoid specifying long messages outside the exception class (TRY003)
[warning] 306-306: Use explicit conversion flag; replace with conversion flag (RUF010)
[warning] 312-312: Avoid specifying long messages outside the exception class (TRY003)
[warning] 320-320: Avoid specifying long messages outside the exception class (TRY003)
[warning] 324-324: Avoid specifying long messages outside the exception class (TRY003)
[warning] 332-332: Avoid specifying long messages outside the exception class (TRY003)
[warning] 334-334: Avoid specifying long messages outside the exception class (TRY003)
[warning] 339-343: Avoid specifying long messages outside the exception class (TRY003)
[warning] 345-349: Avoid specifying long messages outside the exception class (TRY003)
[warning] 352-355: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/evotorch/algorithms/functional/funcdcem.py` around lines 210 - 370, The
function dcem currently accepts parenthood_ratio without bounds checking; add a
validation just before constructing the DCEMState: coerce parenthood_ratio to
float (parenthood_ratio = float(parenthood_ratio)) and raise ValueError if not
(0.0 < parenthood_ratio <= 1.0) with a clear message about allowed range; this
prevents surprising behavior later in _dcem_tell which computes parent counts
using parenthood_ratio and relies on it being within (0,1].
Summary by CodeRabbit