@@ -170,21 +170,58 @@ UIntT modexp_odd(const UIntT& base, Exponent exp, const UIntT& mod) noexcept
170170 return ret;
171171}
172172
173- template <typename UIntT>
174- UIntT modexp_pow2 (const UIntT& base, Exponent exp, unsigned k) noexcept
173+ // / Trims the multi-word number x[] to k bits.
174+ // / TODO: Currently this assumes no leading zeros in x. Re-design this after modexp is dynamic.
175+ void mask_pow2 (std::span<uint64_t > x, unsigned k) noexcept
176+ {
177+ assert (k != 0 );
178+ assert (x.size () >= (k + 63 ) / 64 );
179+ assert (!x.empty ());
180+ if (const auto rem = k % 64 ; rem != 0 )
181+ x.back () &= (uint64_t {1 } << rem) - 1 ;
182+ }
183+
184+ // / Computes r[] = base[]^exp % 2^k.
185+ // / Only the least words matching the k bits of the base are used.
186+ // / Also, the same amount of the result words are produced. The rest is not modified.
187+ void modexp_pow2 (
188+ std::span<uint64_t > r, std::span<const uint64_t > base, Exponent exp, unsigned k) noexcept
175189{
176- assert (k != 0 ); // Modulus of 1 should be covered as "odd".
177- UIntT ret = 1 ;
178- for (auto i = exp.bit_width (); i != 0 ; --i)
190+ assert (k != 0 ); // Modulus of 1 should be covered as "odd".
191+ assert (exp.bit_width () != 0 ); // Exponent of zero must be handled outside.
192+ assert (r.data () != base.data ()); // No in-place operation.
193+
194+ const auto num_pow2_words = (k + 63 ) / 64 ;
195+ assert (r.size () >= num_pow2_words);
196+ assert (base.size () >= num_pow2_words);
197+
198+ const auto base_k = base.subspan (0 , num_pow2_words);
199+ auto r_k = r.subspan (0 , num_pow2_words);
200+
201+ // Allocate temporary storage for iterations.
202+ // TODO: Move to stack if the size is small enough or provide from the caller.
203+ const auto tmp_storage = std::make_unique_for_overwrite<uint64_t []>(num_pow2_words);
204+ auto tmp = std::span{tmp_storage.get (), num_pow2_words};
205+
206+ std::ranges::copy (base_k, r_k.begin ());
207+
208+ for (auto i = exp.bit_width () - 1 ; i != 0 ; --i)
179209 {
180- ret *= ret;
210+ mul (tmp, r_k, r_k);
211+ std::swap (r_k, tmp);
212+
181213 if (exp[i - 1 ])
182- ret *= base;
214+ {
215+ mul (tmp, r_k, base_k);
216+ std::swap (r_k, tmp);
217+ }
183218 }
184219
185- const auto mod_pow2_mask = (UIntT{1 } << k) - 1 ;
186- ret &= mod_pow2_mask;
187- return ret;
220+ mask_pow2 (r_k, k);
221+
222+ // result_k may point to the tmp_storage. Copy back to the result buffer if needed.
223+ if (r_k.data () != r.data ())
224+ std::ranges::copy (r_k, r.begin ());
188225}
189226
190227// / Computes modular inversion of the multi-word number x[] modulo 2^(r.size() * 64).
@@ -229,7 +266,9 @@ UIntT modexp_even(const UIntT& base, Exponent exp, const UIntT& mod_odd, unsigne
229266 assert (k != 0 );
230267
231268 const auto x1 = modexp_odd (base, exp, mod_odd);
232- const auto x2 = modexp_pow2 (base, exp, k);
269+
270+ UIntT x2;
271+ modexp_pow2 (as_words (x2), as_words (base), exp, k);
233272
234273 const auto mod_odd_words = as_words (mod_odd);
235274 UIntT mod_odd_inv;
@@ -255,10 +294,10 @@ void modexp_impl(std::span<const uint8_t> base_bytes, Exponent exp,
255294 result = mod != 1 ; // - result is 1 except mod 1
256295 else if (const auto mod_tz = ctz (mod); mod_tz == 0 ) // Modulus is:
257296 result = modexp_odd (base, exp, mod); // - odd
258- else if (const auto mod_odd = mod >> mod_tz; mod_odd == 1 ) //
259- result = modexp_pow2 (base, exp, mod_tz); // - power of 2
260- else //
261- result = modexp_even (base, exp, mod_odd, mod_tz); // - even
297+ else if (const auto mod_odd = mod >> mod_tz; mod_odd == 1 ) // - power of 2
298+ modexp_pow2 ( as_words ( result), as_words (base) , exp, mod_tz);
299+ else //
300+ result = modexp_even (base, exp, mod_odd, mod_tz); // - even
262301
263302 intx::be::trunc (std::span{output, mod_bytes.size ()}, result);
264303}
0 commit comments