Skip to content

Commit 0864d8b

Browse files
committed
crypto: Compute modexp base_mont using var-length division
Replace fixed-width intx::udivrem with span-based rem() for computing the Montgomery form of the base in modexp_odd. The rem() function reuses intx's internal division primitives operating on dynamic word spans.
1 parent 58563c8 commit 0864d8b

File tree

2 files changed

+178
-22
lines changed

2 files changed

+178
-22
lines changed

lib/evmone_precompiles/modexp.cpp

Lines changed: 143 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ void store(std::span<uint8_t> r, std::span<const uint64_t> words) noexcept
129129
std::ranges::fill(r.subspan(0, pos), uint8_t{0});
130130
}
131131

132+
/// Trims a little-endian word array to significant words.
133+
template <typename T>
134+
constexpr std::span<T> trim(std::span<T> x) noexcept
135+
{
136+
const auto it = std::ranges::find_if(x.rbegin(), x.rend(), [](auto w) { return w != 0; });
137+
return x.first(static_cast<size_t>(std::ranges::distance(it, x.rend())));
138+
}
139+
132140
/// Counts trailing zeros in a non-zero little-endian word array.
133141
constexpr unsigned ctz(std::span<const uint64_t> x) noexcept
134142
{
@@ -169,6 +177,79 @@ void shr(std::span<uint64_t> r, std::span<const uint64_t> x, unsigned k) noexcep
169177
}
170178

171179

180+
/// Computes r[] = u[] % d[] (remainder only).
181+
/// The d[] must be non-zero. The r[] size must be >= num significant words in d[].
182+
void rem(std::span<uint64_t> r, std::span<const uint64_t> u, std::span<const uint64_t> d) noexcept
183+
{
184+
assert(!d.empty());
185+
assert(!u.empty());
186+
assert(d.back() != 0);
187+
assert(u.back() != 0);
188+
assert(r.size() >= d.size());
189+
assert(u.size() > d.size()); // Because used only for to-Montgomery conversion.
190+
191+
const auto un_storage = std::make_unique_for_overwrite<uint64_t[]>(u.size() + 1);
192+
auto un = std::span{un_storage.get(), u.size() + 1};
193+
un.back() = 0; // Only the extra top word needs zeroing; the rest is set by normalization.
194+
195+
// Normalize: left-shift both u and d so that the MSB of d's top word is set.
196+
const auto shift = static_cast<unsigned>(std::countl_zero(d.back()));
197+
198+
// Allocate normalized divisor.
199+
const auto dn_storage = std::make_unique_for_overwrite<uint64_t[]>(d.size());
200+
const auto dn = std::span{dn_storage.get(), d.size()};
201+
202+
if (shift != 0)
203+
{
204+
for (size_t i = d.size() - 1; i != 0; --i)
205+
dn[i] = (d[i] << shift) | (d[i - 1] >> (64 - shift));
206+
dn[0] = d[0] << shift;
207+
208+
// Normalize numerator into un.
209+
un[u.size()] = u.back() >> (64 - shift);
210+
for (size_t i = u.size() - 1; i != 0; --i)
211+
un[i] = (u[i] << shift) | (u[i - 1] >> (64 - shift));
212+
un[0] = u[0] << shift;
213+
}
214+
else
215+
{
216+
std::ranges::copy(d, dn.begin());
217+
std::ranges::copy(u, un.begin());
218+
}
219+
220+
// Shrink off the extra top word if it is not significant for the normalized numerator.
221+
if (un.back() == 0 && un[un.size() - 2] < dn.back())
222+
un = un.first(un.size() - 1);
223+
224+
const auto denormalize = [&r, shift](std::span<const uint64_t> x) noexcept {
225+
assert(r.size() >= x.size());
226+
shr(r.first(x.size()), x, shift);
227+
std::ranges::fill(r.subspan(x.size()), uint64_t{0});
228+
};
229+
230+
assert(un.size() > dn.size()); // Not possible in the current usage.
231+
232+
if (dn.size() == 1)
233+
{
234+
const uint64_t rem_words[1]{intx::internal::udivrem_by1(un, dn[0])};
235+
denormalize(rem_words);
236+
}
237+
else if (dn.size() == 2)
238+
{
239+
const auto rem2 = intx::internal::udivrem_by2(un, uint128{dn[0], dn[1]});
240+
denormalize(as_words(rem2));
241+
}
242+
else
243+
{
244+
// General case: Knuth's algorithm. Quotient is stored in-place in un[dn.size()..].
245+
// We don't need the quotient, but udivrem_knuth requires storage for it.
246+
const auto q_len = un.size() - dn.size();
247+
const auto q_storage = std::make_unique_for_overwrite<uint64_t[]>(q_len);
248+
intx::internal::udivrem_knuth(q_storage.get(), un, dn);
249+
denormalize(un.subspan(0, dn.size()));
250+
}
251+
}
252+
172253
/// Represents the exponent value of the modular exponentiation operation.
173254
///
174255
/// This is a view type of the big-endian bytes representing the bits of the exponent.
@@ -249,6 +330,8 @@ constexpr void mul_amm(std::span<uint64_t, N> r, std::span<const uint64_t, N> y,
249330
std::ranges::copy(t, r.begin());
250331
}
251332

