1 use std
::{fmt
, fs
::{self, File
}, path
::Path
, io
::Read
};
3 use itertools
::Itertools
;
4 use chrono
::{prelude
::*, Duration
};
5 use rusqlite
::{named_params
, OptionalExtension
, params
, Params
};
7 use r2d2_sqlite
::SqliteConnectionManager
;
8 use rand
::distributions
::{Alphanumeric
, DistString
};
10 use crate::{consts
, user
};
11 use crate::hash
::{hash
, verify_password
};
15 const CURRENT_DB_VERSION
: u32 = 1;
19 SqliteError(rusqlite
::Error
),
20 R2d2Error(r2d2
::Error
),
21 UnsupportedVersion(u32),
25 impl fmt
::Display
for DBError
{
26 fn fmt(&self, f
: &mut fmt
::Formatter
) -> std
::result
::Result
<(), fmt
::Error
> {
27 write!(f
, "{:?}", self)
31 impl std
::error
::Error
for DBError
{ }
33 impl From
<rusqlite
::Error
> for DBError
{
34 fn from(error
: rusqlite
::Error
) -> Self {
35 DBError
::SqliteError(error
)
39 impl From
<r2d2
::Error
> for DBError
{
40 fn from(error
: r2d2
::Error
) -> Self {
41 DBError
::R2d2Error(error
)
46 fn from_dyn_error(error
: Box
<dyn std
::error
::Error
>) -> Self {
47 DBError
::Other(error
.to_string())
51 type Result
<T
> = std
::result
::Result
<T
, DBError
>;
54 pub enum SignUpResult
{
56 UserCreatedWaitingForValidation(String
), // Validation token.
60 pub enum ValidationResult
{
63 Ok(String
, i64), // Returns token and user id.
67 pub enum SignInResult
{
71 Ok(String
, i64), // Returns token and user id.
75 pub enum AuthenticationResult
{
77 Ok(i64), // Returns user id.
81 pub struct Connection
{
82 pool
: Pool
<SqliteConnectionManager
>
86 pub fn new() -> Result
<Connection
> {
87 let path
= Path
::new(consts
::DB_DIRECTORY
).join(consts
::DB_FILENAME
);
88 Self::new_from_file(path
)
91 pub fn new_in_memory() -> Result
<Connection
> {
92 Self::create_connection(SqliteConnectionManager
::memory())
95 pub fn new_from_file
<P
: AsRef
<Path
>>(file
: P
) -> Result
<Connection
> {
96 if let Some(data_dir
) = file
.as_ref().parent() {
97 if !data_dir
.exists() {
98 fs
::DirBuilder
::new().create(data_dir
).unwrap();
102 Self::create_connection(SqliteConnectionManager
::file(file
))
105 fn create_connection(manager
: SqliteConnectionManager
) -> Result
<Connection
> {
106 let pool
= r2d2
::Pool
::new(manager
).unwrap();
107 let connection
= Connection
{ pool
};
108 connection
.create_or_update_db()?
;
112 /// Called after the connection has been established for creating or updating the database.
113 /// The 'Version' table tracks the current state of the database.
114 fn create_or_update_db(&self) -> Result
<()> {
115 // Check the Database version.
116 let mut con
= self.pool
.get()?
;
117 let tx
= con
.transaction()?
;
119 // Version 0 corresponds to an empty database.
122 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
124 |row
| row
.get
::<usize, String
>(0)
126 Ok(_
) => tx
.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row
| row
.get(0)).unwrap_or_default(),
131 while Self::update_to_next_version(version
, &tx
)?
{
140 fn update_to_next_version(current_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<bool
> {
141 let next_version
= current_version
+ 1;
143 if next_version
<= CURRENT_DB_VERSION
{
144 println!("Update to version {}...", next_version
);
147 fn update_version(to_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<()> {
148 tx
.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version
]).map(|_
| ()).map_err(DBError
::from
)
151 fn ok(updated
: bool
) -> Result
<bool
> {
153 println!("Version updated");
160 let sql_file
= consts
::SQL_FILENAME
.replace("{VERSION}", &next_version
.to_string());
161 tx
.execute_batch(&load_sql_file(&sql_file
)?
)?
;
162 update_version(next_version
, tx
)?
;
167 // Version 1 doesn't exist yet.
172 Err(DBError
::UnsupportedVersion(v
)),
176 pub fn get_all_recipe_titles(&self) -> Result
<Vec
<(i64, String
)>> {
177 let con
= self.pool
.get()?
;
179 let mut stmt
= con
.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?
;
181 let titles
: std
::result
::Result
<Vec
<(i64, String
)>, rusqlite
::Error
> =
182 stmt
.query_map([], |row
| {
183 Ok((row
.get("id")?
, row
.get("title")?
))
186 titles
.map_err(DBError
::from
)
189 /* Not used for the moment.
190 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
191 let con = self.pool.get()?;
192 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
194 stmt.query_map([], |row| {
195 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
196 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
200 pub fn get_recipe(&self, id
: i64) -> Result
<model
::Recipe
> {
201 let con
= self.pool
.get()?
;
202 con
.query_row("SELECT [id], [title], [description] FROM [Recipe] WHERE [id] = ?1", [id
], |row
| {
203 Ok(model
::Recipe
::new(row
.get("id")?
, row
.get("title")?
, row
.get("description")?
))
204 }).map_err(DBError
::from
)
207 pub fn get_user_login_info(&self, token
: &str) -> Result
<UserLoginInfo
> {
208 let con
= self.pool
.get()?
;
209 con
.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
211 last_login_datetime
: r
.get("last_login_datetime")?
,
213 user_agent
: r
.get("user_agent")?
,
215 }).map_err(DBError
::from
)
218 pub fn load_user(&self, user_id
: i64) -> Result
<User
> {
219 let con
= self.pool
.get()?
;
220 con
.query_row("SELECT [email] FROM [User] WHERE [id] = ?1", [user_id
], |r
| {
222 email
: r
.get("email")?
,
224 }).map_err(DBError
::from
)
227 pub fn sign_up(&self, email
: &str, password
: &str) -> Result
<SignUpResult
> {
228 self.sign_up_with_given_time(email
, password
, Utc
::now())
231 fn sign_up_with_given_time(&self, email
: &str, password
: &str, datetime
: DateTime
<Utc
>) -> Result
<SignUpResult
> {
232 let mut con
= self.pool
.get()?
;
233 let tx
= con
.transaction()?
;
235 match tx
.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email
], |r
| {
236 Ok((r
.get
::<&str, i64>("id")?
, r
.get
::<&str, Option
<String
>>("validation_token")?
))
238 Some((id
, validation_token
)) => {
239 if validation_token
.is_none() {
240 return Ok(SignUpResult
::UserAlreadyExists
)
242 let token
= generate_token();
243 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
244 tx
.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params
![id
, token
, datetime
, hashed_password
])?
;
248 let token
= generate_token();
249 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
250 tx
.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params
![email
, token
, datetime
, hashed_password
])?
;
255 Ok(SignUpResult
::UserCreatedWaitingForValidation(token
))
258 pub fn validation(&self, token
: &str, validation_time
: Duration
, ip
: &str, user_agent
: &str) -> Result
<ValidationResult
> {
259 let mut con
= self.pool
.get()?
;
260 let tx
= con
.transaction()?
;
262 match tx
.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token
], |r
| {
263 Ok((r
.get
::<&str, i64>("id")?
, r
.get
::<&str, DateTime
<Utc
>>("creation_datetime")?
))
265 Some((id
, creation_datetime
)) => {
266 if Utc
::now() - creation_datetime
> validation_time
{
267 return Ok(ValidationResult
::ValidationExpired
)
269 tx
.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id
])?
;
273 return Ok(ValidationResult
::UnknownUser
)
276 let token
= Connection
::create_login_token(&tx
, user_id
, ip
, user_agent
)?
;
278 Ok(ValidationResult
::Ok(token
, user_id
))
281 pub fn sign_in(&self, email
: &str, password
: &str, ip
: &str, user_agent
: &str) -> Result
<SignInResult
> {
282 let mut con
= self.pool
.get()?
;
283 let tx
= con
.transaction()?
;
284 match tx
.query_row("SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1", [email
], |r
| {
285 Ok((r
.get
::<&str, i64>("id")?
, r
.get
::<&str, String
>("password")?
, r
.get
::<&str, Option
<String
>>("validation_token")?
))
287 Some((id
, stored_password
, validation_token
)) => {
288 if validation_token
.is_some() {
289 Ok(SignInResult
::AccountNotValidated
)
290 } else if verify_password(password
, &stored_password
).map_err(DBError
::from_dyn_error
)?
{
291 let token
= Connection
::create_login_token(&tx
, id
, ip
, user_agent
)?
;
293 Ok(SignInResult
::Ok(token
, id
))
295 Ok(SignInResult
::WrongPassword
)
299 Ok(SignInResult
::UserNotFound
)
304 pub fn authentication(&self, token
: &str, ip
: &str, user_agent
: &str) -> Result
<AuthenticationResult
> {
305 let mut con
= self.pool
.get()?
;
306 let tx
= con
.transaction()?
;
307 match tx
.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
308 Ok((r
.get
::<&str, i64>("id")?
, r
.get
::<&str, i64>("user_id")?
))
310 Some((login_id
, user_id
)) => {
311 tx
.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params
![login_id
, Utc
::now(), ip
, user_agent
])?
;
313 Ok(AuthenticationResult
::Ok(user_id
))
316 Ok(AuthenticationResult
::NotValidToken
)
320 pub fn sign_out(&self, token
: &str) -> Result
<()> {
321 let mut con
= self.pool
.get()?
;
322 let tx
= con
.transaction()?
;
323 match tx
.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
324 Ok(r
.get
::<&str, i64>("id")?
)
327 tx
.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params
![login_id
])?
;
335 pub fn create_recipe(&self, user_id
: i64) -> Result
<i64> {
336 let con
= self.pool
.get()?
;
338 // Verify if an empty recipe already exists. Returns its id if one exists.
340 "SELECT [Recipe].[id] FROM [Recipe]
341 INNER JOIN [Image] ON [Image].[recipe_id] = [Recipe].[id]
342 INNER JOIN [Group] ON [Group].[recipe_id] = [Recipe].[id]
343 WHERE [Recipe].[user_id] = ?1 AND [Recipe].[estimate_time] = NULL AND [Recipe].[description] = NULL",
346 Ok(r
.get
::<&str, i64>("id")?
)
349 Some(recipe_id
) => Ok(recipe_id
),
351 con
.execute("INSERT INTO [Recipe] ([user_id], [title]) VALUES (?1, '')", [user_id
])?
;
352 Ok(con
.last_insert_rowid())
357 pub fn set_recipe_title(&self, recipe_id
: i64, title
: &str) -> Result
<()> {
358 let con
= self.pool
.get()?
;
359 con
.execute("UPDATE [Recipe] SET [title] = ?2 WHERE [id] = ?1", params
![recipe_id
, title
]).map(|_n
| ()).map_err(DBError
::from
)
362 /// Execute a given SQL file.
363 pub fn execute_file
<P
: AsRef
<Path
> + fmt
::Display
>(&self, file
: P
) -> Result
<()> {
364 let con
= self.pool
.get()?
;
365 let sql
= load_sql_file(file
)?
;
366 con
.execute_batch(&sql
).map_err(DBError
::from
)
369 /// Execute any SQL statement.
370 /// Mainly used for testing.
371 pub fn execute_sql
<P
: Params
>(&self, sql
: &str, params
: P
) -> Result
<usize> {
372 let con
= self.pool
.get()?
;
373 con
.execute(sql
, params
).map_err(DBError
::from
)
377 fn create_login_token(tx
: &rusqlite
::Transaction
, user_id
: i64, ip
: &str, user_agent
: &str) -> Result
<String
> {
378 let token
= generate_token();
379 tx
.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params
![user_id
, Utc
::now(), token
, ip
, user_agent
])?
;
384 fn load_sql_file
<P
: AsRef
<Path
> + fmt
::Display
>(sql_file
: P
) -> Result
<String
> {
385 let mut file
= File
::open(&sql_file
).map_err(|err
| DBError
::Other(format!("Cannot open SQL file ({}): {}", &sql_file
, err
.to_string())))?
;
386 let mut sql
= String
::new();
387 file
.read_to_string(&mut sql
).map_err(|err
| DBError
::Other(format!("Cannot read SQL file ({}) : {}", &sql_file
, err
.to_string())))?
;
391 fn generate_token() -> String
{
392 Alphanumeric
.sample_string(&mut rand
::thread_rng(), consts
::AUTHENTICATION_TOKEN_SIZE
)
398 use rusqlite
::{Error
, ErrorCode
, ffi
, types
::Value
};
401 fn sign_up() -> Result
<()> {
402 let connection
= Connection
::new_in_memory()?
;
403 match connection
.sign_up("paul@test.org", "12345")?
{
404 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
405 other
=> panic!("{:?}", other
),
411 fn sign_up_to_an_already_existing_user() -> Result
<()> {
412 let connection
= Connection
::new_in_memory()?
;
413 connection
.execute_sql("
414 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
419 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
423 match connection
.sign_up("paul@test.org", "12345")?
{
424 SignUpResult
::UserAlreadyExists
=> (), // Nominal case.
425 other
=> panic!("{:?}", other
),
431 fn sign_up_and_sign_in_without_validation() -> Result
<()> {
432 let connection
= Connection
::new_in_memory()?
;
434 let email
= "paul@test.org";
435 let password
= "12345";
437 match connection
.sign_up(email
, password
)?
{
438 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
439 other
=> panic!("{:?}", other
),
442 match connection
.sign_in(email
, password
, "127.0.0.1", "Mozilla/5.0")?
{
443 SignInResult
::AccountNotValidated
=> (), // Nominal case.
444 other
=> panic!("{:?}", other
),
451 fn sign_up_to_an_unvalidated_already_existing_user() -> Result
<()> {
452 let connection
= Connection
::new_in_memory()?
;
453 let token
= generate_token();
454 connection
.execute_sql("
455 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
460 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
463 );", named_params
! { ":token": token
})?
;
464 match connection
.sign_up("paul@test.org", "12345")?
{
465 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
466 other
=> panic!("{:?}", other
),
472 fn sign_up_then_send_validation_at_time() -> Result
<()> {
473 let connection
= Connection
::new_in_memory()?
;
474 let validation_token
=
475 match connection
.sign_up("paul@test.org", "12345")?
{
476 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
477 other
=> panic!("{:?}", other
),
479 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
480 ValidationResult
::Ok(_
, _
) => (), // Nominal case.
481 other
=> panic!("{:?}", other
),
487 fn sign_up_then_send_validation_too_late() -> Result
<()> {
488 let connection
= Connection
::new_in_memory()?
;
489 let validation_token
=
490 match connection
.sign_up_with_given_time("paul@test.org", "12345", Utc
::now() - Duration
::days(1))?
{
491 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
492 other
=> panic!("{:?}", other
),
494 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
495 ValidationResult
::ValidationExpired
=> (), // Nominal case.
496 other
=> panic!("{:?}", other
),
502 fn sign_up_then_send_validation_with_bad_token() -> Result
<()> {
503 let connection
= Connection
::new_in_memory()?
;
504 let _validation_token
=
505 match connection
.sign_up("paul@test.org", "12345")?
{
506 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
507 other
=> panic!("{:?}", other
),
509 let random_token
= generate_token();
510 match connection
.validation(&random_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
511 ValidationResult
::UnknownUser
=> (), // Nominal case.
512 other
=> panic!("{:?}", other
),
518 fn sign_up_then_send_validation_then_sign_in() -> Result
<()> {
519 let connection
= Connection
::new_in_memory()?
;
521 let email
= "paul@test.org";
522 let password
= "12345";
525 let validation_token
=
526 match connection
.sign_up(email
, password
)?
{
527 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
528 other
=> panic!("{:?}", other
),
532 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
533 ValidationResult
::Ok(_
, _
) => (),
534 other
=> panic!("{:?}", other
),
538 match connection
.sign_in(email
, password
, "127.0.0.1", "Mozilla/5.0")?
{
539 SignInResult
::Ok(_
, _
) => (), // Nominal case.
540 other
=> panic!("{:?}", other
),
547 fn sign_up_then_send_validation_then_authentication() -> Result
<()> {
548 let connection
= Connection
::new_in_memory()?
;
550 let email
= "paul@test.org";
551 let password
= "12345";
554 let validation_token
=
555 match connection
.sign_up(email
, password
)?
{
556 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
557 other
=> panic!("{:?}", other
),
561 let (authentication_token
, user_id
) = match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla")?
{
562 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
563 other
=> panic!("{:?}", other
),
566 // Check user login information.
567 let user_login_info_1
= connection
.get_user_login_info(&authentication_token
)?
;
568 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
569 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
573 match connection
.authentication(&authentication_token
, "192.168.1.1", "Chrome")?
{
574 AuthenticationResult
::Ok(user_id
) => user_id
, // Nominal case.
575 other
=> panic!("{:?}", other
),
578 // Check user login information.
579 let user_login_info_2
= connection
.get_user_login_info(&authentication_token
)?
;
580 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
581 assert_eq!(user_login_info_2
.user_agent
, "Chrome");
587 fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result
<()> {
588 let connection
= Connection
::new_in_memory()?
;
590 let email
= "paul@test.org";
591 let password
= "12345";
594 let validation_token
=
595 match connection
.sign_up(email
, password
)?
{
596 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
597 other
=> panic!("{:?}", other
),
601 let (authentication_token_1
, user_id_1
) =
602 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla")?
{
603 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
604 other
=> panic!("{:?}", other
),
607 // Check user login information.
608 let user_login_info_1
= connection
.get_user_login_info(&authentication_token_1
)?
;
609 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
610 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
613 connection
.sign_out(&authentication_token_1
)?
;
616 let (authentication_token_2
, user_id_2
) =
617 match connection
.sign_in(email
, password
, "192.168.1.1", "Chrome")?
{
618 SignInResult
::Ok(token
, user_id
) => (token
, user_id
),
619 other
=> panic!("{:?}", other
),
622 assert_eq!(user_id_1
, user_id_2
);
623 assert_ne!(authentication_token_1
, authentication_token_2
);
625 // Check user login information.
626 let user_login_info_2
= connection
.get_user_login_info(&authentication_token_2
)?
;
628 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
629 assert_eq!(user_login_info_2
.user_agent
, "Chrome");
636 fn create_a_new_recipe_then_update_its_title() -> Result
<()> {
637 let connection
= Connection
::new_in_memory()?
;
639 connection
.execute_sql(
640 "INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token]) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
645 "$argon2id$v=19$m=4096,t=3,p=1$G4fjepS05MkRbTqEImUdYg$GGziE8uVQe1L1oFHk37lBno10g4VISnVqynSkLCH3Lc",
646 "2022-11-29 22:05:04.121407300+00:00",
651 match connection
.create_recipe(2) {
652 Err(DBError
::SqliteError(Error
::SqliteFailure(ffi
::Error
{ code
: ErrorCode
::ConstraintViolation
, extended_code
: _
}, Some(_
)))) => (), // Nominal case.
653 other
=> panic!("Creating a recipe with an inexistant user must fail: {:?}", other
),
656 let recipe_id
= connection
.create_recipe(1)?
;
657 assert_eq!(recipe_id
, 1);
659 connection
.set_recipe_title(recipe_id
, "Crêpe")?
;
661 let recipe
= connection
.get_recipe(recipe_id
)?
;
662 assert_eq!(recipe
.title
, "Crêpe".to_string());