Sign up form and other stuff
[recipes.git] / backend / src / db.rs
1 use std::{fmt, fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 use chrono::{prelude::*, Duration};
5 use rusqlite::{named_params, OptionalExtension, params, Params};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8 use rand::distributions::{Alphanumeric, DistString};
9
10 use crate::consts;
11 use crate::hash::{hash, verify_password};
12 use crate::model;
13 use crate::user::*;
14
/// Schema version this code works with. `Connection::create_or_update` applies the
/// migrations needed to bring an older database up to this version.
const CURRENT_DB_VERSION: u32 = 1;
16
17 #[derive(Debug)]
18 pub enum DBError {
19 SqliteError(rusqlite::Error),
20 R2d2Error(r2d2::Error),
21 UnsupportedVersion(u32),
22 Other(String),
23 }
24
25 impl fmt::Display for DBError {
26 fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> {
27 write!(f, "{:?}", self)
28 }
29 }
30
31 impl std::error::Error for DBError { }
32
33 impl From<rusqlite::Error> for DBError {
34 fn from(error: rusqlite::Error) -> Self {
35 DBError::SqliteError(error)
36 }
37 }
38
39 impl From<r2d2::Error> for DBError {
40 fn from(error: r2d2::Error) -> Self {
41 DBError::R2d2Error(error)
42 }
43 }
44
45 // TODO: Is there a better solution?
46 impl DBError {
47 fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
48 DBError::Other(error.to_string())
49 }
50 }
51
52 type Result<T> = std::result::Result<T, DBError>;
53
/// Outcome of `Connection::sign_up`.
#[derive(Debug)]
pub enum SignUpResult {
    /// An account with this email exists and has already been validated.
    UserAlreadyExists,
    /// The account was created (or its pending sign-up restarted).
    /// Carries the validation token.
    UserCreatedWaitingForValidation(String),
}

/// Outcome of `Connection::validation`.
#[derive(Debug)]
pub enum ValidationResult {
    /// No account is waiting for validation with this token.
    UnknownUser,
    /// The token matched, but the allowed validation time has elapsed.
    ValidationExpired,
    /// The account is validated. Carries the new login token and the user id.
    Ok(String, i32),
}

/// Outcome of `Connection::sign_in`.
#[derive(Debug)]
pub enum SignInResult {
    /// No account is registered with this email.
    UserNotFound,
    /// The given password does not match the stored one.
    PasswordsDontMatch,
    /// Signed in. Carries the new login token and the user id.
    Ok(String, i32),
}

