diff --git a/Cargo.lock b/Cargo.lock
index df1c4709c..729d1bdbd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2461,6 +2461,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-trait",
+ "base64 0.21.5",
  "bytes",
  "chrono",
  "dotenvy",
@@ -2477,6 +2478,7 @@ dependencies = [
  "sea-orm",
  "sea-orm-migration",
  "sha2",
+ "sqlx",
  "thiserror",
  "tokio",
  "tokio-util",
diff --git a/apps/keck/tests/pg_notify.rs b/apps/keck/tests/pg_notify.rs
new file mode 100644
index 000000000..d94d90ebc
--- /dev/null
+++ b/apps/keck/tests/pg_notify.rs
@@ -0,0 +1,56 @@
+use std::{io::{BufRead, BufReader}, process::{Child, Command, Stdio}, thread::sleep, time::Duration};
+
+use rand::{thread_rng, Rng};
+
+fn start_server(port: u16, db: &str) -> Child {
+    let mut child = Command::new("cargo")
+        .args(["run", "-p", "keck"])
+        .env("KECK_PORT", port.to_string())
+        .env("DATABASE_URL", db)
+        .stdout(Stdio::piped())
+        .spawn()
+        .expect("Failed to run command");
+
+    if let Some(ref mut stdout) = child.stdout {
+        let reader = BufReader::new(stdout);
+        for line in reader.lines() {
+            let line = line.expect("Failed to read line");
+            if line.contains("listening on 0.0.0.0:") {
+                break;
+            }
+        }
+    }
+
+    child
+}
+
+#[tokio::test]
+#[ignore = "requires external postgres"]
+async fn blocks_consistent_between_nodes() {
+    let port1 = thread_rng().gen_range(20000..30000);
+    let port2 = port1 + 1;
+    let db = std::env::var("TEST_DATABASE_URL")
+        .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".into());
+    let c1 = start_server(port1, &db);
+    let c2 = start_server(port2, &db);
+
+    let client = reqwest::Client::new();
+    let ws = "ws1";
+    let block = "b1";
+    let url1 = format!("http://localhost:{port1}/api/block/{ws}/{block}?flavour=text");
+    client
+        .post(url1)
+        .json(&serde_json::json!({"prop:text": "hi"}))
+        .send()
+        .await
+        .unwrap();
+    sleep(Duration::from_secs(1));
+    let url2 = format!("http://localhost:{port2}/api/block/{ws}/{block}");
+    let resp = client.get(url2).send().await.unwrap();
+    assert!(resp.status().is_success());
+    let json: serde_json::Value = resp.json().await.unwrap();
+    assert_eq!(json["prop:text"], "hi");
+
+    unsafe { libc::kill(c1.id() as i32, libc::SIGTERM) };
+    unsafe { libc::kill(c2.id() as i32, libc::SIGTERM) };
+}
diff --git a/libs/jwst-storage/Cargo.toml b/libs/jwst-storage/Cargo.toml
index 92609437a..b6e9f4ff5 100644
--- a/libs/jwst-storage/Cargo.toml
+++ b/libs/jwst-storage/Cargo.toml
@@ -40,6 +40,8 @@ chrono = { workspace = true, features = ["serde"] }
 futures = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["fs", "macros", "sync"] }
+base64 = "0.21.4"
+sqlx = { version = "0.7", default-features = false, features = ["postgres", "runtime-tokio-rustls"] }
 
 jwst-core = { workspace = true }
 jwst-codec = { workspace = true }
diff --git a/libs/jwst-storage/src/storage/docs/database.rs b/libs/jwst-storage/src/storage/docs/database.rs
index 12325d4f0..84eab0448 100644
--- a/libs/jwst-storage/src/storage/docs/database.rs
+++ b/libs/jwst-storage/src/storage/docs/database.rs
@@ -6,7 +6,7 @@ use sea_orm::Condition;
 use tokio::task::spawn_blocking;
 
 use super::{entities::prelude::*, *};
-use crate::types::JwstStorageResult;
+use crate::types::{JwstStorageError, JwstStorageResult};
 
 const MAX_TRIM_UPDATE_LIMIT: u64 = 500;
 
@@ -38,6 +38,49 @@ impl DocDBStorage {
         Self::init_with_pool(pool, get_bucket(is_sqlite)).await
     }
 
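+    /// Subscribe to Postgres `jwst_docs_update` notifications and apply updates published
+    /// by other nodes to the cached workspace and any local broadcast subscribers.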
+    #[cfg(feature = "postgres")]
+    pub async fn listen_remote(self: Arc<Self>, database: &str) -> JwstStorageResult<()> {
+        use base64::engine::general_purpose::STANDARD as BASE64;
+        use base64::Engine;
+        use sqlx::postgres::PgListener;
+
+        if !database.starts_with("postgres") {
+            return Ok(());
+        }
+
+        let mut listener = PgListener::connect(database)
+            .await
+            .map_err(|e| JwstStorageError::Crud(e.to_string()))?;
+        listener
+            .listen("jwst_docs_update")
+            .await
+            .map_err(|e| JwstStorageError::Crud(e.to_string()))?;
+        tokio::spawn(async move {
+            loop {
+                match listener.recv().await {
+                    Ok(notification) => {
+                        let payload = notification.payload().to_string();
+                        if let Some((ws, data)) = payload.split_once(':') {
+                            if let Ok(binary) = BASE64.decode(data) {
+                                if let Some(workspace) = self.workspaces.write().await.get_mut(ws) {
+                                    workspace.sync_messages(vec![binary.clone()]);
+                                }
+                                if let Some(tx) = self.remote.read().await.get(ws) {
+                                    let _ = tx.send(binary);
+                                }
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        warn!("pg listener error: {:?}", e);
+                        break;
+                    }
+                }
+            }
+        });
+        Ok(())
+    }
+
     pub fn remote(&self) -> &RwLock<HashMap<String, Sender<Vec<u8>>>> {
         &self.remote
     }
@@ -203,14 +246,33 @@ impl DocDBStorage {
         Self::insert(conn, workspace, guid, blob).await?;
         trace!("end update: {guid}");
 
+        let msg = encode_update_as_message(blob.into())?;
         trace!("update {}bytes to {}", blob.len(), guid);
         if let Entry::Occupied(remote) = self.remote.write().await.entry(guid.into()) {
             let broadcast = &remote.get();
-            if broadcast.send(encode_update_as_message(blob.into())?).is_err() {
+            if broadcast.send(msg.clone()).is_err() {
                 // broadcast failures are not fatal errors, only warnings are required
                 warn!("send {guid} update to pipeline failed");
             }
         }
+        #[cfg(feature = "postgres")]
+        if matches!(self.pool, DatabaseConnection::SqlxPostgresPoolConnection(_)) {
+            use base64::engine::general_purpose::STANDARD as BASE64;
+            use base64::Engine;
+            use sea_orm::{DatabaseBackend, Statement};
+            let payload = format!("{}:{}", workspace, BASE64.encode(&msg));
+            if let Err(e) = self
+                .pool
+                .execute(Statement::from_sql_and_values(
+                    DatabaseBackend::Postgres,
+                    "SELECT pg_notify('jwst_docs_update', $1)",
+                    [payload.into()],
+                ))
+                .await
+            {
+                warn!("pg notify failed: {:?}", e);
+            }
+        }
         trace!("end update broadcast: {guid}");
 
         Ok(())
diff --git a/libs/jwst-storage/src/storage/docs/mod.rs b/libs/jwst-storage/src/storage/docs/mod.rs
index 50ac6a891..29df67608 100644
--- a/libs/jwst-storage/src/storage/docs/mod.rs
+++ b/libs/jwst-storage/src/storage/docs/mod.rs
@@ -18,11 +18,25 @@ pub struct SharedDocDBStorage(pub(super) Arc<DocDBStorage>);
 
 impl SharedDocDBStorage {
     pub async fn init_with_pool(pool: DatabaseConnection, bucket: Arc<Bucket>) -> JwstStorageResult<Self> {
-        Ok(Self(Arc::new(DocDBStorage::init_with_pool(pool, bucket).await?)))
+        let storage = Arc::new(DocDBStorage::init_with_pool(pool, bucket).await?);
+        Ok(Self(storage))
     }
 
     pub async fn init_pool(database: &str) -> JwstStorageResult<Self> {
-        Ok(Self(Arc::new(DocDBStorage::init_pool(database).await?)))
+        let storage = Arc::new(DocDBStorage::init_pool(database).await?);
+        #[cfg(feature = "postgres")]
+        if database.starts_with("postgres") {
+            storage.clone().listen_remote(database).await?;
+        }
+        Ok(Self(storage))
+    }
+
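+    /// Start the Postgres LISTEN loop on the inner storage; a no-op for non-Postgres databases.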
+    #[cfg(feature = "postgres")]
+    pub async fn listen_remote(&self, database: &str) -> JwstStorageResult<()> {
+        if database.starts_with("postgres") {
+            self.0.clone().listen_remote(database).await?;
+        }
+        Ok(())
     }
 
     pub fn remote(&self) -> &RwLock<HashMap<String, Sender<Vec<u8>>>> {
diff --git a/libs/jwst-storage/src/storage/mod.rs b/libs/jwst-storage/src/storage/mod.rs
index f1c75e056..6e43f1e71 100644
--- a/libs/jwst-storage/src/storage/mod.rs
+++ b/libs/jwst-storage/src/storage/mod.rs
@@ -51,6 +51,10 @@ impl JwstStorage {
             ),
         };
         let docs = SharedDocDBStorage::init_with_pool(pool.clone(), bucket.clone()).await?;
+        #[cfg(feature = "postgres")]
+        if database.starts_with("postgres") {
+            docs.listen_remote(database).await?;
+        }
 
         Ok(Self {
             pool,