28bf62c657725fd7fd3fb9297d0f35b4a562cfd4
[recipes.git] / backend / src / db.rs
1 use std::{fmt::Display, fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 use chrono::{prelude::*, Duration};
5 use rusqlite::{named_params, OptionalExtension, params, Params};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8 use rand::distributions::{Alphanumeric, DistString};
9
10 use crate::consts;
11 use crate::hash::{hash, verify_password};
12 use crate::model;
13 use crate::user::*;
14
/// Schema version this code works with. `Connection::create_or_update` applies the
/// migrations needed to bring a database up to this version.
const CURRENT_DB_VERSION: u32 = 1;

/// Errors returned by the database layer.
#[derive(Debug)]
pub enum DBError {
    /// Error reported by SQLite through `rusqlite`.
    SqliteError(rusqlite::Error),
    /// Error raised by the `r2d2` connection pool (e.g. by `pool.get()`).
    R2d2Error(r2d2::Error),
    /// Database version for which no migration is available.
    UnsupportedVersion(u32),
    /// Any other error, reduced to a human-readable message.
    Other(String),
}
24
25 impl From<rusqlite::Error> for DBError {
26 fn from(error: rusqlite::Error) -> Self {
27 DBError::SqliteError(error)
28 }
29 }
30
31 impl From<r2d2::Error> for DBError {
32 fn from(error: r2d2::Error) -> Self {
33 DBError::R2d2Error(error)
34 }
35 }
36
37 // TODO: Is there a better solution?
38 impl DBError {
39 fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
40 DBError::Other(error.to_string())
41 }
42 }
43
/// Result type used throughout this module: failures are reported as [`DBError`].
type Result<T> = std::result::Result<T, DBError>;

/// Outcome of a sign-up request.
#[derive(Debug)]
pub enum SignUpResult {
    /// The email belongs to an account that has already been validated.
    UserAlreadyExists,
    /// The account was created (or its pending validation restarted); carries the validation token.
    UserCreatedWaitingForValidation(String),
}

/// Outcome of an account validation request.
#[derive(Debug)]
pub enum ValidationResult {
    /// No account carries the given validation token.
    UnknownUser,
    /// The account was created longer ago than the allowed validation delay.
    ValidationExpired,
    /// The account is validated; carries the new login token and the user id.
    Ok(String, i32),
}

/// Outcome of a sign-in request.
#[derive(Debug)]
pub enum SignInResult {
    /// No account is registered with the given email.
    UserNotFound,
    /// The given password does not match the stored hash.
    PasswordsDontMatch,
    /// Sign-in succeeded; carries the new login token and the user id.
    Ok(String, i32),
}

