Shamir's trick implementation.
[crypto_lab3.git] / src / Tests.cpp
1 #include "Tests.h"
2
3 #include <iostream>
4 using namespace std;
5
6 #include <gmpxx.h>
7
8 #include "Rand.h"
9 #include "RsaStd.h"
10 #include "RsaCrt.h"
11 #include "RsaCrtShamirsTrick.h"
12
13 Tests::Tests(uint keySizeBits, uint rsaPublicExponent) :
14 KEY_SIZE_BITS(keySizeBits),
15 RSA_PUBLIC_EXPONENT(rsaPublicExponent)
16 {
17 }
18
19 void Tests::runTests()
20 {
21 cout << "Tests::runTests() ..." << endl;
22
23 cout << "RSA standard: " << (this->rsaStandard() ? "OK" : "failed!") << endl;
24 cout << "RSA CRT: " << (this->rsaCrt() ? "OK" : "failed!") << endl;
25 }
26
27 void Tests::runTestsWithShamirsTrick()
28 {
29 cout << "Tests::runTestsWithShamirsTrick() ..." << endl;
30
31 cout << "RSA CRT with Shamir's trick: " << (this->rsaCrtWithShamirsTrick() ? "OK" : "failed!") << endl;
32 }
33
34 void Tests::runTimeMeasures()
35 {
36 cout << "Tests::runTimeMeasures() ..." << endl;
37
38 const int N = 1000;
39 const int nbKeys = 20; // Number of different generated key.
40
41 int timeRsaStd = 0;
42 int timeRsaCRT = 0;
43 int timeRsaCRTShamirsTrick = 0;
44
45 for (int k = 0; k < nbKeys; ++k)
46 {
47 timeRsaStd += timeSignRsaStd(N);
48 timeRsaCRT += timeSignRsaCRT(N);
49 timeRsaCRTShamirsTrick += timeSignRsaCRTShamirsTrick(N);
50 }
51
52 cout << N * nbKeys << " x RSA standard: " << timeRsaStd << " ms" << endl;
53 cout << N * nbKeys << " x RSA CRT: " << timeRsaCRT << " ms" << endl;
54 cout << N * nbKeys << " x RSA CRT Shamir's trick: " << timeRsaCRTShamirsTrick << " ms" << endl;
55 cout << "Speedup for CRT: " << (double(timeRsaStd) / double(timeRsaCRT)) << endl;
56 cout << "Speedup for CRT with Shamir's trick: " << (double(timeRsaStd) / double(timeRsaCRTShamirsTrick)) << endl;
57 }
58
59 void Tests::doAttack()
60 {
61 cout << "Tests::doAttack() ..." << endl;
62
63 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
64 const auto& kPub = keys.first;
65 const auto& kPriv = keys.second;
66
67 mpz_class message = Rand::randSize(128);
68 mpz_class faultySignature = RsaCrt::signWithFaultySp(message, kPriv);
69 mpz_class correctSignature = RsaCrt::sign(message, kPriv);
70
71 bool attackSuccessful = true;
72
73 cout << "Original:" << endl;
74 cout << " p = " << kPriv.p << endl;
75 cout << " q = " << kPriv.q << endl;
76
77 // At this point the attacker doesn't know the private key but he has intercepted the message and the faulty signature.
78 {
79 mpz_class faultySignaturePowerE;
80 mpz_pow_ui(faultySignaturePowerE.get_mpz_t(), faultySignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
81 mpz_class messageMinuxFaultySignaturePowerE = message - faultySignaturePowerE;
82 mpz_class q;
83 mpz_gcd(q.get_mpz_t(), messageMinuxFaultySignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
84 mpz_class p = kPub.n / q;
85
86 cout << "Found with a faulty signature:" << endl;
87 cout << " p = " << p << endl;
88 cout << " q = " << q << endl;
89
90 attackSuccessful = attackSuccessful && kPriv.p == p && kPriv.q == q; // With p and q we can recreate the original private key.
91 }
92
93 // Try the attack with a correct signature (p and q shouldn't be found).
94 {
95 mpz_class correctSignaturePowerE;
96 mpz_pow_ui(correctSignaturePowerE.get_mpz_t(), correctSignature.get_mpz_t(), RSA_PUBLIC_EXPONENT);
97 mpz_class messageMinuxCorrectSignaturePowerE = message - correctSignaturePowerE;
98 mpz_class q;
99 mpz_gcd(q.get_mpz_t(), messageMinuxCorrectSignaturePowerE.get_mpz_t(), kPub.n.get_mpz_t());
100 mpz_class p = kPub.n / q;
101
102 cout << "Found with a correct signature:" << endl;
103 cout << " p = " << p << endl; // Equal to 1.
104 cout << " q = " << q << endl; // Equal to n.
105
106 attackSuccessful = attackSuccessful && kPriv.p != p && kPriv.q != q;
107 }
108
109 if (attackSuccessful)
110 cout << "Attack successful" << endl;
111 else
112 cout << "Attack failed" << endl;
113 }
114
115 void Tests::doAttackFixed()
116 {
117 cout << "Tests::doAttackFixed() ..." << endl;
118
119 const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
120 const auto& kPriv = keys.second;
121
122 mpz_class message = Rand::randSize(128);
123
124 try
125 {
126 RsaCrtShamirsTrick::signWithFaultySp(message, kPriv);
127 cout << "Attack successful -> incorrect" << endl;
128 }
129 catch (const RsaCrtShamirsTrick::UnableToSignWithShamirsTrick& e)
130 {
131 cout << "Attack failed -> correct" << endl;
132 }
133 }
134
135 bool Tests::rsaStandard()
136 {
137 const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
138 const auto& kPub = keys.first;
139 const auto& kPriv = keys.second;
140
141 {
142 mpz_class message = kPriv.n;
143 mpz_class signature = RsaStd::sign(message, kPriv);
144 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPriv.n.
145 return false;
146 }
147
148 {
149 mpz_class message = kPriv.n - 1;
150 mpz_class signature = RsaStd::sign(message, kPriv);
151 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
152 return false;
153 }
154
155 {
156 mpz_class message = kPriv.n / 2;
157 mpz_class signature = RsaStd::sign(message, kPriv);
158 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
159 return false;
160 }
161
162 return true;
163 }
164
165 bool Tests::rsaCrt()
166 {
167 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
168 const auto& kPub = keys.first;
169 const auto& kPriv = keys.second;
170
171 {
172 mpz_class message = kPub.n;
173 mpz_class signature = RsaCrt::sign(message, kPriv);
174 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
175 return false;
176 }
177
178 {
179 mpz_class message = kPub.n - 1;
180 mpz_class signature = RsaCrt::sign(message, kPriv);
181 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
182 return false;
183 }
184
185 {
186 mpz_class message = kPub.n / 2;
187 mpz_class signature = RsaCrt::sign(message, kPriv);
188 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
189 return false;
190 }
191
192 return true;
193 }
194
195 bool Tests::rsaCrtWithShamirsTrick()
196 {
197 const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
198 const auto& kPub = keys.first;
199 const auto& kPriv = keys.second;
200
201 {
202 mpz_class message = kPub.n;
203 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
204 if (Rsa::verifySignature(message, signature, kPub)) // Must not be able to signe message greater than kPub.n.
205 return false;
206 }
207
208 {
209 mpz_class message = kPub.n - 1;
210 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
211 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
212 return false;
213 }
214
215 {
216 mpz_class message = kPub.n / 2;
217 mpz_class signature = RsaCrtShamirsTrick::sign(message, kPriv);
218 if (!Rsa::verifySignature(message, signature, kPub) || Rsa::verifySignature(message + 1, signature, kPub))
219 return false;
220 }
221
222 return true;
223 }
224
225 int Tests::timeSignRsaStd(int N)
226 {
227 Timer timer;
228 const auto& keys = RsaStd::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
229
230 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
231 for (int i = 0; i < N; i++)
232 RsaStd::sign(message, keys.second);
233
234 return timer.ms();
235 }
236
237 int Tests::timeSignRsaCRT(int N)
238 {
239 Timer timer;
240 const auto& keys = RsaCrt::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
241
242 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
243 for (int i = 0; i < N; i++)
244 RsaCrt::sign(message, keys.second);
245
246 return timer.ms();
247 }
248
249 int Tests::timeSignRsaCRTShamirsTrick(int N)
250 {
251 Timer timer;
252 const auto& keys = RsaCrtShamirsTrick::generateRSAKeys(RSA_PUBLIC_EXPONENT, KEY_SIZE_BITS);
253
254 mpz_class message = Rand::randSize(KEY_SIZE_BITS / 2);
255 for (int i = 0; i < N; i++)
256 RsaCrtShamirsTrick::sign(message, keys.second);
257
258 return timer.ms();
259 }