55#include " modexp.hpp"
66#include < evmmax/evmmax.hpp>
77#include < bit>
8+ #include < memory>
89
910using namespace intx ;
1011
@@ -32,6 +33,33 @@ constexpr uint64_t addmul(std::span<uint64_t> r, std::span<const uint64_t> p,
3233 return c;
3334}
3435
36+ // / Computes multiplication of x times y and truncates the result to the size of r:
37+ // / r[] = x[] * y[].
38+ constexpr void mul (
39+ std::span<uint64_t > r, std::span<const uint64_t > x, std::span<const uint64_t > y) noexcept
40+ {
41+ assert (!x.empty ());
42+ assert (!y.empty ());
43+ assert (x.size () >= y.size ()); // Required for safe subspan arithmetic in the loop.
44+ assert (r.size () == std::max (x.size (), y.size ()));
45+
46+ std::ranges::fill (r, 0 );
47+ for (size_t j = 0 ; j < y.size (); ++j)
48+ addmul (r.subspan (j), r.subspan (j), x.subspan (0 , x.size () - j), y[j]);
49+ }
50+
51+ // / Computes x[] = 2 - x[].
52+ constexpr void neg_add2 (std::span<uint64_t > x) noexcept
53+ {
54+ assert (!x.empty ());
55+ bool c = false ;
56+
57+ std::tie (x[0 ], c) = intx::subc (2 , x[0 ]);
58+ for (auto it = x.begin () + 1 ; it != x.end (); ++it)
59+ std::tie (*it, c) = intx::subc (0 , *it, c);
60+ }
61+
62+
3563// / Represents the exponent value of the modular exponentiation operation.
3664// /
3765// / This is a view type of the big-endian bytes representing the bits of the exponent.
@@ -159,24 +187,37 @@ UIntT modexp_pow2(const UIntT& base, Exponent exp, unsigned k) noexcept
159187 return ret;
160188}
161189
162- // / Computes modular inversion for modulus of 2ᵏ.
163- // /
164- // / TODO: This actually may return more bits than k, the caller is responsible for masking the
165- // / result. Better design may be to pass std::span<uint64_t> without specifying k.
166- template <typename UIntT>
167- UIntT modinv_pow2 (const UIntT& x, unsigned k) noexcept
190+ // / Computes modular inversion of the multi-word number x[] modulo 2^(r.size() * 64).
191+ void modinv_pow2 (std::span<uint64_t > r, std::span<const uint64_t > x) noexcept
168192{
169- assert (bit_test (x, 0 )); // x must be odd for the inverse to exist.
170- assert (k <= UIntT::num_bits); // k must fit into the type.
193+ assert (!x.empty () && (x[0 ] & 1 ) != 0 ); // x must be odd.
194+ assert (r.size () <= x.size ()); // Truncating version.
195+ assert (!r.empty ());
171196
172- // Start with inversion mod 2⁶⁴.
173- UIntT inv = evmmax::modinv (x[0 ]);
197+ r[0 ] = evmmax::modinv (x[0 ]); // Good start: 64 correct bits.
174198
175- // Each iteration doubles the number of correct bits in the inverse. See modinv(uint32_t).
176- for (size_t iterations = 64 ; iterations < k; iterations *= 2 )
177- inv *= 2 - x * inv;
199+ // Allocate temporary storage for iterations.
200+ // TODO: Move to stack if the size is small enough or provide from the caller.
201+ const auto tmp_storage = std::make_unique_for_overwrite<uint64_t []>(2 * r.size ());
202+ const auto tmp = std::span{tmp_storage.get (), 2 * r.size ()};
178203
179- return inv;
204+ // Each iteration doubles the number of correct bits in the inverse. See evmmax::modinv().
205+ for (size_t i = 1 ; i < r.size (); i *= 2 )
206+ {
207+ // At the start of the iteration we have i-word correct inverse in r[0-i].
208+ // The iteration performs the Newton-Raphson step with double the precision (n=2i).
209+ const auto n = std::min (i * 2 , r.size ());
210+ const auto t1 = tmp.subspan (0 , n);
211+ const auto t2 = tmp.subspan (n, n);
212+
213+ mul (t1, x.subspan (0 , n), r.subspan (0 , i)); // t1 = x * inv
214+ neg_add2 (t1); // t1 = 2 - x * inv
215+ mul (t2, t1, r.subspan (0 , i)); // t2 = inv * (2 - x * inv)
216+ // TODO: Consider implementing the step as (inv << 1) - (x * inv * inv).
217+
218+ // TODO: Avoid copy by swapping buffers.
219+ std::ranges::copy (t2, r.begin ());
220+ }
180221}
181222
182223// / Computes modular exponentiation for even modulus: base^exp % (mod_odd * 2^k).
@@ -190,7 +231,10 @@ UIntT modexp_even(const UIntT& base, Exponent exp, const UIntT& mod_odd, unsigne
190231 const auto x1 = modexp_odd (base, exp, mod_odd);
191232 const auto x2 = modexp_pow2 (base, exp, k);
192233
193- const auto mod_odd_inv = modinv_pow2 (mod_odd, k);
234+ const auto mod_odd_words = as_words (mod_odd);
235+ UIntT mod_odd_inv;
236+ const auto num_pow2_words = (k + 63 ) / 64 ;
237+ modinv_pow2 (as_words (mod_odd_inv).subspan (0 , num_pow2_words), mod_odd_words);
194238
195239 const auto mod_pow2_mask = (UIntT{1 } << k) - 1 ;
196240 const auto y = ((x2 - x1) * mod_odd_inv) & mod_pow2_mask;
0 commit comments