@@ -129,6 +129,14 @@ void store(std::span<uint8_t> r, std::span<const uint64_t> words) noexcept
129129 std::ranges::fill (r.subspan (0 , pos), uint8_t {0 });
130130}
131131
132+ /// Trims a little-endian word array to its significant words.
     ///
     /// Returns the leading part of x[] that ends at the most significant (last) non-zero word.
     /// The result is empty when x[] is empty or when all of its words are zero.
133+ template <typename T>
134+ constexpr std::span<T> trim (std::span<T> x) noexcept
135+ {
     // Search from the top: `it` is the first non-zero word in reverse order, or rend() if none.
136+ const auto it = std::ranges::find_if (x.rbegin (), x.rend (), [](auto w) { return w != 0 ; });
     // The number of elements from `it` to rend() is exactly the number of words to keep.
137+ return x.first (static_cast <size_t >(std::ranges::distance (it, x.rend ())));
138+ }
139+
132140// / Counts trailing zeros in a non-zero little-endian word array.
133141constexpr unsigned ctz (std::span<const uint64_t > x) noexcept
134142{
@@ -169,6 +177,79 @@ void shr(std::span<uint64_t> r, std::span<const uint64_t> x, unsigned k) noexcep
169177}
170178
171179
180+ /// Computes r[] = u[] % d[] (remainder only).
181+ /// The d[] must be non-zero. The r[] size must be >= num significant words in d[].
     ///
     /// All arrays are little-endian sequences of 64-bit words (the top word is .back()).
     /// Preconditions, checked by the asserts at the top of the body:
     /// - d[] and u[] are non-empty and their top words are non-zero (both already trimmed),
     /// - r.size() >= d.size(),
     /// - u.size() > d.size().
     /// On return every word of r[] has been written: the low words hold the remainder,
     /// all words above it are zero.
     ///
     /// Temporary buffers are heap-allocated. Because the function is noexcept, an allocation
     /// failure terminates the program instead of propagating std::bad_alloc.
