Sign up method.
[recipes.git] / backend / src / db.rs
1 use std::{fmt::Display, fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 use chrono::{prelude::*, Duration};
5 use rusqlite::{params, Params, OptionalExtension};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8 use rand::distributions::{Alphanumeric, DistString};
9
10 use crate::consts;
11 use crate::hash::hash;
12 use crate::model;
13
/// Latest database schema version known by this code.
const CURRENT_DB_VERSION: u32 = 1_u32;
15
16 #[derive(Debug)]
17 pub enum DBError {
18 SqliteError(rusqlite::Error),
19 R2d2Error(r2d2::Error),
20 UnsupportedVersion(u32),
21 Other(String),
22 }
23
24 impl From<rusqlite::Error> for DBError {
25 fn from(error: rusqlite::Error) -> Self {
26 DBError::SqliteError(error)
27 }
28 }
29
30 impl From<r2d2::Error> for DBError {
31 fn from(error: r2d2::Error) -> Self {
32 DBError::R2d2Error(error)
33 }
34 }
35
36 // TODO: Is there a better solution?
37 impl DBError {
38 fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
39 DBError::Other(error.to_string())
40 }
41 }
42
43 type Result<T> = std::result::Result<T, DBError>;
44
/// Outcome of a sign-up request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SignUpResult {
    /// A user with this email is already registered.
    UserAlreadyExists,
    /// The user has been created (or re-created) and must now be validated.
    UserCreatedWaitingForValidation(String), // Validation token.
}
50
/// Outcome of a validation request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationResult {
    ValidationExpired,
    OK,
}
56
/// Outcome of a sign-in request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignInResult {
    NotValidToken,
    OK,
}
62
/// Outcome of an authentication request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthenticationResult {
    NotValidToken,
    OK,
}
68
69 #[derive(Clone)]
70 pub struct Connection {
71 //con: rusqlite::Connection
72 pool: Pool<SqliteConnectionManager>
73 }
74
75 impl Connection {
76 pub fn new() -> Result<Connection> {
77 let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
78 Self::new_from_file(path)
79 }
80
81 pub fn new_in_memory() -> Result<Connection> {
82 Self::create_connection(SqliteConnectionManager::memory())
83 }
84
85 pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
86 if let Some(data_dir) = file.as_ref().parent() {
87 if !data_dir.exists() {
88 fs::DirBuilder::new().create(data_dir).unwrap();
89 }
90 }
91
92 Self::create_connection(SqliteConnectionManager::file(file))
93 }
94
95 /// Called after the connection has been established for creating or updating the database.
96 /// The 'Version' table tracks the current state of the database.
97 fn create_or_update(&self) -> Result<()> {
98 // Check the Database version.
99 let mut con = self.pool.get()?;
100 let tx = con.transaction()?;
101
102 // Version 0 corresponds to an empty database.
103 let mut version = {
104 match tx.query_row(
105 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
106 [],
107 |row| row.get::<usize, String>(0)
108 ) {
109 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
110 Err(_) => 0
111 }
112 };
113
114 while Connection::update_to_next_version(version, &tx)? {
115 version += 1;
116 }
117
118 tx.commit()?;
119
120 Ok(())
121 }
122
123 fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {;
124 let pool = r2d2::Pool::new(manager).unwrap();
125 let connection = Connection { pool };
126 connection.create_or_update()?;
127 Ok(connection)
128 }
129
130 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
131 let next_version = current_version + 1;
132
133 if next_version <= CURRENT_DB_VERSION {
134 println!("Update to version {}...", next_version);
135 }
136
137 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
138 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
139 }
140
141 fn ok(updated: bool) -> Result<bool> {
142 if updated {
143 println!("Version updated");
144 }
145 Ok(updated)
146 }
147
148 match next_version {
149 1 => {
150 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
151 tx.execute_batch(&load_sql_file(&sql_file)?)?;
152 update_version(next_version, tx)?;
153
154 ok(true)
155 }
156
157 // Version 1 doesn't exist yet.
158 2 =>
159 ok(false),
160
161 v =>
162 Err(DBError::UnsupportedVersion(v)),
163 }
164 }
165
166 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
167 let con = self.pool.get()?;
168 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
169 let titles =
170 stmt.query_map([], |row| {
171 Ok((row.get(0)?, row.get(1)?))
172 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
173 Ok(titles)
174 }
175
176 /* Not used for the moment.
177 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
178 let con = self.pool.get()?;
179 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
180 let recipes =
181 stmt.query_map([], |row| {
182 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
183 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
184 Ok(recipes)
185 } */
186
187 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
188 let con = self.pool.get()?;
189 con.query_row("SELECT [id], [title] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
190 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
191 }).map_err(DBError::from)
192 }
193
194 ///
195 pub fn sign_up(&self, password: &str, email: &str) -> Result<SignUpResult> {
196 self.sign_up_with_given_time(password, email, Utc::now())
197 }
198
199 fn sign_up_with_given_time(&self, password: &str, email: &str, datetime: DateTime<Utc>) -> Result<SignUpResult> {
200 let mut con = self.pool.get()?;
201 let tx = con.transaction()?;
202 let token =
203 match tx.query_row("SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| {
204 Ok((r.get::<&str, i32>("id")?, r.get::<&str, Option<String>>("validation_token")?))
205 }).optional()? {
206 Some((id, validation_token)) => {
207 if validation_token.is_none() {
208 return Ok(SignUpResult::UserAlreadyExists)
209 }
210 let token = generate_token();
211 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
212 tx.execute("UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password])?;
213 token
214 },
215 None => {
216 let token = generate_token();
217 let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
218 tx.execute("INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password])?;
219 token
220 },
221 };
222 tx.commit()?;
223 Ok(SignUpResult::UserCreatedWaitingForValidation(token))
224 }
225
226 pub fn validation(&self, token: &str, validation_time: Duration) -> Result<ValidationResult> {
227 todo!()
228 }
229
230 pub fn sign_in(&self, password: &str, email: String) -> Result<SignInResult> {
231 todo!()
232 }
233
234 pub fn authentication(&self, token: &str) -> Result<AuthenticationResult> {
235 todo!()
236 }
237
238 pub fn logout(&self, token: &str) -> Result<()> {
239 todo!()
240 }
241
242 /// Execute a given SQL file.
243 pub fn execute_file<P: AsRef<Path> + Display>(&self, file: P) -> Result<()> {
244 let con = self.pool.get()?;
245 let sql = load_sql_file(file)?;
246 con.execute_batch(&sql).map_err(DBError::from)
247 }
248
249 /// Execute any SQL statement.
250 /// Mainly used for testing.
251 pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
252 let con = self.pool.get()?;
253 con.execute(sql, params).map_err(DBError::from)
254 }
255 }
256
257 fn load_sql_file<P: AsRef<Path> + Display>(sql_file: P) -> Result<String> {
258 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
259 let mut sql = String::new();
260 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
261 Ok(sql)
262 }
263
264 fn generate_token() -> String {
265 Alphanumeric.sample_string(&mut rand::thread_rng(), 24)
266 }
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
            INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
            VALUES (
                1,
                'paul@test.org',
                'paul',
                '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
                0,
                NULL
            );", [])?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserAlreadyExists => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    /// Signing up again before validation must reuse the existing row with a fresh token.
    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
            INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
            VALUES (
                1,
                'paul@test.org',
                'paul',
                '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
                0,
                'abcdef'
            );", [])?;
        match connection.sign_up("12345", "paul@test.org")? {
            SignUpResult::UserCreatedWaitingForValidation(token) => {
                assert_ne!(token, "abcdef");

                let con = connection.pool.get()?;
                let user_count: i64 = con.query_row("SELECT COUNT(*) FROM [User]", [], |r| r.get(0))?;
                assert_eq!(user_count, 1, "the existing user must be updated, not duplicated");

                let stored_token: Option<String> = con.query_row(
                    "SELECT [validation_token] FROM [User] WHERE [email] = ?1",
                    ["paul@test.org"],
                    |r| r.get(0),
                )?;
                assert_eq!(stored_token.as_deref(), Some(token.as_str()));
            }
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    #[ignore = "Connection::validation is not implemented yet"]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        todo!()
    }

    #[test]
    #[ignore = "Connection::validation is not implemented yet"]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        todo!()
    }

    // TODO: sign_up_then_send_validation_then_sign_in (needs Connection::sign_in).
}