1 use std
::{fmt
::Display
, fs
::{self, File
}, path
::Path
, io
::Read
};
3 use itertools
::Itertools
;
4 use chrono
::{prelude
::*, Duration
};
5 use rusqlite
::{named_params
, OptionalExtension
, params
, Params
};
7 use r2d2_sqlite
::SqliteConnectionManager
;
8 use rand
::distributions
::{Alphanumeric
, DistString
};
11 use crate::hash
::{hash
, verify_password
};
15 const CURRENT_DB_VERSION
: u32 = 1;
19 SqliteError(rusqlite
::Error
),
20 R2d2Error(r2d2
::Error
),
21 UnsupportedVersion(u32),
25 impl From
<rusqlite
::Error
> for DBError
{
26 fn from(error
: rusqlite
::Error
) -> Self {
27 DBError
::SqliteError(error
)
31 impl From
<r2d2
::Error
> for DBError
{
32 fn from(error
: r2d2
::Error
) -> Self {
33 DBError
::R2d2Error(error
)
37 // TODO: Is there a better solution?
39 fn from_dyn_error(error
: Box
<dyn std
::error
::Error
>) -> Self {
40 DBError
::Other(error
.to_string())
44 type Result
<T
> = std
::result
::Result
<T
, DBError
>;
47 pub enum SignUpResult
{
49 UserCreatedWaitingForValidation(String
), // Validation token.
53 pub enum ValidationResult
{
56 Ok(String
, i32), // Returns token and user id.
60 pub enum SignInResult
{
63 Ok(String
, i32), // Returns token and user id.
67 pub enum AuthenticationResult
{
69 Ok(i32), // Returns user id.
73 pub struct Connection
{
74 //con: rusqlite::Connection
75 pool
: Pool
<SqliteConnectionManager
>
79 pub fn new() -> Result
<Connection
> {
80 let path
= Path
::new(consts
::DB_DIRECTORY
).join(consts
::DB_FILENAME
);
81 Self::new_from_file(path
)
84 pub fn new_in_memory() -> Result
<Connection
> {
85 Self::create_connection(SqliteConnectionManager
::memory())
88 pub fn new_from_file
<P
: AsRef
<Path
>>(file
: P
) -> Result
<Connection
> {
89 if let Some(data_dir
) = file
.as_ref().parent() {
90 if !data_dir
.exists() {
91 fs
::DirBuilder
::new().create(data_dir
).unwrap();
95 Self::create_connection(SqliteConnectionManager
::file(file
))
98 fn create_connection(manager
: SqliteConnectionManager
) -> Result
<Connection
> {;
99 let pool
= r2d2
::Pool
::new(manager
).unwrap();
100 let connection
= Connection
{ pool
};
101 connection
.create_or_update()?
;
105 /// Called after the connection has been established for creating or updating the database.
106 /// The 'Version' table tracks the current state of the database.
107 fn create_or_update(&self) -> Result
<()> {
108 // Check the Database version.
109 let mut con
= self.pool
.get()?
;
110 let tx
= con
.transaction()?
;
112 // Version 0 corresponds to an empty database.
115 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
117 |row
| row
.get
::<usize, String
>(0)
119 Ok(_
) => tx
.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row
| row
.get(0)).unwrap_or_default(),
124 while Self::update_to_next_version(version
, &tx
)?
{
133 fn update_to_next_version(current_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<bool
> {
134 let next_version
= current_version
+ 1;
136 if next_version
<= CURRENT_DB_VERSION
{
137 println!("Update to version {}...", next_version
);
140 fn update_version(to_version
: u32, tx
: &rusqlite
::Transaction
) -> Result
<()> {
141 tx
.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version
]).map(|_
| ()).map_err(DBError
::from
)
144 fn ok(updated
: bool
) -> Result
<bool
> {
146 println!("Version updated");
153 let sql_file
= consts
::SQL_FILENAME
.replace("{VERSION}", &next_version
.to_string());
154 tx
.execute_batch(&load_sql_file(&sql_file
)?
)?
;
155 update_version(next_version
, tx
)?
;
160 // Version 1 doesn't exist yet.
165 Err(DBError
::UnsupportedVersion(v
)),
169 pub fn get_all_recipe_titles(&self) -> Result
<Vec
<(i32, String
)>> {
170 let con
= self.pool
.get()?
;
171 let mut stmt
= con
.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?
;
173 stmt
.query_map([], |row
| {
174 Ok((row
.get(0)?
, row
.get(1)?
))
175 })?
.map(|r
| r
.unwrap()).collect_vec(); // TODO: remove unwrap.
179 /* Not used for the moment.
180 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
181 let con = self.pool.get()?;
182 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
184 stmt.query_map([], |row| {
185 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
186 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
190 pub fn get_recipe(&self, id
: i32) -> Result
<model
::Recipe
> {
191 let con
= self.pool
.get()?
;
192 con
.query_row("SELECT [id], [title] FROM [Recipe] WHERE [id] = ?1", [id
], |row
| {
193 Ok(model
::Recipe
::new(row
.get(0)?
, row
.get(1)?
))
194 }).map_err(DBError
::from
)
197 pub fn get_user_login_info(&self, token
: &str) -> Result
<UserLoginInfo
> {
198 let con
= self.pool
.get()?
;
199 con
.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
201 last_login_datetime
: r
.get("last_login_datetime")?
,
203 user_agent
: r
.get("user_agent")?
,
205 }).map_err(DBError
::from
)
209 pub fn sign_up(&self, password
: &str, email
: &str) -> Result
<SignUpResult
> {
210 self.sign_up_with_given_time(password
, email
, Utc
::now())
213 fn sign_up_with_given_time(&self, password
: &str, email
: &str, datetime
: DateTime
<Utc
>) -> Result
<SignUpResult
> {
214 let mut con
= self.pool
.get()?
;
215 let tx
= con
.transaction()?
;
217 match tx
.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email
], |r
| {
218 Ok((r
.get
::<&str, i32>("id")?
, r
.get
::<&str, Option
<String
>>("validation_token")?
))
220 Some((id
, validation_token
)) => {
221 if validation_token
.is_none() {
222 return Ok(SignUpResult
::UserAlreadyExists
)
224 let token
= generate_token();
225 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
226 tx
.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params
![id
, token
, datetime
, hashed_password
])?
;
230 let token
= generate_token();
231 let hashed_password
= hash(password
).map_err(|e
| DBError
::from_dyn_error(e
))?
;
232 tx
.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params
![email
, token
, datetime
, hashed_password
])?
;
237 Ok(SignUpResult
::UserCreatedWaitingForValidation(token
))
240 pub fn validation(&self, token
: &str, validation_time
: Duration
, ip
: &str, user_agent
: &str) -> Result
<ValidationResult
> {
241 let mut con
= self.pool
.get()?
;
242 let tx
= con
.transaction()?
;
244 match tx
.query_row("SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token
], |r
| {
245 Ok((r
.get
::<&str, i32>("id")?
, r
.get
::<&str, DateTime
<Utc
>>("creation_datetime")?
))
247 Some((id
, creation_datetime
)) => {
248 if Utc
::now() - creation_datetime
> validation_time
{
249 return Ok(ValidationResult
::ValidationExpired
)
251 tx
.execute("UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id
])?
;
255 return Ok(ValidationResult
::UnknownUser
)
258 let token
= Connection
::create_login_token(&tx
, user_id
, ip
, user_agent
)?
;
260 Ok(ValidationResult
::Ok(token
, user_id
))
263 pub fn sign_in(&self, password
: &str, email
: &str, ip
: &str, user_agent
: &str) -> Result
<SignInResult
> {
264 let mut con
= self.pool
.get()?
;
265 let tx
= con
.transaction()?
;
266 match tx
.query_row("SELECT [id], [password] FROM [User] WHERE [email] = ?1", [email
], |r
| {
267 Ok((r
.get
::<&str, i32>("id")?
, r
.get
::<&str, String
>("password")?
))
269 Some((id
, stored_password
)) => {
270 if verify_password(password
, &stored_password
).map_err(DBError
::from_dyn_error
)?
{
271 let token
= Connection
::create_login_token(&tx
, id
, ip
, user_agent
)?
;
273 Ok(SignInResult
::Ok(token
, id
))
275 Ok(SignInResult
::PasswordsDontMatch
)
279 Ok(SignInResult
::UserNotFound
)
284 pub fn authentication(&self, token
: &str, ip
: &str, user_agent
: &str) -> Result
<AuthenticationResult
> {
285 let mut con
= self.pool
.get()?
;
286 let tx
= con
.transaction()?
;
287 match tx
.query_row("SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
288 Ok((r
.get
::<&str, i32>("id")?
, r
.get
::<&str, i32>("user_id")?
))
290 Some((login_id
, user_id
)) => {
291 tx
.execute("UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params
![login_id
, Utc
::now(), ip
, user_agent
])?
;
293 Ok(AuthenticationResult
::Ok(user_id
))
296 Ok(AuthenticationResult
::NotValidToken
)
300 pub fn sign_out(&self, token
: &str) -> Result
<()> {
301 let mut con
= self.pool
.get()?
;
302 let tx
= con
.transaction()?
;
303 match tx
.query_row("SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token
], |r
| {
304 Ok(r
.get
::<&str, i32>("id")?
)
307 tx
.execute("DELETE FROM [UserLoginToken] WHERE [id] = ?1", params
![login_id
])?
;
315 /// Execute a given SQL file.
316 pub fn execute_file
<P
: AsRef
<Path
> + Display
>(&self, file
: P
) -> Result
<()> {
317 let con
= self.pool
.get()?
;
318 let sql
= load_sql_file(file
)?
;
319 con
.execute_batch(&sql
).map_err(DBError
::from
)
322 /// Execute any SQL statement.
323 /// Mainly used for testing.
324 pub fn execute_sql
<P
: Params
>(&self, sql
: &str, params
: P
) -> Result
<usize> {
325 let con
= self.pool
.get()?
;
326 con
.execute(sql
, params
).map_err(DBError
::from
)
330 fn create_login_token(tx
: &rusqlite
::Transaction
, user_id
: i32, ip
: &str, user_agent
: &str) -> Result
<String
> {
331 let token
= generate_token();
332 tx
.execute("INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params
![user_id
, Utc
::now(), token
, ip
, user_agent
])?
;
337 fn load_sql_file
<P
: AsRef
<Path
> + Display
>(sql_file
: P
) -> Result
<String
> {
338 let mut file
= File
::open(&sql_file
).map_err(|err
| DBError
::Other(format!("Cannot open SQL file ({}): {}", &sql_file
, err
.to_string())))?
;
339 let mut sql
= String
::new();
340 file
.read_to_string(&mut sql
).map_err(|err
| DBError
::Other(format!("Cannot read SQL file ({}) : {}", &sql_file
, err
.to_string())))?
;
344 fn generate_token() -> String
{
345 Alphanumeric
.sample_string(&mut rand
::thread_rng(), 24)
353 fn sign_up() -> Result
<()> {
354 let connection
= Connection
::new_in_memory()?
;
355 match connection
.sign_up("12345", "paul@test.org")?
{
356 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
357 other
=> panic!("{:?}", other
),
363 fn sign_up_to_an_already_existing_user() -> Result
<()> {
364 let connection
= Connection
::new_in_memory()?
;
365 connection
.execute_sql("
366 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
371 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
375 match connection
.sign_up("12345", "paul@test.org")?
{
376 SignUpResult
::UserAlreadyExists
=> (), // Nominal case.
377 other
=> panic!("{:?}", other
),
383 fn sign_up_to_an_unvalidated_already_existing_user() -> Result
<()> {
384 let connection
= Connection
::new_in_memory()?
;
385 let token
= generate_token();
386 connection
.execute_sql("
387 INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
392 '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
395 );", named_params
! { ":token": token
})?
;
396 match connection
.sign_up("12345", "paul@test.org")?
{
397 SignUpResult
::UserCreatedWaitingForValidation(_
) => (), // Nominal case.
398 other
=> panic!("{:?}", other
),
404 fn sign_up_then_send_validation_at_time() -> Result
<()> {
405 let connection
= Connection
::new_in_memory()?
;
406 let validation_token
=
407 match connection
.sign_up("12345", "paul@test.org")?
{
408 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
409 other
=> panic!("{:?}", other
),
411 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
412 ValidationResult
::Ok(_
, _
) => (), // Nominal case.
413 other
=> panic!("{:?}", other
),
419 fn sign_up_then_send_validation_too_late() -> Result
<()> {
420 let connection
= Connection
::new_in_memory()?
;
421 let validation_token
=
422 match connection
.sign_up_with_given_time("12345", "paul@test.org", Utc
::now() - Duration
::days(1))?
{
423 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
424 other
=> panic!("{:?}", other
),
426 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
427 ValidationResult
::ValidationExpired
=> (), // Nominal case.
428 other
=> panic!("{:?}", other
),
434 fn sign_up_then_send_validation_with_bad_token() -> Result
<()> {
435 let connection
= Connection
::new_in_memory()?
;
436 let _validation_token
=
437 match connection
.sign_up("12345", "paul@test.org")?
{
438 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
439 other
=> panic!("{:?}", other
),
441 let random_token
= generate_token();
442 match connection
.validation(&random_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
443 ValidationResult
::UnknownUser
=> (), // Nominal case.
444 other
=> panic!("{:?}", other
),
450 fn sign_up_then_send_validation_then_sign_in() -> Result
<()> {
451 let connection
= Connection
::new_in_memory()?
;
453 let password
= "12345";
454 let email
= "paul@test.org";
457 let validation_token
=
458 match connection
.sign_up(password
, email
)?
{
459 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
460 other
=> panic!("{:?}", other
),
464 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla/5.0")?
{
465 ValidationResult
::Ok(_
, _
) => (),
466 other
=> panic!("{:?}", other
),
470 match connection
.sign_in(password
, email
, "127.0.0.1", "Mozilla/5.0")?
{
471 SignInResult
::Ok(_
, _
) => (), // Nominal case.
472 other
=> panic!("{:?}", other
),
479 fn sign_up_then_send_validation_then_authentication() -> Result
<()> {
480 let connection
= Connection
::new_in_memory()?
;
482 let password
= "12345";
483 let email
= "paul@test.org";
486 let validation_token
=
487 match connection
.sign_up(password
, email
)?
{
488 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
489 other
=> panic!("{:?}", other
),
493 let (authentication_token
, user_id
) = match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla")?
{
494 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
495 other
=> panic!("{:?}", other
),
498 // Check user login information.
499 let user_login_info_1
= connection
.get_user_login_info(&authentication_token
)?
;
500 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
501 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
505 match connection
.authentication(&authentication_token
, "192.168.1.1", "Chrome")?
{
506 AuthenticationResult
::Ok(user_id
) => user_id
, // Nominal case.
507 other
=> panic!("{:?}", other
),
510 // Check user login information.
511 let user_login_info_2
= connection
.get_user_login_info(&authentication_token
)?
;
512 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
513 assert_eq!(user_login_info_2
.user_agent
, "Chrome");
519 fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result
<()> {
520 let connection
= Connection
::new_in_memory()?
;
522 let password
= "12345";
523 let email
= "paul@test.org";
526 let validation_token
=
527 match connection
.sign_up(password
, email
)?
{
528 SignUpResult
::UserCreatedWaitingForValidation(token
) => token
, // Nominal case.
529 other
=> panic!("{:?}", other
),
533 let (authentication_token_1
, user_id_1
) =
534 match connection
.validation(&validation_token
, Duration
::hours(1), "127.0.0.1", "Mozilla")?
{
535 ValidationResult
::Ok(token
, user_id
) => (token
, user_id
),
536 other
=> panic!("{:?}", other
),
539 // Check user login information.
540 let user_login_info_1
= connection
.get_user_login_info(&authentication_token_1
)?
;
541 assert_eq!(user_login_info_1
.ip
, "127.0.0.1");
542 assert_eq!(user_login_info_1
.user_agent
, "Mozilla");
545 connection
.sign_out(&authentication_token_1
)?
;
548 let (authentication_token_2
, user_id_2
) =
549 match connection
.sign_in(password
, email
, "192.168.1.1", "Chrome")?
{
550 SignInResult
::Ok(token
, user_id
) => (token
, user_id
),
551 other
=> panic!("{:?}", other
),
554 assert_eq!(user_id_1
, user_id_2
);
555 assert_ne!(authentication_token_1
, authentication_token_2
);
557 // Check user login information.
558 let user_login_info_2
= connection
.get_user_login_info(&authentication_token_2
)?
;
560 assert_eq!(user_login_info_2
.ip
, "192.168.1.1");
561 assert_eq!(user_login_info_2
.user_agent
, "Chrome");