Add API to manage recipe tags
authorGreg Burri <greg.burri@gmail.com>
Fri, 3 Jan 2025 22:32:54 +0000 (23:32 +0100)
committerGreg Burri <greg.burri@gmail.com>
Fri, 3 Jan 2025 22:32:54 +0000 (23:32 +0100)
backend/sql/version_1.sql
backend/src/data/db.rs
backend/src/data/db/recipe.rs
backend/src/main.rs
backend/src/services/ron.rs
common/src/ron_api.rs

index 48b063c..2db2354 100644 (file)
@@ -82,12 +82,11 @@ CREATE TABLE [RecipeTag] (
 
 CREATE TABLE [Tag] (
     [id] INTEGER PRIMARY KEY,
-    [name] TEXT NOT NULL,
+    [name] TEXT NOT NULL
      -- https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes
-    [lang] TEXT NOT NULL DEFAULT 'en'
 ) STRICT;
 
-CREATE UNIQUE INDEX [Tag_name_lang_index] ON [Tag] ([name], [lang]);
+CREATE UNIQUE INDEX [Tag_name_lang_index] ON [Tag]([name]);
 
 CREATE TABLE [Group] (
     [id] INTEGER PRIMARY KEY,
index 9797369..ab30bb8 100644 (file)
@@ -195,6 +195,13 @@ WHERE [type] = 'table' AND [name] = 'Version'
             .map(|db_result| db_result.rows_affected())
             .map_err(DBError::from)
     }
+
+    // pub async fn execute_sql_and_fetch_all<'a>(
+    //     &self,
+    //     query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
+    // ) -> Result<Vec<SqliteRow>> {
+    //     query.fetch_all(&self.pool).await.map_err(DBError::from)
+    // }
 }
 
 fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
index 6aeb4d6..28608df 100644 (file)
@@ -197,6 +197,133 @@ WHERE [Recipe].[user_id] = $1
             .map_err(DBError::from)
     }
 
