.map_err(DBError::from)
}
+ pub async fn get_all_tags(&self) -> Result<Vec<String>> {
+ sqlx::query_scalar(
+ r#"
+SELECT [name] FROM [Tag]
+ORDER BY [name]
+ "#,
+ )
+ .fetch_all(&self.pool)
+ .await
+ .map_err(DBError::from)
+ }
+
+ pub async fn get_all_tags_by_lang(&self, lang: &str) -> Result<Vec<String>> {
+ sqlx::query_scalar(
+ r#"
+SELECT DISTINCT [name] FROM [Tag]
+INNER JOIN [RecipeTag] ON [RecipeTag].[tag_id] = [Tag].[id]
+INNER JOIN [Recipe] ON [Recipe].[id] = [RecipeTag].[recipe_id]
+WHERE [Recipe].[lang] = $1
+ORDER BY [name]
+ "#,
+ )
+ .bind(lang)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(DBError::from)
+ }
+
+ pub async fn get_recipes_tags(&self, recipe_id: i64) -> Result<Vec<String>> {
+ sqlx::query_scalar(
+ r#"
+SELECT [name]
+FROM [Tag]
+INNER JOIN [RecipeTag] ON [RecipeTag].[tag_id] = [Tag].[id]
+INNER JOIN [Recipe] ON [Recipe].[id] = [RecipeTag].[recipe_id]
+WHERE [Recipe].[id] = $1
+ORDER BY [name]
+ "#,
+ )
+ .bind(recipe_id)
+ .fetch_all(&self.pool)
+ .await
+ .map_err(DBError::from)
+ }
+
+ pub async fn add_recipe_tags<T>(&self, recipe_id: i64, tags: &[T]) -> Result<()>
+ where
+ T: AsRef<str>,
+ {
+ let mut tx = self.tx().await?;
+ for tag in tags {
+ let tag = tag.as_ref().trim().to_lowercase();
+ let tag_id: i64 = if let Some(tag_id) =
+ sqlx::query_scalar("SELECT [id] FROM [Tag] WHERE [name] = $1")
+ .bind(&tag)
+ .fetch_optional(&mut *tx)
+ .await?
+ {
+ tag_id
+ } else {
+ let result = sqlx::query("INSERT INTO [Tag] ([name]) VALUES ($1)")
+ .bind(&tag)
+ .execute(&mut *tx)
+ .await?;
+ result.last_insert_rowid()
+ };
+
+ sqlx::query(
+ r#"
+INSERT INTO [RecipeTag] ([recipe_id], [tag_id])
+VALUES ($1, $2)
+ON CONFLICT DO NOTHING
+ "#,
+ )
+ .bind(recipe_id)
+ .bind(tag_id)
+ .execute(&mut *tx)
+ .await?;
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
+ pub async fn rm_recipe_tags<T>(&self, recipe_id: i64, tags: &[T]) -> Result<()>
+ where
+ T: AsRef<str>,
+ {
+ let mut tx = self.tx().await?;
+ for tag in tags {
+ if let Some(tag_id) = sqlx::query_scalar::<_, i64>(
+ r#"
+DELETE FROM [RecipeTag]
+WHERE [id] IN (
+ SELECT [RecipeTag].[id] FROM [RecipeTag]
+ INNER JOIN [Tag] ON [Tag].[id] = [tag_id]
+ WHERE [recipe_id] = $1 AND [Tag].[name] = $2
+)
+RETURNING [RecipeTag].[tag_id]
+ "#,
+ )
+ .bind(recipe_id)
+ .bind(tag.as_ref())
+ .fetch_optional(&mut *tx)
+ .await?
+ {
+ sqlx::query(
+ r#"
+DELETE FROM [Tag]
+WHERE [id] = $1 AND [id] NOT IN (
+ SELECT [tag_id] FROM [RecipeTag]
+ WHERE [tag_id] = $1
+)
+ "#,
+ )
+ .bind(tag_id)
+ .execute(&mut *tx)
+ .await?;
+ }
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+
pub async fn set_recipe_difficulty(
&self,
recipe_id: i64,
#[cfg(test)]
mod tests {
+ use axum::routing::connect;
+
use super::*;
#[tokio::test]
).await?;
Ok(user_id)
}
+
+ #[tokio::test]
+ async fn add_and_remove_tags() -> Result<()> {
+ let connection = Connection::new_in_memory().await?;
+ let user_id = create_a_user(&connection).await?;
+ let recipe_id_1 = connection.create_recipe(user_id).await?;
+ connection.set_recipe_title(recipe_id_1, "recipe 1").await?;
+
+ let tags_1 = ["abc", "xyz"];
+ connection.add_recipe_tags(recipe_id_1, &tags_1).await?;
+
+ // Adding the same tags should do nothing.
+ connection.add_recipe_tags(recipe_id_1, &tags_1).await?;
+
+ assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, tags_1);
+
+ let tags_2 = ["abc", "def", "xyz"];
+ let recipe_id_2 = connection.create_recipe(user_id).await?;
+ connection.set_recipe_title(recipe_id_2, "recipe 2").await?;
+
+ connection.add_recipe_tags(recipe_id_2, &tags_2).await?;
+
+ assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, tags_1);
+ assert_eq!(connection.get_recipes_tags(recipe_id_2).await?, tags_2);
+
+ assert_eq!(connection.get_all_tags().await?, ["abc", "def", "xyz"]);
+ connection.rm_recipe_tags(recipe_id_2, &["abc"]).await?;
+ assert_eq!(connection.get_all_tags().await?, ["abc", "def", "xyz"]);
+
+ assert_eq!(
+ connection.get_recipes_tags(recipe_id_1).await?,
+ ["abc", "xyz"]
+ );
+ assert_eq!(
+ connection.get_recipes_tags(recipe_id_2).await?,
+ ["def", "xyz"]
+ );
+
+ connection.rm_recipe_tags(recipe_id_1, &["abc"]).await?;
+
+ assert_eq!(connection.get_recipes_tags(recipe_id_1).await?, ["xyz"]);
+ assert_eq!(
+ connection.get_recipes_tags(recipe_id_2).await?,
+ ["def", "xyz"]
+ );
+ assert_eq!(connection.get_all_tags().await?, ["def", "xyz"]);
+ assert_eq!(connection.get_all_tags_by_lang("en").await?, ["def", "xyz"]);
+
+ connection.rm_recipe_tags(recipe_id_1, &tags_1).await?;
+ connection.rm_recipe_tags(recipe_id_2, &tags_2).await?;
+
+ assert!(connection.get_all_tags().await?.is_empty());
+
+ Ok(())
+ }
}
const NOT_AUTHORIZED_MESSAGE: &str = "Action not authorized";
+#[derive(Deserialize)]
+pub struct RecipeId {
+ #[serde(rename = "recipe_id")]
+ id: i64,
+}
+
#[allow(dead_code)]
#[debug_handler]
pub async fn update_user(
Ok(StatusCode::OK)
}
+#[debug_handler]
+pub async fn get_tags(
+ State(connection): State<db::Connection>,
+ recipe_id: Query<RecipeId>,
+) -> Result<impl IntoResponse> {
+ Ok(ron_response(
+ StatusCode::OK,
+ common::ron_api::Tags {
+ recipe_id: recipe_id.id,
+ tags: connection.get_recipes_tags(recipe_id.id).await?,
+ },
+ ))
+}
+
+#[debug_handler]
+pub async fn add_tags(
+ State(connection): State<db::Connection>,
+ Extension(user): Extension<Option<model::User>>,
+ ExtractRon(ron): ExtractRon<common::ron_api::Tags>,
+) -> Result<impl IntoResponse> {
+ check_user_rights_recipe(&connection, &user, ron.recipe_id).await?;
+ connection.add_recipe_tags(ron.recipe_id, &ron.tags).await?;
+ Ok(StatusCode::OK)
+}
+
+#[debug_handler]
+pub async fn rm_tags(
+ State(connection): State<db::Connection>,
+ Extension(user): Extension<Option<model::User>>,
+ ExtractRon(ron): ExtractRon<common::ron_api::Tags>,
+) -> Result<impl IntoResponse> {
+ check_user_rights_recipe(&connection, &user, ron.recipe_id).await?;
+ connection.rm_recipe_tags(ron.recipe_id, &ron.tags).await?;
+ Ok(StatusCode::OK)
+}
+
#[debug_handler]
pub async fn set_difficulty(
State(connection): State<db::Connection>,
}
}
-#[derive(Deserialize)]
-pub struct RecipeId {
- #[serde(rename = "recipe_id")]
- id: i64,
-}
-
#[debug_handler]
pub async fn get_groups(
State(connection): State<db::Connection>,