333+
/// Performs modular exponentiation for an odd modulus using Montgomery multiplication.
334+
/// The base must already be in Montgomery form: base = (orig_base * R) % mod.
252335
template <typename UIntT>
253336
UIntT modexp_odd_fixed_size(const UIntT& base, Exponent exp, const UIntT& mod) noexcept
254337
{
@@ -257,16 +340,12 @@ UIntT modexp_odd_fixed_size(const UIntT& base, Exponent exp, const UIntT& mod) n
257340

258341
const auto mod_inv = evmmax::compute_mont_mod_inv(mod);
259342

260-
/// Convert the base to Montgomery form: base*R % mod, where R = 2^(num_bits).
261-
const auto base_mont =
262-
udivrem(intx::uint<UIntT::num_bits * 2>{base} << UIntT::num_bits, mod).rem;
263-
264-
auto ret = base_mont;
343+
auto ret = base;
265344
for (auto i = exp.bit_width() - 1; i != 0; --i)
266345
{
267346
mul_amm<N>(as_words(ret), as_words(ret), as_words(mod), mod_inv);
268347
if (exp[i - 1])
269-
mul_amm<N>(as_words(ret), as_words(base_mont), as_words(mod), mod_inv);
348+
mul_amm<N>(as_words(ret), as_words(base), as_words(mod), mod_inv);
270349
}
271350

272351
// Convert the result from Montgomery form by multiplying with the standard integer 1.
@@ -281,33 +360,75 @@ UIntT modexp_odd_fixed_size(const UIntT& base, Exponent exp, const UIntT& mod) n
281360
return ret;
282361
}
283362