+    pub async fn get_all_tags(&self) -> Result<Vec<String>> {
+        sqlx::query_scalar(
+            r#"
+SELECT [name] FROM [Tag]
+ORDER BY [name]
+            "#,
+        )
+        .fetch_all(&self.pool)
+        .await
+        .map_err(DBError::from)
+    }
+
+    pub async fn get_all_tags_by_lang(&self, lang: &str) -> Result<Vec<String>> {
+        sqlx::query_scalar(
+            r#"
+SELECT DISTINCT [name] FROM [Tag]
+INNER JOIN [RecipeTag] ON [RecipeTag].[tag_id] = [Tag].[id]
+INNER JOIN [Recipe] ON [Recipe].[id] = [RecipeTag].[recipe_id]
+WHERE [Recipe].[lang] = $1
+ORDER BY [name]
+            "#,
+        )
+        .bind(lang)
+        .fetch_all(&self.pool)
+        .await
+        .map_err(DBError::from)
+    }
+
+    pub async fn get_recipes_tags(&self, recipe_id: i64) -> Result<Vec<String>> {
+        sqlx::query_scalar(
+            r#"
+SELECT [name]
+FROM [Tag]
+INNER JOIN [RecipeTag] ON [RecipeTag].[tag_id] = [Tag].[id]
+INNER JOIN [Recipe] ON [Recipe].[id] = [RecipeTag].[recipe_id]
+WHERE [Recipe].[id] = $1
+ORDER BY [name]
+            "#,
+        )
+        .bind(recipe_id)
+        .fetch_all(&self.pool)
+        .await
+        .map_err(DBError::from)
+    }
+
+    pub async fn add_recipe_tags<T>(&self, recipe_id: i64, tags: &[T]) -> Result<()>
+    where
+        T: AsRef<str>,
+    {
+        let mut tx = self.tx().await?;
+        for tag in tags {
+            let tag = tag.as_ref().trim().to_lowercase();
+            let tag_id: i64 = if let Some(tag_id) =
+                sqlx::query_scalar("SELECT [id] FROM [Tag] WHERE [name] = $1")
+                    .bind(&tag)
+                    .fetch_optional(&mut *tx)
+                    .await?
+            {
+                tag_id
+            } else {
+                let result = sqlx::query("INSERT INTO [Tag] ([name]) VALUES ($1)")
+                    .bind(&tag)
+                    .execute(&mut *tx)
+                    .await?;
+                result.last_insert_rowid()
+            };
+
+            sqlx::query(
+                r#"
+INSERT INTO [RecipeTag] ([recipe_id], [tag_id])
+VALUES ($1, $2)
+ON CONFLICT DO NOTHING
+                "#,
+            )
+            .bind(recipe_id)
+            .bind(tag_id)
+            .execute(&mut *tx)
+            .await?;
+        }
+
+        tx.commit().await?;
+
+        Ok(())
+    }
+
+    pub async fn rm_recipe_tags<T>(&self, recipe_id: i64, tags: &[T]) -> Result<()>
+    where
+        T: AsRef<str>,
+    {
+        let mut tx = self.tx().await?;
+        for tag in tags {
+            if let Some(tag_id) = sqlx::query_scalar::<_, i64>(
+                r#"
+DELETE FROM [RecipeTag]
+WHERE [id] IN (
+    SELECT [RecipeTag].[id] FROM [RecipeTag]
+    INNER JOIN [Tag] ON [Tag].[id] = [tag_id]
+    WHERE [recipe_id] = $1 AND [Tag].[name] = $2
+)
+RETURNING [RecipeTag].[tag_id]
+                "#,
+            )
+            .bind(recipe_id)
+            .bind(tag.as_ref())
+            .fetch_optional(&mut *tx)
+            .await?
+            {
+                sqlx::query(
+                    r#"
+DELETE FROM [Tag]
+WHERE [id] = $1 AND [id] NOT IN (
+    SELECT [tag_id] FROM [RecipeTag]
+    WHERE [tag_id] = $1
+)
+                    "#,
+                )
+                .bind(tag_id)
+                .execute(&mut *tx)
+                .await?;
+            }
+        }
+
+        tx.commit().await?;
+
+        Ok(())
+    }
+
     pub async fn set_recipe_difficulty(
         &self,
         recipe_id: i64,
@@ -416,6 +543,8 @@ ORDER BY [name]
 
 #[cfg(test)]
 mod tests {
+    use axum::routing::connect;
+
     use super::*;
 
     #[tokio::test]
@@ -503,4 +632,59 @@ VALUES
         ).await?;
         Ok(user_id)
     }
+
+    #[tokio::test]
+    async fn add_and_remove_tags() -> Result<()> {
+        let connection = Connection::new_in_memory().await?;
+        let user_id = create_a_user(&connection).await?;
+        let recipe_id_1 = connection.create_recipe(user_id).await?;
+        connection.set_recipe_title(recipe_id_1, "recipe 1").await?;
+
+        let tags_1 = ["abc", "xyz"];
+        connection.add_recipe_tags(recipe_id_1, &tags_1).await?;
+
+        // Adding the same tags should do nothing.
+        connection.add_recipe_tags(recipe_id_1, &tags_1).await?;
+
+        assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, tags_1);
+
+        let tags_2 = ["abc", "def", "xyz"];
+        let recipe_id_2 = connection.create_recipe(user_id).await?;
+        connection.set_recipe_title(recipe_id_2, "recipe 2").await?;
+
+        connection.add_recipe_tags(recipe_id_2, &tags_2).await?;
+
+        assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, tags_1);
+        assert_eq!(connection.get_recipes_tags(recipe_id_2).await?, tags_2);
+
+        assert_eq!(connection.get_all_tags().await?, ["abc", "def", "xyz"]);
+        connection.rm_recipe_tags(recipe_id_2, &["abc"]).await?;
+        assert_eq!(connection.get_all_tags().await?, ["abc", "def", "xyz"]);
+
+        assert_eq!(
+            connection.get_recipes_tags(recipe_id_1).await?,
+            ["abc", "xyz"]
+        );
+        assert_eq!(
+            connection.get_recipes_tags(recipe_id_2).await?,
+            ["def", "xyz"]
+        );
+
+        connection.rm_recipe_tags(recipe_id_1, &["abc"]).await?;
+
+        assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, ["xyz"]);
+        assert_eq!(
+            connection.get_recipes_tags(recipe_id_2).await?,
+            ["def", "xyz"]
+        );
+        assert_eq!(connection.get_all_tags().await?, ["def", "xyz"]);
+        assert_eq!(connection.get_all_tags_by_lang("en").await?, ["def", "xyz"]);
+
+        connection.rm_recipe_tags(recipe_id_1, &tags_1).await?;
+        connection.rm_recipe_tags(recipe_id_2, &tags_2).await?;
+
+        assert!(connection.get_all_tags().await?.is_empty());
+
+        Ok(())
+    }
 }
