Skip to content

Commit dd55876

Browse files
authored
crypto: Use variadic-length modinv in modexp (#1434)
Replace the modular inversion modulo 2ᵏ, which used a fixed-size `intx::uint`, with one that operates on variable-length numbers represented by `std::span<uint64_t>`. This improves performance in pathological cases where the 2ᵏ part of the modexp modulus is significantly shorter than the modulus itself.
1 parent 2c62685 commit dd55876

File tree

1 file changed

+59
-15
lines changed

1 file changed

+59
-15
lines changed

lib/evmone_precompiles/modexp.cpp

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "modexp.hpp"
66
#include <evmmax/evmmax.hpp>
77
#include <bit>
8+
#include <memory>
89

910
using namespace intx;
1011

@@ -32,6 +33,33 @@ constexpr uint64_t addmul(std::span<uint64_t> r, std::span<const uint64_t> p,
3233
return c;
3334
}
3435

36+
/// Computes multiplication of x times y and truncates the result to the size of r:
37+
/// r[] = x[] * y[].
38+
constexpr void mul(
39+
std::span<uint64_t> r, std::span<const uint64_t> x, std::span<const uint64_t> y) noexcept
40+
{
41+
assert(!x.empty());
42+
assert(!y.empty());
43+
assert(x.size() >= y.size()); // Required for safe subspan arithmetic in the loop.
44+
assert(r.size() == std::max(x.size(), y.size()));
45+
46+
std::ranges::fill(r, 0);
47+
for (size_t j = 0; j < y.size(); ++j)
48+
addmul(r.subspan(j), r.subspan(j), x.subspan(0, x.size() - j), y[j]);
49+
}
50+
51+
/// Computes x[] = 2 - x[].
52+
constexpr void neg_add2(std::span<uint64_t> x) noexcept
53+
{
54+
assert(!x.empty());
55+
bool c = false;
56+
57+
std::tie(x[0], c) = intx::subc(2, x[0]);
58+
for (auto it = x.begin() + 1; it != x.end(); ++it)
59+
std::tie(*it, c) = intx::subc(0, *it, c);
60+
}
61+
62+
3563
/// Represents the exponent value of the modular exponentiation operation.
3664
///
3765
/// This is a view type of the big-endian bytes representing the bits of the exponent.
@@ -159,24 +187,37 @@ UIntT modexp_pow2(const UIntT& base, Exponent exp, unsigned k) noexcept
159187
return ret;
160188
}
161189

162-
/// Computes modular inversion for modulus of 2ᵏ.
163-
///
164-
/// TODO: This actually may return more bits than k, the caller is responsible for masking the
165-
/// result. Better design may be to pass std::span<uint64_t> without specifying k.
166-
template <typename UIntT>
167-
UIntT modinv_pow2(const UIntT& x, unsigned k) noexcept
190+
/// Computes modular inversion of the multi-word number x[] modulo 2^(r.size() * 64).
///
/// @param r  Output words receiving the inverse; must be non-empty and not longer than x.
/// @param x  Words of the number to invert; must be non-empty and odd (lowest bit of x[0] set).
///           Only the lowest r.size() words of x are read.
///
/// Temporary storage of 2 * r.size() words is allocated dynamically on every call.
/// NOTE(review): the function is noexcept, so a failed allocation would terminate the program.
191+
void modinv_pow2(std::span<uint64_t> r, std::span<const uint64_t> x) noexcept
168192
{
169-
assert(bit_test(x, 0)); // x must be odd for the inverse to exist.
170-
assert(k <= UIntT::num_bits); // k must fit into the type.
193+
assert(!x.empty() && (x[0] & 1) != 0); // x must be odd.
194+
assert(r.size() <= x.size()); // Truncating version: higher words of x are ignored.
195+
assert(!r.empty());
171196

172-
// Start with inversion mod 2⁶⁴.
173-
UIntT inv = evmmax::modinv(x[0]);
197+
r[0] = evmmax::modinv(x[0]); // Good start: 64 correct bits (one word).
174198

175-
// Each iteration doubles the number of correct bits in the inverse. See modinv(uint32_t).
176-
for (size_t iterations = 64; iterations < k; iterations *= 2)
177-
inv *= 2 - x * inv;
199+
// Allocate temporary storage for iterations: two buffers of up to r.size() words each.
200+
// TODO: Move to stack if the size is small enough or provide from the caller.
201+
const auto tmp_storage = std::make_unique_for_overwrite<uint64_t[]>(2 * r.size());
202+
const auto tmp = std::span{tmp_storage.get(), 2 * r.size()};
178203

179-
return inv;
204+
// Each iteration doubles the number of correct bits in the inverse. See evmmax::modinv().
205+
for (size_t i = 1; i < r.size(); i *= 2)
206+
{
207+
// At the start of the iteration we have an i-word correct inverse in r[0..i).
208+
// The iteration performs the Newton-Raphson step with double the precision (n=2i),
// capped at r.size() so the last step does not go past the requested output length.
209+
const auto n = std::min(i * 2, r.size());
210+
const auto t1 = tmp.subspan(0, n); // First n words of the temporary storage.
211+
const auto t2 = tmp.subspan(n, n); // The following n words, disjoint from t1.
212+
213+
mul(t1, x.subspan(0, n), r.subspan(0, i)); // t1 = x * inv
214+
neg_add2(t1); // t1 = 2 - x * inv
215+
mul(t2, t1, r.subspan(0, i)); // t2 = inv * (2 - x * inv)
216+
// TODO: Consider implementing the step as (inv << 1) - (x * inv * inv).
217+
218+
// Store the n-word result as the current inverse in r[0..n).
// TODO: Avoid copy by swapping buffers.
219+
std::ranges::copy(t2, r.begin());
220+
}
180221
}
181222

182223
/// Computes modular exponentiation for even modulus: base^exp % (mod_odd * 2^k).
@@ -190,7 +231,10 @@ UIntT modexp_even(const UIntT& base, Exponent exp, const UIntT& mod_odd, unsigne
190231
const auto x1 = modexp_odd(base, exp, mod_odd);
191232
const auto x2 = modexp_pow2(base, exp, k);
192233

193-
const auto mod_odd_inv = modinv_pow2(mod_odd, k);
234+
const auto mod_odd_words = as_words(mod_odd);
235+
UIntT mod_odd_inv;
236+
const auto num_pow2_words = (k + 63) / 64;
237+
modinv_pow2(as_words(mod_odd_inv).subspan(0, num_pow2_words), mod_odd_words);
194238

195239
const auto mod_pow2_mask = (UIntT{1} << k) - 1;
196240
const auto y = ((x2 - x1) * mod_odd_inv) & mod_pow2_mask;

0 commit comments

Comments
 (0)