182+ void rem (std::span<uint64_t > r, std::span<const uint64_t > u, std::span<const uint64_t > d) noexcept
183+ {
184+ assert (!d.empty ());
185+ assert (!u.empty ());
186+ assert (d.back () != 0 );
187+ assert (u.back () != 0 );
188+ assert (r.size () >= d.size ());
189+ assert (u.size () > d.size ()); // Holds for the only current use: to-Montgomery conversion.
190+
     // Working copy of the numerator with one extra top word to receive bits shifted out of u.
191+ const auto un_storage = std::make_unique_for_overwrite<uint64_t []>(u.size () + 1 );
192+ auto un = std::span{un_storage.get (), u.size () + 1 };
193+ un.back () = 0 ; // Only the extra top word needs zeroing; the rest is set by normalization.
194+
195+ // Normalize: left-shift both u and d so that the MSB of d's top word is set.
     // The shift is the number of leading zero bits of d's top word; it is < 64 as d.back() != 0.
196+ const auto shift = static_cast <unsigned >(std::countl_zero (d.back ()));
197+
198+ // Allocate normalized divisor.
199+ const auto dn_storage = std::make_unique_for_overwrite<uint64_t []>(d.size ());
200+ const auto dn = std::span{dn_storage.get (), d.size ()};
201+
     // shift == 0 is handled separately: `64 - shift` would then be a shift by the full word
     // width, which is undefined behavior for a 64-bit operand.
202+ if (shift != 0 )
203+ {
     // Walk from the top word down: each word is its own bits shifted up combined with
     // the bits carried out of the word below it.
204+ for (size_t i = d.size () - 1 ; i != 0 ; --i)
205+ dn[i] = (d[i] << shift) | (d[i - 1 ] >> (64 - shift));
206+ dn[0 ] = d[0 ] << shift;
207+
208+ // Normalize numerator into un.
     // The bits shifted out of u's top word land in the extra top word of un.
209+ un[u.size ()] = u.back () >> (64 - shift);
210+ for (size_t i = u.size () - 1 ; i != 0 ; --i)
211+ un[i] = (u[i] << shift) | (u[i - 1 ] >> (64 - shift));
212+ un[0 ] = u[0 ] << shift;
213+ }
214+ else
215+ {
     // Already normalized: plain copies. The extra top word of un keeps the zero set above.
216+ std::ranges::copy (d, dn.begin ());
217+ std::ranges::copy (u, un.begin ());
218+ }
219+
220+ // Shrink off the extra top word if it is not significant for the normalized numerator.
221+ if (un.back () == 0 && un[un.size () - 2 ] < dn.back ())
222+ un = un.first (un.size () - 1 );
223+
     // Stores the normalized remainder x into r: shr() is applied with the normalization shift
     // (presumably a right shift by `shift` bits - confirm against shr) into the low x.size()
     // words of r, and the words of r above them are zero-filled.
224+ const auto denormalize = [&r, shift](std::span<const uint64_t > x) noexcept {
225+ assert (r.size () >= x.size ());
226+ shr (r.first (x.size ()), x, shift);
227+ std::ranges::fill (r.subspan (x.size ()), uint64_t {0 });
228+ };
229+
230+ assert (un.size () > dn.size ()); // un.size() <= dn.size() is not possible in the current usage.
231+
     // Dispatch on the divisor length; the one- and two-word cases use dedicated intx routines
     // whose return value is used as the (still normalized) remainder.
232+ if (dn.size () == 1 )
233+ {
234+ const uint64_t rem_words[1 ]{intx::internal::udivrem_by1 (un, dn[0 ])};
235+ denormalize (rem_words);
236+ }
237+ else if (dn.size () == 2 )
238+ {
239+ const auto rem2 = intx::internal::udivrem_by2 (un, uint128{dn[0 ], dn[1 ]});
240+ denormalize (as_words (rem2));
241+ }
242+ else
243+ {
244+ // General case: Knuth's algorithm. Quotient is stored in-place in un[dn.size()..].
     // NOTE(review): a separate q_storage buffer is passed for the quotient below, so the
     // "in-place" statement above may be outdated - confirm against udivrem_knuth.
245+ // We don't need the quotient, but udivrem_knuth requires storage for it.
246+ const auto q_len = un.size () - dn.size ();
247+ const auto q_storage = std::make_unique_for_overwrite<uint64_t []>(q_len);
248+ intx::internal::udivrem_knuth (q_storage.get (), un, dn);
     // The low dn.size() words of un are taken as the normalized remainder.
249+ denormalize (un.subspan (0 , dn.size ()));
250+ }
251+ }
252+
172253// / Represents the exponent value of the modular exponentiation operation.
173254// /
174255// / This is a view type of the big-endian bytes representing the bits of the exponent.
@@ -249,6 +330,8 @@ constexpr void mul_amm(std::span<uint64_t, N> r, std::span<const uint64_t, N> y,
249330 std::ranges::copy (t, r.begin ());
250331}
251332
333+ /// Performs modular exponentiation for an odd modulus using Montgomery multiplication.
334+ /// The base must already be in Montgomery form: base = (orig_base * R) % mod, where R = 2^(UIntT::num_bits).
252335template <typename UIntT>
253336UIntT modexp_odd_fixed_size (const UIntT& base, Exponent exp, const UIntT& mod) noexcept
254337{
@@ -257,16 +340,12 @@ UIntT modexp_odd_fixed_size(const UIntT& base, Exponent exp, const UIntT& mod) n
257340
258341 const auto mod_inv = evmmax::compute_mont_mod_inv (mod);
259342
260- // / Convert the base to Montgomery form: base*R % mod, where R = 2^(num_bits).
261- const auto base_mont =
262- udivrem (intx::uint<UIntT::num_bits * 2 >{base} << UIntT::num_bits, mod).rem ;
263-
264- auto ret = base_mont;
343+ auto ret = base;
265344 for (auto i = exp.bit_width () - 1 ; i != 0 ; --i)
266345 {
267346 mul_amm<N>(as_words (ret), as_words (ret), as_words (mod), mod_inv);
268347 if (exp[i - 1 ])
269- mul_amm<N>(as_words (ret), as_words (base_mont ), as_words (mod), mod_inv);
348+ mul_amm<N>(as_words (ret), as_words (base ), as_words (mod), mod_inv);
270349 }
271350
272351 // Convert the result from Montgomery form by multiplying with the standard integer 1.
@@ -281,33 +360,75 @@ UIntT modexp_odd_fixed_size(const UIntT& base, Exponent exp, const UIntT& mod) n
281360 return ret;
282361}
283362
284- void modexp_odd (std::span<uint64_t > result, const std::span<const uint64_t > base, Exponent exp,
285- const std::span<const uint64_t > mod) noexcept
     /// Modular exponentiation for an odd modulus; dispatches to modexp_odd_fixed_size.
     ///
     /// The base[], mod[] and result[] are little-endian arrays of 64-bit words.
     /// - base[] and mod[] may carry zero most-significant words; both are trimmed here first.
     /// - mod[] must be non-zero and exp must have a non-zero bit width (asserted).
     /// - mod[] must have at most MAX_SIZE (128) significant words (asserted).
     /// Every word of result[] is written: the low words receive the value, the rest is zeroed.
