#include "Rand.h"
#include "RsaStd.h"
#include "RsaCrt.h"
+#include "RsaCrtShamirsTrick.h"
Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
KEY_SIZE_BITS(keySizeBits),
void Tests::runTests()
{
- if (this->rsaStandard())
- cout << "RSA standard OK" << endl;
- else
- cout << "RSA standard failed!" << endl;
+ cout << "Tests::runTests() ..." << endl;
- if (this->rsaCrt())
- cout << "RSA CRT OK" << endl;
- else
- cout << "RSA CRT failed!" << endl;
+ cout << "RSA standard: " << (this->rsaStandard() ? "OK" : "failed!") << endl;
+ cout << "RSA CRT: " << (this->rsaCrt() ? "OK" : "failed!") << endl;
+}
+
+void Tests::runTestsWithShamirsTrick()
+{
+ cout << "Tests::runTestsWithShamirsTrick() ..." << endl;
+
+ cout << "RSA CRT with Shamir's trick: " << (this->rsaCrtWithShamirsTrick() ? "OK" : "failed!") << endl;
}
void Tests::runTimeMeasures()
{
+ cout << "Tests::runTimeMeasures() ..." << endl;
+
const int N = 1000;
const int nbKeys = 20; // Number of different generated key.
int timeRsaStd = 0;
int timeRsaCRT = 0;
+ int timeRsaCRTShamirsTrick = 0;
for (int k = 0; k < nbKeys; ++k)
{
timeRsaStd += timeSignRsaStd(N);
timeRsaCRT += timeSignRsaCRT(N);
+ timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
}
cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
- cout << "Speedup: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+ cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
+ cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+ cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRTShamirsTrick)) << endl;
}
void Tests::doAttack()
{
+ cout << "Tests::doAttack() ..." << endl;
+
const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
const auto& kPub = keys.first;
const auto& kPriv = keys.second;
+
mpz_class message = Rand::randSize(128);
mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
mpz_class correctSignature = RsaCrt::sign(message, kPriv);
- bool attackOK = true;
+ bool attackSuccessful = true;
cout << "Original:" << endl;
cout << " p = " << kPriv.p << endl;
cout << " p = " << p << endl;
cout << " q = " << q << endl;
- attackOK = attackOK && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
+ attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
}
- // Try the attack with a correct signature.
+ // Try the attack with a correct signature (p and q shouldn't be found).
{
mpz_class correctSignaturePowerE;
mpz_pow_ui(correctSignaturePowerE.get_mpz_t(), correctSignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
cout << " p = " << p << endl; // Equal to 1.
cout << " q = " << q << endl; // Equal to n.
- attackOK = attackOK && kPriv.p != p && kPriv.q != q;
+ attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
}
- if (attackOK)
+ if (attackSuccessful)
cout << "Attack successful" << endl;
else
cout << "Attack failed" << endl;
}
+void Tests::doAttackFixed()
+{
+ cout << "Tests::doAttackFixed() ..." << endl;
+
+ const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPriv = keys.second;
+
+ mpz_class message = Rand::randSize(128);
+
+ try
+ {
+ RsaCrtShamirsTrick::signWithFaultySp(message, kPriv);
+ cout << "Attack successful -> incorrect" << endl;
+ }
+ catch (const RsaCrtShamirsTrick::UnableToSignWithShamirsTrick& e)
+ {
+ cout << "Attack failed -> correct" << endl;
+ }
+}
+
bool Tests::rsaStandard()
{
const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
return true;
}
+bool Tests::rsaCrtWithShamirsTrick()
+{
+ const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPub = keys.first;
+ const auto& kPriv = keys.second;
+
+ {
+ mpz_class message = kPub.n;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
+ return false;
+ }
+
+ {
+ mpz_class message = kPub.n - 1;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ {
+ mpz_class message = kPub.n / 2;
+ mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ return true;
+}
+
int Tests::timeSignRsaStd(int N)
{
Timer timer;
return timer.ms();
}
+
+int Tests::timeSignRsaCRTShamirsTrick(int N)
+{
+ Timer timer;
+ const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+ mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+ for (int i = 0; i < N; i++)
+ RsaCrtShamirsTrick::sign(message, keys.second);
+
+ return timer.ms();
+}