From 1fb5dfae9d7cc769e3701422555333ee180c4d93 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 16 Mar 2026 16:31:57 -0700 Subject: [PATCH 01/17] Add Push Mode (Task Dispatchers and Pushers) --- Cargo.lock | 56 ++++++------ Cargo.toml | 10 +-- README.md | 2 +- src/config.rs | 20 +++++ src/dispatch.rs | 190 +++++++++++++++++++++++++++++++++++++++ src/grpc/server.rs | 8 ++ src/grpc/server_tests.rs | 42 ++++++--- src/lib.rs | 1 + src/main.rs | 52 ++++++++--- 9 files changed, 325 insertions(+), 56 deletions(-) create mode 100644 src/dispatch.rs diff --git a/Cargo.lock b/Cargo.lock index ec1f3551..d2a2ac9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1189,7 +1189,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2", "tokio", "tower-service", "tracing", @@ -1999,9 +1999,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -2009,9 +2009,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools", @@ -2022,9 +2022,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -2541,13 +2541,13 @@ dependencies = [ [[package]] name = "sentry_protos" 
-version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae1eac4a748b11a2bb5b342bea8546085751cf9a45e30fb1276b072bb5541e6" +version = "0.8.4" +source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-taskbroker%2Fcreate-worker-service#f4cd3043b043c2f42e069104c3704177e1696504" dependencies = [ "prost", "prost-types", "tonic", + "tonic-prost", ] [[package]] @@ -2688,16 +2688,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -3166,7 +3156,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", "windows-sys 0.59.0", ] @@ -3236,9 +3226,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", @@ -3253,8 +3243,8 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", "tower", @@ -3265,14 +3255,26 @@ dependencies = [ [[package]] name = "tonic-health" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb87334d340313fefa513b6e60794d44a86d5f039b523229c99c323e4e19ca4b" +checksum = "f4ff0636fef47afb3ec02818f5bceb4377b8abb9d6a386aeade18bd6212f8eb7" dependencies = [ "prost", "tokio", "tokio-stream", "tonic", + "tonic-prost", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 67119a08..6ed2b7b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,8 @@ http-body-util = "0.1.2" libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" -prost = "0.13" -prost-types = "0.13.3" +prost = "0.14" +prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } sentry = { version = "0.41.0", default-features = false, features = [ @@ -41,7 +41,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = "0.4.11" +sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-taskbroker/create-worker-service" } serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" @@ -49,8 +49,8 @@ sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "po tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", features = ["full"] } tokio-util = "0.7.12" -tonic = "0.13" -tonic-health = "0.13" +tonic = "0.14" +tonic-health = "0.14" tower = "0.5.1" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = [ diff --git a/README.md b/README.md index 860f446b..7a4f79bc 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ The test suite is composed of unit and integration tests in Rust, and end-to-end ```bash # Run unit/integration tests -make test +make unit-test # Run end-to-end tests make integration-test diff --git a/src/config.rs b/src/config.rs index 67b571d2..3f3bddb5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -239,6 +239,21 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, + + /// Run the taskbroker in push mode (as opposed to pull mode). 
+ pub push_mode: bool, + + /// The number of concurrent dispatchers to run. + pub dispatchers: usize, + + /// The number of concurrent pushers each dispatcher should run. + pub pushers: usize, + + /// The size of the push queue. + pub push_queue_size: usize, + + /// The worker service endpoint. + pub worker_endpoint: String, } impl Default for Config { @@ -308,6 +323,11 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, + push_mode: false, + dispatchers: 1, + pushers: 1, + push_queue_size: 1, + worker_endpoint: "http://127.0.0.1:50052".into(), } } } diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 00000000..97f1bcd8 --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,190 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use prost::Message; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::time::sleep; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; + +/// This data structure fetches pending activations from the store and pushes them to the worker service. Each dispatcher has... +/// - One "fetch" loop that gets a pending activation from the store, sends it to a push channel, and repeats +/// - One or more "push" loops, each of which receives an activation from a channel, pushes that activation to a worker, and repeats +pub struct TaskDispatcher { + /// Sender for every push loop. + senders: Vec>, + + /// Receiver for every push loop. + receivers: Vec>, + + /// For every pending activation, increment and send to the channel with this index. + next_sender_idx: usize, + + /// Broker configuration. 
+ config: Arc, + + /// Broker inflight activation store. + store: Arc, +} + +impl TaskDispatcher { + /// Create a new task dispatcher. + pub fn new(config: Arc, store: Arc) -> Self { + let n = config.pushers; + + let mut senders = Vec::with_capacity(n); + let mut receivers = Vec::with_capacity(n); + let next_sender_idx = 0; + + for _ in 0..n { + let (tx, rx) = mpsc::channel(config.push_queue_size); + senders.push(tx); + receivers.push(rx); + } + + Self { + senders, + receivers, + next_sender_idx, + config, + store, + } + } + + /// Initialize push loops and dispatcher loop. + pub async fn start(mut self) -> Result<()> { + let n = self.senders.len(); + info!("Starting {n} push loops..."); + + let endpoint = self.config.worker_endpoint.clone(); + let receivers = std::mem::take(&mut self.receivers); + + // Initialize each push loop + for mut rx in receivers.into_iter() { + let endpoint = endpoint.clone(); + + tokio::spawn(async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + return; + } + }; + + while let Some(activation) = rx.recv().await { + // Receive activation from the channel + let id = activation.id.clone(); + + // Try to push activation to the worker service + if let Err(e) = push_task(&mut worker, activation).await { + error!("Pushing activation {id} resulted in error - {:?}", e); + } else { + debug!("Activation {id} was sent to worker!"); + } + } + }); + } + + info!("Starting fetch loop..."); + let guard = get_shutdown_guard().shutdown_on_drop(); + + // Initialize the fetch loop + loop { + tokio::select! 
{ + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + break; + } + + _ = async { + debug!("About to fetch next activation..."); + self.fetch_activation().await; + } => {} + } + } + + info!("Activation dispatcher shutting down..."); + Ok(()) + } + + /// Grab the next pending activation from the store, mark it as processing, and send to push channel. + async fn fetch_activation(&mut self) { + let start = Instant::now(); + metrics::counter!("pusher.fetch_activation.runs").increment(1); + + debug!("Fetching next pending activation..."); + + match self.store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + + let idx = self.next_sender_idx % self.senders.len(); + self.next_sender_idx = self.next_sender_idx.wrapping_add(1); + + if let Err(e) = self.senders[idx].send(activation).await { + error!("Failed to send activation {id} to worker - {:?}", e); + } + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + + Ok(_) => { + debug!("No pending activations, sleeping briefly..."); + sleep(milliseconds(100)).await; + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + + Err(e) => { + error!("Failed to fetch pending activations - {:?}", e); + sleep(milliseconds(100)).await; + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + } + } +} + +/// Decode task activation and push it to a worker. 
+async fn push_task( + worker: &mut WorkerServiceClient, + activation: InflightActivation, +) -> Result<()> { + let start = Instant::now(); + let id = activation.id.clone(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { task: Some(task) }; + + let result = match worker.push_task(request).await { + Ok(_) => { + debug!("Successfully sent activation {id} to worker service!"); + Ok(()) + } + + Err(e) => { + error!("Could not push activation {id} to worker service - {:?}", e); + Err(e.into()) + } + }; + + metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + result +} + +#[inline] +fn milliseconds(i: u64) -> Duration { + Duration::from_millis(i) +} diff --git a/src/grpc/server.rs b/src/grpc/server.rs index f5ac9292..68d06699 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -9,11 +9,13 @@ use std::sync::Arc; use std::time::Instant; use tonic::{Request, Response, Status}; +use crate::config::Config; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; pub struct TaskbrokerServer { pub store: Arc, + pub config: Arc, } #[tonic::async_trait] @@ -23,6 +25,12 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request, ) -> Result, Status> { + if self.config.push_mode { + return Err(Status::permission_denied( + "Cannot call while broker is in PUSH mode", + )); + } + let start_time = Instant::now(); let application = &request.get_ref().application; let namespace = &request.get_ref().namespace; diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b99911a2..29e36f96 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -7,7 +7,7 @@ use sentry_protos::taskbroker::v1::{ }; use tonic::{Code, Request}; -use crate::test_utils::{create_test_store, make_activations}; +use 
crate::test_utils::{create_config, create_test_store, make_activations}; #[tokio::test] #[rstest] @@ -15,7 +15,9 @@ use crate::test_utils::{create_test_store, make_activations}; #[case::postgres("postgres")] async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -34,7 +36,9 @@ async fn test_get_task(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete @@ -53,7 +57,9 @@ async fn test_set_task_status(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -76,10 +82,12 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -99,6 +107,8 @@ async fn test_get_task_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_application_success(#[case] adapter: &str) { let store = 
create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,7 +118,7 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), @@ -129,12 +139,14 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: Some(namespace), application: None, @@ -153,10 +165,12 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str #[allow(deprecated)] async fn test_set_task_status_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, @@ -192,6 +206,8 @@ async fn test_set_task_status_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -201,7 +217,7 @@ 
async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -229,6 +245,8 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -238,7 +256,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -261,12 +279,14 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete diff --git a/src/lib.rs b/src/lib.rs index 33567944..6ff2b08d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ use clap::Parser; use std::fs; pub mod config; +pub mod dispatch; pub mod grpc; pub mod kafka; pub mod logging; diff --git a/src/main.rs b/src/main.rs index 7970939d..909fb594 100644 --- a/src/main.rs +++ 
b/src/main.rs @@ -2,6 +2,7 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; +use taskbroker::dispatch::TaskDispatcher; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; @@ -39,16 +40,16 @@ use taskbroker::store::postgres_activation_store::{ use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; -async fn log_task_completion(name: &str, task: JoinHandle>) { +async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { Ok(Ok(())) => { - info!("Task {} completed", name); + info!("Task {} completed", name.as_ref()); } Ok(Err(e)) => { - error!("Task {} failed: {:?}", name, e); + error!("Task {} failed: {:?}", name.as_ref(), e); } Err(e) => { - error!("Task {} panicked: {:?}", name, e); + error!("Task {} panicked: {:?}", name.as_ref(), e); } } } @@ -190,22 +191,24 @@ async fn main() -> Result<(), Error> { // GRPC server let grpc_server_task = tokio::spawn({ - let grpc_store = store.clone(); - let grpc_config = config.clone(); + let store = store.clone(); + let config = config.clone(); + async move { - let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) + let addr = format!("{}:{}", config.grpc_addr, config.grpc_port) .parse() .expect("Failed to parse address"); let layers = tower::ServiceBuilder::new() .layer(MetricsLayer::default()) - .layer(AuthLayer::new(&grpc_config)) + .layer(AuthLayer::new(&config)) .into_inner(); let server = Server::builder() .layer(layers) .add_service(ConsumerServiceServer::new(TaskbrokerServer { - store: grpc_store, + store, + config, })) .add_service(health_service.clone()) .serve(addr); @@ -236,7 +239,27 @@ async fn main() -> Result<(), Error> { } }); - elegant_departure::tokio::depart() + // Activation dispatchers + let dispatchers = if config.push_mode { + info!("Running in PUSH mode"); + + (0..config.dispatchers) + .map(|_| { + let store = store.clone(); + let config = 
config.clone(); + + tokio::spawn(async move { + let dispatcher = TaskDispatcher::new(config, store); + dispatcher.start().await + }) + }) + .collect() + } else { + info!("Running in PULL mode"); + vec![] + }; + + let mut departure = elegant_departure::tokio::depart() .on_termination() .on_sigint() .on_signal(SignalKind::hangup()) @@ -244,8 +267,13 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("consumer", consumer_task)) .on_completion(log_task_completion("grpc_server", grpc_server_task)) .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)) - .await; + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + // Register each activation dispatch task + for (i, handle) in dispatchers.into_iter().enumerate() { + let task_name = format!("activation_dispatcher_{}", i); + departure = departure.on_completion(log_task_completion(task_name, handle)); + } Ok(()) } From a7286132f1ca96fd1196881d2ca11ca508a5d94a Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 16 Mar 2026 17:20:25 -0700 Subject: [PATCH 02/17] Add Unit Tests, Flush Tasks on Shutdown --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/dispatch.rs | 275 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 275 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2a2ac9a..6e56936d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2542,7 +2542,7 @@ dependencies = [ [[package]] name = "sentry_protos" version = "0.8.4" -source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-taskbroker%2Fcreate-worker-service#f4cd3043b043c2f42e069104c3704177e1696504" +source = "git+https://github.com/getsentry/sentry-protos#7873851032c697925dd7e532b6ad9888911f93b8" dependencies = [ "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 6ed2b7b9..1e82bc3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ sentry = { version = 
"0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-taskbroker/create-worker-service" } +sentry_protos = { git = "https://github.com/getsentry/sentry-protos" } serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/src/dispatch.rs b/src/dispatch.rs index 97f1bcd8..4ed5dcf9 100644 --- a/src/dispatch.rs +++ b/src/dispatch.rs @@ -8,6 +8,7 @@ use anyhow::Result; use elegant_departure::get_shutdown_guard; use prost::Message; use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::task::JoinHandle; use tokio::time::sleep; use tonic::transport::Channel; use tracing::{debug, error, info}; @@ -59,6 +60,18 @@ impl TaskDispatcher { } } + /// Number of senders (and receivers) for testing purposes. + #[cfg(test)] + pub fn pusher_count(&self) -> usize { + self.senders.len() + } + + /// Take the receivers so a test can drain them. + #[cfg(test)] + pub fn take_receivers(&mut self) -> Vec> { + std::mem::take(&mut self.receivers) + } + /// Initialize push loops and dispatcher loop. 
pub async fn start(mut self) -> Result<()> { let n = self.senders.len(); @@ -67,11 +80,14 @@ impl TaskDispatcher { let endpoint = self.config.worker_endpoint.clone(); let receivers = std::mem::take(&mut self.receivers); + // Collect pusher handles so we can wait on them if shutdown is initiated + let mut handles: Vec> = Vec::with_capacity(receivers.len()); + // Initialize each push loop for mut rx in receivers.into_iter() { let endpoint = endpoint.clone(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { Ok(w) => w, @@ -93,6 +109,8 @@ impl TaskDispatcher { } } }); + + handles.push(handle); } info!("Starting fetch loop..."); @@ -114,11 +132,19 @@ impl TaskDispatcher { } info!("Activation dispatcher shutting down..."); + + // Close channels and drain any tasks still in the pushing pipeline + drop(std::mem::take(&mut self.senders)); + for handle in handles { + let _ = handle.await; + } + + info!("Activation dispatcher shut down."); Ok(()) } /// Grab the next pending activation from the store, mark it as processing, and send to push channel. - async fn fetch_activation(&mut self) { + pub async fn fetch_activation(&mut self) { let start = Instant::now(); metrics::counter!("pusher.fetch_activation.runs").increment(1); @@ -188,3 +214,248 @@ async fn push_task( fn milliseconds(i: u64) -> Duration { Duration::from_millis(i) } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use crate::config::Config; + use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivation, InflightActivationStatus, + InflightActivationStore, QueryResult, + }; + use crate::test_utils::{create_test_store, make_activations}; + + use anyhow::Error; + use async_trait::async_trait; + use chrono::{DateTime, Utc}; + + use super::TaskDispatcher; + + /// Mock store that returns activations from a queue for `get_pending_activation`. 
+ struct MockStore { + activations: Mutex>, + } + + impl MockStore { + fn new(activations: Vec) -> Arc { + Arc::new(Self { + activations: Mutex::new(activations), + }) + } + } + + #[async_trait] + impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + limit: Option, + ) -> Result, Error> { + let limit = limit.unwrap_or(1) as usize; + let mut list = self.activations.lock().unwrap(); + let n = limit.min(list.len()); + + if n == 0 { + return Ok(vec![]); + } + + Ok(list.drain(..n).collect()) + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + Ok(self.activations.lock().unwrap().len()) + } + + async fn count(&self) -> Result { + Ok(self.activations.lock().unwrap().len()) + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn 
handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched( + &self, + _killswitched_tasks: Vec, + ) -> Result { + unimplemented!() + } + } + + /// Asserts that a dispatcher built with X pushers has exactly X senders (and thus X receivers). + #[test] + fn pushers_x_creates_x_senders_and_receivers() { + // Use an empty mock store because we only care about construction, not fetching + let store: Arc = MockStore::new(vec![]); + + let config = Arc::new(Config { + pushers: 5, + push_queue_size: 10, + ..Config::default() + }); + + let dispatcher = TaskDispatcher::new(config, store); + + // One sender (and one receiver) per pusher + assert_eq!(dispatcher.pusher_count(), 5); + } + + /// Asserts that the fetch loop distributes activations round-robin across channels (0, 1, 2, 0, 1, 2, ...) + #[tokio::test] + async fn round_robin_sends_to_channels_0_1_2_0_1_2() { + // Six activations (id_0 .. 
id_5) so we get two full cycles across three channels + let activations = make_activations(6); + let store = MockStore::new(activations); + + let config = Arc::new(Config { + pushers: 3, + push_queue_size: 10, + ..Config::default() + }); + + let mut dispatcher = TaskDispatcher::new(config, store); + + // Take receivers so we can drain them - dispatcher keeps senders and will push to them + let mut receivers = dispatcher.take_receivers(); + assert_eq!(receivers.len(), 3); + + // Run the fetch loop six times - each run takes one activation from the mock and sends to next channel + for _ in 0..6 { + dispatcher.fetch_activation().await; + } + + // Receive in the same order the dispatcher sends - channel 0, then 1, then 2, then 0, 1, 2 + let mut received_by_channel: Vec> = vec![vec![], vec![], vec![]]; + for i in 0..6 { + let idx = i % 3; + let activation = receivers[idx].recv().await.expect("activation"); + received_by_channel[idx].push(activation.id.clone()); + } + + // Make sure round-robin works as intended... + // - Activations 1 and 4 go to channel 0 + // - Activations 2 and 5 go to channel 1 + // - Activations 3 and 6 go to channel 2 + assert_eq!(received_by_channel[0], &["id_0", "id_3"]); + assert_eq!(received_by_channel[1], &["id_1", "id_4"]); + assert_eq!(received_by_channel[2], &["id_2", "id_5"]); + } + + /// Asserts that after N fetch steps the store has zero pending activations (each fetch marks one as processing). 
+ #[tokio::test] + async fn fetch_loop_drains_store() { + let activations = make_activations(3); + let store = create_test_store("sqlite").await; + + // Add activations to test store + store.store(activations).await.unwrap(); + assert_eq!(store.count_pending_activations().await.unwrap(), 3); + + let config = Arc::new(Config { + pushers: 2, + push_queue_size: 10, + ..Config::default() + }); + + let mut dispatcher = TaskDispatcher::new(config, store.clone()); + let mut receivers = dispatcher.take_receivers(); + + // Run fetch three times - each call gets one pending activation and moves it to processing + for _ in 0..3 { + dispatcher.fetch_activation().await; + } + + // Drain all activations from the channels so we've fully consumed what was fetched + let mut received = 0; + + for mut rx in receivers.drain(..) { + while rx.try_recv().is_ok() { + received += 1; + } + } + + // Have all activations been received? + assert_eq!(received, 3); + + // Real store marks as processing on `get_pending_activation` - so no pending left + assert_eq!(store.count_pending_activations().await.unwrap(), 0); + } +} From 6866ef20cb8d53874426f020c08b16b0a8b7e451 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 09:08:26 -0700 Subject: [PATCH 03/17] Switch to Sentry Protos Release --- Cargo.lock | 5 +++-- Cargo.toml | 2 +- src/dispatch.rs | 11 +++++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6e56936d..cf633443 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2541,8 +2541,9 @@ dependencies = [ [[package]] name = "sentry_protos" -version = "0.8.4" -source = "git+https://github.com/getsentry/sentry-protos#7873851032c697925dd7e532b6ad9888911f93b8" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d3c4e8bca4c556eec616dc2594e519248891ca84f8bf958016c2c416223d8ff" dependencies = [ "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 1e82bc3a..2d41f501 100644 --- 
a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = { git = "https://github.com/getsentry/sentry-protos" } +sentry_protos = "0.8.5" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/src/dispatch.rs b/src/dispatch.rs index 4ed5dcf9..79258e4f 100644 --- a/src/dispatch.rs +++ b/src/dispatch.rs @@ -78,6 +78,8 @@ impl TaskDispatcher { info!("Starting {n} push loops..."); let endpoint = self.config.worker_endpoint.clone(); + let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let receivers = std::mem::take(&mut self.receivers); // Collect pusher handles so we can wait on them if shutdown is initiated @@ -86,6 +88,7 @@ impl TaskDispatcher { // Initialize each push loop for mut rx in receivers.into_iter() { let endpoint = endpoint.clone(); + let callback_url = callback_url.clone(); let handle = tokio::spawn(async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { @@ -102,7 +105,7 @@ impl TaskDispatcher { let id = activation.id.clone(); // Try to push activation to the worker service - if let Err(e) = push_task(&mut worker, activation).await { + if let Err(e) = push_task(&mut worker, activation, callback_url.clone()).await { error!("Pushing activation {id} resulted in error - {:?}", e); } else { debug!("Activation {id} was sent to worker!"); @@ -185,6 +188,7 @@ impl TaskDispatcher { async fn push_task( worker: &mut WorkerServiceClient, activation: InflightActivation, + callback_url: String, ) -> Result<()> { let start = Instant::now(); let id = activation.id.clone(); @@ -192,7 +196,10 @@ async fn push_task( // Try to decode activation (if it fails, we will see the error where `push_task` is called) let task = TaskActivation::decode(&activation.activation as &[u8])?; - let request = PushTaskRequest { task: Some(task) }; + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; 
let result = match worker.push_task(request).await { Ok(_) => { From 0a53d58971d7b666336069b3afbaa7e7f792f577 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 17:33:02 -0700 Subject: [PATCH 04/17] Replace Dispatcher w/Separate Fetch and Push Pools --- Cargo.lock | 20 ++- Cargo.toml | 1 + src/config.rs | 8 +- src/dispatch.rs | 468 ------------------------------------------------ src/fetch.rs | 111 ++++++++++++ src/lib.rs | 3 +- src/main.rs | 46 ++--- src/push.rs | 146 +++++++++++++++ 8 files changed, 307 insertions(+), 496 deletions(-) delete mode 100644 src/dispatch.rs create mode 100644 src/fetch.rs create mode 100644 src/push.rs diff --git a/Cargo.lock b/Cargo.lock index cf633443..9d26de75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,6 +741,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.16", +] [[package]] name = "figment" @@ -781,6 +784,18 @@ dependencies = [ "spin", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -940,8 +955,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -2893,7 +2910,7 @@ checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", "chrono", - "flume", + "flume 0.11.1", "futures-channel", "futures-core", "futures-executor", @@ -2983,6 +3000,7 @@ dependencies = [ "derive_builder", "elegant-departure", "figment", + "flume 
0.12.0", "futures", "futures-util", "hex", diff --git a/Cargo.toml b/Cargo.toml index 2d41f501..a9758e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ clap = { version = "4.5.20", features = ["derive"] } derive_builder = "0.20.2" elegant-departure = { version = "0.3.1", features = ["tokio"] } figment = { version = "0.10.19", features = ["env", "yaml", "test"] } +flume = "0.12.0" futures = "0.3.31" futures-util = "0.3.31" hex = "0.4.3" diff --git a/src/config.rs b/src/config.rs index 3f3bddb5..e9c03dfe 100644 --- a/src/config.rs +++ b/src/config.rs @@ -244,10 +244,10 @@ pub struct Config { pub push_mode: bool, /// The number of concurrent dispatchers to run. - pub dispatchers: usize, + pub fetch_threads: usize, /// The number of concurrent pushers each dispatcher should run. - pub pushers: usize, + pub push_threads: usize, /// The size of the push queue. pub push_queue_size: usize, @@ -324,8 +324,8 @@ impl Default for Config { vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, push_mode: false, - dispatchers: 1, - pushers: 1, + fetch_threads: 1, + push_threads: 1, push_queue_size: 1, worker_endpoint: "http://127.0.0.1:50052".into(), } diff --git a/src/dispatch.rs b/src/dispatch.rs deleted file mode 100644 index 79258e4f..00000000 --- a/src/dispatch.rs +++ /dev/null @@ -1,468 +0,0 @@ -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; -use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; - -use anyhow::Result; -use elegant_departure::get_shutdown_guard; -use prost::Message; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio::task::JoinHandle; -use tokio::time::sleep; -use tonic::transport::Channel; -use tracing::{debug, error, info}; - -use crate::config::Config; -use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; - -/// This data structure fetches pending activations from the store and pushes 
them to the worker service. Each dispatcher has... -/// - One "fetch" loop that gets a pending activation from the store, sends it to a push channel, and repeats -/// - One or more "push" loops, each of which receives an activation from a channel, pushes that activation to a worker, and repeats -pub struct TaskDispatcher { - /// Sender for every push loop. - senders: Vec>, - - /// Receiver for every push loop. - receivers: Vec>, - - /// For every pending activation, increment and send to the channel with this index. - next_sender_idx: usize, - - /// Broker configuration. - config: Arc, - - /// Broker inflight activation store. - store: Arc, -} - -impl TaskDispatcher { - /// Create a new task dispatcher. - pub fn new(config: Arc, store: Arc) -> Self { - let n = config.pushers; - - let mut senders = Vec::with_capacity(n); - let mut receivers = Vec::with_capacity(n); - let next_sender_idx = 0; - - for _ in 0..n { - let (tx, rx) = mpsc::channel(config.push_queue_size); - senders.push(tx); - receivers.push(rx); - } - - Self { - senders, - receivers, - next_sender_idx, - config, - store, - } - } - - /// Number of senders (and receivers) for testing purposes. - #[cfg(test)] - pub fn pusher_count(&self) -> usize { - self.senders.len() - } - - /// Take the receivers so a test can drain them. - #[cfg(test)] - pub fn take_receivers(&mut self) -> Vec> { - std::mem::take(&mut self.receivers) - } - - /// Initialize push loops and dispatcher loop. 
- pub async fn start(mut self) -> Result<()> { - let n = self.senders.len(); - info!("Starting {n} push loops..."); - - let endpoint = self.config.worker_endpoint.clone(); - let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); - - let receivers = std::mem::take(&mut self.receivers); - - // Collect pusher handles so we can wait on them if shutdown is initiated - let mut handles: Vec> = Vec::with_capacity(receivers.len()); - - // Initialize each push loop - for mut rx in receivers.into_iter() { - let endpoint = endpoint.clone(); - let callback_url = callback_url.clone(); - - let handle = tokio::spawn(async move { - let mut worker = match WorkerServiceClient::connect(endpoint).await { - Ok(w) => w, - - Err(e) => { - error!("Failed to connect to worker - {:?}", e); - return; - } - }; - - while let Some(activation) = rx.recv().await { - // Receive activation from the channel - let id = activation.id.clone(); - - // Try to push activation to the worker service - if let Err(e) = push_task(&mut worker, activation, callback_url.clone()).await { - error!("Pushing activation {id} resulted in error - {:?}", e); - } else { - debug!("Activation {id} was sent to worker!"); - } - } - }); - - handles.push(handle); - } - - info!("Starting fetch loop..."); - let guard = get_shutdown_guard().shutdown_on_drop(); - - // Initialize the fetch loop - loop { - tokio::select! { - _ = guard.wait() => { - info!("Fetch loop received shutdown signal"); - break; - } - - _ = async { - debug!("About to fetch next activation..."); - self.fetch_activation().await; - } => {} - } - } - - info!("Activation dispatcher shutting down..."); - - // Close channels and drain any tasks still in the pushing pipeline - drop(std::mem::take(&mut self.senders)); - for handle in handles { - let _ = handle.await; - } - - info!("Activation dispatcher shut down."); - Ok(()) - } - - /// Grab the next pending activation from the store, mark it as processing, and send to push channel. 
- pub async fn fetch_activation(&mut self) { - let start = Instant::now(); - metrics::counter!("pusher.fetch_activation.runs").increment(1); - - debug!("Fetching next pending activation..."); - - match self.store.get_pending_activation(None, None).await { - Ok(Some(activation)) => { - let id = activation.id.clone(); - - let idx = self.next_sender_idx % self.senders.len(); - self.next_sender_idx = self.next_sender_idx.wrapping_add(1); - - if let Err(e) = self.senders[idx].send(activation).await { - error!("Failed to send activation {id} to worker - {:?}", e); - } - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - - Ok(_) => { - debug!("No pending activations, sleeping briefly..."); - sleep(milliseconds(100)).await; - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - - Err(e) => { - error!("Failed to fetch pending activations - {:?}", e); - sleep(milliseconds(100)).await; - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - } - } -} - -/// Decode task activation and push it to a worker. 
-async fn push_task( - worker: &mut WorkerServiceClient, - activation: InflightActivation, - callback_url: String, -) -> Result<()> { - let start = Instant::now(); - let id = activation.id.clone(); - - // Try to decode activation (if it fails, we will see the error where `push_task` is called) - let task = TaskActivation::decode(&activation.activation as &[u8])?; - - let request = PushTaskRequest { - task: Some(task), - callback_url, - }; - - let result = match worker.push_task(request).await { - Ok(_) => { - debug!("Successfully sent activation {id} to worker service!"); - Ok(()) - } - - Err(e) => { - error!("Could not push activation {id} to worker service - {:?}", e); - Err(e.into()) - } - }; - - metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); - result -} - -#[inline] -fn milliseconds(i: u64) -> Duration { - Duration::from_millis(i) -} - -#[cfg(test)] -mod tests { - use std::sync::{Arc, Mutex}; - - use crate::config::Config; - use crate::store::inflight_activation::{ - FailedTasksForwarder, InflightActivation, InflightActivationStatus, - InflightActivationStore, QueryResult, - }; - use crate::test_utils::{create_test_store, make_activations}; - - use anyhow::Error; - use async_trait::async_trait; - use chrono::{DateTime, Utc}; - - use super::TaskDispatcher; - - /// Mock store that returns activations from a queue for `get_pending_activation`. 
- struct MockStore { - activations: Mutex>, - } - - impl MockStore { - fn new(activations: Vec) -> Arc { - Arc::new(Self { - activations: Mutex::new(activations), - }) - } - } - - #[async_trait] - impl InflightActivationStore for MockStore { - async fn vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn full_vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn db_size(&self) -> Result { - unimplemented!() - } - - async fn get_by_id(&self, _id: &str) -> Result, Error> { - unimplemented!() - } - - async fn store(&self, _batch: Vec) -> Result { - unimplemented!() - } - - async fn get_pending_activations_from_namespaces( - &self, - _application: Option<&str>, - _namespaces: Option<&[String]>, - limit: Option, - ) -> Result, Error> { - let limit = limit.unwrap_or(1) as usize; - let mut list = self.activations.lock().unwrap(); - let n = limit.min(list.len()); - - if n == 0 { - return Ok(vec![]); - } - - Ok(list.drain(..n).collect()) - } - - async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { - unimplemented!() - } - - async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { - Ok(self.activations.lock().unwrap().len()) - } - - async fn count(&self) -> Result { - Ok(self.activations.lock().unwrap().len()) - } - - async fn set_status( - &self, - _id: &str, - _status: InflightActivationStatus, - ) -> Result, Error> { - unimplemented!() - } - - async fn set_processing_deadline( - &self, - _id: &str, - _deadline: Option>, - ) -> Result<(), Error> { - unimplemented!() - } - - async fn delete_activation(&self, _id: &str) -> Result<(), Error> { - unimplemented!() - } - - async fn get_retry_activations(&self) -> Result, Error> { - unimplemented!() - } - - async fn clear(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn handle_processing_deadline(&self) -> Result { - unimplemented!() - } - - async fn handle_processing_attempts(&self) -> Result { - unimplemented!() - } - - async fn 
handle_expires_at(&self) -> Result { - unimplemented!() - } - - async fn handle_delay_until(&self) -> Result { - unimplemented!() - } - - async fn handle_failed_tasks(&self) -> Result { - unimplemented!() - } - - async fn mark_completed(&self, _ids: Vec) -> Result { - unimplemented!() - } - - async fn remove_completed(&self) -> Result { - unimplemented!() - } - - async fn remove_killswitched( - &self, - _killswitched_tasks: Vec, - ) -> Result { - unimplemented!() - } - } - - /// Asserts that a dispatcher built with X pushers has exactly X senders (and thus X receivers). - #[test] - fn pushers_x_creates_x_senders_and_receivers() { - // Use an empty mock store because we only care about construction, not fetching - let store: Arc = MockStore::new(vec![]); - - let config = Arc::new(Config { - pushers: 5, - push_queue_size: 10, - ..Config::default() - }); - - let dispatcher = TaskDispatcher::new(config, store); - - // One sender (and one receiver) per pusher - assert_eq!(dispatcher.pusher_count(), 5); - } - - /// Asserts that the fetch loop distributes activations round-robin across channels (0, 1, 2, 0, 1, 2, ...) - #[tokio::test] - async fn round_robin_sends_to_channels_0_1_2_0_1_2() { - // Six activations (id_0 .. 
id_5) so we get two full cycles across three channels - let activations = make_activations(6); - let store = MockStore::new(activations); - - let config = Arc::new(Config { - pushers: 3, - push_queue_size: 10, - ..Config::default() - }); - - let mut dispatcher = TaskDispatcher::new(config, store); - - // Take receivers so we can drain them - dispatcher keeps senders and will push to them - let mut receivers = dispatcher.take_receivers(); - assert_eq!(receivers.len(), 3); - - // Run the fetch loop six times - each run takes one activation from the mock and sends to next channel - for _ in 0..6 { - dispatcher.fetch_activation().await; - } - - // Receive in the same order the dispatcher sends - channel 0, then 1, then 2, then 0, 1, 2 - let mut received_by_channel: Vec> = vec![vec![], vec![], vec![]]; - for i in 0..6 { - let idx = i % 3; - let activation = receivers[idx].recv().await.expect("activation"); - received_by_channel[idx].push(activation.id.clone()); - } - - // Make sure round-robin works as intended... - // - Activations 1 and 4 go to channel 0 - // - Activations 2 and 5 go to channel 1 - // - Activations 3 and 6 go to channel 2 - assert_eq!(received_by_channel[0], &["id_0", "id_3"]); - assert_eq!(received_by_channel[1], &["id_1", "id_4"]); - assert_eq!(received_by_channel[2], &["id_2", "id_5"]); - } - - /// Asserts that after N fetch steps the store has zero pending activations (each fetch marks one as processing). 
- #[tokio::test] - async fn fetch_loop_drains_store() { - let activations = make_activations(3); - let store = create_test_store("sqlite").await; - - // Add activations to test store - store.store(activations).await.unwrap(); - assert_eq!(store.count_pending_activations().await.unwrap(), 3); - - let config = Arc::new(Config { - pushers: 2, - push_queue_size: 10, - ..Config::default() - }); - - let mut dispatcher = TaskDispatcher::new(config, store.clone()); - let mut receivers = dispatcher.take_receivers(); - - // Run fetch three times - each call gets one pending activation and moves it to processing - for _ in 0..3 { - dispatcher.fetch_activation().await; - } - - // Drain all activations from the channels so we've fully consumed what was fetched - let mut received = 0; - - for mut rx in receivers.drain(..) { - while rx.try_recv().is_ok() { - received += 1; - } - } - - // Have all activations been received? - assert_eq!(received, 3); - - // Real store marks as processing on `get_pending_activation` - so no pending left - assert_eq!(store.count_pending_activations().await.unwrap(), 0); - } -} diff --git a/src/fetch.rs b/src/fetch.rs new file mode 100644 index 00000000..10cfbbf4 --- /dev/null +++ b/src/fetch.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use tokio::time::sleep; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::push::PushPool; +use crate::store::inflight_activation::InflightActivationStore; + +/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes it to the push pool, and repeats. +pub struct FetchPool { + /// Inflight activation store. + store: Arc, + + /// Pool of push threads that push activations to the worker service. + push_pool: Arc, + + /// Taskbroker configuration. + config: Arc, +} + +impl FetchPool { + /// Initialize a new fetch pool. 
+ pub fn new( + store: Arc, + config: Arc, + push_pool: Arc, + ) -> Self { + Self { + store, + push_pool, + config, + } + } + + /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.fetch_threads.max(1) { + let guard = get_shutdown_guard().shutdown_on_drop(); + + let store = self.store.clone(); + let push_pool = self.push_pool.clone(); + + let handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + break; + } + + _ = async { + debug!("About to fetch next activation..."); + fetch_activations(store.clone(), push_pool.clone()).await; + } => {} + } + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } +} + +/// Grab the next pending activation from the store, mark it as processing, and send to push channel. 
+pub async fn fetch_activations(store: Arc, push_pool: Arc) { + let start = Instant::now(); + metrics::counter!("fetch.fetch_activations.runs").increment(1); + + debug!("Fetching next pending activation..."); + + match store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + debug!("Atomically fetched and marked task {id} as processing"); + + if let Err(e) = push_pool.submit(activation).await { + error!("Failed to submit task {id} to push pool - {:?}", e); + } + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Ok(_) => { + debug!("No pending activations, sleeping briefly..."); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Err(e) => { + error!("Failed to fetch pending activation - {:?}", e); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6ff2b08d..baf480d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,11 +2,12 @@ use clap::Parser; use std::fs; pub mod config; -pub mod dispatch; +pub mod fetch; pub mod grpc; pub mod kafka; pub mod logging; pub mod metrics; +pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; diff --git a/src/main.rs b/src/main.rs index 909fb594..4efdbc58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,10 +2,11 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; -use taskbroker::dispatch::TaskDispatcher; +use taskbroker::fetch::FetchPool; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; +use taskbroker::push::PushPool; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; @@ -239,24 +240,22 @@ async fn main() -> Result<(), Error> { } }); - // Activation 
dispatchers - let dispatchers = if config.push_mode { - info!("Running in PUSH mode"); - - (0..config.dispatchers) - .map(|_| { - let store = store.clone(); - let config = config.clone(); - - tokio::spawn(async move { - let dispatcher = TaskDispatcher::new(config, store); - dispatcher.start().await - }) - }) - .collect() + // Initialize push and fetch pools + let push_pool = Arc::new(PushPool::new(config.clone())); + let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); + + // Initialize push threads + let push_task = if config.push_mode { + Some(tokio::spawn(async move { push_pool.start().await })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.push_mode { + Some(tokio::spawn(async move { fetch_pool.start().await })) } else { - info!("Running in PULL mode"); - vec![] + None }; let mut departure = elegant_departure::tokio::depart() @@ -269,11 +268,14 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("upkeep_task", upkeep_task)) .on_completion(log_task_completion("maintenance_task", maintenance_task)); - // Register each activation dispatch task - for (i, handle) in dispatchers.into_iter().enumerate() { - let task_name = format!("activation_dispatcher_{}", i); - departure = departure.on_completion(log_task_completion(task_name, handle)); + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); } + departure.await; Ok(()) } diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 00000000..6ffa23bc --- /dev/null +++ b/src/push.rs @@ -0,0 +1,146 @@ +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use flume::{Receiver, Sender}; +use prost::Message; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use 
sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::InflightActivation; + +/// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. +pub struct PushPool { + /// The sending end of a channel that accepts task activations. + sender: Sender, + + /// The receiving end of a channel that accepts task activations. + receiver: Receiver, + + /// Taskbroker configuration. + config: Arc, +} + +impl PushPool { + /// Initialize a new push pool. + pub fn new(config: Arc) -> Self { + let (sender, receiver) = flume::bounded(config.push_queue_size); + + Self { + sender, + receiver, + config, + } + } + + /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.push_threads { + let endpoint = self.config.worker_endpoint.clone(); + + let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let receiver = self.receiver.clone(); + let guard = get_shutdown_guard().shutdown_on_drop(); + + let handle = tokio::spawn(async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + return; + } + }; + + loop { + tokio::select! 
{ + _ = guard.wait() => { + info!("Push worker received shutdown signal"); + break; + } + + message = receiver.recv_async() => { + let activation = match message { + // Received activation from fetch thread + Ok(a) => a, + + // Channel closed + Err(_) => break + }; + + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e) + }; + } + } + } + + // Drain channel before exiting + for activation in receiver.drain() { + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), + }; + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } + + /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. + pub async fn submit(&self, activation: InflightActivation) -> Result<()> { + Ok(self.sender.send_async(activation).await?) + } +} + +/// Decode task activation and push it to a worker. 
+async fn push_task( + worker: &mut WorkerServiceClient, + activation: InflightActivation, + callback_url: String, +) -> Result<()> { + let start = Instant::now(); + let id = activation.id.clone(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; + + let result = match worker.push_task(request).await { + Ok(_) => { + debug!("Successfully sent activation {id} to worker service!"); + Ok(()) + } + + Err(e) => { + error!("Could not push activation {id} to worker service - {:?}", e); + Err(e.into()) + } + }; + + metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + result +} From 043559ad3b320d7c5e2ce9ecc477f8286cd3e4c2 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 17:36:19 -0700 Subject: [PATCH 05/17] Initialize gRPC Server w/`0.0.0.0` --- src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 4efdbc58..0a2ca923 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,7 +196,7 @@ async fn main() -> Result<(), Error> { let config = config.clone(); async move { - let addr = format!("{}:{}", config.grpc_addr, config.grpc_port) + let addr = format!("0.0.0.0:{}", config.grpc_port) .parse() .expect("Failed to parse address"); From f70cc3a0171c8e987eb4865b0552c9f10a1bf0fd Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:13:28 -0700 Subject: [PATCH 06/17] Add `PushPool` Unit Tests --- src/config.rs | 52 +++++++++++++++++ src/push.rs | 153 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 200 insertions(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index e9c03dfe..2831cae1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -254,6 +254,12 @@ pub struct Config { /// The worker service endpoint. 
pub worker_endpoint: String, + + /// The hostname used to construct `callback_url` for task push requests. + pub callback_addr: String, + + /// The port used to construct `callback_url` for task push requests. + pub callback_port: u32, } impl Default for Config { @@ -328,6 +334,8 @@ impl Default for Config { push_threads: 1, push_queue_size: 1, worker_endpoint: "http://127.0.0.1:50052".into(), + callback_addr: "0.0.0.0".into(), + callback_port: 50051, } } } @@ -732,4 +740,48 @@ mod tests { Ok(()) }); } + + #[test] + fn test_default_push_callback_fields() { + let config = Config::default(); + assert_eq!(config.callback_addr, "0.0.0.0"); + assert_eq!(config.callback_port, 50051); + } + + #[test] + fn test_from_args_push_callback_fields_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_CALLBACK_ADDR", "127.0.0.1"); + jail.set_env("TASKBROKER_CALLBACK_PORT", "51000"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "127.0.0.1"); + assert_eq!(config.callback_port, 51000); + + Ok(()) + }); + } + + #[test] + fn test_from_args_push_callback_fields_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + callback_addr: 10.0.0.1 + callback_port: 52000 + "#, + )?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "10.0.0.1"); + assert_eq!(config.callback_port, 52000); + + Ok(()) + }); + } } diff --git a/src/push.rs b/src/push.rs index 6ffa23bc..ee5aeb1d 100644 --- a/src/push.rs +++ b/src/push.rs @@ -7,12 +7,28 @@ use flume::{Receiver, Sender}; use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use tonic::async_trait; use tonic::transport::Channel; use tracing::{debug, error, info}; use crate::config::Config; use 
crate::store::inflight_activation::InflightActivation; +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +trait WorkerClient { + /// Send a single `PushTaskRequest` to the worker service. + async fn send(&mut self, request: PushTaskRequest) -> Result<()>; +} + +#[async_trait] +impl WorkerClient for WorkerServiceClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.push_task(request).await?; + Ok(()) + } +} + /// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. pub struct PushPool { /// The sending end of a channel that accepts task activations. @@ -44,7 +60,11 @@ impl PushPool { for _ in 0..self.config.push_threads { let endpoint = self.config.worker_endpoint.clone(); - let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let callback_url = format!( + "{}:{}", + self.config.callback_addr, self.config.callback_port + ); + let receiver = self.receiver.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -113,8 +133,8 @@ impl PushPool { } /// Decode task activation and push it to a worker. 
-async fn push_task( - worker: &mut WorkerServiceClient, +async fn push_task( + worker: &mut W, activation: InflightActivation, callback_url: String, ) -> Result<()> { @@ -129,7 +149,7 @@ async fn push_task( callback_url, }; - let result = match worker.push_task(request).await { + let result = match worker.send(request).await { Ok(_) => { debug!("Successfully sent activation {id} to worker service!"); Ok(()) @@ -141,6 +161,129 @@ async fn push_task( } }; - metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + metrics::histogram!("push.push_task.duration").record(start.elapsed()); result } + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use std::sync::Arc; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::test_utils::make_activations; + + /// Fake worker client for unit testing. + struct MockWorkerClient { + /// Capture all received requests so we can assert things about them. + captured_requests: Vec, + + /// Should requests to the worker client fail? 
+ should_fail: bool, + } + + impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + let captured_requests = vec![]; + + Self { + captured_requests, + should_fail, + } + } + } + + #[async_trait] + impl WorkerClient for MockWorkerClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.captured_requests.push(request); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + let callback_url = "taskbroker:50051".to_string(); + + let result = push_task(&mut worker, activation.clone(), callback_url.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_requests.len(), 1); + + let request = &worker.captured_requests[0]; + assert_eq!(request.callback_url, callback_url); + assert_eq!( + request.task.as_ref().map(|task| task.id.as_str()), + Some(activation.id.as_str()) + ); + } + + #[tokio::test] + async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_requests.is_empty(), + "worker should not be called if decode fails" + ); + } + + #[tokio::test] + async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_requests.len(), 1); + } + + #[tokio::test] + async fn 
push_pool_submit_enqueues_item() { + let config = Arc::new(Config { + push_queue_size: 2, + ..Config::default() + }); + + let pool = PushPool::new(config); + let activation = make_activations(1).remove(0); + + let result = pool.submit(activation).await; + assert!(result.is_ok(), "submit should enqueue activation"); + } + + #[tokio::test] + async fn push_pool_submit_backpressures_when_queue_full() { + let config = Arc::new(Config { + push_queue_size: 1, + ..Config::default() + }); + + let pool = PushPool::new(config); + + let first = make_activations(1).remove(0); + let second = make_activations(1).remove(0); + + pool.submit(first) + .await + .expect("first submit should fill queue"); + + let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; + assert!( + second_submit.is_err(), + "second submit should block when queue is full" + ); + } +} From dafa06c254dcea6f2b5b284263ed9e1715e670e4 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:35:07 -0700 Subject: [PATCH 07/17] Add `FetchPool` Unit Tests --- src/fetch.rs | 309 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 300 insertions(+), 9 deletions(-) diff --git a/src/fetch.rs b/src/fetch.rs index 10cfbbf4..94ffc1fa 100644 --- a/src/fetch.rs +++ b/src/fetch.rs @@ -4,35 +4,51 @@ use std::time::{Duration, Instant}; use anyhow::Result; use elegant_departure::get_shutdown_guard; use tokio::time::sleep; +use tonic::async_trait; use tracing::{debug, error, info}; use crate::config::Config; use crate::push::PushPool; +use crate::store::inflight_activation::InflightActivation; use crate::store::inflight_activation::InflightActivationStore; +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +#[async_trait] +pub trait TaskPusher { + /// Push a single task to the worker service. 
+ async fn push_task(&self, activation: InflightActivation) -> Result<()>; +} + +#[async_trait] +impl TaskPusher for PushPool { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.submit(activation).await + } +} + /// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes is to the push pool, and repeats. -pub struct FetchPool { +pub struct FetchPool { /// Inflight activation store. store: Arc, /// Pool of push threads that push activations to the worker service. - push_pool: Arc, + pusher: Arc, /// Taskbroker configuration. config: Arc, } -impl FetchPool { +impl FetchPool { /// Initialize a new fetch pool. pub fn new( store: Arc, config: Arc, - push_pool: Arc, + pusher: Arc, ) -> Self { Self { store, - push_pool, config, + pusher, } } @@ -44,7 +60,7 @@ impl FetchPool { let guard = get_shutdown_guard().shutdown_on_drop(); let store = self.store.clone(); - let push_pool = self.push_pool.clone(); + let task_pusher = self.pusher.clone(); let handle = tokio::spawn(async move { loop { @@ -56,7 +72,7 @@ impl FetchPool { _ = async { debug!("About to fetch next activation..."); - fetch_activations(store.clone(), push_pool.clone()).await; + fetch_activations(store.clone(), task_pusher.clone()).await; } => {} } } @@ -76,7 +92,10 @@ impl FetchPool { } /// Grab the next pending activation from the store, mark it as processing, and send to push channel. 
-pub async fn fetch_activations(store: Arc, push_pool: Arc) { +pub async fn fetch_activations( + store: Arc, + pusher: Arc, +) { let start = Instant::now(); metrics::counter!("fetch.fetch_activations.runs").increment(1); @@ -87,7 +106,7 @@ pub async fn fetch_activations(store: Arc, push_poo let id = activation.id.clone(); debug!("Atomically fetched and marked task {id} as processing"); - if let Err(e) = push_pool.submit(activation).await { + if let Err(e) = pusher.push_task(activation).await { error!("Failed to submit task {id} to push pool - {:?}", e); } @@ -109,3 +128,275 @@ pub async fn fetch_activations(store: Arc, push_poo } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use anyhow::{Error, anyhow}; + use chrono::{DateTime, Utc}; + use tokio::sync::Mutex; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivationStatus, QueryResult, + }; + use crate::test_utils::make_activations; + + enum MockPendingResult { + Some(InflightActivation), + None, + Err, + } + + /// Fake store for testing. + struct MockStore { + /// How should all calls to `get_pending_activation` respond? + pending_result: MockPendingResult, + + /// How many calls to `get_pending_activation` have been performed? 
+ pending_calls: AtomicUsize, + } + + impl MockStore { + fn new(pending_result: MockPendingResult) -> Self { + let pending_calls = AtomicUsize::new(0); + + Self { + pending_result, + pending_calls, + } + } + } + + #[async_trait] + impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activation( + &self, + _application: Option<&str>, + _namespace: Option<&str>, + ) -> Result, Error> { + self.pending_calls.fetch_add(1, Ordering::SeqCst); + match &self.pending_result { + MockPendingResult::Some(activation) => Ok(Some(activation.clone())), + MockPendingResult::None => Ok(None), + MockPendingResult::Err => Err(anyhow!("mock store error")), + } + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option, + ) -> Result, Error> { + unimplemented!() + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + unimplemented!() + } + + async fn count(&self) -> Result { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + 
unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched( + &self, + _killswitched_tasks: Vec, + ) -> Result { + unimplemented!() + } + } + + /// Fake push pool for testing. + struct MockTaskPusher { + /// List of the IDs of all the activations that have been pushed. + pushed_ids: Mutex>, + + /// Should `push_task` fail? + should_fail: bool, + } + + impl MockTaskPusher { + fn new(should_fail: bool) -> Self { + let pushed_ids = Mutex::new(vec![]); + + Self { + pushed_ids, + should_fail, + } + } + } + + #[async_trait] + impl TaskPusher for MockTaskPusher { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.pushed_ids.lock().await.push(activation.id); + + if self.should_fail { + return Err(anyhow!("mock push error")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn fetch_activations_submits_when_pending_exists() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1); + assert_eq!(pushed[0], activation.id); + } + + #[tokio::test] + async fn fetch_activations_logs_submit_error_but_does_not_fail() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation))); + let 
pusher = Arc::new(MockTaskPusher::new(true)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); + } + + #[tokio::test] + async fn fetch_activations_no_pending_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push if no activation is pending" + ); + } + + #[tokio::test] + async fn fetch_activations_store_error_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Err)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push when pending activation lookup fails" + ); + } + + #[tokio::test] + async fn fetch_pool_start_spawns_at_least_one_worker_when_fetch_threads_zero() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let config = Arc::new(Config { + fetch_threads: 0, + ..Config::default() + }); + + let pool = FetchPool::new(store, config, pusher); + + let result = timeout(Duration::from_millis(50), pool.start()).await; + assert!( + result.is_err(), + "start() should not complete immediately when fetch_threads is 0 because .max(1) starts one worker loop" + ); + } +} From e47a1e8ce1cfab0293c28394452730941475d58c Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:39:45 -0700 Subject: [PATCH 08/17] Fix Linting --- src/fetch.rs | 1 + src/push.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fetch.rs b/src/fetch.rs index 94ffc1fa..dbd98132 100644 --- a/src/fetch.rs +++ b/src/fetch.rs @@ -145,6 +145,7 @@ mod 
tests { }; use crate::test_utils::make_activations; + #[allow(clippy::large_enum_variant)] enum MockPendingResult { Some(InflightActivation), None, diff --git a/src/push.rs b/src/push.rs index ee5aeb1d..3156ec78 100644 --- a/src/push.rs +++ b/src/push.rs @@ -157,7 +157,7 @@ async fn push_task( Err(e) => { error!("Could not push activation {id} to worker service - {:?}", e); - Err(e.into()) + Err(e) } }; From cb47f99a951f745136f944ffacd35eb59c9790c9 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 10:31:59 -0700 Subject: [PATCH 09/17] Address PR Comments (Fix Bugs, Make More Robust) --- src/config.rs | 4 + src/fetch.rs | 403 ----------------------------------- src/fetch/mod.rs | 139 ++++++++++++ src/fetch/tests.rs | 286 +++++++++++++++++++++++++ src/helpers.rs | 19 ++ src/lib.rs | 1 + src/main.rs | 2 +- src/{push.rs => push/mod.rs} | 162 +++----------- src/push/tests.rs | 123 +++++++++++ 9 files changed, 598 insertions(+), 541 deletions(-) delete mode 100644 src/fetch.rs create mode 100644 src/fetch/mod.rs create mode 100644 src/fetch/tests.rs create mode 100644 src/helpers.rs rename src/{push.rs => push/mod.rs} (56%) create mode 100644 src/push/tests.rs diff --git a/src/config.rs b/src/config.rs index 2831cae1..659d7a9a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -252,6 +252,9 @@ pub struct Config { /// The size of the push queue. pub push_queue_size: usize, + /// Maximum time in milliseconds to wait when submitting an activation to the push pool. + pub push_timeout_ms: u64, + /// The worker service endpoint. 
pub worker_endpoint: String, @@ -333,6 +336,7 @@ impl Default for Config { fetch_threads: 1, push_threads: 1, push_queue_size: 1, + push_timeout_ms: 5000, worker_endpoint: "http://127.0.0.1:50052".into(), callback_addr: "0.0.0.0".into(), callback_port: 50051, diff --git a/src/fetch.rs b/src/fetch.rs deleted file mode 100644 index dbd98132..00000000 --- a/src/fetch.rs +++ /dev/null @@ -1,403 +0,0 @@ -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use anyhow::Result; -use elegant_departure::get_shutdown_guard; -use tokio::time::sleep; -use tonic::async_trait; -use tracing::{debug, error, info}; - -use crate::config::Config; -use crate::push::PushPool; -use crate::store::inflight_activation::InflightActivation; -use crate::store::inflight_activation::InflightActivationStore; - -/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. -#[async_trait] -pub trait TaskPusher { - /// Push a single task to the worker service. - async fn push_task(&self, activation: InflightActivation) -> Result<()>; -} - -#[async_trait] -impl TaskPusher for PushPool { - async fn push_task(&self, activation: InflightActivation) -> Result<()> { - self.submit(activation).await - } -} - -/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes is to the push pool, and repeats. -pub struct FetchPool { - /// Inflight activation store. - store: Arc, - - /// Pool of push threads that push activations to the worker service. - pusher: Arc, - - /// Taskbroker configuration. - config: Arc, -} - -impl FetchPool { - /// Initialize a new fetch pool. 
- pub fn new( - store: Arc, - config: Arc, - pusher: Arc, - ) -> Self { - Self { - store, - config, - pusher, - } - } - - /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received. - pub async fn start(&self) -> Result<()> { - let mut handles = vec![]; - - for _ in 0..self.config.fetch_threads.max(1) { - let guard = get_shutdown_guard().shutdown_on_drop(); - - let store = self.store.clone(); - let task_pusher = self.pusher.clone(); - - let handle = tokio::spawn(async move { - loop { - tokio::select! { - _ = guard.wait() => { - info!("Fetch loop received shutdown signal"); - break; - } - - _ = async { - debug!("About to fetch next activation..."); - fetch_activations(store.clone(), task_pusher.clone()).await; - } => {} - } - } - }); - - handles.push(handle); - } - - for handle in handles { - if let Err(e) = handle.await { - return Err(e.into()); - } - } - - Ok(()) - } -} - -/// Grab the next pending activation from the store, mark it as processing, and send to push channel. 
-pub async fn fetch_activations( - store: Arc, - pusher: Arc, -) { - let start = Instant::now(); - metrics::counter!("fetch.fetch_activations.runs").increment(1); - - debug!("Fetching next pending activation..."); - - match store.get_pending_activation(None, None).await { - Ok(Some(activation)) => { - let id = activation.id.clone(); - debug!("Atomically fetched and marked task {id} as processing"); - - if let Err(e) = pusher.push_task(activation).await { - error!("Failed to submit task {id} to push pool - {:?}", e); - } - - metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); - } - - Ok(_) => { - debug!("No pending activations, sleeping briefly..."); - sleep(Duration::from_millis(100)).await; - - metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); - } - - Err(e) => { - error!("Failed to fetch pending activation - {:?}", e); - sleep(Duration::from_millis(100)).await; - - metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - use std::sync::atomic::{AtomicUsize, Ordering}; - - use anyhow::{Error, anyhow}; - use chrono::{DateTime, Utc}; - use tokio::sync::Mutex; - use tokio::time::{Duration, timeout}; - - use super::*; - use crate::store::inflight_activation::{ - FailedTasksForwarder, InflightActivationStatus, QueryResult, - }; - use crate::test_utils::make_activations; - - #[allow(clippy::large_enum_variant)] - enum MockPendingResult { - Some(InflightActivation), - None, - Err, - } - - /// Fake store for testing. - struct MockStore { - /// How should all calls to `get_pending_activation` respond? - pending_result: MockPendingResult, - - /// How many calls to `get_pending_activation` have been performed? 
- pending_calls: AtomicUsize, - } - - impl MockStore { - fn new(pending_result: MockPendingResult) -> Self { - let pending_calls = AtomicUsize::new(0); - - Self { - pending_result, - pending_calls, - } - } - } - - #[async_trait] - impl InflightActivationStore for MockStore { - async fn vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn full_vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn db_size(&self) -> Result { - unimplemented!() - } - - async fn get_by_id(&self, _id: &str) -> Result, Error> { - unimplemented!() - } - async fn store(&self, _batch: Vec) -> Result { - unimplemented!() - } - - async fn get_pending_activation( - &self, - _application: Option<&str>, - _namespace: Option<&str>, - ) -> Result, Error> { - self.pending_calls.fetch_add(1, Ordering::SeqCst); - match &self.pending_result { - MockPendingResult::Some(activation) => Ok(Some(activation.clone())), - MockPendingResult::None => Ok(None), - MockPendingResult::Err => Err(anyhow!("mock store error")), - } - } - - async fn get_pending_activations_from_namespaces( - &self, - _application: Option<&str>, - _namespaces: Option<&[String]>, - _limit: Option, - ) -> Result, Error> { - unimplemented!() - } - - async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { - unimplemented!() - } - - async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { - unimplemented!() - } - - async fn count(&self) -> Result { - unimplemented!() - } - - async fn set_status( - &self, - _id: &str, - _status: InflightActivationStatus, - ) -> Result, Error> { - unimplemented!() - } - - async fn set_processing_deadline( - &self, - _id: &str, - _deadline: Option>, - ) -> Result<(), Error> { - unimplemented!() - } - - async fn delete_activation(&self, _id: &str) -> Result<(), Error> { - unimplemented!() - } - - async fn get_retry_activations(&self) -> Result, Error> { - unimplemented!() - } - - async fn clear(&self) -> Result<(), Error> { - 
unimplemented!() - } - - async fn handle_processing_deadline(&self) -> Result { - unimplemented!() - } - - async fn handle_processing_attempts(&self) -> Result { - unimplemented!() - } - - async fn handle_expires_at(&self) -> Result { - unimplemented!() - } - - async fn handle_delay_until(&self) -> Result { - unimplemented!() - } - - async fn handle_failed_tasks(&self) -> Result { - unimplemented!() - } - - async fn mark_completed(&self, _ids: Vec) -> Result { - unimplemented!() - } - - async fn remove_completed(&self) -> Result { - unimplemented!() - } - - async fn remove_killswitched( - &self, - _killswitched_tasks: Vec, - ) -> Result { - unimplemented!() - } - } - - /// Fake push pool for testing. - struct MockTaskPusher { - /// List of the IDs of all the activations that have been pushed. - pushed_ids: Mutex>, - - /// Should `push_task` fail? - should_fail: bool, - } - - impl MockTaskPusher { - fn new(should_fail: bool) -> Self { - let pushed_ids = Mutex::new(vec![]); - - Self { - pushed_ids, - should_fail, - } - } - } - - #[async_trait] - impl TaskPusher for MockTaskPusher { - async fn push_task(&self, activation: InflightActivation) -> Result<()> { - self.pushed_ids.lock().await.push(activation.id); - - if self.should_fail { - return Err(anyhow!("mock push error")); - } - - Ok(()) - } - } - - #[tokio::test] - async fn fetch_activations_submits_when_pending_exists() { - let activation = make_activations(1).remove(0); - let store: Arc = - Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); - let pusher = Arc::new(MockTaskPusher::new(false)); - - fetch_activations(store, pusher.clone()).await; - - let pushed = pusher.pushed_ids.lock().await; - assert_eq!(pushed.len(), 1); - assert_eq!(pushed[0], activation.id); - } - - #[tokio::test] - async fn fetch_activations_logs_submit_error_but_does_not_fail() { - let activation = make_activations(1).remove(0); - let store: Arc = - Arc::new(MockStore::new(MockPendingResult::Some(activation))); - let 
pusher = Arc::new(MockTaskPusher::new(true)); - - fetch_activations(store, pusher.clone()).await; - - let pushed = pusher.pushed_ids.lock().await; - assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); - } - - #[tokio::test] - async fn fetch_activations_no_pending_does_not_submit() { - let store: Arc = - Arc::new(MockStore::new(MockPendingResult::None)); - let pusher = Arc::new(MockTaskPusher::new(false)); - - fetch_activations(store, pusher.clone()).await; - - let pushed = pusher.pushed_ids.lock().await; - assert!( - pushed.is_empty(), - "should not push if no activation is pending" - ); - } - - #[tokio::test] - async fn fetch_activations_store_error_does_not_submit() { - let store: Arc = - Arc::new(MockStore::new(MockPendingResult::Err)); - let pusher = Arc::new(MockTaskPusher::new(false)); - - fetch_activations(store, pusher.clone()).await; - - let pushed = pusher.pushed_ids.lock().await; - assert!( - pushed.is_empty(), - "should not push when pending activation lookup fails" - ); - } - - #[tokio::test] - async fn fetch_pool_start_spawns_at_least_one_worker_when_fetch_threads_zero() { - let store: Arc = - Arc::new(MockStore::new(MockPendingResult::None)); - let pusher = Arc::new(MockTaskPusher::new(false)); - - let config = Arc::new(Config { - fetch_threads: 0, - ..Config::default() - }); - - let pool = FetchPool::new(store, config, pusher); - - let result = timeout(Duration::from_millis(50), pool.start()).await; - assert!( - result.is_err(), - "start() should not complete immediately when fetch_threads is 0 because .max(1) starts one worker loop" - ); - } -} diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs new file mode 100644 index 00000000..2f97f957 --- /dev/null +++ b/src/fetch/mod.rs @@ -0,0 +1,139 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use tokio::time::sleep; +use tonic::async_trait; +use tracing::{debug, error, info}; + +use crate::config::Config; 
+use crate::helpers;
+use crate::push::PushPool;
+use crate::store::inflight_activation::InflightActivation;
+use crate::store::inflight_activation::InflightActivationStore;
+
+/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further.
+#[async_trait]
+pub trait TaskPusher {
+    /// Push a single task to the worker service.
+    async fn push_task(&self, activation: InflightActivation) -> Result<()>;
+}
+
+#[async_trait]
+impl TaskPusher for PushPool {
+    async fn push_task(&self, activation: InflightActivation) -> Result<()> {
+        self.submit(activation).await
+    }
+}
+
+/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes it to the push pool, and repeats.
+pub struct FetchPool {
+    /// Inflight activation store.
+    store: Arc,
+
+    /// Pool of push threads that push activations to the worker service.
+    pusher: Arc,
+
+    /// Taskbroker configuration.
+    config: Arc,
+}
+
+impl FetchPool {
+    /// Initialize a new fetch pool.
+    pub fn new(
+        store: Arc,
+        config: Arc,
+        pusher: Arc,
+    ) -> Self {
+        Self {
+            store,
+            config,
+            pusher,
+        }
+    }
+
+    /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received.
+    pub async fn start(&self) -> Result<()> {
+        let mut fetch_pool = helpers::spawn_pool(self.config.fetch_threads, |_| {
+            let store = self.store.clone();
+            let pusher = self.pusher.clone();
+
+            let guard = get_shutdown_guard().shutdown_on_drop();
+
+            async move {
+                loop {
+                    tokio::select!
{ + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + return; + } + + _ = async { + debug!("About to fetch next activation..."); + + // Instead of returning when `fetch_activation` fails, we just try again + if let Ok(false) = fetch_activation(store.clone(), pusher.clone()).await { + // Found no pending activations, wait for some to appear + sleep(Duration::from_millis(100)).await; + } + } => {} + } + } + } + }); + + while let Some(res) = fetch_pool.join_next().await { + if let Err(e) = res { + return Err(e.into()); + } + } + + Ok(()) + } +} + +/// Grab the next pending activation from the store, mark it as processing, and send to push channel. Return... +/// - `Ok(true)` if an activation was found +/// - `Ok(false)` if none pending +/// - `Err` if fetching failed. +pub async fn fetch_activation( + store: Arc, + pusher: Arc, +) -> Result { + let start = Instant::now(); + metrics::counter!("fetch.fetch_activation.calls").increment(1); + + debug!("Fetching next pending activation..."); + + let found = match store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + debug!("Atomically fetched and marked task {id} as processing"); + + // Times out after `config.push_timeout_ms` milliseconds + if let Err(e) = pusher.push_task(activation).await { + // Do not return `Err` because the fetch itself succeeded + error!("Failed to submit task {id} to push pool - {:?}", e); + } + + true + } + + Ok(_) => { + debug!("No pending activations"); + false + } + + Err(e) => { + error!("Failed to fetch pending activation - {:?}", e); + return Err(e); + } + }; + + metrics::histogram!("fetch.fetch_activation.duration").record(start.elapsed()); + Ok(found) +} + +#[cfg(test)] +mod tests; diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs new file mode 100644 index 00000000..6c496554 --- /dev/null +++ b/src/fetch/tests.rs @@ -0,0 +1,286 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use 
anyhow::{Error, anyhow}; +use chrono::{DateTime, Utc}; +use tokio::sync::Mutex; +use tokio::time::{Duration, timeout}; +use tonic::async_trait; + +use super::*; +use crate::config::Config; +use crate::store::inflight_activation::InflightActivation; +use crate::store::inflight_activation::InflightActivationStore; +use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivationStatus, QueryResult, +}; +use crate::test_utils::make_activations; + +#[allow(clippy::large_enum_variant)] +enum MockPendingResult { + Some(InflightActivation), + None, + Err, +} + +/// Fake store for testing. +struct MockStore { + /// How should all calls to `get_pending_activation` respond? + pending_result: MockPendingResult, + + /// How many calls to `get_pending_activation` have been performed? + pending_calls: AtomicUsize, +} + +impl MockStore { + fn new(pending_result: MockPendingResult) -> Self { + let pending_calls = AtomicUsize::new(0); + + Self { + pending_result, + pending_calls, + } + } +} + +#[async_trait] +impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activation( + &self, + _application: Option<&str>, + _namespace: Option<&str>, + ) -> Result, Error> { + self.pending_calls.fetch_add(1, Ordering::SeqCst); + match &self.pending_result { + MockPendingResult::Some(activation) => Ok(Some(activation.clone())), + MockPendingResult::None => Ok(None), + MockPendingResult::Err => Err(anyhow!("mock store error")), + } + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option, + ) -> Result, 
Error> { + unimplemented!() + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + unimplemented!() + } + + async fn count(&self) -> Result { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> Result { + unimplemented!() + } +} + +/// Fake push pool for testing. +struct MockTaskPusher { + /// List of the IDs of all the activations that have been pushed. + pushed_ids: Mutex>, + + /// Should `push_task` fail? 
+ should_fail: bool, +} + +impl MockTaskPusher { + fn new(should_fail: bool) -> Self { + let pushed_ids = Mutex::new(vec![]); + + Self { + pushed_ids, + should_fail, + } + } +} + +#[async_trait] +impl TaskPusher for MockTaskPusher { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.pushed_ids.lock().await.push(activation.id); + + if self.should_fail { + return Err(anyhow!("mock push error")); + } + + Ok(()) + } +} + +#[tokio::test] +async fn fetch_activation_submits_when_pending_exists() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let found = fetch_activation(store, pusher.clone()) + .await + .expect("fetch should succeed"); + assert!( + found, + "should return true when activation was found and submitted" + ); + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1); + assert_eq!(pushed[0], activation.id); +} + +#[tokio::test] +async fn fetch_activation_logs_submit_error_but_does_not_fail() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation))); + let pusher = Arc::new(MockTaskPusher::new(true)); + + let found = fetch_activation(store, pusher.clone()) + .await + .expect("fetch should succeed"); + assert!( + found, + "should return true when activation was found even if push fails" + ); + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); +} + +#[tokio::test] +async fn fetch_activation_no_pending_returns_false() { + let store: Arc = Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let found = fetch_activation(store, pusher.clone()) + .await + .expect("fetch should succeed"); + assert!(!found, "should return false when no activation is pending"); + + 
let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push if no activation is pending" + ); +} + +#[tokio::test] +async fn fetch_activation_store_error_returns_false() { + let store: Arc = Arc::new(MockStore::new(MockPendingResult::Err)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let result = fetch_activation(store, pusher.clone()).await; + assert!( + result.is_err(), + "should return error when pending activation lookup fails" + ); + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push when pending activation lookup fails" + ); +} + +#[tokio::test] +async fn fetch_pool_start_spawns_at_least_one_worker_when_fetch_threads_zero() { + let store: Arc = Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let config = Arc::new(Config { + fetch_threads: 0, + ..Config::default() + }); + + let pool = FetchPool::new(store, config, pusher); + + let result = timeout(Duration::from_millis(50), pool.start()).await; + assert!( + result.is_err(), + "start() should not complete immediately when fetch_threads is 0 because .max(1) starts one worker loop" + ); +} diff --git a/src/helpers.rs b/src/helpers.rs new file mode 100644 index 00000000..d252818a --- /dev/null +++ b/src/helpers.rs @@ -0,0 +1,19 @@ +use tokio::task::JoinSet; + +/// Spawns `max(n, 1)` tasks, each running the future produced by `f` with the task's index. +/// Returns a [`JoinSet`] containing all spawned tasks. 
+pub fn spawn_pool(n: usize, f: F) -> JoinSet +where + F: Fn(usize) -> Fut, + Fut: Future + Send + 'static, + Fut::Output: Send, +{ + let mut join_set = JoinSet::new(); + + let count = n.max(1); + for i in 0..count { + join_set.spawn(f(i)); + } + + join_set +} diff --git a/src/lib.rs b/src/lib.rs index baf480d7..df19be34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ use std::fs; pub mod config; pub mod fetch; pub mod grpc; +pub mod helpers; pub mod kafka; pub mod logging; pub mod metrics; diff --git a/src/main.rs b/src/main.rs index 0a2ca923..4efdbc58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,7 +196,7 @@ async fn main() -> Result<(), Error> { let config = config.clone(); async move { - let addr = format!("0.0.0.0:{}", config.grpc_port) + let addr = format!("{}:{}", config.grpc_addr, config.grpc_port) .parse() .expect("Failed to parse address"); diff --git a/src/push.rs b/src/push/mod.rs similarity index 56% rename from src/push.rs rename to src/push/mod.rs index 3156ec78..7d69ce54 100644 --- a/src/push.rs +++ b/src/push/mod.rs @@ -1,5 +1,5 @@ use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use anyhow::Result; use elegant_departure::get_shutdown_guard; @@ -12,6 +12,7 @@ use tonic::transport::Channel; use tracing::{debug, error, info}; use crate::config::Config; +use crate::helpers; use crate::store::inflight_activation::InflightActivation; /// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. @@ -55,20 +56,18 @@ impl PushPool { /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. 
pub async fn start(&self) -> Result<()> { - let mut handles = vec![]; - - for _ in 0..self.config.push_threads { + let mut push_pool = helpers::spawn_pool(self.config.push_threads, |_| { let endpoint = self.config.worker_endpoint.clone(); + let receiver = self.receiver.clone(); + + let guard = get_shutdown_guard().shutdown_on_drop(); let callback_url = format!( "{}:{}", self.config.callback_addr, self.config.callback_port ); - let receiver = self.receiver.clone(); - let guard = get_shutdown_guard().shutdown_on_drop(); - - let handle = tokio::spawn(async move { + async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { Ok(w) => w, Err(e) => { @@ -112,13 +111,11 @@ impl PushPool { Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), }; } - }); - - handles.push(handle); - } + } + }); - for handle in handles { - if let Err(e) = handle.await { + while let Some(res) = push_pool.join_next().await { + if let Err(e) = res { return Err(e.into()); } } @@ -126,9 +123,20 @@ impl PushPool { Ok(()) } - /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. + /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. Times out after `config.push_timeout_ms` milliseconds. pub async fn submit(&self, activation: InflightActivation) -> Result<()> { - Ok(self.sender.send_async(activation).await?) + let duration = Duration::from_millis(self.config.push_timeout_ms); + + tokio::time::timeout(duration, self.sender.send_async(activation)) + .await + .map_err(|_| { + anyhow::anyhow!( + "failed to submit to push pool within {} milliseconds", + self.config.push_timeout_ms + ) + })??; + + Ok(()) } } @@ -166,124 +174,4 @@ async fn push_task( } #[cfg(test)] -mod tests { - use anyhow::anyhow; - use std::sync::Arc; - use tokio::time::{Duration, timeout}; - - use super::*; - use crate::test_utils::make_activations; - - /// Fake worker client for unit testing. 
- struct MockWorkerClient { - /// Capture all received requests so we can assert things about them. - captured_requests: Vec, - - /// Should requests to the worker client fail? - should_fail: bool, - } - - impl MockWorkerClient { - fn new(should_fail: bool) -> Self { - let captured_requests = vec![]; - - Self { - captured_requests, - should_fail, - } - } - } - - #[async_trait] - impl WorkerClient for MockWorkerClient { - async fn send(&mut self, request: PushTaskRequest) -> Result<()> { - self.captured_requests.push(request); - - if self.should_fail { - return Err(anyhow!("mock send failure")); - } - - Ok(()) - } - } - - #[tokio::test] - async fn push_task_returns_ok_on_client_success() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(false); - let callback_url = "taskbroker:50051".to_string(); - - let result = push_task(&mut worker, activation.clone(), callback_url.clone()).await; - assert!(result.is_ok(), "push_task should succeed"); - assert_eq!(worker.captured_requests.len(), 1); - - let request = &worker.captured_requests[0]; - assert_eq!(request.callback_url, callback_url); - assert_eq!( - request.task.as_ref().map(|task| task.id.as_str()), - Some(activation.id.as_str()) - ); - } - - #[tokio::test] - async fn push_task_returns_err_on_invalid_payload() { - let mut activation = make_activations(1).remove(0); - activation.activation = vec![1, 2, 3, 4]; - - let mut worker = MockWorkerClient::new(false); - let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; - - assert!(result.is_err(), "invalid payload should fail decoding"); - assert!( - worker.captured_requests.is_empty(), - "worker should not be called if decode fails" - ); - } - - #[tokio::test] - async fn push_task_propagates_client_error() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(true); - - let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; - 
assert!(result.is_err(), "worker send errors should propagate"); - assert_eq!(worker.captured_requests.len(), 1); - } - - #[tokio::test] - async fn push_pool_submit_enqueues_item() { - let config = Arc::new(Config { - push_queue_size: 2, - ..Config::default() - }); - - let pool = PushPool::new(config); - let activation = make_activations(1).remove(0); - - let result = pool.submit(activation).await; - assert!(result.is_ok(), "submit should enqueue activation"); - } - - #[tokio::test] - async fn push_pool_submit_backpressures_when_queue_full() { - let config = Arc::new(Config { - push_queue_size: 1, - ..Config::default() - }); - - let pool = PushPool::new(config); - - let first = make_activations(1).remove(0); - let second = make_activations(1).remove(0); - - pool.submit(first) - .await - .expect("first submit should fill queue"); - - let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; - assert!( - second_submit.is_err(), - "second submit should block when queue is full" - ); - } -} +mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs new file mode 100644 index 00000000..edc45cf7 --- /dev/null +++ b/src/push/tests.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use sentry_protos::taskbroker::v1::PushTaskRequest; +use tokio::time::{Duration, timeout}; +use tonic::async_trait; + +use super::*; +use crate::config::Config; +use crate::test_utils::make_activations; + +/// Fake worker client for unit testing. +struct MockWorkerClient { + /// Capture all received requests so we can assert things about them. + captured_requests: Vec, + + /// Should requests to the worker client fail? 
+ should_fail: bool, +} + +impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + let captured_requests = vec![]; + + Self { + captured_requests, + should_fail, + } + } +} + +#[async_trait] +impl WorkerClient for MockWorkerClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.captured_requests.push(request); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +#[tokio::test] +async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + let callback_url = "taskbroker:50051".to_string(); + + let result = push_task(&mut worker, activation.clone(), callback_url.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_requests.len(), 1); + + let request = &worker.captured_requests[0]; + assert_eq!(request.callback_url, callback_url); + assert_eq!( + request.task.as_ref().map(|task| task.id.as_str()), + Some(activation.id.as_str()) + ); +} + +#[tokio::test] +async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_requests.is_empty(), + "worker should not be called if decode fails" + ); +} + +#[tokio::test] +async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_requests.len(), 1); +} + +#[tokio::test] +async fn push_pool_submit_enqueues_item() { + let config 
= Arc::new(Config { + push_queue_size: 2, + ..Config::default() + }); + + let pool = PushPool::new(config); + let activation = make_activations(1).remove(0); + + let result = pool.submit(activation).await; + assert!(result.is_ok(), "submit should enqueue activation"); +} + +#[tokio::test] +async fn push_pool_submit_backpressures_when_queue_full() { + let config = Arc::new(Config { + push_queue_size: 1, + ..Config::default() + }); + + let pool = PushPool::new(config); + + let first = make_activations(1).remove(0); + let second = make_activations(1).remove(0); + + pool.submit(first) + .await + .expect("first submit should fill queue"); + + let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; + assert!( + second_submit.is_err(), + "second submit should block when queue is full" + ); +} From aa055bc9fa0547c116137e3a9a1de032010722e4 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 10:47:35 -0700 Subject: [PATCH 10/17] Move Tasks Back to Pending on Push Failure, Add Server Unit Test for Push Mode --- src/fetch/mod.rs | 17 ++++++++++++----- src/fetch/tests.rs | 2 +- src/grpc/server_tests.rs | 24 ++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 2f97f957..46e78b64 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -10,8 +10,9 @@ use tracing::{debug, error, info}; use crate::config::Config; use crate::helpers; use crate::push::PushPool; -use crate::store::inflight_activation::InflightActivation; -use crate::store::inflight_activation::InflightActivationStore; +use crate::store::inflight_activation::{ + InflightActivation, InflightActivationStatus, InflightActivationStore, +}; /// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. 
#[async_trait] @@ -111,10 +112,16 @@ pub async fn fetch_activation( let id = activation.id.clone(); debug!("Atomically fetched and marked task {id} as processing"); - // Times out after `config.push_timeout_ms` milliseconds if let Err(e) = pusher.push_task(activation).await { - // Do not return `Err` because the fetch itself succeeded - error!("Failed to submit task {id} to push pool - {:?}", e); + error!("Failed to submit task {id} to push pool - {e:?}"); + + // Change status back to pending + if let Err(e) = store + .set_status(&id, InflightActivationStatus::Pending) + .await + { + error!("Failed to change task {id} back to pending - {e:?}"); + } } true diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 6c496554..377bb124 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -103,7 +103,7 @@ impl InflightActivationStore for MockStore { _id: &str, _status: InflightActivationStatus, ) -> Result, Error> { - unimplemented!() + Ok(None) } async fn set_processing_deadline( diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index 29e36f96..c96072a3 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use crate::config::Config; use crate::grpc::server::TaskbrokerServer; use prost::Message; use rstest::rstest; @@ -9,6 +12,27 @@ use tonic::{Code, Request}; use crate::test_utils::{create_config, create_test_store, make_activations}; +#[tokio::test] +async fn test_get_task_push_mode_returns_permission_denied() { + let store = create_test_store("sqlite").await; + let config = Arc::new(Config { + push_mode: true, + ..Config::default() + }); + + let service = TaskbrokerServer { store, config }; + let request = GetTaskRequest { + namespace: None, + application: None, + }; + let response = service.get_task(Request::new(request)).await; + + assert!(response.is_err()); + let e = response.unwrap_err(); + assert_eq!(e.code(), Code::PermissionDenied); + assert_eq!(e.message(), "Cannot call while broker 
is in PUSH mode"); +} + #[tokio::test] #[rstest] #[case::sqlite("sqlite")] From 2f540ad27a48a55cf16002c3af727f7e314e52b3 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 11:14:35 -0700 Subject: [PATCH 11/17] Make Empty Store Backoff Configurable, Other Fixes and Tests --- src/config.rs | 4 ++ src/fetch/mod.rs | 14 ++++-- src/main.rs | 2 +- src/push/mod.rs | 33 +++++++++++-- src/push/tests.rs | 120 ++++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 162 insertions(+), 11 deletions(-) diff --git a/src/config.rs b/src/config.rs index 659d7a9a..152b9982 100644 --- a/src/config.rs +++ b/src/config.rs @@ -246,6 +246,9 @@ pub struct Config { /// The number of concurrent dispatchers to run. pub fetch_threads: usize, + /// Time in milliseconds to wait between fetch attempts when no pending activation is found. + pub fetch_wait_ms: u64, + /// The number of concurrent pushers each dispatcher should run. pub push_threads: usize, @@ -334,6 +337,7 @@ impl Default for Config { enable_sqlite_status_metrics: true, push_mode: false, fetch_threads: 1, + fetch_wait_ms: 100, push_threads: 1, push_queue_size: 1, push_timeout_ms: 5000, diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 46e78b64..617ae96a 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -56,6 +56,8 @@ impl FetchPool { /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received. 
pub async fn start(&self) -> Result<()> { + let fetch_wait_ms = self.config.fetch_wait_ms; + let mut fetch_pool = helpers::spawn_pool(self.config.fetch_threads, |_| { let store = self.store.clone(); let pusher = self.pusher.clone(); @@ -74,9 +76,15 @@ impl FetchPool { debug!("About to fetch next activation..."); // Instead of returning when `fetch_activation` fails, we just try again - if let Ok(false) = fetch_activation(store.clone(), pusher.clone()).await { - // Found no pending activations, wait for some to appear - sleep(Duration::from_millis(100)).await; + match fetch_activation(store.clone(), pusher.clone()).await { + Ok(false) | Err(_) => { + // Found no pending activations OR there is an issue with the store, wait some time before trying again + sleep(Duration::from_millis(fetch_wait_ms)).await; + } + + Ok(true) => { + // Fetched pending activation successfully, so nothing else needs to be done + } } } => {} } diff --git a/src/main.rs b/src/main.rs index 4efdbc58..47b3123e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -241,7 +241,7 @@ async fn main() -> Result<(), Error> { }); // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone())); + let push_pool = Arc::new(PushPool::new(store.clone(), config.clone())); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads diff --git a/src/push/mod.rs b/src/push/mod.rs index 7d69ce54..d189c9a6 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -13,7 +13,9 @@ use tracing::{debug, error, info}; use crate::config::Config; use crate::helpers; -use crate::store::inflight_activation::InflightActivation; +use crate::store::inflight_activation::{ + InflightActivation, InflightActivationStatus, InflightActivationStore, +}; /// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. 
#[async_trait] @@ -38,27 +40,33 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver, + /// Inflight activation store (used to set status back to pending on push failure). + store: Arc, + /// Taskbroker configuration. config: Arc, } impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc) -> Self { + pub fn new(store: Arc, config: Arc) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); Self { sender, receiver, + store, config, } } /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. pub async fn start(&self) -> Result<()> { + let store = self.store.clone(); let mut push_pool = helpers::spawn_pool(self.config.push_threads, |_| { let endpoint = self.config.worker_endpoint.clone(); let receiver = self.receiver.clone(); + let store = store.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -96,7 +104,16 @@ impl PushPool { match push_task(&mut worker, activation, callback_url.clone()).await { Ok(_) => debug!("Activation {id} was sent to worker!"), - Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e) + Err(e) => { + error!("Pushing activation {id} resulted in error - {:?}", e); + + if let Err(e) = store + .set_status(&id, InflightActivationStatus::Pending) + .await + { + error!("Failed to change task {id} back to pending - {e:?}"); + } + } }; } } @@ -108,7 +125,15 @@ impl PushPool { match push_task(&mut worker, activation, callback_url.clone()).await { Ok(_) => debug!("Activation {id} was sent to worker!"), - Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), + Err(e) => { + error!("Pushing activation {id} resulted in error - {:?}", e); + if let Err(e) = store + .set_status(&id, InflightActivationStatus::Pending) + .await + { + error!("Failed to change task {id} back to pending - {e:?}"); + } 
+ } }; } } diff --git a/src/push/tests.rs b/src/push/tests.rs index edc45cf7..76a35251 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,14 +1,126 @@ use std::sync::Arc; -use anyhow::anyhow; +use anyhow::{Error, anyhow}; +use chrono::{DateTime, Utc}; use sentry_protos::taskbroker::v1::PushTaskRequest; use tokio::time::{Duration, timeout}; use tonic::async_trait; use super::*; use crate::config::Config; +use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivation, InflightActivationStatus, InflightActivationStore, + QueryResult, +}; use crate::test_utils::make_activations; +/// Minimal store for tests. +struct MockStore; + +#[async_trait] +impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option, + ) -> Result, Error> { + unimplemented!() + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + unimplemented!() + } + + async fn count(&self) -> Result { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + Ok(None) + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + 
async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> Result { + unimplemented!() + } +} + /// Fake worker client for unit testing. struct MockWorkerClient { /// Capture all received requests so we can assert things about them. @@ -91,8 +203,9 @@ async fn push_pool_submit_enqueues_item() { push_queue_size: 2, ..Config::default() }); + let store = Arc::new(MockStore); - let pool = PushPool::new(config); + let pool = PushPool::new(store, config); let activation = make_activations(1).remove(0); let result = pool.submit(activation).await; @@ -105,8 +218,9 @@ async fn push_pool_submit_backpressures_when_queue_full() { push_queue_size: 1, ..Config::default() }); + let store = Arc::new(MockStore); - let pool = PushPool::new(config); + let pool = PushPool::new(store, config); let first = make_activations(1).remove(0); let second = make_activations(1).remove(0); From 97f00b0fdececdafcd495f6ca05bd9676f1406c1 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 11:27:58 -0700 Subject: [PATCH 12/17] Return Error from `fetch_activation` on Submit Failure --- src/fetch/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 617ae96a..1868a4e5 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -78,7 +78,7 @@ impl FetchPool { // Instead of returning when `fetch_activation` fails, we just 
try again match fetch_activation(store.clone(), pusher.clone()).await { Ok(false) | Err(_) => { - // Found no pending activations OR there is an issue with the store, wait some time before trying again + // Found no pending activations OR there is an issue with the store OR submitting to push pool failed sleep(Duration::from_millis(fetch_wait_ms)).await; } @@ -130,6 +130,8 @@ pub async fn fetch_activation( { error!("Failed to change task {id} back to pending - {e:?}"); } + + return Err(e); } true From 6db46a0128b6b65a3026064e9039925726227ab1 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 11:37:04 -0700 Subject: [PATCH 13/17] Fix Fetch Unit Tests --- src/fetch/tests.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 377bb124..8360bab9 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -213,22 +213,24 @@ async fn fetch_activation_submits_when_pending_exists() { } #[tokio::test] -async fn fetch_activation_logs_submit_error_but_does_not_fail() { +async fn fetch_activation_logs_submit_error_and_returns_err() { let activation = make_activations(1).remove(0); let store: Arc = Arc::new(MockStore::new(MockPendingResult::Some(activation))); let pusher = Arc::new(MockTaskPusher::new(true)); - let found = fetch_activation(store, pusher.clone()) - .await - .expect("fetch should succeed"); + let result = fetch_activation(store, pusher.clone()).await; assert!( - found, - "should return true when activation was found even if push fails" + result.is_err(), + "should return error when push fails (after reverting status to pending)" ); let pushed = pusher.pushed_ids.lock().await; - assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); + assert_eq!( + pushed.len(), + 1, + "should attempt one push before returning error" + ); } #[tokio::test] From cf06e6b5e26509eff255c2527e37265bbc907564 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 
15:47:13 -0700 Subject: [PATCH 14/17] Don't Use `FetchNextTask` When in Push Mode --- src/grpc/server.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 68d06699..68ee58dc 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -106,6 +106,10 @@ impl ConsumerService for TaskbrokerServer { } metrics::histogram!("grpc_server.set_status.duration").record(start_time.elapsed()); + if self.config.push_mode { + return Ok(Response::new(SetTaskStatusResponse { task: None })); + } + let Some(FetchNextTask { ref namespace, ref application, From d28397d583d85217547412261c0a1564a045a2d5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 17:51:09 -0700 Subject: [PATCH 15/17] Do Not Reset Task Status on Push Failure --- src/fetch/mod.rs | 14 +----- src/fetch/tests.rs | 2 +- src/main.rs | 2 +- src/push/mod.rs | 37 +++----------- src/push/tests.rs | 120 ++------------------------------------------- 5 files changed, 15 insertions(+), 160 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 1868a4e5..1d500109 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -10,9 +10,7 @@ use tracing::{debug, error, info}; use crate::config::Config; use crate::helpers; use crate::push::PushPool; -use crate::store::inflight_activation::{ - InflightActivation, InflightActivationStatus, InflightActivationStore, -}; +use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; /// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. 
#[async_trait] @@ -121,16 +119,8 @@ pub async fn fetch_activation( debug!("Atomically fetched and marked task {id} as processing"); if let Err(e) = pusher.push_task(activation).await { + // Once processing deadline expires, status will be set back to pending error!("Failed to submit task {id} to push pool - {e:?}"); - - // Change status back to pending - if let Err(e) = store - .set_status(&id, InflightActivationStatus::Pending) - .await - { - error!("Failed to change task {id} back to pending - {e:?}"); - } - return Err(e); } diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 8360bab9..1f6af405 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -103,7 +103,7 @@ impl InflightActivationStore for MockStore { _id: &str, _status: InflightActivationStatus, ) -> Result, Error> { - Ok(None) + unimplemented!() } async fn set_processing_deadline( diff --git a/src/main.rs b/src/main.rs index 47b3123e..4efdbc58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -241,7 +241,7 @@ async fn main() -> Result<(), Error> { }); // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(store.clone(), config.clone())); + let push_pool = Arc::new(PushPool::new(config.clone())); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads diff --git a/src/push/mod.rs b/src/push/mod.rs index d189c9a6..54463bc1 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -13,9 +13,7 @@ use tracing::{debug, error, info}; use crate::config::Config; use crate::helpers; -use crate::store::inflight_activation::{ - InflightActivation, InflightActivationStatus, InflightActivationStore, -}; +use crate::store::inflight_activation::InflightActivation; /// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. 
#[async_trait] @@ -40,33 +38,27 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver, - /// Inflight activation store (used to set status back to pending on push failure). - store: Arc, - /// Taskbroker configuration. config: Arc, } impl PushPool { /// Initialize a new push pool. - pub fn new(store: Arc, config: Arc) -> Self { + pub fn new(config: Arc) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); Self { sender, receiver, - store, config, } } /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. pub async fn start(&self) -> Result<()> { - let store = self.store.clone(); let mut push_pool = helpers::spawn_pool(self.config.push_threads, |_| { let endpoint = self.config.worker_endpoint.clone(); let receiver = self.receiver.clone(); - let store = store.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -104,16 +96,9 @@ impl PushPool { match push_task(&mut worker, activation, callback_url.clone()).await { Ok(_) => debug!("Activation {id} was sent to worker!"), - Err(e) => { - error!("Pushing activation {id} resulted in error - {:?}", e); - - if let Err(e) = store - .set_status(&id, InflightActivationStatus::Pending) - .await - { - error!("Failed to change task {id} back to pending - {e:?}"); - } - } + + // Once processing deadline expires, status will be set back to pending + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e) }; } } @@ -125,15 +110,9 @@ impl PushPool { match push_task(&mut worker, activation, callback_url.clone()).await { Ok(_) => debug!("Activation {id} was sent to worker!"), - Err(e) => { - error!("Pushing activation {id} resulted in error - {:?}", e); - if let Err(e) = store - .set_status(&id, InflightActivationStatus::Pending) - .await - { - error!("Failed to change task {id} back to pending - {e:?}"); - } 
- } + + // Once processing deadline expires, status will be set back to pending + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), }; } } diff --git a/src/push/tests.rs b/src/push/tests.rs index 76a35251..edc45cf7 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,126 +1,14 @@ use std::sync::Arc; -use anyhow::{Error, anyhow}; -use chrono::{DateTime, Utc}; +use anyhow::anyhow; use sentry_protos::taskbroker::v1::PushTaskRequest; use tokio::time::{Duration, timeout}; use tonic::async_trait; use super::*; use crate::config::Config; -use crate::store::inflight_activation::{ - FailedTasksForwarder, InflightActivation, InflightActivationStatus, InflightActivationStore, - QueryResult, -}; use crate::test_utils::make_activations; -/// Minimal store for tests. -struct MockStore; - -#[async_trait] -impl InflightActivationStore for MockStore { - async fn vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn full_vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn db_size(&self) -> Result { - unimplemented!() - } - - async fn get_by_id(&self, _id: &str) -> Result, Error> { - unimplemented!() - } - - async fn store(&self, _batch: Vec) -> Result { - unimplemented!() - } - - async fn get_pending_activations_from_namespaces( - &self, - _application: Option<&str>, - _namespaces: Option<&[String]>, - _limit: Option, - ) -> Result, Error> { - unimplemented!() - } - - async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { - unimplemented!() - } - - async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { - unimplemented!() - } - - async fn count(&self) -> Result { - unimplemented!() - } - - async fn set_status( - &self, - _id: &str, - _status: InflightActivationStatus, - ) -> Result, Error> { - Ok(None) - } - - async fn set_processing_deadline( - &self, - _id: &str, - _deadline: Option>, - ) -> Result<(), Error> { - unimplemented!() - } - - async fn 
delete_activation(&self, _id: &str) -> Result<(), Error> { - unimplemented!() - } - - async fn get_retry_activations(&self) -> Result, Error> { - unimplemented!() - } - - async fn clear(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn handle_processing_deadline(&self) -> Result { - unimplemented!() - } - - async fn handle_processing_attempts(&self) -> Result { - unimplemented!() - } - - async fn handle_expires_at(&self) -> Result { - unimplemented!() - } - - async fn handle_delay_until(&self) -> Result { - unimplemented!() - } - - async fn handle_failed_tasks(&self) -> Result { - unimplemented!() - } - - async fn mark_completed(&self, _ids: Vec) -> Result { - unimplemented!() - } - - async fn remove_completed(&self) -> Result { - unimplemented!() - } - - async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> Result { - unimplemented!() - } -} - /// Fake worker client for unit testing. struct MockWorkerClient { /// Capture all received requests so we can assert things about them. 
@@ -203,9 +91,8 @@ async fn push_pool_submit_enqueues_item() { push_queue_size: 2, ..Config::default() }); - let store = Arc::new(MockStore); - let pool = PushPool::new(store, config); + let pool = PushPool::new(config); let activation = make_activations(1).remove(0); let result = pool.submit(activation).await; @@ -218,9 +105,8 @@ async fn push_pool_submit_backpressures_when_queue_full() { push_queue_size: 1, ..Config::default() }); - let store = Arc::new(MockStore); - let pool = PushPool::new(store, config); + let pool = PushPool::new(config); let first = make_activations(1).remove(0); let second = make_activations(1).remove(0); From d42ffb5ab49da7d9b6e63992413dc8cfa5c32a75 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 18:01:48 -0700 Subject: [PATCH 16/17] Change Push Mode Config Name, Rename Helpers Module --- src/config.rs | 53 +++++++++++++++++++++++++++++++++--- src/fetch/mod.rs | 3 +- src/grpc/server.rs | 6 ++-- src/grpc/server_tests.rs | 4 +-- src/lib.rs | 2 +- src/main.rs | 6 ++-- src/push/mod.rs | 3 +- src/{helpers.rs => tokio.rs} | 0 8 files changed, 60 insertions(+), 17 deletions(-) rename src/{helpers.rs => tokio.rs} (100%) diff --git a/src/config.rs b/src/config.rs index 152b9982..4eafa381 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,6 +19,17 @@ pub enum DatabaseAdapter { Postgres, } +/// How the taskbroker delivers tasks to workers. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum DeliveryMode { + /// Workers pull tasks from the broker. + Pull, + + /// Broker pushes tasks to workers. + Push, +} + #[derive(PartialEq, Debug, Deserialize, Serialize)] pub struct Config { /// The sentry DSN to use for error reporting. @@ -240,8 +251,8 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, - /// Run the taskbroker in push mode (as opposed to pull mode). 
- pub push_mode: bool, + /// How to deliver tasks to workers: "push" or "pull". + pub delivery_mode: DeliveryMode, /// The number of concurrent dispatchers to run. pub fetch_threads: usize, @@ -335,7 +346,7 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, - push_mode: false, + delivery_mode: DeliveryMode::Pull, fetch_threads: 1, fetch_wait_ms: 100, push_threads: 1, @@ -458,7 +469,7 @@ impl Provider for Config { mod tests { use std::{borrow::Cow, collections::BTreeMap}; - use super::{Config, DatabaseAdapter}; + use super::{Config, DatabaseAdapter, DeliveryMode}; use crate::{Args, logging::LogFormat}; use figment::Jail; @@ -749,6 +760,40 @@ mod tests { }); } + #[test] + fn test_default_delivery_mode() { + let config = Config::default(); + assert_eq!(config.delivery_mode, DeliveryMode::Pull); + } + + #[test] + fn test_from_args_delivery_mode_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_DELIVERY_MODE", "push"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.delivery_mode, DeliveryMode::Push); + + Ok(()) + }); + } + + #[test] + fn test_from_args_delivery_mode_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file("config.yaml", "delivery_mode: push")?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.delivery_mode, DeliveryMode::Push); + + Ok(()) + }); + } + #[test] fn test_default_push_callback_fields() { let config = Config::default(); diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 1d500109..02ec9682 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -8,7 +8,6 @@ use tonic::async_trait; use tracing::{debug, error, info}; use crate::config::Config; -use crate::helpers; use crate::push::PushPool; use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; @@ -56,7 +55,7 @@ 
impl FetchPool { pub async fn start(&self) -> Result<()> { let fetch_wait_ms = self.config.fetch_wait_ms; - let mut fetch_pool = helpers::spawn_pool(self.config.fetch_threads, |_| { + let mut fetch_pool = crate::tokio::spawn_pool(self.config.fetch_threads, |_| { let store = self.store.clone(); let pusher = self.pusher.clone(); diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 68ee58dc..7f1fcea2 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use std::time::Instant; use tonic::{Request, Response, Status}; -use crate::config::Config; +use crate::config::{Config, DeliveryMode}; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; @@ -25,7 +25,7 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request, ) -> Result, Status> { - if self.config.push_mode { + if self.config.delivery_mode == DeliveryMode::Push { return Err(Status::permission_denied( "Cannot call while broker is in PUSH mode", )); @@ -106,7 +106,7 @@ impl ConsumerService for TaskbrokerServer { } metrics::histogram!("grpc_server.set_status.duration").record(start_time.elapsed()); - if self.config.push_mode { + if self.config.delivery_mode == DeliveryMode::Push { return Ok(Response::new(SetTaskStatusResponse { task: None })); } diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index c96072a3..f0834a95 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::config::Config; +use crate::config::{Config, DeliveryMode}; use crate::grpc::server::TaskbrokerServer; use prost::Message; use rstest::rstest; @@ -16,7 +16,7 @@ use crate::test_utils::{create_config, create_test_store, make_activations}; async fn test_get_task_push_mode_returns_permission_denied() { let store = create_test_store("sqlite").await; let config = Arc::new(Config { - push_mode: true, + delivery_mode: DeliveryMode::Push, 
..Config::default() }); diff --git a/src/lib.rs b/src/lib.rs index df19be34..6ce53cd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ use std::fs; pub mod config; pub mod fetch; pub mod grpc; -pub mod helpers; pub mod kafka; pub mod logging; pub mod metrics; @@ -12,6 +11,7 @@ pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; +pub mod tokio; pub mod upkeep; /// Name of the grpc service. diff --git a/src/main.rs b/src/main.rs index 4efdbc58..74f6a0a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,7 @@ use tracing::{debug, error, info, warn}; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; use taskbroker::SERVICE_NAME; -use taskbroker::config::{Config, DatabaseAdapter}; +use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; use taskbroker::grpc::server::TaskbrokerServer; @@ -245,14 +245,14 @@ async fn main() -> Result<(), Error> { let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads - let push_task = if config.push_mode { + let push_task = if config.delivery_mode == DeliveryMode::Push { Some(tokio::spawn(async move { push_pool.start().await })) } else { None }; // Initialize fetch threads - let fetch_task = if config.push_mode { + let fetch_task = if config.delivery_mode == DeliveryMode::Push { Some(tokio::spawn(async move { fetch_pool.start().await })) } else { None diff --git a/src/push/mod.rs b/src/push/mod.rs index 54463bc1..daf60581 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -12,7 +12,6 @@ use tonic::transport::Channel; use tracing::{debug, error, info}; use crate::config::Config; -use crate::helpers; use crate::store::inflight_activation::InflightActivation; /// Thin interface for the worker client. 
It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. @@ -56,7 +55,7 @@ impl PushPool { /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. pub async fn start(&self) -> Result<()> { - let mut push_pool = helpers::spawn_pool(self.config.push_threads, |_| { + let mut push_pool = crate::tokio::spawn_pool(self.config.push_threads, |_| { let endpoint = self.config.worker_endpoint.clone(); let receiver = self.receiver.clone(); diff --git a/src/helpers.rs b/src/tokio.rs similarity index 100% rename from src/helpers.rs rename to src/tokio.rs From 90825d44496a8b032a920ad976d44810a21f09bb Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 18 Mar 2026 19:11:44 -0700 Subject: [PATCH 17/17] Add Application and Namespace Filters to Config --- src/config.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++ src/fetch/mod.rs | 18 +++++++++++++--- src/fetch/tests.rs | 31 +++++++++++++++------------ 3 files changed, 86 insertions(+), 16 deletions(-) diff --git a/src/config.rs b/src/config.rs index 4eafa381..db954538 100644 --- a/src/config.rs +++ b/src/config.rs @@ -277,6 +277,12 @@ pub struct Config { /// The port used to construct `callback_url` for task push requests. pub callback_port: u32, + + /// Application filter for push mode. When set, only pending activations for this application are considered. + pub application: Option, + + /// List of namespaces for push mode. When set, application must also be set (store requirement). 
+ pub namespaces: Option>, } impl Default for Config { @@ -355,6 +361,8 @@ impl Default for Config { worker_endpoint: "http://127.0.0.1:50052".into(), callback_addr: "0.0.0.0".into(), callback_port: 50051, + application: None, + namespaces: None, } } } @@ -837,4 +845,49 @@ mod tests { Ok(()) }); } + + #[test] + fn test_default_application_and_namespaces() { + let config = Config::default(); + assert_eq!(config.application, None); + assert_eq!(config.namespaces, None); + } + + #[test] + fn test_from_args_application_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_APPLICATION", "getsentry"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.application.as_deref(), Some("getsentry")); + assert_eq!(config.namespaces, None); + + Ok(()) + }); + } + + #[test] + fn test_from_args_application_and_namespaces_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + application: getsentry + namespaces: + - ns1 + - ns2 + "#, + )?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.application.as_deref(), Some("getsentry")); + assert_eq!(config.namespaces, Some(vec!["ns1".into(), "ns2".into()])); + + Ok(()) + }); + } } diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 02ec9682..6c701e4e 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -58,6 +58,7 @@ impl FetchPool { let mut fetch_pool = crate::tokio::spawn_pool(self.config.fetch_threads, |_| { let store = self.store.clone(); let pusher = self.pusher.clone(); + let config = self.config.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -73,7 +74,7 @@ impl FetchPool { debug!("About to fetch next activation..."); // Instead of returning when `fetch_activation` fails, we just try again - match fetch_activation(store.clone(), pusher.clone()).await { + match fetch_activation(store.clone(), pusher.clone(), 
config.clone()).await { Ok(false) | Err(_) => { // Found no pending activations OR there is an issue with the store OR submitting to push pool failed sleep(Duration::from_millis(fetch_wait_ms)).await; @@ -106,15 +107,26 @@ impl FetchPool { pub async fn fetch_activation( store: Arc, pusher: Arc, + config: Arc, ) -> Result { let start = Instant::now(); metrics::counter!("fetch.fetch_activation.calls").increment(1); debug!("Fetching next pending activation..."); - let found = match store.get_pending_activation(None, None).await { - Ok(Some(activation)) => { + let result = store + .get_pending_activations_from_namespaces( + config.application.as_deref(), + config.namespaces.as_deref(), + Some(1), + ) + .await; + + let found = match result { + Ok(activations) if !activations.is_empty() => { + let activation = activations.into_iter().next().unwrap(); let id = activation.id.clone(); + debug!("Atomically fetched and marked task {id} as processing"); if let Err(e) = pusher.push_task(activation).await { diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 1f6af405..6f95147b 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -25,10 +25,10 @@ enum MockPendingResult { /// Fake store for testing. struct MockStore { - /// How should all calls to `get_pending_activation` respond? + /// How should all calls to `get_pending_activations_from_namespaces` respond? pending_result: MockPendingResult, - /// How many calls to `get_pending_activation` have been performed? + /// How many calls to `get_pending_activations_from_namespaces` have been performed? 
pending_calls: AtomicUsize, } @@ -69,12 +69,7 @@ impl InflightActivationStore for MockStore { _application: Option<&str>, _namespace: Option<&str>, ) -> Result, Error> { - self.pending_calls.fetch_add(1, Ordering::SeqCst); - match &self.pending_result { - MockPendingResult::Some(activation) => Ok(Some(activation.clone())), - MockPendingResult::None => Ok(None), - MockPendingResult::Err => Err(anyhow!("mock store error")), - } + unimplemented!() } async fn get_pending_activations_from_namespaces( @@ -83,7 +78,13 @@ impl InflightActivationStore for MockStore { _namespaces: Option<&[String]>, _limit: Option, ) -> Result, Error> { - unimplemented!() + self.pending_calls.fetch_add(1, Ordering::SeqCst); + + match &self.pending_result { + MockPendingResult::Some(activation) => Ok(vec![activation.clone()]), + MockPendingResult::None => Ok(vec![]), + MockPendingResult::Err => Err(anyhow!("mock store error")), + } } async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { @@ -199,7 +200,8 @@ async fn fetch_activation_submits_when_pending_exists() { Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); let pusher = Arc::new(MockTaskPusher::new(false)); - let found = fetch_activation(store, pusher.clone()) + let config = Arc::new(Config::default()); + let found = fetch_activation(store, pusher.clone(), config) .await .expect("fetch should succeed"); assert!( @@ -219,7 +221,8 @@ async fn fetch_activation_logs_submit_error_and_returns_err() { Arc::new(MockStore::new(MockPendingResult::Some(activation))); let pusher = Arc::new(MockTaskPusher::new(true)); - let result = fetch_activation(store, pusher.clone()).await; + let config = Arc::new(Config::default()); + let result = fetch_activation(store, pusher.clone(), config).await; assert!( result.is_err(), "should return error when push fails (after reverting status to pending)" @@ -238,7 +241,8 @@ async fn fetch_activation_no_pending_returns_false() { let store: Arc = 
Arc::new(MockStore::new(MockPendingResult::None)); let pusher = Arc::new(MockTaskPusher::new(false)); - let found = fetch_activation(store, pusher.clone()) + let config = Arc::new(Config::default()); + let found = fetch_activation(store, pusher.clone(), config) .await .expect("fetch should succeed"); assert!(!found, "should return false when no activation is pending"); @@ -255,7 +259,8 @@ async fn fetch_activation_store_error_returns_false() { let store: Arc = Arc::new(MockStore::new(MockPendingResult::Err)); let pusher = Arc::new(MockTaskPusher::new(false)); - let result = fetch_activation(store, pusher.clone()).await; + let config = Arc::new(Config::default()); + let result = fetch_activation(store, pusher.clone(), config).await; assert!( result.is_err(), "should return error when pending activation lookup fails"