Implementation of Shamir's trick (working in progress).
[crypto_lab3.git] / src / Tests.cpp
index 2adcd32..c9a2839 100644 (file)
@@ -8,6 +8,7 @@ using namespace std;
 #include "Rand.h"
 #include "RsaStd.h"
 #include "RsaCrt.h"
+#include "RsaCrtShamirsTrick.h"
 
 Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
    KEY_SIZE_BITS(keySizeBits),
@@ -28,6 +29,14 @@ void Tests::runTests()
       cout << "RSA CRT failed!" << endl;
 }
 
+void Tests::runTestsWithShamirsTrick()
+{
+   if (this->rsaCrtWithShamirsTrick())
+      cout << "RSA CRT with shamir's trick OK" << endl;
+   else
+      cout << "RSA CRT with shamir's trick failed!" << endl;
+}
+
 void Tests::runTimeMeasures()
 {
    const int N = 1000;
@@ -35,16 +44,20 @@ void Tests::runTimeMeasures()
 
    int timeRsaStd = 0;
    int timeRsaCRT = 0;
+   int timeRsaCRTShamirsTrick = 0;
 
    for (int k = 0; k < nbKeys; ++k)
    {
       timeRsaStd += timeSignRsaStd(N);
       timeRsaCRT += timeSignRsaCRT(N);
+      timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
    }
 
    cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
    cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
-   cout << "Speedup: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+   cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
+   cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+   cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
 }
 
 void Tests::doAttack()
@@ -52,11 +65,12 @@ void Tests::doAttack()
    const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
    const auto& kPub = keys.first;
    const auto& kPriv = keys.second;
+
    mpz_class message = Rand::randSize(128);
    mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
    mpz_class correctSignature = RsaCrt::sign(message, kPriv);
 
-   bool attackOK = true;
+   bool attackSuccessful = true;
 
    cout << "Original:" << endl;
    cout << " p = " << kPriv.p << endl;
@@ -75,7 +89,7 @@ void Tests::doAttack()
       cout << " p = " << p << endl;
       cout << " q = " << q << endl;
 
-      attackOK = attackOK && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
+      attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
    }
 
    // Try the attack with a correct signature.
@@ -91,15 +105,22 @@ void Tests::doAttack()
       cout << " p = " << p << endl; // Equal to 1.
       cout << " q = " << q << endl; // Equal to n.
 
-      attackOK = attackOK && kPriv.p != p && kPriv.q != q;
+      attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
    }
 
-   if (attackOK)
+   if (attackSuccessful)
       cout << "Attack successful" << endl;
    else
       cout << "Attack failed" << endl;
 }
 
+void Tests::doAttackFixed()
+{
+   const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+   const auto& kPub = keys.first;
+   const auto& kPriv = keys.second;
+}
+
 bool Tests::rsaStandard()
 {
    const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
@@ -160,6 +181,36 @@ bool Tests::rsaCrt()
    return true;
 }
 
+bool Tests::rsaCrtWithShamirsTrick()
+{
+    const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+    const auto& kPub = keys.first;
+    const auto& kPriv = keys.second;
+
+    {
+       mpz_class message = kPub.n;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
+          return false;
+    }
+
+    {
+       mpz_class message = kPub.n - 1;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+          return false;
+    }
+
+    {
+       mpz_class message = kPub.n / 2;
+       mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
+       if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+          return false;
+    }
+
+    return true;
+}
+
 int Tests::timeSignRsaStd(int N)
 {
    Timer timer;
@@ -183,3 +234,15 @@ int Tests::timeSignRsaCRT(int N)
 
    return timer.ms();
 }
+
+int Tests::timeSignRsaCRTShamirsTrick(int N)
+{
+   Timer timer;
+   const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+   mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+   for (int i = 0; i < N; i++)
+      RsaCrtShamirsTrick::sign(message, keys.second);
+
+   return timer.ms();
+}