Shamir's trick implementation.
[crypto_lab3.git] / src / Tests.cpp
index 92bccd0..73ee552 100644 (file)
@@ -8,6 +8,7 @@ using namespace std;
 #include "Rand.h"
 #include "RsaStd.h"
 #include "RsaCrt.h"
+#include "RsaCrtShamirsTrick.h"
 
 Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
    KEY_SIZE_BITS(keySizeBits),
@@ -17,34 +18,118 @@ Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
 
 void Tests::runTests()
 {
-   if (this->rsaStandard())
-      cout << "RSA standard OK" << endl;
-   else
-      cout << "RSA standard failed!" << endl;
+   cout << "Tests::runTests() ..." << endl;
 
-   if (this->rsaCrt())
-      cout << "RSA CRT OK" << endl;
-   else
-      cout << "RSA CRT failed!" << endl;
+   cout << "RSA standard: " << (this->rsaStandard() ? "OK" : "failed!") << endl;
+   cout << "RSA CRT: " <<  (this->rsaCrt() ? "OK" : "failed!") << endl;
+}
+
+void Tests::runTestsWithShamirsTrick()
+{
+   cout << "Tests::runTestsWithShamirsTrick() ..." << endl;
+
+   cout << "RSA CRT with Shamir's trick: " << (this->rsaCrtWithShamirsTrick() ? "OK" : "failed!") << endl;
 }
 
 void Tests::runTimeMeasures()
 {
+   cout << "Tests::runTimeMeasures() ..." << endl;
+
    const int N = 1000;
    const int nbKeys = 20; // Number of different generated key.
 
    int timeRsaStd = 0;
    int timeRsaCRT = 0;
+   int timeRsaCRTShamirsTrick = 0;
 
    for (int k = 0; k < nbKeys; ++k)
    {
       timeRsaStd += timeSignRsaStd(N);
       timeRsaCRT += timeSignRsaCRT(N);
+      timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
    }
 
    cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
    cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
-   cout << "Speedup: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+   cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
+   cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+   cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRTShamirsTrick)) << endl;
+}
+
+void Tests::doAttack()
+{
+   cout << "Tests::doAttack() ..." << endl;
+
+   const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+   const auto& kPub = keys.first;
+   const auto& kPriv = keys.second;
+
+   mpz_class message = Rand::randSize(128);
+   mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
+   mpz_class correctSignature = RsaCrt::sign(message, kPriv);
+
+   bool attackSuccessful = true;
+
+   cout << "Original:" << endl;
+   cout << " p = " << kPriv.p << endl;
+   cout << " q = " << kPriv.q << endl;
+
+   // At this point the attacker doesn't know the private key but he has intercepted the message and the faulty signature.
+   {
+      mpz_class faultySignaturePowerE;
+      mpz_pow_ui(faultySignaturePowerE.get_mpz_t(), faultySignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
+      mpz_class messageMinuxFaultySignaturePowerE = message - faultySignaturePowerE;
+      mpz_class q;
+      mpz_gcd(q.get_mpz_t(), messageMinuxFaultySignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
+      mpz_class p = kPub.n / q;
+
+      cout << "Found with a faulty signature:" << endl;
+      cout << " p = " << p << endl;
+      cout << " q = " << q << endl;
+
+      attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
+   }
+
+   // Try the attack with a correct signature (p and q shouldn't be found).
+   {
+      mpz_class correctSignaturePowerE;
+      mpz_pow_ui(correctSignaturePowerE.get_mpz_t(), correctSignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
+      mpz_class messageMinuxCorrectSignaturePowerE = message - correctSignaturePowerE;
+      mpz_class q;
+      mpz_gcd(q.get_mpz_t(), messageMinuxCorrectSignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
+      mpz_class p = kPub.n / q;
+
+      cout << "Found with a correct signature:" << endl;
+      cout << " p = " << p << endl; // Equal to 1.
+      cout << " q = " << q << endl; // Equal to n.
+
+      attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
+   }
+
+   if (attackSuccessful)
+      cout << "Attack successful" << endl;
+   else
+      cout << "Attack failed" << endl;
+}
+
+void Tests::doAttackFixed()
+{
+   cout << "Tests::doAttackFixed() ..." << endl;
+
+   const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+   const auto& kPriv = keys.second;
+
+   mpz_class message = Rand::randSize(128);
+
+   try
+   {
+      RsaCrtShamirsTrick::signWithFaultySp(message, kPriv);
+      cout << "Attack successful -> incorrect" << endl;
+   }
+   catch (const RsaCrtShamirsTrick::UnableToSignWithShamirsTrick& e)
+   {
+      cout << "Attack failed -> correct" << endl;
+   }
 }
 
 bool Tests::rsaStandard()
@@ -107,6 +192,36 @@ bool Tests::rsaCrt()
    return true;
 }
 
+bool Tests::rsaCrtWithShamirsTrick()
+{
+    const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+    const auto& kPub = keys.first;
+    const auto& kPriv = keys.second;
+
+    {
+       mpz_class message = kPub.n;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
+          return false;
+    }
+
+    {
+       mpz_class message = kPub.n - 1;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+          return false;
+    }
+
+    {
+       mpz_class message = kPub.n / 2;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+          return false;
+    }
+
+    return true;
+}
+
 int Tests::timeSignRsaStd(int N)
 {
    Timer timer;
@@ -130,3 +245,15 @@ int Tests::timeSignRsaCRT(int N)
 
    return timer.ms();
 }
+
+int Tests::timeSignRsaCRTShamirsTrick(int N)
+{
+   Timer timer;
+   const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+   mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+   for (int i = 0; i < N; i++)
+      RsaCrtShamirsTrick::sign(message, keys.second);
+
+   return timer.ms();
+}