Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,18 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
return m_paramsP;
}

/**
* Gets cached params for the first sizeQl towers of Q, used in KeySwitchDown.
* Only populated when key switching technique is HYBRID.
* @param sizeQl number of Q towers (1 to GetElementParams()->GetParams().size()).
* @return shared_ptr to ILDCRTParams for the first sizeQl moduli, or nullptr if not cached.
*/
const std::shared_ptr<ILDCRTParams<BigInteger>> GetParamsQlHybrid(uint32_t sizeQl) const {
if (m_paramsQlHybrid.empty() || sizeQl == 0 || sizeQl > m_paramsQlHybrid.size())
return nullptr;
return m_paramsQlHybrid[sizeQl - 1];
}

/**
* Method that returns the number of towers within every digit.
* This is the alpha parameter from the paper (see documentation
Expand Down Expand Up @@ -1446,6 +1458,10 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
// used in GHS key switching
std::shared_ptr<ILDCRTParams<BigInteger>> m_paramsQP;

// Cached params for Q_l (first sizeQl towers of Q) used in KeySwitchDown.
// m_paramsQlHybrid[l] = params for first (l+1) Q towers; index by sizeQl-1 when looking up.
std::vector<std::shared_ptr<ILDCRTParams<BigInteger>>> m_paramsQlHybrid;

// Stores the partition size {PartQ} = {Q_1,...,Q_l}
// where each Q_i is the product of q_j
uint32_t m_numPartQ = 0;
Expand Down
39 changes: 22 additions & 17 deletions src/pke/lib/keyswitch/keyswitch-hybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@

namespace lbcrypto {

namespace {

std::shared_ptr<DCRTPoly::Params> MakeParamsQlFromQlP(
const std::shared_ptr<DCRTPoly::Params>& paramsQlP, uint32_t sizeQl) {
std::vector<NativeInteger> moduliQ(sizeQl);
std::vector<NativeInteger> rootsQ(sizeQl);
for (uint32_t i = 0; i < sizeQl; ++i) {
moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus();
rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity();
}
return std::make_shared<DCRTPoly::Params>(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ);
}

} // namespace

EvalKey<DCRTPoly> KeySwitchHYBRID::KeySwitchGenInternal(const PrivateKey<DCRTPoly> oldKey,
const PrivateKey<DCRTPoly> newKey) const {
return KeySwitchHYBRID::KeySwitchGenInternal(oldKey, newKey, nullptr);
Expand Down Expand Up @@ -248,17 +263,12 @@ Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchDown(ConstCiphertext<DCRTPoly> ci
const auto paramsQlP = cv[0].GetParams();

const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
const auto paramsP = cryptoParams->GetParamsP();
const auto paramsP = cryptoParams->GetParamsP();

// TODO : (Andrey) precompute paramsQl in cryptoparameters
const uint32_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size();
std::vector<NativeInteger> moduliQ(sizeQl);
std::vector<NativeInteger> rootsQ(sizeQl);
for (uint32_t i = 0; i < sizeQl; ++i) {
moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus();
rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity();
}
const auto paramsQl = std::make_shared<ParmType>(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ);
std::shared_ptr<ParmType> paramsQl = cryptoParams->GetParamsQlHybrid(sizeQl);
if (!paramsQl)
paramsQl = MakeParamsQlFromQlP(paramsQlP, sizeQl);

const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

Expand Down Expand Up @@ -286,15 +296,10 @@ DCRTPoly KeySwitchHYBRID::KeySwitchDownFirstElement(ConstCiphertext<DCRTPoly> ci
const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(ciphertext->GetCryptoParameters());
const auto paramsP = cryptoParams->GetParamsP();

// TODO : (Andrey) precompute paramsQl in cryptoparameters
const uint32_t sizeQl = paramsQlP->GetParams().size() - paramsP->GetParams().size();
std::vector<NativeInteger> moduliQ(sizeQl);
std::vector<NativeInteger> rootsQ(sizeQl);
for (uint32_t i = 0; i < sizeQl; ++i) {
moduliQ[i] = paramsQlP->GetParams()[i]->GetModulus();
rootsQ[i] = paramsQlP->GetParams()[i]->GetRootOfUnity();
}
const auto paramsQl = std::make_shared<ParmType>(paramsQlP->GetCyclotomicOrder(), moduliQ, rootsQ);
std::shared_ptr<ParmType> paramsQl = cryptoParams->GetParamsQlHybrid(sizeQl);
if (!paramsQl)
paramsQl = MakeParamsQlFromQlP(paramsQlP, sizeQl);

const PlaintextModulus t = (cryptoParams->GetNoiseScale() == 1) ? 0 : cryptoParams->GetPlaintextModulus();

Expand Down
13 changes: 13 additions & 0 deletions src/pke/lib/schemerns/rns-cryptoparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,19 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling

m_paramsQP = std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQP, rootsQP);

// Precompute params for first 1..sizeQ towers of Q for KeySwitchDown (avoids per-call allocation).
m_paramsQlHybrid.resize(sizeQ);
for (size_t l = 0; l < sizeQ; ++l) {
std::vector<NativeInteger> moduliQl(l + 1);
std::vector<NativeInteger> rootsQl(l + 1);
for (size_t i = 0; i <= l; ++i) {
moduliQl[i] = moduliQ[i];
rootsQl[i] = rootsQ[i];
}
m_paramsQlHybrid[l] =
std::make_shared<ILDCRTParams<BigInteger>>(2 * n, std::move(moduliQl), std::move(rootsQl));
}

// Pre-compute CRT::FFT values for P
ChineseRemainderTransformFTT<NativeVector>().PreCompute(rootsP, 2 * n, moduliP);

Expand Down
Loading