Implementation of Shamir's trick (work in progress).
[crypto_lab3.git] / src / Tests.cpp
1 #include "Tests.h"
2
3 #include <iostream>
4 using namespace std;
5
6 #include <gmpxx.h>
7
8 #include "Rand.h"
9 #include "RsaStd.h"
10 #include "RsaCrt.h"
11 #include "RsaCrtShamirsTrick.h"
12
13 Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
14 KEY_SIZE_BITS(keySizeBits),
15 RSA_PUBLIC_EXPONENT(rsaPublicExponent)
16 {
17 }
18
19 void Tests::runTests()
20 {
21 if (this->rsaStandard())
22 cout << "RSA standard OK" << endl;
23 else
24 cout << "RSA standard failed!" << endl;
25
26 if (this->rsaCrt())
27 cout << "RSA CRT OK" << endl;
28 else
29 cout << "RSA CRT failed!" << endl;
30 }
31
32 void Tests::runTestsWithShamirsTrick()
33 {
34 if (this->rsaCrtWithShamirsTrick())
35 cout << "RSA CRT with shamir's trick OK" << endl;
36 else
37 cout << "RSA CRT with shamir's trick failed!" << endl;
38 }
39
40 void Tests::runTimeMeasures()
41 {
42 const int N = 1000;
43 const int nbKeys = 20; // Number of different generated key.
44
45 int timeRsaStd = 0;
46 int timeRsaCRT = 0;
47 int timeRsaCRTShamirsTrick = 0;
48
49 for (int k = 0; k < nbKeys; ++k)
50 {
51 timeRsaStd += timeSignRsaStd(N);
52 timeRsaCRT += timeSignRsaCRT(N);
53 timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
54 }
55
56 cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
57 cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
58 cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
59 cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
60 cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
61 }
62
63 void Tests::doAttack()
64 {
65 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
66 const auto& kPub = keys.first;
67 const auto& kPriv = keys.second;
68
69 mpz_class message = Rand::randSize(128);
70 mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
71 mpz_class correctSignature = RsaCrt::sign(message, kPriv);
72
73 bool attackSuccessful = true;
74
75 cout << "Original:" << endl;
76 cout << " p = " << kPriv.p << endl;
77 cout << " q = " << kPriv.q << endl;
78
79 // At this point the attacker doesn't know the private key but he has intercepted the message and the faulty signature.
80 {
81 mpz_class faultySignaturePowerE;
82 mpz_pow_ui(faultySignaturePowerE.get_mpz_t(), faultySignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
83 mpz_class messageMinuxFaultySignaturePowerE = message - faultySignaturePowerE;
84 mpz_class q;
85 mpz_gcd(q.get_mpz_t(), messageMinuxFaultySignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
86 mpz_class p = kPub.n / q;
87
88 cout << "Found with a faulty signature:" << endl;
89 cout << " p = " << p << endl;
90 cout << " q = " << q << endl;
91
92 attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
93 }
94
95 // Try the attack with a correct signature.
96 {
97 mpz_class correctSignaturePowerE;
98 mpz_pow_ui(correctSignaturePowerE.get_mpz_t(), correctSignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
99 mpz_class messageMinuxCorrectSignaturePowerE = message - correctSignaturePowerE;
100 mpz_class q;
101 mpz_gcd(q.get_mpz_t(), messageMinuxCorrectSignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
102 mpz_class p = kPub.n / q;
103
104 cout << "Found with a correct signature:" << endl;
105 cout << " p = " << p << endl; // Equal to 1.
106 cout << " q = " << q << endl; // Equal to n.
107
108 attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
109 }
110
111 if (attackSuccessful)
112 cout << "Attack successful" << endl;
113 else
114 cout << "Attack failed" << endl;
115 }
116
117 void Tests::doAttackFixed()
118 {
119 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
120 const auto& kPub = keys.first;
121 const auto& kPriv = keys.second;
122 }
123
124 bool Tests::rsaStandard()
125 {
126 const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
127 const auto& kPub = keys.first;
128 const auto& kPriv = keys.second;
129
130 {
131 mpz_class message = kPriv.n;
132 mpz_class signature = RsaStd::sign(message, kPriv);
133 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPriv.n.
134 return false;
135 }
136
137 {
138 mpz_class message = kPriv.n - 1;
139 mpz_class signature = RsaStd::sign(message, kPriv);
140 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
141 return false;
142 }
143
144 {
145 mpz_class message = kPriv.n / 2;
146 mpz_class signature = RsaStd::sign(message, kPriv);
147 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
148 return false;
149 }
150
151 return true;
152 }
153
154 bool Tests::rsaCrt()
155 {
156 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
157 const auto& kPub = keys.first;
158 const auto& kPriv = keys.second;
159
160 {
161 mpz_class message = kPub.n;
162 mpz_class signature = RsaCrt::sign(message, kPriv);
163 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
164 return false;
165 }
166
167 {
168 mpz_class message = kPub.n - 1;
169 mpz_class signature = RsaCrt::sign(message, kPriv);
170 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
171 return false;
172 }
173
174 {
175 mpz_class message = kPub.n / 2;
176 mpz_class signature = RsaCrt::sign(message, kPriv);
177 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
178 return false;
179 }
180
181 return true;
182 }
183
184 bool Tests::rsaCrtWithShamirsTrick()
185 {
186 const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
187 const auto& kPub = keys.first;
188 const auto& kPriv = keys.second;
189
190 {
191 mpz_class message = kPub.n;
192 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
193 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
194 return false;
195 }
196
197 {
198 mpz_class message = kPub.n - 1;
199 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
200 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
201 return false;
202 }
203
204 {
205 mpz_class message = kPub.n / 2;
206 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
207 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
208 return false;
209 }
210
211 return true;
212 }
213
214 int Tests::timeSignRsaStd(int N)
215 {
216 Timer timer;
217 const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
218
219 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
220 for (int i = 0; i < N; i++)
221 RsaStd::sign(message, keys.second);
222
223 return timer.ms();
224 }
225
226 int Tests::timeSignRsaCRT(int N)
227 {
228 Timer timer;
229 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
230
231 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
232 for (int i = 0; i < N; i++)
233 RsaCrt::sign(message, keys.second);
234
235 return timer.ms();
236 }
237
238 int Tests::timeSignRsaCRTShamirsTrick(int N)
239 {
240 Timer timer;
241 const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
242
243 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
244 for (int i = 0; i < N; i++)
245 RsaCrtShamirsTrick::sign(message, keys.second);
246
247 return timer.ms();
248 }