index 5358bf6..bb3d485 100644 (file)
@@ -96,6 +96,9 @@ async fn main() {
             "/recipe/set_estimated_time",
             put(services::ron::set_estimated_time),
         )
+        .route("/recipe/get_tags", get(services::ron::get_tags))
+        .route("/recipe/add_tags", post(services::ron::add_tags))
+        .route("/recipe/rm_tags", delete(services::ron::rm_tags))
         .route("/recipe/set_difficulty", put(services::ron::set_difficulty))
         .route("/recipe/set_language", put(services::ron::set_language))
         .route(
index bbda7c2..73330aa 100644 (file)
@@ -16,6 +16,12 @@ use crate::{
 
 const NOT_AUTHORIZED_MESSAGE: &str = "Action not authorized";
 
+#[derive(Deserialize)]
+pub struct RecipeId {
+    #[serde(rename = "recipe_id")]
+    id: i64,
+}
+
 #[allow(dead_code)]
 #[debug_handler]
 pub async fn update_user(
@@ -169,6 +175,42 @@ pub async fn set_estimated_time(
     Ok(StatusCode::OK)
 }
 
+#[debug_handler]
+pub async fn get_tags(
+    State(connection): State<db::Connection>,
+    recipe_id: Query<RecipeId>,
+) -> Result<impl IntoResponse> {
+    Ok(ron_response(
+        StatusCode::OK,
+        common::ron_api::Tags {
+            recipe_id: recipe_id.id,
+            tags: connection.get_recipes_tags(recipe_id.id).await?,
+        },
+    ))
+}
+
+#[debug_handler]
+pub async fn add_tags(
+    State(connection): State<db::Connection>,
+    Extension(user): Extension<Option<model::User>>,
+    ExtractRon(ron): ExtractRon<common::ron_api::Tags>,
+) -> Result<impl IntoResponse> {
+    check_user_rights_recipe(&connection, &user, ron.recipe_id).await?;
+    connection.add_recipe_tags(ron.recipe_id, &ron.tags).await?;
+    Ok(StatusCode::OK)
+}
+
+#[debug_handler]
+pub async fn rm_tags(
+    State(connection): State<db::Connection>,
+    Extension(user): Extension<Option<model::User>>,
+    ExtractRon(ron): ExtractRon<common::ron_api::Tags>,
+) -> Result<impl IntoResponse> {
+    check_user_rights_recipe(&connection, &user, ron.recipe_id).await?;
+    connection.rm_recipe_tags(ron.recipe_id, &ron.tags).await?;
+    Ok(StatusCode::OK)
+}
+
 #[debug_handler]
 pub async fn set_difficulty(
     State(connection): State<db::Connection>,
@@ -260,12 +302,6 @@ impl From<model::Ingredient> for common::ron_api::Ingredient {
     }
 }
 
-#[derive(Deserialize)]
-pub struct RecipeId {
-    #[serde(rename = "recipe_id")]
-    id: i64,
-}
-
 #[debug_handler]
 pub async fn get_groups(
     State(connection): State<db::Connection>,
index 5efe510..9525dd6 100644 (file)
@@ -163,6 +163,12 @@ pub struct SetIngredientUnit {
     pub unit: String,
 }
 
+#[derive(Serialize, Deserialize, Clone)]
+pub struct Tags {
+    pub recipe_id: i64,
+    pub tags: Vec<String>,
+}
+
 #[derive(Serialize, Deserialize, Clone, Debug)]
 pub struct Group {
     pub id: i64,