Add asynchronous call to database.
[recipes.git] / backend / src / data / db.rs
1 use std::{fmt, fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 use chrono::{prelude::*, Duration};
5 use rusqlite::{named_params, OptionalExtension, params, Params};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8 use rand::distributions::{Alphanumeric, DistString};
9
10 use crate::{consts, user};
11 use crate::hash::{hash, verify_password};
12 use crate::model;
13 use crate::user::*;
14
/// Highest schema version this code knows how to migrate a database to.
const CURRENT_DB_VERSION: u32 = 1;
16
17 #[derive(Debug)]
18 pub enum DBError {
19 SqliteError(rusqlite::Error),
20 R2d2Error(r2d2::Error),
21 UnsupportedVersion(u32),
22 Other(String),
23 }
24
25 impl fmt::Display for DBError {
26 fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> {
27 write!(f, "{:?}", self)
28 }
29 }
30
31 impl std::error::Error for DBError { }
32
33 impl From<rusqlite::Error> for DBError {
34 fn from(error: rusqlite::Error) -> Self {
35 DBError::SqliteError(error)
36 }
37 }
38
39 impl From<r2d2::Error> for DBError {
40 fn from(error: r2d2::Error) -> Self {
41 DBError::R2d2Error(error)
42 }
43 }
44
45 impl DBError {
46 fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
47 DBError::Other(error.to_string())
48 }
49 }
50
51 type Result<T> = std::result::Result<T, DBError>;
52
/// Outcome of a sign-up request.
#[derive(Debug)]
pub enum SignUpResult {
    /// A validated account already uses this email.
    UserAlreadyExists,
    /// The account was created or refreshed; carries the validation token.
    UserCreatedWaitingForValidation(String),
}
58
/// Outcome of an account validation request.
#[derive(Debug)]
pub enum ValidationResult {
    /// No user holds the given validation token.
    UnknownUser,
    /// The validation token was presented after the allowed delay.
    ValidationExpired,
    /// Validation succeeded; carries the login token and the user id.
    Ok(String, i32),
}
65
/// Outcome of a sign-in request.
#[derive(Debug)]
pub enum SignInResult {
    /// No account uses the given email.
    UserNotFound,
    /// The password does not match the stored hash.
    WrongPassword,
    /// The account still has a pending validation token.
    AccountNotValidated,
    /// Sign-in succeeded; carries the login token and the user id.
    Ok(String, i32),
}
73
/// Outcome of authenticating with a login token.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// The token matches no stored login token.
    NotValidToken,
    /// Authentication succeeded; carries the user id.
    Ok(i32),
}
79
80 #[derive(Clone)]
81 pub struct Connection {
82 //con: rusqlite::Connection
83 pool: Pool<SqliteConnectionManager>
84 }
85
86 impl Connection {
87 pub fn new() -> Result<Connection> {
88 let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
89 Self::new_from_file(path)
90 }
91
92 pub fn new_in_memory() -> Result<Connection> {
93 Self::create_connection(SqliteConnectionManager::memory())
94 }
95
96 pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
97 if let Some(data_dir) = file.as_ref().parent() {
98 if !data_dir.exists() {
99 fs::DirBuilder::new().create(data_dir).unwrap();
100 }
101 }
102
103 Self::create_connection(SqliteConnectionManager::file(file))
104 }
105
106 fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {
107 let pool = r2d2::Pool::new(manager).unwrap();
108 let connection = Connection { pool };
109 connection.create_or_update()?;
110 Ok(connection)
111 }
112
113 /// Called after the connection has been established for creating or updating the database.
114 /// The 'Version' table tracks the current state of the database.
115 fn create_or_update(&self) -> Result<()> {
116 // Check the Database version.
117 let mut con = self.pool.get()?;
118 let tx = con.transaction()?;
119
120 // Version 0 corresponds to an empty database.
121 let mut version = {
122 match tx.query_row(
123 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
124 [],
125 |row| row.get::<usize, String>(0)
126 ) {
127 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
128 Err(_) => 0
129 }
130 };
131
132 while Self::update_to_next_version(version, &tx)? {
133 version += 1;
134 }
135
136 tx.commit()?;
137
138 Ok(())
139 }
140
141 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
142 let next_version = current_version + 1;
143
144 if next_version <= CURRENT_DB_VERSION {
145 println!("Update to version {}...", next_version);
146 }
147
148 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
149 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
150 }
151
152 fn ok(updated: bool) -> Result<bool> {
153 if updated {
154 println!("Version updated");
155 }
156 Ok(updated)
157 }
158
159 match next_version {
160 1 => {
161 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
162 tx.execute_batch(&load_sql_file(&sql_file)?)?;
163 update_version(next_version, tx)?;
164
165 ok(true)
166 }
167
168 // Version 1 doesn't exist yet.
169 2 =>
170 ok(false),
171
172 v =>
173 Err(DBError::UnsupportedVersion(v)),
174 }
175 }
176
177 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
178 let con = self.pool.get()?;
179
180 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
181
182 let titles: std::result::Result<Vec<(i32, String)>, rusqlite::Error> =
183 stmt.query_map([], |row| {
184 Ok((row.get("id")?, row.get("title")?))
185 })?.collect();
186
187 titles.map_err(DBError::from)
188 }
189
190 /* Not used for the moment.
191 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
192 let con = self.pool.get()?;
193 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
194 let recipes =
195 stmt.query_map([], |row| {
196 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
197 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
198 Ok(recipes)
199 } */
200
201 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
202 let con = self.pool.get()?;
203 con.query_row("SELECT [id], [title], [description] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
204 Ok(model::Recipe::new(row.get("id")?, row.get("title")?, row.get("description")?))
205 }).map_err(DBError::from)
206 }
207
208 pub fn get_user_login_info(&self, token: &str) -> Result<UserLoginInfo> {
209 let con = self.pool.get()?;
210 con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
211 Ok(UserLoginInfo {
212 last_login_datetime: r.get("last_login_datetime")?,
213 ip: r.get("ip")?,
214 user_agent: r.get("user_agent")?,
215 })
216 }).map_err(DBError::from)
217 }
218
219 pub fn load_user(&self, user_id: i32) -> Result<User> {
220 let con = self.pool.get()?;
221 con.query_row("SELECT [email] FROM [User] WHERE [id] = ?1", [user_id], |r| {
222 Ok(User {
223 email: r.get("email")?,
224 })
225 }).map_err(DBError::from)
226 }
227
228 pub fn sign_up(&self, email: &str, password: &str) -> Result<SignUpResult> {
229 self.sign_up_with_given_time(email, password, Utc::now())
230 }
231
232 fn sign_up_with_given_time(&self, email: &str, password: &str, datetime: DateTime<Utc>) -> Result<SignUpResult> {
233 let mut con = self.pool.get()?;
234 let tx = con.transaction()?;
235 let token =
236 match tx.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
237 Ok((r.get::<&str, i32>("id")?, r.get::<&str, Option<String>>("validation_token")?))
238 }).optional()? {
239 Some((id, validation_token)) => {
240 if validation_token.is_none() {
241 return Ok(SignUpResult::UserAlreadyExists)
242 }
243 let token = generate_token();
244 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
245 tx.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password])?;
246 token
247 },
248 None => {
249 let token = generate_token();
250 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
251 tx.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password])?;
252 token
253 },
254 };
255 tx.commit()?;
256 Ok(SignUpResult::UserCreatedWaitingForValidation(token))
257 }
258
259 pub fn validation(&self, token: &str, validation_time: Duration, ip: &str, user_agent: &str) -> Result<ValidationResult> {
260 let mut con = self.pool.get()?;
261 let tx = con.transaction()?;
262 let user_id =
263 match tx.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token], |r| {
264 Ok((r.get::<&str, i32>("id")?, r.get::<&str, DateTime<Utc>>("creation_datetime")?))
265 }).optional()? {
266 Some((id, creation_datetime)) => {
267 if Utc::now() - creation_datetime > validation_time {
268 return Ok(ValidationResult::ValidationExpired)
269 }
270 tx.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id])?;
271 id
272 },
273 None => {
274 return Ok(ValidationResult::UnknownUser)
275 },
276 };
277 let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?;
278 tx.commit()?;
279 Ok(ValidationResult::Ok(token, user_id))
280 }
281
282 pub fn sign_in(&self, email: &str, password: &str, ip: &str, user_agent: &str) -> Result<SignInResult> {
283 let mut con = self.pool.get()?;
284 let tx = con.transaction()?;
285 match tx.query_row("SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
286 Ok((r.get::<&str, i32>("id")?, r.get::<&str, String>("password")?, r.get::<&str, Option<String>>("validation_token")?))
287 }).optional()? {
288 Some((id, stored_password, validation_token)) => {
289 if validation_token.is_some() {
290 Ok(SignInResult::AccountNotValidated)
291 } else if verify_password(password, &stored_password).map_err(DBError::from_dyn_error)? {
292 let token = Connection::create_login_token(&tx, id, ip, user_agent)?;
293 tx.commit()?;
294 Ok(SignInResult::Ok(token, id))
295 } else {
296 Ok(SignInResult::WrongPassword)
297 }
298 },
299 None => {
300 Ok(SignInResult::UserNotFound)
301 },
302 }
303 }
304
305 pub fn authentication(&self, token: &str, ip: &str, user_agent: &str) -> Result<AuthenticationResult> {
306 let mut con = self.pool.get()?;
307 let tx = con.transaction()?;
308 match tx.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
309 Ok((r.get::<&str, i32>("id")?, r.get::<&str, i32>("user_id")?))
310 }).optional()? {
311 Some((login_id, user_id)) => {
312 tx.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params![login_id, Utc::now(), ip, user_agent])?;
313 tx.commit()?;
314 Ok(AuthenticationResult::Ok(user_id))
315 },
316 None =>
317 Ok(AuthenticationResult::NotValidToken)
318 }
319 }
320
321 pub fn sign_out(&self, token: &str) -> Result<()> {
322 let mut con = self.pool.get()?;
323 let tx = con.transaction()?;
324 match tx.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
325 Ok(r.get::<&str, i32>("id")?)
326 }).optional()? {
327 Some(login_id) => {
328 tx.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params![login_id])?;
329 tx.commit()?
330 },
331 None => (),
332 }
333 Ok(())
334 }
335
336 /// Execute a given SQL file.
337 pub fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
338 let con = self.pool.get()?;
339 let sql = load_sql_file(file)?;
340 con.execute_batch(&sql).map_err(DBError::from)
341 }
342
343 /// Execute any SQL statement.
344 /// Mainly used for testing.
345 pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
346 let con = self.pool.get()?;
347 con.execute(sql, params).map_err(DBError::from)
348 }
349
350 // Return the token.
351 fn create_login_token(tx: &rusqlite::Transaction, user_id: i32, ip: &str, user_agent: &str) -> Result<String> {
352 let token = generate_token();
353 tx.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, Utc::now(), token, ip, user_agent])?;
354 Ok(token)
355 }
356 }
357
358 fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
359 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
360 let mut sql = String::new();
361 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
362 Ok(sql)
363 }
364
365 fn generate_token() -> String {
366 Alphanumeric.sample_string(&mut rand::thread_rng(), consts::AUTHENTICATION_TOKEN_SIZE)
367 }
368
#[cfg(test)]
mod tests {
    use super::*;

