e38aae7f2715174855c435c992ae7201d5eb8d7e
[recipes.git] / backend / src / db.rs
1 use std::{fmt, fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 use chrono::{prelude::*, Duration};
5 use rusqlite::{named_params, OptionalExtension, params, Params};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8 use rand::distributions::{Alphanumeric, DistString};
9
10 use crate::{consts, user};
11 use crate::hash::{hash, verify_password};
12 use crate::model;
13 use crate::user::*;
14
/// Schema version this code knows how to create and migrate to.
/// Stored rows in the `Version` table are compared against it at start-up.
const CURRENT_DB_VERSION: u32 = 1;
16
17 #[derive(Debug)]
18 pub enum DBError {
19 SqliteError(rusqlite::Error),
20 R2d2Error(r2d2::Error),
21 UnsupportedVersion(u32),
22 Other(String),
23 }
24
25 impl fmt::Display for DBError {
26 fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> {
27 write!(f, "{:?}", self)
28 }
29 }
30
31 impl std::error::Error for DBError { }
32
33 impl From<rusqlite::Error> for DBError {
34 fn from(error: rusqlite::Error) -> Self {
35 DBError::SqliteError(error)
36 }
37 }
38
39 impl From<r2d2::Error> for DBError {
40 fn from(error: r2d2::Error) -> Self {
41 DBError::R2d2Error(error)
42 }
43 }
44
45 // TODO: Is there a better solution?
46 impl DBError {
47 fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
48 DBError::Other(error.to_string())
49 }
50 }
51
52 type Result<T> = std::result::Result<T, DBError>;
53
/// Outcome of a sign-up request.
#[derive(Debug)]
pub enum SignUpResult {
    /// A validated account already uses this email.
    UserAlreadyExists,
    /// The account was created (or refreshed) and awaits validation; carries the
    /// validation token.
    UserCreatedWaitingForValidation(String),
}
59
/// Outcome of validating an account with its validation token.
#[derive(Debug)]
pub enum ValidationResult {
    /// No account is waiting for validation with this token.
    UnknownUser,
    /// The validation window has elapsed.
    ValidationExpired,
    /// Validation succeeded; carries the new login token and the user id.
    Ok(String, i32),
}
66
/// Outcome of a sign-in attempt.
#[derive(Debug)]
pub enum SignInResult {
    /// No account uses this email.
    UserNotFound,
    /// The password does not match.
    WrongPassword,
    /// The account exists but has not been validated yet.
    AccountNotValidated,
    /// Sign-in succeeded; carries the new login token and the user id.
    Ok(String, i32),
}
74
/// Outcome of authenticating a request with a login token.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// No login session matches the token.
    NotValidToken,
    /// The token is valid; carries the user id.
    Ok(i32),
}
80
81 #[derive(Clone)]
82 pub struct Connection {
83 //con: rusqlite::Connection
84 pool: Pool<SqliteConnectionManager>
85 }
86
87 impl Connection {
88 pub fn new() -> Result<Connection> {
89 let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
90 Self::new_from_file(path)
91 }
92
93 pub fn new_in_memory() -> Result<Connection> {
94 Self::create_connection(SqliteConnectionManager::memory())
95 }
96
97 pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
98 if let Some(data_dir) = file.as_ref().parent() {
99 if !data_dir.exists() {
100 fs::DirBuilder::new().create(data_dir).unwrap();
101 }
102 }
103
104 Self::create_connection(SqliteConnectionManager::file(file))
105 }
106
107 fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {
108 let pool = r2d2::Pool::new(manager).unwrap();
109 let connection = Connection { pool };
110 connection.create_or_update()?;
111 Ok(connection)
112 }
113
114 /// Called after the connection has been established for creating or updating the database.
115 /// The 'Version' table tracks the current state of the database.
116 fn create_or_update(&self) -> Result<()> {
117 // Check the Database version.
118 let mut con = self.pool.get()?;
119 let tx = con.transaction()?;
120
121 // Version 0 corresponds to an empty database.
122 let mut version = {
123 match tx.query_row(
124 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
125 [],
126 |row| row.get::<usize, String>(0)
127 ) {
128 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
129 Err(_) => 0
130 }
131 };
132
133 while Self::update_to_next_version(version, &tx)? {
134 version += 1;
135 }
136
137 tx.commit()?;
138
139 Ok(())
140 }
141
142 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
143 let next_version = current_version + 1;
144
145 if next_version <= CURRENT_DB_VERSION {
146 println!("Update to version {}...", next_version);
147 }
148
149 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
150 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
151 }
152
153 fn ok(updated: bool) -> Result<bool> {
154 if updated {
155 println!("Version updated");
156 }
157 Ok(updated)
158 }
159
160 match next_version {
161 1 => {
162 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
163 tx.execute_batch(&load_sql_file(&sql_file)?)?;
164 update_version(next_version, tx)?;
165
166 ok(true)
167 }
168
169 // Version 1 doesn't exist yet.
170 2 =>
171 ok(false),
172
173 v =>
174 Err(DBError::UnsupportedVersion(v)),
175 }
176 }
177
178 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
179 let con = self.pool.get()?;
180 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
181 let titles =
182 stmt.query_map([], |row| {
183 Ok((row.get(0)?, row.get(1)?))
184 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
185 Ok(titles)
186 }
187
188 /* Not used for the moment.
189 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
190 let con = self.pool.get()?;
191 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
192 let recipes =
193 stmt.query_map([], |row| {
194 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
195 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
196 Ok(recipes)
197 } */
198
199 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
200 let con = self.pool.get()?;
201 con.query_row("SELECT [id], [title], [description] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
202 Ok(model::Recipe::new(row.get("id")?, row.get("title")?, row.get("description")?))
203 }).map_err(DBError::from)
204 }
205
206 pub fn get_user_login_info(&self, token: &str) -> Result<UserLoginInfo> {
207 let con = self.pool.get()?;
208 con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
209 Ok(UserLoginInfo {
210 last_login_datetime: r.get("last_login_datetime")?,
211 ip: r.get("ip")?,
212 user_agent: r.get("user_agent")?,
213 })
214 }).map_err(DBError::from)
215 }
216
217 pub fn load_user(&self, user_id: i32) -> Result<User> {
218 let con = self.pool.get()?;
219 con.query_row("SELECT [email] FROM [User] WHERE [id] = ?1", [user_id], |r| {
220 Ok(User {
221 email: r.get("email")?,
222 })
223 }).map_err(DBError::from)
224 }
225
226 ///
227 pub fn sign_up(&self, email: &str, password: &str) -> Result<SignUpResult> {
228 self.sign_up_with_given_time(email, password, Utc::now())
229 }
230
231 fn sign_up_with_given_time(&self, email: &str, password: &str, datetime: DateTime<Utc>) -> Result<SignUpResult> {
232 let mut con = self.pool.get()?;
233 let tx = con.transaction()?;
234 let token =
235 match tx.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
236 Ok((r.get::<&str, i32>("id")?, r.get::<&str, Option<String>>("validation_token")?))
237 }).optional()? {
238 Some((id, validation_token)) => {
239 if validation_token.is_none() {
240 return Ok(SignUpResult::UserAlreadyExists)
241 }
242 let token = generate_token();
243 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
244 tx.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password])?;
245 token
246 },
247 None => {
248 let token = generate_token();
249 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
250 tx.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password])?;
251 token
252 },
253 };
254 tx.commit()?;
255 Ok(SignUpResult::UserCreatedWaitingForValidation(token))
256 }
257
258 pub fn validation(&self, token: &str, validation_time: Duration, ip: &str, user_agent: &str) -> Result<ValidationResult> {
259 let mut con = self.pool.get()?;
260 let tx = con.transaction()?;
261 let user_id =
262 match tx.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token], |r| {
263 Ok((r.get::<&str, i32>("id")?, r.get::<&str, DateTime<Utc>>("creation_datetime")?))
264 }).optional()? {
265 Some((id, creation_datetime)) => {
266 if Utc::now() - creation_datetime > validation_time {
267 return Ok(ValidationResult::ValidationExpired)
268 }
269 tx.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id])?;
270 id
271 },
272 None => {
273 return Ok(ValidationResult::UnknownUser)
274 },
275 };
276 let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?;
277 tx.commit()?;
278 Ok(ValidationResult::Ok(token, user_id))
279 }
280
281 pub fn sign_in(&self, email: &str, password: &str, ip: &str, user_agent: &str) -> Result<SignInResult> {
282 let mut con = self.pool.get()?;
283 let tx = con.transaction()?;
284 match tx.query_row("SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
285 Ok((r.get::<&str, i32>("id")?, r.get::<&str, String>("password")?, r.get::<&str, Option<String>>("validation_token")?))
286 }).optional()? {
287 Some((id, stored_password, validation_token)) => {
288 if validation_token.is_some() {
289 Ok(SignInResult::AccountNotValidated)
290 } else if verify_password(password, &stored_password).map_err(DBError::from_dyn_error)? {
291 let token = Connection::create_login_token(&tx, id, ip, user_agent)?;
292 tx.commit()?;
293 Ok(SignInResult::Ok(token, id))
294 } else {
295 Ok(SignInResult::WrongPassword)
296 }
297 },
298 None => {
299 Ok(SignInResult::UserNotFound)
300 },
301 }
302 }
303
304 pub fn authentication(&self, token: &str, ip: &str, user_agent: &str) -> Result<AuthenticationResult> {
305 let mut con = self.pool.get()?;
306 let tx = con.transaction()?;
307 match tx.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
308 Ok((r.get::<&str, i32>("id")?, r.get::<&str, i32>("user_id")?))
309 }).optional()? {
310 Some((login_id, user_id)) => {
311 tx.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params![login_id, Utc::now(), ip, user_agent])?;
312 tx.commit()?;
313 Ok(AuthenticationResult::Ok(user_id))
314 },
315 None =>
316 Ok(AuthenticationResult::NotValidToken)
317 }
318 }
319
320 pub fn sign_out(&self, token: &str) -> Result<()> {
321 let mut con = self.pool.get()?;
322 let tx = con.transaction()?;
323 match tx.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
324 Ok(r.get::<&str, i32>("id")?)
325 }).optional()? {
326 Some(login_id) => {
327 tx.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params![login_id])?;
328 tx.commit()?
329 },
330 None => (),
331 }
332 Ok(())
333 }
334
335 /// Execute a given SQL file.
336 pub fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
337 let con = self.pool.get()?;
338 let sql = load_sql_file(file)?;
339 con.execute_batch(&sql).map_err(DBError::from)
340 }
341
342 /// Execute any SQL statement.
343 /// Mainly used for testing.
344 pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
345 let con = self.pool.get()?;
346 con.execute(sql, params).map_err(DBError::from)
347 }
348
349 // Return the token.
350 fn create_login_token(tx: &rusqlite::Transaction, user_id: i32, ip: &str, user_agent: &str) -> Result<String> {
351 let token = generate_token();
352 tx.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, Utc::now(), token, ip, user_agent])?;
353 Ok(token)
354 }
355 }
356
357 fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
358 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
359 let mut sql = String::new();
360 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
361 Ok(sql)
362 }
363
364 fn generate_token() -> String {
365 Alphanumeric.sample_string(&mut rand::thread_rng(), 24)
366 }
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Signs `email` up and returns the pending validation token; any other outcome fails
    /// the calling test.
    fn sign_up_pending(connection: &Connection, email: &str, password: &str) -> Result<String> {
        match connection.sign_up(email, password)? {
            SignUpResult::UserCreatedWaitingForValidation(token) => Ok(token),
            other => panic!("{:?}", other),
        }
    }

    /// Validates `validation_token` within a one-hour window and returns the login token
    /// and user id; any other outcome fails the calling test.
    fn validate_ok(connection: &Connection, validation_token: &str, ip: &str, user_agent: &str) -> Result<(String, i32)> {
        match connection.validation(validation_token, Duration::hours(1), ip, user_agent)? {
            ValidationResult::Ok(token, user_id) => Ok((token, user_id)),
            other => panic!("{:?}", other),
        }
    }

    /// Signs in and returns the login token and user id; any other outcome fails the
    /// calling test.
    fn sign_in_ok(connection: &Connection, email: &str, password: &str, ip: &str, user_agent: &str) -> Result<(String, i32)> {
        match connection.sign_in(email, password, ip, user_agent)? {
            SignInResult::Ok(token, user_id) => Ok((token, user_id)),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        sign_up_pending(&connection, "paul@test.org", "12345")?;
        Ok(())
    }

    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
NULL
);", [])?;
        match connection.sign_up("paul@test.org", "12345")? {
            SignUpResult::UserAlreadyExists => Ok(()), // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_and_sign_in_without_validation() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let (email, password) = ("paul@test.org", "12345");

        sign_up_pending(&connection, email, password)?;

        match connection.sign_in(email, password, "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::AccountNotValidated => Ok(()), // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let token = generate_token();
        connection.execute_sql("
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
:token
);", named_params! { ":token": token })?;
        sign_up_pending(&connection, "paul@test.org", "12345")?;
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token = sign_up_pending(&connection, "paul@test.org", "12345")?;
        validate_ok(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?;
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let one_day_ago = Utc::now() - Duration::days(1);
        let validation_token =
            match connection.sign_up_with_given_time("paul@test.org", "12345", one_day_ago)? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token,
                other => panic!("{:?}", other),
            };
        match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::ValidationExpired => Ok(()), // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_then_send_validation_with_bad_token() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        sign_up_pending(&connection, "paul@test.org", "12345")?;

        let random_token = generate_token();
        match connection.validation(&random_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::UnknownUser => Ok(()), // Nominal case.
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let (email, password) = ("paul@test.org", "12345");

        let validation_token = sign_up_pending(&connection, email, password)?;
        validate_ok(&connection, &validation_token, "127.0.0.1", "Mozilla/5.0")?;
        sign_in_ok(&connection, email, password, "127.0.0.1", "Mozilla/5.0")?;

        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_authentication() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let (email, password) = ("paul@test.org", "12345");

        let validation_token = sign_up_pending(&connection, email, password)?;
        let (authentication_token, _user_id) =
            validate_ok(&connection, &validation_token, "127.0.0.1", "Mozilla")?;

        // The validation itself opened the session.
        let first_login = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(first_login.ip, "127.0.0.1");
        assert_eq!(first_login.user_agent, "Mozilla");

        match connection.authentication(&authentication_token, "192.168.1.1", "Chrome")? {
            AuthenticationResult::Ok(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }

        // Authentication refreshes the session's IP and user agent.
        let second_login = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(second_login.ip, "192.168.1.1");
        assert_eq!(second_login.user_agent, "Chrome");

        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let (email, password) = ("paul@test.org", "12345");

        let validation_token = sign_up_pending(&connection, email, password)?;
        let (first_token, first_user_id) =
            validate_ok(&connection, &validation_token, "127.0.0.1", "Mozilla")?;

        let first_login = connection.get_user_login_info(&first_token)?;
        assert_eq!(first_login.ip, "127.0.0.1");
        assert_eq!(first_login.user_agent, "Mozilla");

        connection.sign_out(&first_token)?;

        let (second_token, second_user_id) =
            sign_in_ok(&connection, email, password, "192.168.1.1", "Chrome")?;

        assert_eq!(first_user_id, second_user_id);
        assert_ne!(first_token, second_token);

        let second_login = connection.get_user_login_info(&second_token)?;
        assert_eq!(second_login.ip, "192.168.1.1");
        assert_eq!(second_login.user_agent, "Chrome");

        Ok(())
    }
}