Add some data access methods to Connection
[recipes.git] / backend / src / db.rs
1 use std::{fs::{self, File}, path::Path, io::Read};
2
3 use itertools::Itertools;
4 //use rusqlite::types::ToSql;
5 //use rusqlite::{Connection, Result, NO_PARAMS};
6 use r2d2::Pool;
7 use r2d2_sqlite::SqliteConnectionManager;
8
9 use crate::consts;
10 use crate::model;
11
12 const CURRENT_DB_VERSION: u32 = 1;
13
14 #[derive(Debug)]
15 pub enum DBError {
16 SqliteError(rusqlite::Error),
17 R2d2Error(r2d2::Error),
18 UnsupportedVersion(u32),
19 Other(String),
20 }
21
22 impl From<rusqlite::Error> for DBError {
23 fn from(error: rusqlite::Error) -> Self {
24 DBError::SqliteError(error)
25 }
26 }
27
28 impl From<r2d2::Error> for DBError {
29 fn from(error: r2d2::Error) -> Self {
30 DBError::R2d2Error(error)
31 }
32 }
33
34 type Result<T> = std::result::Result<T, DBError>;
35
36 #[derive(Clone)]
37 pub struct Connection {
38 //con: rusqlite::Connection
39 pool: Pool<SqliteConnectionManager>
40 }
41
42 impl Connection {
43 pub fn new() -> Result<Connection> {
44
45 let data_dir = Path::new(consts::DB_DIRECTORY);
46
47 if !data_dir.exists() {
48 fs::DirBuilder::new().create(data_dir).unwrap();
49 }
50
51 let manager = SqliteConnectionManager::file(consts::DB_FILENAME);
52 let pool = r2d2::Pool::new(manager).unwrap();
53
54 let connection = Connection { pool };
55 connection.create_or_update()?;
56 Ok(connection)
57 }
58
59 /*
60 * Called after the connection has been established for creating or updating the database.
61 * The 'Version' table tracks the current state of the database.
62 */
63 fn create_or_update(&self) -> Result<()> {
64 // Check the Database version.
65 let mut con = self.pool.get()?;
66 let tx = con.transaction()?;
67
68 // Version 0 corresponds to an empty database.
69 let mut version = {
70 match tx.query_row(
71 "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
72 [],
73 |row| row.get::<usize, String>(0)
74 ) {
75 Ok(_) => tx.query_row("SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0)).unwrap_or_default(),
76 Err(_) => 0
77 }
78 };
79
80 while Connection::update_to_next_version(version, &tx)? {
81 version += 1;
82 }
83
84 tx.commit()?;
85
86 Ok(())
87 }
88
89 fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
90 let next_version = current_version + 1;
91
92 if next_version <= CURRENT_DB_VERSION {
93 println!("Update to version {}...", next_version);
94 }
95
96 fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
97 tx.execute("INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version]).map(|_| ()).map_err(DBError::from)
98 }
99
100 fn ok(updated: bool) -> Result<bool> {
101 if updated {
102 println!("Version updated");
103 }
104 Ok(updated)
105 }
106
107 match next_version {
108 1 => {
109 tx.execute_batch(&load_sql_file(next_version)?)?;
110 update_version(next_version, tx)?;
111
112 ok(true)
113 }
114
115 // Version 1 doesn't exist yet.
116 2 =>
117 ok(false),
118
119 v =>
120 Err(DBError::UnsupportedVersion(v)),
121 }
122 }
123
124 pub fn get_all_recipe_titles(&self) -> Result<Vec<(i32, String)>> {
125 let con = self.pool.get()?;
126 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
127 let titles =
128 stmt.query_map([], |row| {
129 Ok((row.get(0)?, row.get(1)?))
130 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
131 Ok(titles)
132 }
133
134 pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
135 let con = self.pool.get()?;
136 let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
137 let recipes =
138 stmt.query_map([], |row| {
139 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
140 })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
141 Ok(recipes)
142 }
143
144 pub fn get_recipe(&self, id: i32) -> Result<model::Recipe> {
145 let con = self.pool.get()?;
146 con.query_row("SELECT [id], [title] FROM [Recipe] WHERE [id] = ?1", [id], |row| {
147 Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
148 }).map_err(DBError::from)
149 }
150 }
151
152 fn load_sql_file(version: u32) -> Result<String> {
153 let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &version.to_string());
154 let mut file = File::open(&sql_file).map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err.to_string())))?;
155 let mut sql = String::new();
156 file.read_to_string(&mut sql).map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err.to_string())))?;
157 Ok(sql)
158 }