diff --git a/src/pke/include/schemerns/rns-cryptoparameters.h b/src/pke/include/schemerns/rns-cryptoparameters.h index cd9ceb41b..a5ed7bb65 100644 --- a/src/pke/include/schemerns/rns-cryptoparameters.h +++ b/src/pke/include/schemerns/rns-cryptoparameters.h @@ -390,6 +390,18 @@ class CryptoParametersRNS : public CryptoParametersRLWE { return m_paramsP; } + /** + * Gets cached params for the first sizeQl towers of Q, used in KeySwitchDown. + * Only populated when key switching technique is HYBRID. + * @param sizeQl number of Q towers (1 to GetElementParams()->GetParams().size()). + * @return shared_ptr to ILDCRTParams for the first sizeQl moduli, or nullptr if not cached. + */ + const std::shared_ptr> GetParamsQlHybrid(uint32_t sizeQl) const { + if (m_paramsQlHybrid.empty() || sizeQl == 0 || sizeQl > m_paramsQlHybrid.size()) + return nullptr; + return m_paramsQlHybrid[sizeQl - 1]; + } + /** * Method that returns the number of towers within every digit. * This is the alpha parameter from the paper (see documentation @@ -1446,6 +1458,10 @@ class CryptoParametersRNS : public CryptoParametersRLWE { // used in GHS key switching std::shared_ptr> m_paramsQP; + // Cached params for Q_l (first sizeQl towers of Q) used in KeySwitchDown. + // m_paramsQlHybrid[l] = params for first (l+1) Q towers; index by sizeQl-1 when looking up. + std::vector>> m_paramsQlHybrid; + // Stores the partition size {PartQ} = {Q_1,...,Q_l} // where each Q_i is the product of q_j uint32_t m_numPartQ = 0; diff --git a/src/pke/lib/keyswitch/keyswitch-hybrid.cpp b/src/pke/lib/keyswitch/keyswitch-hybrid.cpp index a33e42298..15604b685 100644 --- a/src/pke/lib/keyswitch/keyswitch-hybrid.cpp +++ b/src/pke/lib/keyswitch/keyswitch-hybrid.cpp @@ -43,6 +43,21 @@ namespace lbcrypto { +namespace { + + std::shared_ptr MakeParamsQlFromQlP( + const std::shared_ptr& paramsQlP, uint32_t sizeQl) { + std::vector moduliQ(sizeQl); + std::vector rootsQ(sizeQl); + for (uint32_t i = 0; i < sizeQl; ++i) { + moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus(); + rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity(); + } + return std::make_shared(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ); + } + + } // namespace + EvalKey KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey oldKey, const PrivateKey newKey) const { return KeySwitchHYBRID::KeySwitchGenInternal(oldKey, newKey, nullptr); @@ -248,17 +263,12 @@ Ciphertext KeySwitchHYBRID::KeySwitchDown(ConstCiphertext ci const auto paramsQlP = cv[0].GetParams(); const auto cryptoParams = std::dynamic_pointer_cast(ciphertext->GetCryptoParameters()); - const auto paramsP = cryptoParams->GetParamsP(); + const auto paramsP = cryptoParams->GetParamsP(); - // TODO : (Andrey) precompute paramsQl in cryptoparameters const uint32_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size(); - std::vector moduliQ(sizeQl); - std::vector rootsQ(sizeQl); - for (uint32_t i = 0; i < sizeQl; ++i) { - moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus(); - rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity(); - } - const auto paramsQl = std::make_shared(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ); + std::shared_ptr paramsQl = cryptoParams->GetParamsQlHybrid(sizeQl); + if (!paramsQl) + paramsQl = MakeParamsQlFromQlP(paramsQlP, sizeQl); const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus(); @@ -286,15 +296,10 @@ DCRTPoly KeySwitchHYBRID::KeySwitchDownFirstElement(ConstCiphertext ci const auto cryptoParams = std::dynamic_pointer_cast(ciphertext->GetCryptoParameters()); const auto paramsP = cryptoParams->GetParamsP(); - // TODO : (Andrey) precompute paramsQl in cryptoparameters const uint32_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size(); - std::vector moduliQ(sizeQl); - std::vector rootsQ(sizeQl); - for (uint32_t i = 0; i < sizeQl; ++i) { - moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus(); - rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity(); - } - const auto paramsQl = std::make_shared(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ); + std::shared_ptr paramsQl = cryptoParams->GetParamsQlHybrid(sizeQl); + if (!paramsQl) + paramsQl = MakeParamsQlFromQlP(paramsQlP, sizeQl); const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus(); diff --git a/src/pke/lib/schemerns/rns-cryptoparameters.cpp b/src/pke/lib/schemerns/rns-cryptoparameters.cpp index 02c659239..29679fea6 100644 --- a/src/pke/lib/schemerns/rns-cryptoparameters.cpp +++ b/src/pke/lib/schemerns/rns-cryptoparameters.cpp @@ -193,6 +193,19 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling m_paramsQP = std::make_shared>(2 * n, moduliQP, rootsQP); + // Precompute params for first 1..sizeQ towers of Q for KeySwitchDown (avoids per-call allocation). + m_paramsQlHybrid.resize(sizeQ); + for (size_t l = 0; l < sizeQ; ++l) { + std::vector moduliQl(l + 1); + std::vector rootsQl(l + 1); + for (size_t i = 0; i <= l; ++i) { + moduliQl[i] = moduliQ[i]; + rootsQl[i] = rootsQ[i]; + } + m_paramsQlHybrid[l] = + std::make_shared>(2 * n, std::move(moduliQl), std::move(rootsQl)); + } + // Pre-compute CRT::FFT values for P ChineseRemainderTransformFTT().PreCompute(rootsP, 2 * n, moduliP);