284-
void modexp_odd(std::span<uint64_t> result, const std::span<const uint64_t> base, Exponent exp,
285-
const std::span<const uint64_t> mod) noexcept
363+
void modexp_odd(std::span<uint64_t> result, std::span<const uint64_t> base, Exponent exp,
364+
std::span<const uint64_t> mod) noexcept
286365
{
287-
static constexpr auto MAX_INPUT_SIZE = 1024 / sizeof(uint64_t); // 8192 bits, as in EIP-7823.
288-
assert(base.size() <= MAX_INPUT_SIZE);
289-
assert(base.size() <= MAX_INPUT_SIZE);
290-
assert(result.size() == mod.size());
291-
assert(base.size() == mod.size()); // True for the current callers. Relax if needed.
366+
base = trim(base);
367+
mod = trim(mod);
368+
assert(!mod.empty());
369+
assert(exp.bit_width() != 0);
370+
371+
if (base.empty()) [[unlikely]] // base is 0: 0^exp = 0 for exp > 0.
372+
{
373+
std::ranges::fill(result, uint64_t{0});
374+
return;
375+
}
376+
377+
const auto n = mod.size();
378+
379+
// Select the fixed-size width (in words) for Montgomery multiplication.
380+
static constexpr auto MAX_SIZE = 1024 / sizeof(uint64_t); // 8192 bits, as in EIP-7823.
381+
assert(n <= MAX_SIZE);
382+
static constexpr size_t SIZES[] = {2, 4, 8, 16, 32, MAX_SIZE};
383+
const auto r_size = *std::ranges::lower_bound(SIZES, n);
384+
385+
// Compute base_mont = (base * R) % mod, where R = 2^(r_size*64).
386+
// R must match the width used by Montgomery multiplication (mul_amm).
387+
// The numerator is base shifted left by r_size words (r_size + base.size() words).
388+
// The result (base * R) % mod can be up to mod-1, always requiring n words.
389+
const auto u_len = r_size + base.size();
390+
const auto tmp_storage = std::make_unique_for_overwrite<uint64_t[]>(u_len + n);
391+
const auto tmp = std::span{tmp_storage.get(), u_len + n};
392+
const auto u = tmp.first(u_len);
393+
const auto base_mont = tmp.subspan(u_len, n);
394+
std::ranges::fill(u.first(r_size), uint64_t{0});
395+
std::ranges::copy(base, u.subspan(r_size).begin());
396+
rem(base_mont, u, mod);
292397

293398
const auto impl = [=]<size_t N>() {
294399
using UintT = intx::uint<N * 64>;
295-
const auto r = modexp_odd_fixed_size(UintT{base}, exp, UintT{mod});
296-
std::ranges::copy(as_words(r).subspan(0, result.size()), result.begin());
400+
401+
// Pass zero-extended fixed-size representation.
402+
const auto r = modexp_odd_fixed_size(UintT{base_mont}, exp, UintT{mod});
403+
404+
// TODO: Because the caller's mod is not trimmed, we must also zero-extend the result.
405+
const auto rw = as_words(r);
406+
const auto [_, out] =
407+
std::ranges::copy(rw.first(std::min(rw.size(), result.size())), result.begin());
408+
std::fill(out, result.end(), 0);
297409
};
298410

299-
if (const auto n = mod.size(); n <= 2)
411+
switch (r_size)
412+
{
413+
case 2:
300414
impl.operator()<2>();
301-
else if (n <= 4)
415+
break;
416+
case 4:
302417
impl.operator()<4>();
303-
else if (n <= 8)
418+
break;
419+
case 8:
304420
impl.operator()<8>();
305-
else if (n <= 16)
421+
break;
422+
case 16:
306423
impl.operator()<16>();
307-
else if (n <= 32)
424+
break;
425+
case 32:
308426
impl.operator()<32>();
309-
else
310-
impl.operator()<MAX_INPUT_SIZE>();
427+
break;
428+
default:
429+
impl.operator()<MAX_SIZE>();
430+
break;
431+
}
311432
}
312433

313434
/// Trims the multi-word number x[] to k bits.

test/unittests/precompiles_expmod_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,43 @@ TEST(expmod, test_vectors)
5959
{"09", "05", "18", "09"},
6060
{"03", "80", "ff", "ab"},
6161
{"03", "1c93", "61", "5f"},
62+
// base=0 with exp>0: 0^n = 0 for all paths (odd, even, pow2).
63+
{"00", "01", "07", "00"},
64+
{"00", "03", "06", "00"},
65+
{"00", "01", "08", "00"},
66+
// Different base/mod byte sizes.
67+
{"0100", "01", "07", "04"}, // large base, small odd mod
68+
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "01", "0006", "0003"},
69+
{"02", "05", "060000000000000000", "000000000000000020"},
70+
{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "02",
71+
"fffffffffffffffd", "0000000000001900"},
72+
{"02", "03", "0100000000000000000000000000000001", "0000000000000000000000000000000008"},
6273
// Power-of-two modulus bigger than single word.
6374
{"cc", "11", "00000001000000000000000000000000", "00000000fe8477d6c9cef3cc00000000"},
75+
// Odd modulus dispatch-width coverage (n != N for n=1,3,5).
76+
{"0000000000000002", "01", "8000000000000001", "0000000000000002"},
77+
{"000000000000000000000000000000000000000000000002", "01",
78+
"800000000000000000000000000000000000000000000001",
79+
"000000000000000000000000000000000000000000000002"},
80+
{"00000000000000000000000000000000000000000000000000000000000000000000000000000002", "01",
81+
"80000000000000000000000000000000000000000000000000000000000000000000000000000001",
82+
"00000000000000000000000000000000000000000000000000000000000000000000000000000002"},
83+
// Full-width base triggers normalization headroom path in rem().
84+
{"80000000000000000000000000000000", "01", "80000000000000000000000000000001",
85+
"80000000000000000000000000000000"},
86+
{"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
87+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
88+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
89+
"00002",
90+
"01",
91+
"80000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
92+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
93+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
94+
"00000000000001",
95+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
96+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
97+
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
98+
"00000000000002"},
6499
{
65100
"03",
66101
"fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e",

0 commit comments

Comments
 (0)