+#include "Tests.h"
+
+#include <iostream>
+using namespace std;
+
+#include <gmpxx.h>
+
+#include "Rand.h"
+#include "RsaStd.h"
+#include "RsaCrt.h"
+
+Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
+ KEY_SIZE_BITS(keySizeBits),
+ RSA_PUBLIC_EXPONENT(rsaPublicExponent)
+{
+}
+
+void Tests::runTests()
+{
+ if (this->rsaStandard())
+ cout << "RSA standard OK" << endl;
+ else
+ cout << "RSA standard failed!" << endl;
+
+ if (this->rsaCrt())
+ cout << "RSA CRT OK" << endl;
+ else
+ cout << "RSA CRT failed!" << endl;
+}
+
+void Tests::runTimeMeasures()
+{
+ const int N = 1000;
+ const int nbKeys = 20; // Number of different generated key.
+
+ int timeRsaStd = 0;
+ int timeRsaCRT = 0;
+
+ for (int k = 0; k < nbKeys; ++k)
+ {
+ timeRsaStd += timeSignRsaStd(N);
+ timeRsaCRT += timeSignRsaCRT(N);
+ }
+
+ cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
+ cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
+ cout << "Speedup: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
+}
+
+bool Tests::rsaStandard()
+{
+ const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPub = keys.first;
+ const auto& kPriv = keys.second;
+
+ {
+ mpz_class message = kPriv.n;
+ mpz_class signature = RsaStd::sign(message, kPriv);
+ if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPriv.n.
+ return false;
+ }
+
+ {
+ mpz_class message = kPriv.n - 1;
+ mpz_class signature = RsaStd::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ {
+ mpz_class message = kPriv.n / 2;
+ mpz_class signature = RsaStd::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ return true;
+}
+
+bool Tests::rsaCrt()
+{
+ const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+ const auto& kPub = keys.first;
+ const auto& kPriv = keys.second;
+
+ {
+ mpz_class message = kPriv.n;
+ mpz_class signature = RsaCrt::sign(message, kPriv);
+ if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPriv.n.
+ return false;
+ }
+
+ {
+ mpz_class message = kPriv.n - 1;
+ mpz_class signature = RsaCrt::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ {
+ mpz_class message = kPriv.n / 2;
+ mpz_class signature = RsaCrt::sign(message, kPriv);
+ if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
+ return false;
+ }
+
+ return true;
+}
+
+int Tests::timeSignRsaStd(int N)
+{
+ Timer timer;
+ const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+ mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+ for (int i = 0; i < N; i++)
+ RsaStd::sign(message, keys.second);
+
+ return timer.ms();
+}
+
+int Tests::timeSignRsaCRT(int N)
+{
+ Timer timer;
+ const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
+
+ mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
+ for (int i = 0; i < N; i++)
+ RsaCrt::sign(message, keys.second);
+
+ return timer.ms();
+}