/// Outcome of authenticating with a login token.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// No login entry matches the given token.
    NotValidToken,
    /// The token is valid; carries the user id.
    Ok(i32),
}
71
/// Handle to the application database, backed by an `r2d2` pool of SQLite connections.
///
/// NOTE(review): cloning presumably yields another handle to the same pool (r2d2's
/// `Pool` is reference-counted) rather than opening new connections — confirm.
#[derive(Clone)]
pub struct Connection {
    pool: Pool<SqliteConnectionManager>
}
77
78 impl Connection {
79 pub fn new() -> Result<Connection> {
80 let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
81 Self::new_from_file(path)
82 }
83
84 pub fn new_in_memory() -> Result<Connection> {
85 Self::create_connection(SqliteConnectionManager::memory())
86 }
87
88 pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
89 if let Some(data_dir) = file.as_ref().parent() {
90 if !data_dir.exists() {
91 fs::DirBuilder::new().create(data_dir).unwrap();
92 }
93 }
94
95 Self::create_connection(SqliteConnectionManager::file(file))
96 }
97
98 fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {;
99 let pool = r2d2::Pool::new(manager).unwrap();
100 let connection = Connection { pool };
101 connection.create_or_update()?;
102 Ok(connection)
103 }
104
105 /// Called after the connection has been established for creating or updating the database.
106 /// The 'Version' table tracks the current state of the database.
107 fn create_or_update(&self) -> Result<()> {
108 // Check the Database version.
109 let mut con = self.pool.get()?;
110 let tx = con.transaction()?;
111
112 // Version 0 corresponds to an empty database.
113 let mut version = {
114 match tx.query_row(
115 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
116 [],
117 |row| row.get::<usize, String>(0)
118 ) {
119 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
120 Err(_) => 0
121 }
122 };
123
124 while Self::update_to_next_version(version, &tx)? {
125 version += 1;
126 }
127
128 tx.commit()?;
129
130 Ok(())
131 }
132
133 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
134 let next_version = current_version + 1;
135
136 if next_version <= CURRENT_DB_VERSION {
137 println!("Update to version {}...", next_version);
138 }
139
140 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
141 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
142 }
143
144 fn ok(updated: bool) -> Result<bool> {
145 if updated {
146 println!("Version updated");
147 }
148 Ok(updated)
149 }
150
151 match next_version {
152 1 => {
153 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
154 tx.execute_batch(&load_sql_file(&sql_file)?)?;
155 update_version(next_version, tx)?;
156
157 ok(true)
158 }
159
160 // Version 1 doesn't exist yet.
161 2 =>
162 ok(false),
163
164 v =>
165 Err(DBError::UnsupportedVersion(v)),
166 }
167 }
168
169 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
170 let con = self.pool.get()?;
171 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
172 let titles =
173 stmt.query_map([], |row| {
174 Ok((row.get(0)?, row.get(1)?))
175 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
176 Ok(titles)
177 }
178
179 /* Not used for the moment.
180 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
181 let con = self.pool.get()?;
182 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
183 let recipes =
184 stmt.query_map([], |row| {
185 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
186 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
187 Ok(recipes)
188 } */
189
190 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
191 let con = self.pool.get()?;
192 con.query_row("SELECT [id], [title] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
193 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
194 }).map_err(DBError::from)
195 }
196
197 pub fn get_user_login_info(&self, token: &str) -> Result<UserLoginInfo> {
198 let con = self.pool.get()?;
199 con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
200 Ok(UserLoginInfo {
201 last_login_datetime: r.get("last_login_datetime")?,
202 ip: r.get("ip")?,
203 user_agent: r.get("user_agent")?,
204 })
205 }).map_err(DBError::from)
206 }
207
208 ///
209 pub fn sign_up(&self, password: &str, email: &str) -> Result<SignUpResult> {
210 self.sign_up_with_given_time(password, email, Utc::now())
211 }
212
213 fn sign_up_with_given_time(&self, password: &str, email: &str, datetime: DateTime<Utc>) -> Result<SignUpResult> {
214 let mut con = self.pool.get()?;
215 let tx = con.transaction()?;
216 let token =
217 match tx.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
218 Ok((r.get::<&str, i32>("id")?, r.get::<&str, Option<String>>("validation_token")?))
219 }).optional()? {
220 Some((id, validation_token)) => {
221 if validation_token.is_none() {
222 return Ok(SignUpResult::UserAlreadyExists)
223 }
224 let token = generate_token();
225 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
226 tx.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password])?;
227 token
228 },
229 None => {
230 let token = generate_token();
231 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
232 tx.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password])?;
233 token
234 },
235 };
236 tx.commit()?;
237 Ok(SignUpResult::UserCreatedWaitingForValidation(token))
238 }
239
240 pub fn validation(&self, token: &str, validation_time: Duration, ip: &str, user_agent: &str) -> Result<ValidationResult> {
241 let mut con = self.pool.get()?;
242 let tx = con.transaction()?;
243 let user_id =
244 match tx.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token], |r| {
245 Ok((r.get::<&str, i32>("id")?, r.get::<&str, DateTime<Utc>>("creation_datetime")?))
246 }).optional()? {
247 Some((id, creation_datetime)) => {
248 if Utc::now() - creation_datetime > validation_time {
249 return Ok(ValidationResult::ValidationExpired)
250 }
251 tx.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id])?;
252 id
253 },
254 None => {
255 return Ok(ValidationResult::UnknownUser)
256 },
257 };
258 let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?;
259 tx.commit()?;
260 Ok(ValidationResult::Ok(token, user_id))
261 }
262
263 pub fn sign_in(&self, password: &str, email: &str, ip: &str, user_agent: &str) -> Result<SignInResult> {
264 let mut con = self.pool.get()?;
265 let tx = con.transaction()?;
266 match tx.query_row("SELECT [id], [password] FROM [User] WHERE [email] = ?1", [email], |r| {
267 Ok((r.get::<&str, i32>("id")?, r.get::<&str, String>("password")?))
268 }).optional()? {
269 Some((id, stored_password)) => {
270 if verify_password(password, &stored_password).map_err(DBError::from_dyn_error)? {
271 let token = Connection::create_login_token(&tx, id, ip, user_agent)?;
272 tx.commit()?;
273 Ok(SignInResult::Ok(token, id))
274 } else {
275 Ok(SignInResult::PasswordsDontMatch)
276 }
277 },
278 None => {
279 Ok(SignInResult::UserNotFound)
280 },
281 }
282 }
283
284 pub fn authentication(&self, token: &str, ip: &str, user_agent: &str) -> Result<AuthenticationResult> {
285 let mut con = self.pool.get()?;
286 let tx = con.transaction()?;
287 match tx.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
288 Ok((r.get::<&str, i32>("id")?, r.get::<&str, i32>("user_id")?))
289 }).optional()? {
290 Some((login_id, user_id)) => {
291 tx.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params![login_id, Utc::now(), ip, user_agent])?;
292 tx.commit()?;
293 Ok(AuthenticationResult::Ok(user_id))
294 },
295 None =>
296 Ok(AuthenticationResult::NotValidToken)
297 }
298 }
299
300 pub fn sign_out(&self, token: &str) -> Result<()> {
301 let mut con = self.pool.get()?;
302 let tx = con.transaction()?;
303 match tx.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
304 Ok(r.get::<&str, i32>("id")?)
305 }).optional()? {
306 Some(login_id) => {
307 tx.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params![login_id])?;
308 tx.commit()?
309 },
310 None => (),
311 }
312 Ok(())
313 }
314
315 /// Execute a given SQL file.
316 pub fn execute_file<P: AsRef<Path> + Display>(&self, file: P) -> Result<()> {
317 let con = self.pool.get()?;
318 let sql = load_sql_file(file)?;
319 con.execute_batch(&sql).map_err(DBError::from)
320 }
321
322 /// Execute any SQL statement.
323 /// Mainly used for testing.
324 pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
325 let con = self.pool.get()?;
326 con.execute(sql, params).map_err(DBError::from)
327 }
328
329 // Return the token.
330 fn create_login_token(tx: &rusqlite::Transaction, user_id: i32, ip: &str, user_agent: &str) -> Result<String> {
331 let token = generate_token();
332 tx.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, Utc::now(), token, ip, user_agent])?;
333 Ok(token)
334 }
335 }
336
337 fn load_sql_file<P: AsRef<Path> + Display>(sql_file: P) -> Result<String> {
338 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
339 let mut sql = String::new();
340 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
341 Ok(sql)
342 }
343
344 fn generate_token() -> String {
345 Alphanumeric.sample_string(&mut rand::thread_rng(), 24)
346 }
347
// Tests for the account workflow (sign up, validation, sign in, authentication,
// sign out). Each test works on its own fresh in-memory database.
#[cfg(test)]
mod tests {
    use super::*;