363+ void modexp_odd (std::span<uint64_t > result, std::span<const uint64_t > base, Exponent exp,
364+ std::span<const uint64_t > mod) noexcept
286365{
287- static constexpr auto MAX_INPUT_SIZE = 1024 / sizeof (uint64_t ); // 8192 bits, as in EIP-7823.
288- assert (base.size () <= MAX_INPUT_SIZE);
289- assert (base.size () <= MAX_INPUT_SIZE);
290- assert (result.size () == mod.size ());
291- assert (base.size () == mod.size ()); // True for the current callers. Relax if needed.
     // Work on significant words only; after trimming, an empty span means the value is zero.
366+ base = trim (base);
367+ mod = trim (mod);
368+ assert (!mod.empty ());
369+ assert (exp.bit_width () != 0 );
370+
371+ if (base.empty ()) [[unlikely]] // base is 0: 0^exp = 0 for exp > 0.
372+ {
373+ std::ranges::fill (result, uint64_t {0 });
374+ return ;
375+ }
376+
377+ const auto n = mod.size (); // Number of significant words in mod.
378+
379+ // Select the fixed-size width (in words) for Montgomery multiplication.
380+ static constexpr auto MAX_SIZE = 1024 / sizeof (uint64_t ); // 8192 bits, as in EIP-7823.
381+ assert (n <= MAX_SIZE);
     // SIZES is sorted, so lower_bound yields the smallest supported width >= n.
     // The assert above guarantees the returned iterator is not the end of SIZES.
382+ static constexpr size_t SIZES[] = {2 , 4 , 8 , 16 , 32 , MAX_SIZE};
383+ const auto r_size = *std::ranges::lower_bound (SIZES, n);
384+
385+ // Compute base_mont = (base * R) % mod, where R = 2^(r_size*64).
386+ // R must match the width used by Montgomery multiplication (mul_amm).
387+ // The numerator is base shifted left by r_size words (r_size + base.size() words).
388+ // The result (base * R) % mod can be up to mod-1, always requiring n words.
     // One allocation backs both buffers: tmp = [ u (u_len words) | base_mont (n words) ].
389+ const auto u_len = r_size + base.size ();
390+ const auto tmp_storage = std::make_unique_for_overwrite<uint64_t []>(u_len + n);
391+ const auto tmp = std::span{tmp_storage.get (), u_len + n};
392+ const auto u = tmp.first (u_len);
393+ const auto base_mont = tmp.subspan (u_len, n);
     // u = base * R: r_size zero words followed by the words of base.
394+ std::ranges::fill (u.first (r_size), uint64_t {0 });
395+ std::ranges::copy (base, u.subspan (r_size).begin ());
     // rem's asserted preconditions hold: u's top word is base's (non-zero) top word, and
     // u_len > n because r_size >= n and base is non-empty.
396+ rem (base_mont, u, mod);
292397
     // Runs the exponentiation at a fixed width of N words. The by-copy capture copies only
     // the span views and exp, not the underlying word data.
293398 const auto impl = [=]<size_t N>() {
294399 using UintT = intx::uint<N * 64 >;
295- const auto r = modexp_odd_fixed_size (UintT{base}, exp, UintT{mod});
296- std::ranges::copy (as_words (r).subspan (0 , result.size ()), result.begin ());
400+
401+ // Pass zero-extended fixed-size representation.
402+ const auto r = modexp_odd_fixed_size (UintT{base_mont}, exp, UintT{mod});
403+
404+ // TODO: Because the caller's mod is not trimmed, we must also zero-extend the result.
     // Copy as many words as fit into result, then zero the words of result above them.
     // NOTE(review): presumably result.size() >= n so nothing significant is cut off;
     // this is no longer asserted here - confirm against the callers.
405+ const auto rw = as_words (r);
406+ const auto [_, out] =
407+ std::ranges::copy (rw.first (std::min (rw.size (), result.size ())), result.begin ());
408+ std::fill (out, result.end (), 0 );
297409 };
298410
299- if (const auto n = mod.size (); n <= 2 )
     // r_size is always one of SIZES, so the default case handles exactly MAX_SIZE.
411+ switch (r_size)
412+ {
413+ case 2 :
300414 impl.operator ()<2 >();
301- else if (n <= 4 )
415+ break ;
416+ case 4 :
302417 impl.operator ()<4 >();
303- else if (n <= 8 )
418+ break ;
419+ case 8 :
304420 impl.operator ()<8 >();
305- else if (n <= 16 )
421+ break ;
422+ case 16 :
306423 impl.operator ()<16 >();
307- else if (n <= 32 )
424+ break ;
425+ case 32 :
308426 impl.operator ()<32 >();
309- else
310- impl.operator ()<MAX_INPUT_SIZE>();
427+ break ;
428+ default :
429+ impl.operator ()<MAX_SIZE>();
430+ break ;
431+ }
311432}
312433
313434// / Trims the multi-word number x[] to k bits.
0 commit comments