#include "Rand.h"
#include "RsaStd.h"
#include "RsaCrt.h"
+#include "RsaCrtShamirsTrick.h"
Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
KEY_SIZE_BITS(keySizeBits),
cout << "RSA CRT failed!" << endl;
}
+void Tests::runTestsWithShamirsTrick()
+{
+ if (this->rsaCrtWithShamirsTrick())
+ cout << "RSA CRT with shamir's trick OK" << endl;
+ else
+ cout << "RSA CRT with shamir's trick failed!" << endl;
+}
+
void Tests::runTimeMeasures()
{
const int N = 1000;
int timeRsaStd = 0;
int timeRsaCRT = 0;
+ int timeRsaCRTShamirsTrick = 0;
for (int k = 0; k < nbKeys; ++k)
{
timeRsaStd += timeSignRsaStd(N);
timeRsaCRT += timeSignRsaCRT(N);
+ timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
}
cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
- cout << "Speedup: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+ cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
+ cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+ cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
}
void Tests::doAttack()
const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
const auto& kPub = keys.first;
const auto& kPriv = keys.second;
+
mpz_class message = Rand::randSize(128);
mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
mpz_class correctSignature = RsaCrt::sign(message, kPriv);
- bool attackOK = true;
+ bool attackSuccessful = true;
cout << "Original:" << endl;
cout << " p = " << kPriv.p << endl;
cout << " p = " << p << endl;
cout << " q = " << q << endl;
- attackOK = attackOK && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
+ attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
}
// Try the attack with a correct signature.
cout << " p = " << p << endl; // Equal to 1.
cout << " q = " << q << endl; // Equal to n.
- attackOK = attackOK && kPriv.p != p && kPriv.q != q;
+ attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
}
- if (attackOK)
+ if (attackSuccessful)
cout << "Attack successful" << endl;
else
cout << "Attack failed" << endl;
}
+void Tests::doAttackFixed()
+{
+ const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPub = keys.first;
+ const auto& kPriv = keys.second;
+}
+
bool Tests::rsaStandard()
{
const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
return true;
}
+bool Tests::rsaCrtWithShamirsTrick()
+{
+ const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPub = keys.first;
+ const auto& kPriv = keys.second;
+
+ {
+ mpz_class message = kPub.n;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
+ return false;
+ }
+
+ {
+ mpz_class message = kPub.n - 1;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ {
+ mpz_class message = kPub.n / 2;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ return true;
+}
+
int Tests::timeSignRsaStd(int N)
{
Timer timer;
return timer.ms();
}
+
+int Tests::timeSignRsaCRTShamirsTrick(int N)
+{
+ Timer timer;
+ const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+ mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+ for (int i = 0; i < N; i++)
+ RsaCrtShamirsTrick::sign(message, keys.second);
+
+ return timer.ms();
+}