/// Outcome of `Connection::authentication`.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// The login token is not known.
    NotValidToken,
    /// Authenticated. Carries the user id.
    Ok(i32),
}
79
80 #[derive(Clone)]
81 pub struct Connection {
82 //con: rusqlite::Connection
83 pool: Pool<SqliteConnectionManager>
84 }
85
86 impl Connection {
87 pub fn new() -> Result<Connection> {
88 let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
89 Self::new_from_file(path)
90 }
91
92 pub fn new_in_memory() -> Result<Connection> {
93 Self::create_connection(SqliteConnectionManager::memory())
94 }
95
96 pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
97 if let Some(data_dir) = file.as_ref().parent() {
98 if !data_dir.exists() {
99 fs::DirBuilder::new().create(data_dir).unwrap();
100 }
101 }
102
103 Self::create_connection(SqliteConnectionManager::file(file))
104 }
105
106 fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {
107 let pool = r2d2::Pool::new(manager).unwrap();
108 let connection = Connection { pool };
109 connection.create_or_update()?;
110 Ok(connection)
111 }
112
113 /// Called after the connection has been established for creating or updating the database.
114 /// The 'Version' table tracks the current state of the database.
115 fn create_or_update(&self) -> Result<()> {
116 // Check the Database version.
117 let mut con = self.pool.get()?;
118 let tx = con.transaction()?;
119
120 // Version 0 corresponds to an empty database.
121 let mut version = {
122 match tx.query_row(
123 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
124 [],
125 |row| row.get::<usize, String>(0)
126 ) {
127 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
128 Err(_) => 0
129 }
130 };
131
132 while Self::update_to_next_version(version, &tx)? {
133 version += 1;
134 }
135
136 tx.commit()?;
137
138 Ok(())
139 }
140
141 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
142 let next_version = current_version + 1;
143
144 if next_version <= CURRENT_DB_VERSION {
145 println!("Update to version {}...", next_version);
146 }
147
148 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
149 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
150 }
151
152 fn ok(updated: bool) -> Result<bool> {
153 if updated {
154 println!("Version updated");
155 }
156 Ok(updated)
157 }
158
159 match next_version {
160 1 => {
161 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
162 tx.execute_batch(&load_sql_file(&sql_file)?)?;
163 update_version(next_version, tx)?;
164
165 ok(true)
166 }
167
168 // Version 1 doesn't exist yet.
169 2 =>
170 ok(false),
171
172 v =>
173 Err(DBError::UnsupportedVersion(v)),
174 }
175 }
176
177 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
178 let con = self.pool.get()?;
179 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
180 let titles =
181 stmt.query_map([], |row| {
182 Ok((row.get(0)?, row.get(1)?))
183 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
184 Ok(titles)
185 }
186
187 /* Not used for the moment.
188 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
189 let con = self.pool.get()?;
190 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
191 let recipes =
192 stmt.query_map([], |row| {
193 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
194 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
195 Ok(recipes)
196 } */
197
198 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
199 let con = self.pool.get()?;
200 con.query_row("SELECT [id], [title] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
201 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
202 }).map_err(DBError::from)
203 }
204
205 pub fn get_user_login_info(&self, token: &str) -> Result<UserLoginInfo> {
206 let con = self.pool.get()?;
207 con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
208 Ok(UserLoginInfo {
209 last_login_datetime: r.get("last_login_datetime")?,
210 ip: r.get("ip")?,
211 user_agent: r.get("user_agent")?,
212 })
213 }).map_err(DBError::from)
214 }
215
216 ///
217 pub fn sign_up(&self, email: &str, password: &str) -> Result<SignUpResult> {
218 self.sign_up_with_given_time(email, password, Utc::now())
219 }
220
221 fn sign_up_with_given_time(&self, email: &str, password: &str, datetime: DateTime<Utc>) -> Result<SignUpResult> {
222 let mut con = self.pool.get()?;
223 let tx = con.transaction()?;
224 let token =
225 match tx.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
226 Ok((r.get::<&str, i32>("id")?, r.get::<&str, Option<String>>("validation_token")?))
227 }).optional()? {
228 Some((id, validation_token)) => {
229 if validation_token.is_none() {
230 return Ok(SignUpResult::UserAlreadyExists)
231 }
232 let token = generate_token();
233 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
234 tx.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password])?;
235 token
236 },
237 None => {
238 let token = generate_token();
239 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
240 tx.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password])?;
241 token
242 },
243 };
244 tx.commit()?;
245 Ok(SignUpResult::UserCreatedWaitingForValidation(token))
246 }
247
248 pub fn validation(&self, token: &str, validation_time: Duration, ip: &str, user_agent: &str) -> Result<ValidationResult> {
249 let mut con = self.pool.get()?;
250 let tx = con.transaction()?;
251 let user_id =
252 match tx.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token], |r| {
253 Ok((r.get::<&str, i32>("id")?, r.get::<&str, DateTime<Utc>>("creation_datetime")?))
254 }).optional()? {
255 Some((id, creation_datetime)) => {
256 if Utc::now() - creation_datetime > validation_time {
257 return Ok(ValidationResult::ValidationExpired)
258 }
259 tx.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id])?;
260 id
261 },
262 None => {
263 return Ok(ValidationResult::UnknownUser)
264 },
265 };
266 let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?;
267 tx.commit()?;
268 Ok(ValidationResult::Ok(token, user_id))
269 }
270
271 pub fn sign_in(&self, password: &str, email: &str, ip: &str, user_agent: &str) -> Result<SignInResult> {
272 let mut con = self.pool.get()?;
273 let tx = con.transaction()?;
274 match tx.query_row("SELECT [id], [password] FROM [User] WHERE [email] = ?1", [email], |r| {
275 Ok((r.get::<&str, i32>("id")?, r.get::<&str, String>("password")?))
276 }).optional()? {
277 Some((id, stored_password)) => {
278 if verify_password(password, &stored_password).map_err(DBError::from_dyn_error)? {
279 let token = Connection::create_login_token(&tx, id, ip, user_agent)?;
280 tx.commit()?;
281 Ok(SignInResult::Ok(token, id))
282 } else {
283 Ok(SignInResult::PasswordsDontMatch)
284 }
285 },
286 None => {
287 Ok(SignInResult::UserNotFound)
288 },
289 }
290 }
291
292 pub fn authentication(&self, token: &str, ip: &str, user_agent: &str) -> Result<AuthenticationResult> {
293 let mut con = self.pool.get()?;
294 let tx = con.transaction()?;
295 match tx.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
296 Ok((r.get::<&str, i32>("id")?, r.get::<&str, i32>("user_id")?))
297 }).optional()? {
298 Some((login_id, user_id)) => {
299 tx.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params![login_id, Utc::now(), ip, user_agent])?;
300 tx.commit()?;
301 Ok(AuthenticationResult::Ok(user_id))
302 },
303 None =>
304 Ok(AuthenticationResult::NotValidToken)
305 }
306 }
307
308 pub fn sign_out(&self, token: &str) -> Result<()> {
309 let mut con = self.pool.get()?;
310 let tx = con.transaction()?;
311 match tx.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
312 Ok(r.get::<&str, i32>("id")?)
313 }).optional()? {
314 Some(login_id) => {
315 tx.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params![login_id])?;
316 tx.commit()?
317 },
318 None => (),
319 }
320 Ok(())
321 }
322
323 /// Execute a given SQL file.
324 pub fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
325 let con = self.pool.get()?;
326 let sql = load_sql_file(file)?;
327 con.execute_batch(&sql).map_err(DBError::from)
328 }
329
330 /// Execute any SQL statement.
331 /// Mainly used for testing.
332 pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
333 let con = self.pool.get()?;
334 con.execute(sql, params).map_err(DBError::from)
335 }
336
337 // Return the token.
338 fn create_login_token(tx: &rusqlite::Transaction, user_id: i32, ip: &str, user_agent: &str) -> Result<String> {
339 let token = generate_token();
340 tx.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, Utc::now(), token, ip, user_agent])?;
341 Ok(token)
342 }
343 }
344
345 fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
346 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
347 let mut sql = String::new();
348 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
349 Ok(sql)
350 }
351
352 fn generate_token() -> String {
353 Alphanumeric.sample_string(&mut rand::thread_rng(), 24)
354 }
355
#[cfg(test)]
mod tests {
    use super::*;