    /// Signing up with an email that is not in the database creates the account
    /// and hands back a validation token.
    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Signing up with the email of an already validated account (NULL
    /// `validation_token`) is refused.
    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
NULL
);", [])?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserAlreadyExists => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Signing up again with the email of an account whose validation is still pending
    /// (non-NULL `validation_token`) is accepted and issues a new validation token.
    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let token = generate_token();
        connection.execute_sql("
INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
VALUES (
1,
'paul@test.org',
'paul',
'$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
0,
:token
);", named_params! { ":token": token })?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Validating immediately after signing up, with a one-hour validation window, succeeds.
    #[test]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token =
            match connection.sign_up("12345", "paul@test.org")? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };
        match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::Ok(_, _) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// A sign-up dated one day in the past can no longer be validated with a one-hour window.
    #[test]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token =
            match connection.sign_up_with_given_time("12345", "paul@test.org", Utc::now() - Duration::days(1))? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };
        match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::ValidationExpired => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Validating with a token that no account carries is reported as an unknown user.
    #[test]
    fn sign_up_then_send_validation_with_bad_token() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let _validation_token =
            match connection.sign_up("12345", "paul@test.org")? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };
        let random_token = generate_token();
        match connection.validation(&random_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::UnknownUser => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Full flow: sign up, validate, then sign in with the same credentials.
    #[test]
    fn sign_up_then_send_validation_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let password = "12345";
        let email = "paul@test.org";

        // Sign up.
        let validation_token =
            match connection.sign_up(password, email)? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };

        // Validation.
        match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0")? {
            ValidationResult::Ok(_, _) => (),
            other => panic!("{:?}", other),
        };

        // Sign in.
        match connection.sign_in(password, email, "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::Ok(_, _) => (), // Nominal case.
            other => panic!("{:?}", other),
        }

        Ok(())
    }

    /// Authenticating with the login token returned by the validation succeeds and
    /// replaces the stored IP and user agent with the ones given at authentication.
    #[test]
    fn sign_up_then_send_validation_then_authentication() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let password = "12345";
        let email = "paul@test.org";

        // Sign up.
        let validation_token =
            match connection.sign_up(password, email)? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };

        // Validation.
        let (authentication_token, user_id) = match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla")? {
            ValidationResult::Ok(token, user_id) => (token, user_id),
            other => panic!("{:?}", other),
        };

        // Check user login information.
        let user_login_info_1 = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(user_login_info_1.ip, "127.0.0.1");
        assert_eq!(user_login_info_1.user_agent, "Mozilla");

        // Authentication.
        let _user_id =
            match connection.authentication(&authentication_token, "192.168.1.1", "Chrome")? {
                AuthenticationResult::Ok(user_id) => user_id, // Nominal case.
                other => panic!("{:?}", other),
            };

        // Check user login information: the authentication must have refreshed it.
        let user_login_info_2 = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(user_login_info_2.ip, "192.168.1.1");
        assert_eq!(user_login_info_2.user_agent, "Chrome");

        Ok(())
    }

    /// After signing out, signing in again yields the same user id, a different login
    /// token, and login information matching the new sign-in.
    #[test]
    fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;

        let password = "12345";
        let email = "paul@test.org";

        // Sign up.
        let validation_token =
            match connection.sign_up(password, email)? {
                SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
                other => panic!("{:?}", other),
            };

        // Validation.
        let (authentication_token_1, user_id_1) =
            match connection.validation(&validation_token, Duration::hours(1), "127.0.0.1", "Mozilla")? {
                ValidationResult::Ok(token, user_id) => (token, user_id),
                other => panic!("{:?}", other),
            };

        // Check user login information.
        let user_login_info_1 = connection.get_user_login_info(&authentication_token_1)?;
        assert_eq!(user_login_info_1.ip, "127.0.0.1");
        assert_eq!(user_login_info_1.user_agent, "Mozilla");

        // Sign out.
        connection.sign_out(&authentication_token_1)?;

        // Sign in.
        let (authentication_token_2, user_id_2) =
            match connection.sign_in(password, email, "192.168.1.1", "Chrome")? {
                SignInResult::Ok(token, user_id) => (token, user_id),
                other => panic!("{:?}", other),
            };

        assert_eq!(user_id_1, user_id_2);
        assert_ne!(authentication_token_1, authentication_token_2);

        // Check user login information.
        let user_login_info_2 = connection.get_user_login_info(&authentication_token_2)?;

        assert_eq!(user_login_info_2.ip, "192.168.1.1");
        assert_eq!(user_login_info_2.user_agent, "Chrome");

        Ok(())
    }
}