    const EMAIL: &str = "paul@test.org";
    const PASSWORD: &str = "12345";

    /// Extracts the validation token of a pending sign-up; panics on any other outcome.
    fn expect_pending(result: SignUpResult) -> String {
        match result {
            SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    /// Extracts the login token and user id of a successful validation; panics otherwise.
    fn expect_validated(result: ValidationResult) -> (String, i32) {
        match result {
            ValidationResult::Ok(token, user_id) => (token, user_id), // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    /// Signs up the default test user and returns its validation token.
    fn sign_up_default_user(connection: &Connection) -> Result<String> {
        Ok(expect_pending(connection.sign_up(EMAIL, PASSWORD)?))
    }

    /// Sends a validation request with a one hour validation window.
    fn validate_within_an_hour(connection: &Connection, token: &str, ip: &str, user_agent: &str) -> Result<ValidationResult> {
        connection.validation(token, Duration::hours(1), ip, user_agent)
    }

    /// Checks the IP and user agent stored for a login token.
    fn assert_login_info(connection: &Connection, token: &str, ip: &str, user_agent: &str) -> Result<()> {
        let info = connection.get_user_login_info(token)?;
        assert_eq!(info.ip, ip);
        assert_eq!(info.user_agent, user_agent);
        Ok(())
    }

    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        sign_up_default_user(&connection)?;
        Ok(())
    }

    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
            INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
            VALUES (
                1,
                'paul@test.org',
                'paul',
                '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
                0,
                NULL
            );", [])?;

        let result = connection.sign_up(EMAIL, PASSWORD)?;
        if !matches!(result, SignUpResult::UserAlreadyExists) {
            panic!("{:?}", result);
        }
        Ok(())
    }

    #[test]
    fn sign_up_and_sign_in_without_validation() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        sign_up_default_user(&connection)?;

        let result = connection.sign_in(EMAIL, PASSWORD, "127.0.0.1", "Mozilla/5.0")?;
        if !matches!(result, SignInResult::AccountNotValidated) {
            panic!("{:?}", result);
        }

        Ok(())
    }

    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let token = generate_token();
        connection.execute_sql("
            INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
            VALUES (
                1,
                'paul@test.org',
                'paul',
                '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
                0,
                :token
            );", named_params! { ":token": token })?;

        sign_up_default_user(&connection)?;
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token = sign_up_default_user(&connection)?;

        expect_validated(validate_within_an_hour(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?);
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let a_day_ago = Utc::now() - Duration::days(1);
        let validation_token = expect_pending(connection.sign_up_with_given_time(EMAIL, PASSWORD, a_day_ago)?);

        let result = validate_within_an_hour(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?;
        if !matches!(result, ValidationResult::ValidationExpired) {
            panic!("{:?}", result);
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_with_bad_token() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let _validation_token = sign_up_default_user(&connection)?;

        let random_token = generate_token();
        let result = validate_within_an_hour(&connection, &random_token, "127.0.0.1", "Mozilla/5.0")?;
        if !matches!(result, ValidationResult::UnknownUser) {
            panic!("{:?}", result);
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        // Sign up.
        let validation_token = sign_up_default_user(&connection)?;

        // Validation.
        expect_validated(validate_within_an_hour(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?);

        // Sign in.
        let result = connection.sign_in(EMAIL, PASSWORD, "127.0.0.1", "Mozilla/5.0")?;
        if !matches!(result, SignInResult::Ok(_, _)) {
            panic!("{:?}", result);
        }

        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_authentication() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        // Sign up.
        let validation_token = sign_up_default_user(&connection)?;

        // Validation.
        let (authentication_token, _user_id) =
            expect_validated(validate_within_an_hour(&connection, &validation_token, "127.0.0.1", "Mozilla")?);

        // Check user login information.
        assert_login_info(&connection, &authentication_token, "127.0.0.1", "Mozilla")?;

        // Authentication.
        let result = connection.authentication(&authentication_token, "192.168.1.1", "Chrome")?;
        if !matches!(result, AuthenticationResult::Ok(_)) {
            panic!("{:?}", result);
        }

        // Check user login information.
        assert_login_info(&connection, &authentication_token, "192.168.1.1", "Chrome")?;

        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        // Sign up.
        let validation_token = sign_up_default_user(&connection)?;

        // Validation.
        let (authentication_token_1, user_id_1) =
            expect_validated(validate_within_an_hour(&connection, &validation_token, "127.0.0.1", "Mozilla")?);

        // Check user login information.
        assert_login_info(&connection, &authentication_token_1, "127.0.0.1", "Mozilla")?;

        // Sign out.
        connection.sign_out(&authentication_token_1)?;

        // Sign in.
        let (authentication_token_2, user_id_2) =
            match connection.sign_in(EMAIL, PASSWORD, "192.168.1.1", "Chrome")? {
                SignInResult::Ok(token, user_id) => (token, user_id),
                other => panic!("{:?}", other),
            };

        assert_eq!(user_id_1, user_id_2);
        assert_ne!(authentication_token_1, authentication_token_2);

        // Check user login information.
        assert_login_info(&connection, &authentication_token_2, "192.168.1.1", "Chrome")?;

        Ok(())
    }
}