    const EMAIL: &str = "paul@test.org";
    const PASSWORD: &str = "12345";

    /// Inserts an already validated user 'paul' (NULL validation token).
    const INSERT_VALIDATED_PAUL: &str = "
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
NULL
);";

    /// Inserts a user 'paul' still waiting for validation with the `:token` parameter.
    const INSERT_UNVALIDATED_PAUL: &str = "
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
:token
);";

    /// Signs up `email`/`password` and returns the validation token.
    /// Panics (failing the test) on any other outcome.
    fn expect_sign_up(connection: &Connection, email: &str, password: &str) -> Result<String> {
        match connection.sign_up(email, password)? {
            SignUpResult::UserCreatedWaitingForValidation(token) => Ok(token),
            other => panic!("{:?}", other),
        }
    }

    /// Validates `validation_token` with a one-hour window and returns (login token, user id).
    /// Panics (failing the test) on any other outcome.
    fn expect_validation(connection: &Connection, validation_token: &str, ip: &str, user_agent: &str) -> Result<(String, i32)> {
        match connection.validation(validation_token, Duration::hours(1), ip, user_agent)? {
            ValidationResult::Ok(token, user_id) => Ok((token, user_id)),
            other => panic!("{:?}", other),
        }
    }

    /// Asserts the IP and user agent stored for the login `token`.
    fn assert_login_info(connection: &Connection, token: &str, ip: &str, user_agent: &str) -> Result<()> {
        let info = connection.get_user_login_info(token)?;
        assert_eq!(info.ip, ip);
        assert_eq!(info.user_agent, user_agent);
        Ok(())
    }

    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        expect_sign_up(&connection, EMAIL, PASSWORD)?;
        Ok(())
    }

    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql(INSERT_VALIDATED_PAUL, [])?;
        match connection.sign_up(EMAIL, PASSWORD)? {
            SignUpResult::UserAlreadyExists => Ok(()),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let token = generate_token();
        connection.execute_sql(INSERT_UNVALIDATED_PAUL, named_params! { ":token": token })?;
        expect_sign_up(&connection, EMAIL, PASSWORD)?;
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token = expect_sign_up(&connection, EMAIL, PASSWORD)?;
        expect_validation(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?;
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let signed_up_at = Utc::now() - Duration::days(1);
        let validation_token = match connection.sign_up_with_given_time(EMAIL, PASSWORD, signed_up_at)? {
            SignUpResult::UserCreatedWaitingForValidation(token) => token,
            other => panic!("{:?}", other),
        };
        match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::ValidationExpired => Ok(()),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_then_send_validation_with_bad_token() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        expect_sign_up(&connection, EMAIL, PASSWORD)?;
        let random_token = generate_token();
        match connection.validation(&random_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::UnknownUser => Ok(()),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let validation_token = expect_sign_up(&connection, EMAIL, PASSWORD)?;
        expect_validation(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?;

        match connection.sign_in(PASSWORD, EMAIL, "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::Ok(_, _) => Ok(()),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_then_send_validation_then_authentication() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let validation_token = expect_sign_up(&connection, EMAIL, PASSWORD)?;
        let (authentication_token, _user_id) =
            expect_validation(&connection, &validation_token, "127.0.0.1", "Mozilla")?;
        assert_login_info(&connection, &authentication_token, "127.0.0.1", "Mozilla")?;

        match connection.authentication(&authentication_token, "192.168.1.1", "Chrome")? {
            AuthenticationResult::Ok(_) => (),
            other => panic!("{:?}", other),
        }
        assert_login_info(&connection, &authentication_token, "192.168.1.1", "Chrome")?;

        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let validation_token = expect_sign_up(&connection, EMAIL, PASSWORD)?;
        let (first_token, first_user_id) =
            expect_validation(&connection, &validation_token, "127.0.0.1", "Mozilla")?;
        assert_login_info(&connection, &first_token, "127.0.0.1", "Mozilla")?;

        connection.sign_out(&first_token)?;

        let (second_token, second_user_id) =
            match connection.sign_in(PASSWORD, EMAIL, "192.168.1.1", "Chrome")? {
                SignInResult::Ok(token, user_id) => (token, user_id),
                other => panic!("{:?}", other),
            };

        assert_eq!(first_user_id, second_user_id);
        assert_ne!(first_token, second_token);
        assert_login_info(&connection, &second_token, "192.168.1.1", "Chrome")?;

        Ok(())
    }
}