8 use chrono
::{prelude
::*, Duration
};
9 use itertools
::Itertools
;
10 use r2d2
::{Pool
, PooledConnection
};
11 use r2d2_sqlite
::SqliteConnectionManager
;
12 use rand
::distributions
::{Alphanumeric
, DistString
};
13 use rusqlite
::{named_params
, params
, OptionalExtension
, Params
};
15 use crate::hash
::{hash
, verify_password
};
18 use crate::{consts
, user
};
20 const CURRENT_DB_VERSION
: u32 = 1;
24 SqliteError(rusqlite
::Error
),
25 R2d2Error(r2d2
::Error
),
26 UnsupportedVersion(u32),
30 impl fmt
::Display
for DBError
{
31 fn fmt(&self, f
: &mut fmt
::Formatter
) -> std
::result
::Result
<(), fmt
::Error
> {
32 write!(f
, "{:?}", self)
36 impl std
::error
::Error
for DBError {}
38 impl From
<rusqlite
::Error
> for DBError
{
39 fn from(error
: rusqlite
::Error
) -> Self {
40 DBError
::SqliteError(error
)
44 impl From
<r2d2
::Error
> for DBError
{
45 fn from(error
: r2d2
::Error
) -> Self {
46 DBError
::R2d2Error(error
)
51 fn from_dyn_error(error
: Box
<dyn std
::error
::Error
>) -> Self {
52 DBError
::Other(error
.to_string())
56 type Result
<T
> = std
::result
::Result
<T
, DBError
>;
59 pub enum SignUpResult
{
61 UserCreatedWaitingForValidation(String
), // Validation token.
65 pub enum ValidationResult
{
68 Ok(String
, i64), // Returns token and user id.
72 pub enum SignInResult
{
76 Ok(String
, i64), // Returns token and user id.
80 pub enum AuthenticationResult
{
82 Ok(i64), // Returns user id.
86 pub struct Connection
{
87 pool
: Pool
<SqliteConnectionManager
>,
91 pub fn new() -> Result
<Connection
> {
92 let path
= Path
::new(consts
::DB_DIRECTORY
).join(consts
::DB_FILENAME
);
93 Self::new_from_file(path
)
96 pub fn new_in_memory() -> Result
<Connection
> {
97 Self::create_connection(SqliteConnectionManager
::memory())
100 pub fn new_from_file
<P
: AsRef
<Path
>>(file
: P
) -> Result
<Connection
> {
101 if let Some(data_dir
) = file
.as_ref().parent() {
102 if !data_dir
.exists() {
103 fs
::DirBuilder
::new().create(data_dir
).unwrap();
107 Self::create_connection(SqliteConnectionManager
::file(file
))
110 fn create_connection(manager
: SqliteConnectionManager
) -> Result
<Connection
> {
111 let pool
= r2d2
::Pool
::new(manager
).unwrap();
112 let connection
= Connection
{ pool
};
113 connection
.create_or_update_db()?
;
117 fn get(&self) -> Result
<PooledConnection
<SqliteConnectionManager
>> {
118 let con
= self.pool
.get()?
;
119 con
.pragma_update(None
, "synchronous", "NORMAL")?
;
123 /// Called after the connection has been established for creating or updating the database.
124 /// The 'Version' table tracks the current state of the database.
125 fn create_or_update_db(&self) -> Result
<()> {
126 // Check the Database version.
127 let mut con
= self.get()?
;
128 con
.pragma_update(None
, "journal_mode", "WAL")?
;
130 let tx
= con
.transaction()?
;
132 // Version 0 corresponds to an empty database.
135 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
137 |row
| row
.get
::<usize, String
>(0),
141 "SELECT [version] FROM [Version] ORDER BY [id] DESC",
145 .unwrap_or_default(),
150 while Self::update_to_next_version(version
, &tx
)?
{
159 fn update_to_next_version(current_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<bool
> {
160 let next_version
= current_version
+ 1;
162 if next_version
<= CURRENT_DB_VERSION
{
163 println!("Update to version {}...", next_version
);
166 fn update_version(to_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<()> {
168 "INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))",
172 .map_err(DBError
::from
)
175 fn ok(updated
: bool
) -> Result
<bool
> {
177 println!("Version updated");
184 let sql_file
= consts
::SQL_FILENAME
.replace("{VERSION}", &next_version
.to_string());
185 tx
.execute_batch(&load_sql_file(&sql_file
)?
)?
;
186 update_version(next_version
, tx
)?
;
191 // Version 1 doesn't exist yet.
194 v
=> Err(DBError
::UnsupportedVersion(v
)),
198 pub fn get_all_recipe_titles(&self) -> Result
<Vec
<(i64, String
)>> {
199 let con
= self.get()?
;
201 let mut stmt
= con
.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?
;
203 let titles
: std
::result
::Result
<Vec
<(i64, String
)>, rusqlite
::Error
> = stmt
204 .query_map([], |row
| Ok((row
.get("id")?
, row
.get("title")?
)))?
207 titles
.map_err(DBError
::from
)
210 /* Not used for the moment.
211 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
212 let con = self.get()?;
213 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
215 stmt.query_map([], |row| {
216 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
217 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
221 pub fn get_recipe(&self, id
: i64) -> Result
<model
::Recipe
> {
222 let con
= self.get()?
;
224 "SELECT [id], [title], [description] FROM [Recipe] WHERE [id] = ?1",
227 Ok(model
::Recipe
::new(
230 row
.get("description")?
,
234 .map_err(DBError
::from
)
237 pub fn get_user_login_info(&self, token
: &str) -> Result
<UserLoginInfo
> {
238 let con
= self.get()?
;
239 con
.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
241 last_login_datetime
: r
.get("last_login_datetime")?
,
243 user_agent
: r
.get("user_agent")?
,
245 }).map_err(DBError
::from
)
248 pub fn load_user(&self, user_id
: i64) -> Result
<User
> {
249 let con
= self.get()?
;
251 "SELECT [email] FROM [User] WHERE [id] = ?1",
255 email
: r
.get("email")?
,
259 .map_err(DBError
::from
)
262 pub fn sign_up(&self, email
: &str, password
: &str) -> Result
<SignUpResult
> {
263 self.sign_up_with_given_time(email
, password
, Utc
::now())
266 fn sign_up_with_given_time(
270 datetime
: DateTime
<Utc
>,
271 ) -> Result
<SignUpResult
> {
272 let mut con
= self.get()?
;
273 let tx
= con
.transaction()?
;
276 "SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1",
280 r
.get
::<&str, i64>("id")?
,
281 r
.get
::<&str, Option
<String
>>("validation_token")?
,
287 Some((id
, validation_token
)) => {
288 if validation_token
.is_none() {
289 return Ok(SignUpResult
::UserAlreadyExists
);
291 let token
= generate_token();
292 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
293 tx
.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params
![id
, token
, datetime
, hashed_password
])?
;
297 let token
= generate_token();
298 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
299 tx
.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params
![email
, token
, datetime
, hashed_password
])?
;
304 Ok(SignUpResult
::UserCreatedWaitingForValidation(token
))
310 validation_time
: Duration
,
313 ) -> Result
<ValidationResult
> {
314 let mut con
= self.get()?
;
315 let tx
= con
.transaction()?
;
316 let user_id
= match tx
318 "SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1",
322 r
.get
::<&str, i64>("id")?
,
323 r
.get
::<&str, DateTime
<Utc
>>("creation_datetime")?
,
329 Some((id
, creation_datetime
)) => {
330 if Utc
::now() - creation_datetime
> validation_time
{
331 return Ok(ValidationResult
::ValidationExpired
);
334 "UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1",
339 None
=> return Ok(ValidationResult
::UnknownUser
),
341 let token
= Connection
::create_login_token(&tx
, user_id
, ip
, user_agent
)?
;
343 Ok(ValidationResult
::Ok(token
, user_id
))
352 ) -> Result
<SignInResult
> {
353 let mut con
= self.get()?
;
354 let tx
= con
.transaction()?
;
357 "SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1",
361 r
.get
::<&str, i64>("id")?
,
362 r
.get
::<&str, String
>("password")?
,
363 r
.get
::<&str, Option
<String
>>("validation_token")?
,
369 Some((id
, stored_password
, validation_token
)) => {
370 if validation_token
.is_some() {
371 Ok(SignInResult
::AccountNotValidated
)
372 } else if verify_password(password
, &stored_password
)
373 .map_err(DBError
::from_dyn_error
)?
375 let token
= Connection
::create_login_token(&tx
, id
, ip
, user_agent
)?
;
377 Ok(SignInResult
::Ok(token
, id
))
379 Ok(SignInResult
::WrongPassword
)
382 None
=> Ok(SignInResult
::UserNotFound
),
386 pub fn authentication(
391 ) -> Result
<AuthenticationResult
> {
392 let mut con
= self.get()?
;
393 let tx
= con
.transaction()?
;
396 "SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1",
398 |r
| Ok((r
.get
::<&str, i64>("id")?
, r
.get
::<&str, i64>("user_id")?
)),
402 Some((login_id
, user_id
)) => {
403 tx
.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params
![login_id
, Utc
::now(), ip
, user_agent
])?
;
405 Ok(AuthenticationResult
::Ok(user_id
))
407 None
=> Ok(AuthenticationResult
::NotValidToken
),
411 pub fn sign_out(&self, token
: &str) -> Result
<()> {
412 let mut con
= self.get()?
;
413 let tx
= con
.transaction()?
;
416 "SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1",
418 |r
| Ok(r
.get
::<&str, i64>("id")?
),
424 "DELETE FROM [UserLoginToken] WHERE [id] = ?1",
434 pub fn create_recipe(&self, user_id
: i64) -> Result
<i64> {
435 let con
= self.get()?
;
437 // Verify if an empty recipe already exists. Returns its id if one exists.
439 "SELECT [Recipe].[id] FROM [Recipe]
440 INNER JOIN [Image] ON [Image].[recipe_id] = [Recipe].[id]
441 INNER JOIN [Group] ON [Group].[recipe_id] = [Recipe].[id]
442 WHERE [Recipe].[user_id] = ?1 AND [Recipe].[estimate_time] = NULL AND [Recipe].[description] = NULL",
445 Ok(r
.get
::<&str, i64>("id")?
)
448 Some(recipe_id
) => Ok(recipe_id
),
450 con
.execute("INSERT INTO [Recipe] ([user_id], [title]) VALUES (?1, '')", [user_id
])?
;
451 Ok(con
.last_insert_rowid())
456 pub fn set_recipe_title(&self, recipe_id
: i64, title
: &str) -> Result
<()> {
457 let con
= self.get()?
;
459 "UPDATE [Recipe] SET [title] = ?2 WHERE [id] = ?1",
460 params
![recipe_id
, title
],
463 .map_err(DBError
::from
)
466 pub fn set_recipe_description(&self, recipe_id
: i64, description
: &str) -> Result
<()> {
467 let con
= self.get()?
;
469 "UPDATE [Recipe] SET [description] = ?2 WHERE [id] = ?1",
470 params
![recipe_id
, description
],
473 .map_err(DBError
::from
)
476 /// Execute a given SQL file.
477 pub fn execute_file
<P
: AsRef
<Path
> + fmt
::Display
>(&self, file
: P
) -> Result
<()> {
478 let con
= self.get()?
;
479 let sql
= load_sql_file(file
)?
;
480 con
.execute_batch(&sql
).map_err(DBError
::from
)
483 /// Execute any SQL statement.
484 /// Mainly used for testing.
485 pub fn execute_sql
<P
: Params
>(&self, sql
: &str, params
: P
) -> Result
<usize> {
486 let con
= self.get()?
;
487 con
.execute(sql
, params
).map_err(DBError
::from
)
491 fn create_login_token(
492 tx
: &rusqlite
::Transaction
,
496 ) -> Result
<String
> {
497 let token
= generate_token();
498 tx
.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params
![user_id
, Utc
::now(), token
, ip
, user_agent
])?
;
503 fn load_sql_file
<P
: AsRef
<Path
> + fmt
::Display
>(sql_file
: P
) -> Result
<String
> {
504 let mut file
= File
::open(&sql_file
).map_err(|err
| {
505 DBError
::Other(format!(
506 "Cannot open SQL file ({}): {}",
511 let mut sql
= String
::new();
512 file
.read_to_string(&mut sql
).map_err(|err
| {
513 DBError
::Other(format!(
514 "Cannot read SQL file ({}) : {}",
522 fn generate_token() -> String
{
523 Alphanumeric
.sample_string(&mut rand
::thread_rng(), consts
::AUTHENTICATION_TOKEN_SIZE
)
529 use rusqlite
::{ffi
, types
::Value
, Error
, ErrorCode
};
532 fn sign_up() -> Result
<()> {
533 let connection
= Connection
::new_in_memory()?
;
534 match connection
.sign_up("paul@atreides.com", "12345")?
{
535 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
536 other
=> panic!("{:?}", other
),
542 fn sign_up_to_an_already_existing_user() -> Result
<()> {
543 let connection
= Connection
::new_in_memory()?
;
544 connection
.execute_sql("
545 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
550 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
554 match connection
.sign_up("paul@atreides.com", "12345")?
{
555 SignUpResult
::UserAlreadyExists
=> (), // Nominal case.
556 other
=> panic!("{:?}", other
),
562 fn sign_up_and_sign_in_without_validation() -> Result
<()> {
563 let connection
= Connection
::new_in_memory()?
;
565 let email
= "paul@atreides.com";
566 let password
= "12345";
568 match connection
.sign_up(email
, password
)?
{
569 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
570 other
=> panic!("{:?}", other
),
573 match connection
.sign_in(email
, password
, "127.0.0.1", "Mozilla/5.0")?
{
574 SignInResult
::AccountNotValidated
=> (), // Nominal case.
575 other
=> panic!("{:?}", other
),
582 fn sign_up_to_an_unvalidated_already_existing_user() -> Result
<()> {
583 let connection
= Connection
::new_in_memory()?
;
584 let token
= generate_token();
585 connection
.execute_sql("
586 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
591 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
594 );", named_params
! { ":token": token
})?
;
595 match connection
.sign_up("paul@atreides.com", "12345")?
{
596 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
597 other
=> panic!("{:?}", other
),
603 fn sign_up_then_send_validation_at_time() -> Result
<()> {
604 let connection
= Connection
::new_in_memory()?
;
605 let validation_token
= match connection
.sign_up("paul@atreides.com", "12345")?
{
606 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
607 other
=> panic!("{:?}", other
),
609 match connection
.validation(
615 ValidationResult
::Ok(_
, _
) => (), // Nominal case.
616 other
=> panic!("{:?}", other
),
622 fn sign_up_then_send_validation_too_late() -> Result
<()> {
623 let connection
= Connection
::new_in_memory()?
;
624 let validation_token
= match connection
.sign_up_with_given_time(
627 Utc
::now() - Duration
::days(1),
629 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
630 other
=> panic!("{:?}", other
),
632 match connection
.validation(
638 ValidationResult
::ValidationExpired
=> (), // Nominal case.
639 other
=> panic!("{:?}", other
),
645 fn sign_up_then_send_validation_with_bad_token() -> Result
<()> {
646 let connection
= Connection
::new_in_memory()?
;
647 let _validation_token
= match connection
.sign_up("paul@atreides.com", "12345")?
{
648 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
649 other
=> panic!("{:?}", other
),
651 let random_token
= generate_token();
652 match connection
.validation(
658 ValidationResult
::UnknownUser
=> (), // Nominal case.
659 other
=> panic!("{:?}", other
),
665 fn sign_up_then_send_validation_then_sign_in() -> Result
<()> {
666 let connection
= Connection
::new_in_memory()?
;
668 let email
= "paul@atreides.com";
669 let password
= "12345";
672 let validation_token
= match connection
.sign_up(email
, password
)?
{
673 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
674 other
=> panic!("{:?}", other
),
678 match connection
.validation(
684 ValidationResult
::Ok(_
, _
) => (),
685 other
=> panic!("{:?}", other
),
689 match connection
.sign_in(email
, password
, "127.0.0.1", "Mozilla/5.0")?
{
690 SignInResult
::Ok(_
, _
) => (), // Nominal case.
691 other
=> panic!("{:?}", other
),
698 fn sign_up_then_send_validation_then_authentication() -> Result
<()> {
699 let connection
= Connection
::new_in_memory()?
;
701 let email
= "paul@atreides.com";
702 let password
= "12345";
705 let validation_token
= match connection
.sign_up(email
, password
)?
{
706 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
707 other
=> panic!("{:?}", other
),
711 let (authentication_token
, user_id
) = match connection
.validation(
717 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
718 other
=> panic!("{:?}", other
),
721 // Check user login information.
722 let user_login_info_1
= connection
.get_user_login_info(&authentication_token
)?
;
723 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
724 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
728 match connection
.authentication(&authentication_token
, "192.168.1.1", "Chrome")?
{
729 AuthenticationResult
::Ok(user_id
) => user_id
, // Nominal case.
730 other
=> panic!("{:?}", other
),
733 // Check user login information.
734 let user_login_info_2
= connection
.get_user_login_info(&authentication_token
)?
;
735 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
736 assert_eq!(user_login_info_2
.user_agent
, "Chrome");
742 fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result
<()> {
743 let connection
= Connection
::new_in_memory()?
;
745 let email
= "paul@atreides.com";
746 let password
= "12345";
749 let validation_token
= match connection
.sign_up(email
, password
)?
{
750 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
751 other
=> panic!("{:?}", other
),
755 let (authentication_token_1
, user_id_1
) = match connection
.validation(
761 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
762 other
=> panic!("{:?}", other
),
765 // Check user login information.
766 let user_login_info_1
= connection
.get_user_login_info(&authentication_token_1
)?
;
767 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
768 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
771 connection
.sign_out(&authentication_token_1
)?
;
774 let (authentication_token_2
, user_id_2
) =
775 match connection
.sign_in(email
, password
, "192.168.1.1", "Chrome")?
{
776 SignInResult
::Ok(token
, user_id
) => (token
, user_id
),
777 other
=> panic!("{:?}", other
),
780 assert_eq!(user_id_1
, user_id_2
);
781 assert_ne!(authentication_token_1
, authentication_token_2
);
783 // Check user login information.
784 let user_login_info_2
= connection
.get_user_login_info(&authentication_token_2
)?
;
786 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
787 assert_eq!(user_login_info_2
.user_agent
, "Chrome");
793 fn create_a_new_recipe_then_update_its_title() -> Result
<()> {
794 let connection
= Connection
::new_in_memory()?
;
796 connection
.execute_sql(
797 "INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token]) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
802 "$argon2id$v=19$m=4096,t=3,p=1$G4fjepS05MkRbTqEImUdYg$GGziE8uVQe1L1oFHk37lBno10g4VISnVqynSkLCH3Lc",
803 "2022-11-29 22:05:04.121407300+00:00",
808 match connection
.create_recipe(2) {
809 Err(DBError
::SqliteError(Error
::SqliteFailure(
811 code
: ErrorCode
::ConstraintViolation
,
815 ))) => (), // Nominal case.
817 "Creating a recipe with an inexistant user must fail: {:?}",
822 let recipe_id
= connection
.create_recipe(1)?
;
823 assert_eq!(recipe_id
, 1);
825 connection
.set_recipe_title(recipe_id
, "Crêpe")?
;
827 let recipe
= connection
.get_recipe(recipe_id
)?
;
828 assert_eq!(recipe
.title
, "Crêpe".to_string());