diff --git a/architecture/02-schema.md b/architecture/02-schema.md index b02ddfc050..10cceef123 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -214,12 +214,15 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ## Cog-Specific Extensions -| Extension | Purpose | -| --------------------- | ------------------------------------------------- | -| `x-order` | Preserves parameter order from function signature | -| `x-cog-array-type` | Marks iterators vs regular arrays | -| `x-cog-array-display` | Hints for how to display streaming output | -| `x-cog-secret` | Marks sensitive inputs | +| Extension | Purpose | +| --------------------- | --------------------------------------------------- | +| `x-order` | Preserves parameter order from function signature | +| `x-cog-array-type` | Marks iterators vs regular arrays | +| `x-cog-array-display` | Hints for how to display streaming output | +| `x-cog-secret` | Marks sensitive inputs | +| `x-cog-streaming` | Marks prediction operations that accept SSE clients | + +Iterator output types describe the shape of accumulated JSON output. SSE response support is a separate prediction operation capability and is only advertised when the prediction handler opts in with `@cog.streaming`. ## Where the Schema Lives diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index c0aec8b9ed..b49ba1fc5e 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -7,7 +7,7 @@ //! 4. Run event loop routing responses to predictions //! 5. On worker crash: fail all predictions, shut down -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -29,6 +29,8 @@ use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; use crate::prediction::Prediction; +const MAX_PENDING_CANCELLATIONS: usize = 1000; + /// Upload a file to a signed endpoint, returning the final URL. /// /// Matches the behavior of Python cog's `put_file_to_signed_endpoint`: @@ -353,15 +355,18 @@ pub struct OrchestratorReady { pub setup_logs: String, } +struct RegisterPredictionMessage { + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + registered_ack: tokio::sync::oneshot::Sender<()>, +} + pub struct OrchestratorHandle { child: Child, ctrl_writer: Arc>>>, - register_tx: mpsc::Sender<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + register_tx: mpsc::Sender, healthcheck_tx: mpsc::Sender>, cancel_tx: mpsc::Sender, slot_ids: Vec, @@ -375,10 +380,17 @@ impl Orchestrator for OrchestratorHandle { prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ) { + let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let _ = self .register_tx - .send((slot_id, prediction, idle_sender)) + .send(RegisterPredictionMessage { + slot_id, + prediction, + idle_sender, + registered_ack: ack_tx, + }) .await; + let _ = ack_rx.await; } async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> { @@ -688,6 +700,18 @@ pub async fn spawn_worker( }) } +fn record_pending_cancellation(pending_cancellations: &mut HashSet, prediction_id: String) { + if pending_cancellations.len() >= MAX_PENDING_CANCELLATIONS { + tracing::warn!( + prediction_id = %prediction_id, + cap = MAX_PENDING_CANCELLATIONS, + "Dropping pending cancellation because the pending cancellation buffer is full" + ); + return; + } + pending_cancellations.insert(prediction_id); +} + #[allow(clippy::too_many_arguments)] async fn run_event_loop( mut ctrl_reader: FramedRead>, @@ -698,11 +722,7 @@ async fn run_event_loop( SlotId, FramedRead>, )>, - mut register_rx: mpsc::Receiver<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + mut register_rx: mpsc::Receiver, mut healthcheck_rx: mpsc::Receiver>, mut cancel_rx: mpsc::Receiver, pool: Arc, @@ -718,6 +738,7 @@ async fn run_event_loop( let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; let mut pending_uploads: HashMap>> = HashMap::new(); + let mut pending_cancellations: HashSet = HashSet::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -923,17 +944,19 @@ async fn run_event_loop( } } None => { - tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); + tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); + record_pending_cancellation(&mut pending_cancellations, prediction_id); } } } - Some((slot_id, prediction, idle_sender)) = register_rx.recv() => { + Some(RegisterPredictionMessage { slot_id, prediction, idle_sender, registered_ack }) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); + let _ = registered_ack.send(()); continue; } }; @@ -949,6 +972,24 @@ async fn run_event_loop( ); tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); + let pending_cancel = pending_cancellations.remove(&prediction_id); + let _ = registered_ack.send(()); + if pending_cancel { + tracing::info!( + target: "coglet::prediction", + %prediction_id, + %slot_id, + "Applying pending cancellation" + ); + let mut writer = ctrl_writer.lock().await; + if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await { + tracing::error!( + %slot_id, + error = %e, + "Failed to send pending cancel request to worker" + ); + } + } } Some((slot_id, result)) = slot_msg_rx.recv() => { @@ -966,7 +1007,7 @@ async fn run_event_loop( Ok(SlotResponse::LogLine { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_log(&data); + p.append_log_source(source, &data); (Some(p.id().to_string()), false) } else { (None, true) @@ -1015,10 +1056,10 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::OutputChunk { output, index: _ }) => { + Ok(SlotResponse::OutputChunk { output, index }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_output(output); + p.append_output_chunk(output, index); false } else { true @@ -1255,6 +1296,19 @@ mod tests { assert_eq!(result.into_values(), Vec::::new()); } + #[test] + fn record_pending_cancellation_caps_stored_ids() { + let mut pending = HashSet::new(); + for index in 0..MAX_PENDING_CANCELLATIONS { + record_pending_cancellation(&mut pending, format!("pred-{index}")); + } + + record_pending_cancellation(&mut pending, "overflow".to_string()); + + assert_eq!(pending.len(), MAX_PENDING_CANCELLATIONS); + assert!(!pending.contains("overflow")); + } + #[test] fn wrap_outputs_schema_array_single_item() { // List[Path] with num_outputs=1 → ["url"] not "url" diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 81ab018b40..2b52e91129 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -1,15 +1,19 @@ //! Prediction state tracking. -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use std::time::Instant; use tokio::sync::Notify; pub use tokio_util::sync::CancellationToken; -use crate::bridge::protocol::MetricMode; +use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; +pub const STREAM_CHANNEL_CAPACITY: usize = 1024; +pub const DEFAULT_STREAM_HISTORY_CAPACITY: usize = 1024; +const STREAM_HISTORY_CAPACITY_ENV: &str = "COG_STREAM_HISTORY_CAPACITY"; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { Starting, @@ -64,6 +68,73 @@ impl PredictionOutput { } } +#[derive(Debug, Clone)] +pub enum PredictionStreamEvent { + Start { + id: String, + status: String, + }, + Output { + chunk: serde_json::Value, + index: u64, + }, + Log { + source: LogSource, + data: String, + }, + Metric { + name: String, + value: serde_json::Value, + mode: MetricMode, + }, + Completed { + payload: serde_json::Value, + }, +} + +pub type SharedPredictionStreamEvent = Arc; + +pub struct PredictionStreamReplay { + pub replay: VecDeque, + pub skipped: u64, + pub receiver: tokio::sync::broadcast::Receiver, +} + +impl PredictionStreamEvent { + pub fn event_name(&self) -> &'static str { + match self { + Self::Start { .. } => "start", + Self::Output { .. } => "output", + Self::Log { .. } => "log", + Self::Metric { .. } => "metric", + Self::Completed { .. } => "completed", + } + } + + pub fn json_data(&self) -> serde_json::Value { + match self { + Self::Start { id, status } => serde_json::json!({ + "id": id, + "status": status, + }), + Self::Output { chunk, index } => serde_json::json!({ + "chunk": chunk, + "index": index, + }), + Self::Log { source, data } => serde_json::json!({ + "source": source, + "data": data, + }), + Self::Metric { name, value, mode } => serde_json::json!({ + "name": name, + "value": value, + "mode": mode, + }), + Self::Completed { payload } => payload.clone(), + } + } +} + /// Prediction lifecycle state. pub struct Prediction { id: String, @@ -76,12 +147,19 @@ pub struct Prediction { error: Option, webhook: Option, completion: Arc, + stream_tx: tokio::sync::broadcast::Sender, + stream_history: VecDeque, + stream_history_capacity: usize, + stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } impl Prediction { pub fn new(id: String, webhook: Option) -> Self { + let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_CHANNEL_CAPACITY); + let stream_history_capacity = stream_history_capacity_from_env(); + Self { id, cancel_token: CancellationToken::new(), @@ -93,6 +171,10 @@ impl Prediction { error: None, webhook, completion: Arc::new(Notify::new()), + stream_tx, + stream_history: VecDeque::new(), + stream_history_capacity, + stream_history_skipped: 0, metrics: HashMap::new(), } } @@ -105,6 +187,66 @@ impl Prediction { self.cancel_token.clone() } + pub fn subscribe_stream( + &self, + ) -> tokio::sync::broadcast::Receiver { + self.stream_tx.subscribe() + } + + pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { + PredictionStreamReplay { + replay: self.stream_history.clone(), + skipped: self.stream_history_skipped, + receiver: self.stream_tx.subscribe(), + } + } + + pub fn stream_receiver_count(&self) -> usize { + self.stream_tx.receiver_count() + } + + fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + let event = Arc::new(event); + if self.stream_history_capacity > 0 { + if self.stream_history.len() == self.stream_history_capacity { + self.stream_history.pop_front(); + self.stream_history_skipped += 1; + } + self.stream_history.push_back(Arc::clone(&event)); + } + let _ = self.stream_tx.send(event); + } +} + +fn stream_history_capacity_from_env() -> usize { + match std::env::var(STREAM_HISTORY_CAPACITY_ENV) { + Ok(value) => match value.parse::() { + Ok(capacity) => capacity, + Err(error) => { + tracing::warn!( + env_var = STREAM_HISTORY_CAPACITY_ENV, + value, + error = %error, + default = DEFAULT_STREAM_HISTORY_CAPACITY, + "Invalid stream history capacity; using default" + ); + DEFAULT_STREAM_HISTORY_CAPACITY + } + }, + Err(std::env::VarError::NotPresent) => DEFAULT_STREAM_HISTORY_CAPACITY, + Err(error) => { + tracing::warn!( + env_var = STREAM_HISTORY_CAPACITY_ENV, + error = %error, + default = DEFAULT_STREAM_HISTORY_CAPACITY, + "Invalid stream history capacity; using default" + ); + DEFAULT_STREAM_HISTORY_CAPACITY + } + } +} + +impl Prediction { pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } @@ -119,6 +261,10 @@ impl Prediction { pub fn set_processing(&mut self) { self.status = PredictionStatus::Processing; + self.emit_stream_event(PredictionStreamEvent::Start { + id: self.id.clone(), + status: self.status.as_str().to_string(), + }); self.fire_webhook(WebhookEventType::Start); } @@ -128,6 +274,9 @@ impl Prediction { } self.status = PredictionStatus::Succeeded; self.output = Some(output); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); // notify_one stores a permit so a future .notified().await will // consume it immediately. notify_waiters only wakes currently- @@ -144,6 +293,9 @@ impl Prediction { } self.status = PredictionStatus::Failed; self.error = Some(error); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -153,6 +305,9 @@ impl Prediction { return; } self.status = PredictionStatus::Canceled; + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -162,7 +317,15 @@ impl Prediction { } pub fn append_log(&mut self, data: &str) { + self.append_log_source(LogSource::Stdout, data); + } + + pub fn append_log_source(&mut self, source: LogSource, data: &str) { self.logs.push_str(data); + self.emit_stream_event(PredictionStreamEvent::Log { + source, + data: data.to_string(), + }); self.fire_webhook(WebhookEventType::Logs); } @@ -184,6 +347,12 @@ impl Prediction { return; } + self.emit_stream_event(PredictionStreamEvent::Metric { + name: name.clone(), + value: value.clone(), + mode, + }); + // Dot-path resolution: "a.b.c" → nested objects let parts: Vec<&str> = name.split('.').collect(); if parts.len() > 1 { @@ -297,7 +466,16 @@ impl Prediction { } pub fn append_output(&mut self, output: serde_json::Value) { - self.outputs.push(output); + let index = self.outputs.len() as u64; + self.append_output_chunk(output, index); + } + + pub fn append_output_chunk(&mut self, output: serde_json::Value, index: u64) { + self.outputs.push(output.clone()); + self.emit_stream_event(PredictionStreamEvent::Output { + chunk: output, + index, + }); self.fire_webhook(WebhookEventType::Output); } @@ -411,6 +589,36 @@ impl Prediction { #[cfg(test)] mod tests { use super::*; + use std::ffi::OsString; + use std::sync::{Mutex, MutexGuard, OnceLock}; + + struct StreamHistoryCapacityEnvGuard { + previous: Option, + _lock: MutexGuard<'static, ()>, + } + + impl Drop for StreamHistoryCapacityEnvGuard { + fn drop(&mut self) { + match &self.previous { + Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) }, + None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) }, + } + } + } + + fn set_stream_history_capacity(value: Option<&str>) -> StreamHistoryCapacityEnvGuard { + static LOCK: OnceLock> = OnceLock::new(); + let lock = LOCK.get_or_init(|| Mutex::new(())).lock().unwrap(); + let previous = std::env::var_os(STREAM_HISTORY_CAPACITY_ENV); + match value { + Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) }, + None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) }, + } + StreamHistoryCapacityEnvGuard { + previous, + _lock: lock, + } + } #[test] fn status_is_terminal() { @@ -484,6 +692,184 @@ mod tests { assert_eq!(pred.outputs().len(), 2); } + #[tokio::test] + async fn prediction_stream_emits_start_output_log_and_completed() { + let mut prediction = Prediction::new("pred_stream".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.append_log("loading\n"); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let start = rx.recv().await.unwrap(); + assert_eq!(start.event_name(), "start"); + assert_eq!( + start.json_data(), + serde_json::json!({"id":"pred_stream","status":"processing"}) + ); + + let output = rx.recv().await.unwrap(); + assert_eq!(output.event_name(), "output"); + assert_eq!( + output.json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + + let log = rx.recv().await.unwrap(); + assert_eq!(log.event_name(), "log"); + assert_eq!( + log.json_data(), + serde_json::json!({"source":"stdout","data":"loading\n"}) + ); + + let completed = rx.recv().await.unwrap(); + assert_eq!(completed.event_name(), "completed"); + assert_eq!(completed.json_data()["id"], "pred_stream"); + assert_eq!(completed.json_data()["status"], "succeeded"); + assert_eq!( + completed.json_data()["output"], + serde_json::json!(["hello"]) + ); + } + + #[tokio::test] + async fn prediction_stream_emits_metric_event() { + let mut prediction = Prediction::new("pred_metric".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_metric( + "tokens".to_string(), + serde_json::json!(1), + MetricMode::Increment, + ); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "metric"); + assert_eq!( + event.json_data(), + serde_json::json!({ + "name":"tokens", + "value":1, + "mode":"increment" + }) + ); + } + + #[tokio::test] + async fn prediction_stream_preserves_log_source() { + let mut prediction = Prediction::new("pred_log_source".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.append_log_source(crate::bridge::protocol::LogSource::Stderr, "warning\n"); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "log"); + assert_eq!( + event.json_data(), + serde_json::json!({"source":"stderr","data":"warning\n"}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_includes_already_emitted_events() { + let _guard = set_stream_history_capacity(None); + let mut prediction = Prediction::new("pred_replay".to_string(), None); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let replay = prediction.subscribe_stream_replay(); + let events: Vec<&str> = replay + .replay + .iter() + .map(|event| event.event_name()) + .collect(); + + assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!(replay.skipped, 0); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); + } + + #[tokio::test] + async fn prediction_stream_replay_is_bounded_to_recent_events() { + let _guard = set_stream_history_capacity(None); + let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None); + + prediction.set_processing(); + for index in 0..1100 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), DEFAULT_STREAM_HISTORY_CAPACITY); + assert_eq!(replay.skipped, 77); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":76,"index":76}) + ); + assert_eq!( + replay.replay[DEFAULT_STREAM_HISTORY_CAPACITY - 1].json_data(), + serde_json::json!({"chunk":1099,"index":1099}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_uses_configured_history_capacity() { + let _guard = set_stream_history_capacity(Some("2")); + let mut prediction = Prediction::new("pred_replay_configured".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + prediction.append_output_chunk(serde_json::json!(1), 1); + prediction.append_output_chunk(serde_json::json!(2), 2); + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), 2); + assert_eq!(replay.skipped, 1); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":1,"index":1}) + ); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":2,"index":2}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_can_be_disabled_with_zero_capacity() { + let _guard = set_stream_history_capacity(Some("0")); + let mut prediction = Prediction::new("pred_replay_disabled".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + prediction.append_output_chunk(serde_json::json!(1), 1); + + let replay = prediction.subscribe_stream_replay(); + + assert!(replay.replay.is_empty()); + assert_eq!(replay.skipped, 0); + } + + #[tokio::test] + async fn prediction_stream_replay_uses_default_for_invalid_capacity() { + let _guard = set_stream_history_capacity(Some("nope")); + let mut prediction = Prediction::new("pred_replay_invalid".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), 1); + assert_eq!(replay.skipped, 0); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 300ccab9e7..20b22300af 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -20,7 +20,10 @@ use crate::health::{Health, SetupResult}; use crate::input_validation::InputValidator; use crate::orchestrator::{HealthcheckResult, Orchestrator}; use crate::permit::{PermitPool, PredictionSlot, UnregisteredPredictionSlot}; -use crate::prediction::{CancellationToken, Prediction, PredictionStatus}; +use crate::prediction::{ + CancellationToken, Prediction, PredictionStatus, STREAM_CHANNEL_CAPACITY, + SharedPredictionStreamEvent, +}; use crate::predictor::{PredictionError, PredictionOutput, PredictionResult}; use crate::version::VersionInfo; use crate::webhook::WebhookSender; @@ -50,6 +53,18 @@ pub enum CreatePredictionError { AtCapacity, } +const MAX_STREAM_SUBSCRIBERS: usize = STREAM_CHANNEL_CAPACITY; + +#[derive(Debug, thiserror::Error)] +pub enum SubscribePredictionStreamError { + #[error("Prediction not found")] + NotFound, + #[error("Too many stream subscribers")] + TooManySubscribers, + #[error("Prediction stream unavailable")] + Unavailable, +} + /// Snapshot of service health for transports to query. #[derive(Debug, Clone)] pub struct HealthSnapshot { @@ -79,6 +94,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, + cancel_on_stream_drop: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -106,6 +122,56 @@ impl PredictionHandle { } } +pub struct PredictionStreamSubscription { + id: String, + replay: std::collections::VecDeque, + skipped: u64, + // Drop order matters: receiver must drop before guard so stream_receiver_count() + // reaches zero before the guard decides whether to cancel on disconnect. + receiver: tokio::sync::broadcast::Receiver, + guard: PredictionStreamGuard, +} + +impl PredictionStreamSubscription { + pub fn prediction_id(&self) -> &str { + &self.id + } + + pub fn into_parts( + self, + ) -> ( + std::collections::VecDeque, + u64, + tokio::sync::broadcast::Receiver, + PredictionStreamGuard, + ) { + (self.replay, self.skipped, self.receiver, self.guard) + } +} + +pub struct PredictionStreamGuard { + id: String, + service: Arc, + cancel_on_stream_drop: bool, +} + +impl Drop for PredictionStreamGuard { + fn drop(&mut self) { + if !self.cancel_on_stream_drop { + return; + } + + // Prediction cleanup may remove the service entry before the SSE response + // finishes draining. Missing entries deliberately report zero receivers and + // terminal state so this guard cannot cancel an already-cleaned prediction. + if self.service.stream_receiver_count(&self.id) == 0 + && !self.service.prediction_is_terminal(&self.id) + { + self.service.cancel(&self.id); + } + } +} + /// Guard for sync predictions - cancels on drop unless disarmed. /// /// When the HTTP connection drops (client disconnect), axum drops the @@ -177,6 +243,7 @@ pub struct PredictionService { schema: RwLock>, input_validator: RwLock>, train_validator: RwLock>, + supports_prediction_streaming: RwLock, } impl PredictionService { @@ -196,6 +263,7 @@ impl PredictionService { schema: RwLock::new(None), input_validator: RwLock::new(None), train_validator: RwLock::new(None), + supports_prediction_streaming: RwLock::new(false), } } @@ -250,6 +318,10 @@ impl PredictionService { self.train_validator.read().await.is_some() } + pub async fn supports_prediction_streaming(&self) -> bool { + *self.supports_prediction_streaming.read().await + } + /// Get the permit pool from orchestrator. pub async fn pool(&self) -> Option> { if let Some(ref state) = *self.orchestrator.read().await { @@ -310,6 +382,9 @@ impl PredictionService { } pub async fn set_schema(&self, schema: serde_json::Value) { + let supports_prediction_streaming = Self::schema_supports_prediction_streaming(&schema); + *self.supports_prediction_streaming.write().await = supports_prediction_streaming; + // Compile input validators from the schema components let validator = InputValidator::from_openapi_schema(&schema); if let Some(v) = &validator { @@ -333,6 +408,16 @@ impl PredictionService { *self.schema.write().await = Some(schema); } + fn schema_supports_prediction_streaming(schema: &serde_json::Value) -> bool { + schema + .get("paths") + .and_then(|paths| paths.get("/predictions")) + .and_then(|path| path.get("post")) + .and_then(|operation| operation.get("x-cog-streaming")) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + } + pub async fn schema(&self) -> Option { self.schema.read().await.clone() } @@ -415,6 +500,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, + cancel_on_stream_drop: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -442,6 +528,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, + cancel_on_stream_drop, }, ); @@ -469,6 +556,58 @@ impl PredictionService { Some(response) } + pub fn subscribe_prediction_stream( + self: &Arc, + id: &str, + ) -> Result { + let entry = self + .predictions + .get(id) + .ok_or(SubscribePredictionStreamError::NotFound)?; + let stream = { + let Some(prediction) = try_lock_prediction(&entry.prediction) else { + return Err(SubscribePredictionStreamError::Unavailable); + }; + if prediction.stream_receiver_count() >= MAX_STREAM_SUBSCRIBERS { + return Err(SubscribePredictionStreamError::TooManySubscribers); + } + prediction.subscribe_stream_replay() + }; + let cancel_on_stream_drop = entry.cancel_on_stream_drop; + let id = id.to_string(); + Ok(PredictionStreamSubscription { + id: id.clone(), + replay: stream.replay, + skipped: stream.skipped, + receiver: stream.receiver, + guard: PredictionStreamGuard { + id, + service: Arc::clone(self), + cancel_on_stream_drop, + }, + }) + } + + fn stream_receiver_count(&self, id: &str) -> usize { + self.predictions + .get(id) + .and_then(|entry| { + entry + .prediction + .lock() + .ok() + .map(|p| p.stream_receiver_count()) + }) + .unwrap_or(0) + } + + fn prediction_is_terminal(&self, id: &str) -> bool { + self.predictions + .get(id) + .and_then(|entry| entry.prediction.lock().ok().map(|p| p.is_terminal())) + .unwrap_or(true) + } + /// Run a prediction to completion via orchestrator. pub async fn predict( &self, @@ -541,6 +680,22 @@ impl PredictionService { ))); } + let was_cancelled_before_send = try_lock_prediction(&prediction_arc) + .map(|p| p.is_canceled()) + .unwrap_or(false); + if was_cancelled_before_send + && let Err(e) = state + .orchestrator + .cancel_by_prediction_id(&prediction_id) + .await + { + tracing::error!( + prediction_id = %prediction_id, + error = %e, + "Failed to forward pending cancellation after registration" + ); + } + // Wait for prediction to complete // Check if already terminal first to avoid race with fast completions let (already_terminal, completion) = { @@ -610,21 +765,15 @@ impl PredictionService { // Delegate to orchestrator to actually cancel the worker-side prediction. // This must be non-blocking since cancel() is sync, so we spawn a task. let id_owned = id.to_string(); - let orchestrator = self - .orchestrator - .try_read() - .ok() - .and_then(|guard| guard.as_ref().map(|s| Arc::clone(&s.orchestrator))); + let orchestrator = match self.orchestrator.try_read() { + Ok(guard) => guard.as_ref().map(|s| Arc::clone(&s.orchestrator)), + Err(_) => { + tracing::warn!(prediction_id = %id, "Skipped worker cancel: orchestrator lock unavailable"); + None + } + }; if let Some(orch) = orchestrator { - tokio::spawn(async move { - if let Err(e) = orch.cancel_by_prediction_id(&id_owned).await { - tracing::error!( - prediction_id = %id_owned, - error = %e, - "Failed to send cancel to orchestrator" - ); - } - }); + spawn_orchestrator_cancel(orch, id_owned); } true } else { @@ -646,6 +795,22 @@ impl PredictionService { } } +fn spawn_orchestrator_cancel(orch: Arc, id: String) { + let Ok(handle) = tokio::runtime::Handle::try_current() else { + tracing::warn!(prediction_id = %id, "No tokio runtime available to cancel prediction"); + return; + }; + handle.spawn(async move { + if let Err(e) = orch.cancel_by_prediction_id(&id).await { + tracing::error!( + prediction_id = %id, + error = %e, + "Failed to send cancel to orchestrator" + ); + } + }); +} + /// Build a `SlotRequest::Predict`, spilling the input to disk if it exceeds /// `MAX_INLINE_IPC_SIZE`. This prevents IPC frame overflow on the slot socket. /// @@ -760,6 +925,103 @@ mod tests { } } + struct CountingCancelOrchestrator { + cancel_count: AtomicUsize, + } + + impl CountingCancelOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CountingCancelOrchestrator { + async fn register_prediction( + &self, + _slot_id: SlotId, + _prediction: Arc>, + _idle_sender: tokio::sync::oneshot::Sender, + ) { + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + + struct CancelRecordingOrchestrator { + cancel_count: AtomicUsize, + prediction: std::sync::Mutex>>>, + } + + impl CancelRecordingOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + prediction: std::sync::Mutex::new(None), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CancelRecordingOrchestrator { + async fn register_prediction( + &self, + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + ) { + *self.prediction.lock().unwrap() = Some(prediction); + let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate()); + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + if let Some(prediction) = self.prediction.lock().unwrap().as_ref() { + prediction.lock().unwrap().set_canceled(); + } + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -863,7 +1125,7 @@ mod tests { let svc = PredictionService::new_no_pool(); let result = svc - .submit_prediction("test".to_string(), serde_json::json!({}), None) + .submit_prediction("test".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::NotReady))); } @@ -908,7 +1170,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -916,6 +1178,200 @@ mod tests { assert!(svc.prediction_exists("test-1")); } + #[tokio::test] + async fn subscribe_prediction_stream_returns_receiver_for_existing_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("stream-test".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("stream-test").unwrap(); + assert_eq!(subscription.prediction_id(), "stream-test"); + } + + #[tokio::test] + async fn dropping_only_sync_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("sync-stream".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_json_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "async-json-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + let subscription = svc + .subscribe_prediction_stream("async-json-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + + #[tokio::test] + async fn dropping_live_sse_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "live-sse-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("live-sse-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_one_of_two_sync_stream_subscriptions_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "multi-sse-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let first = svc.subscribe_prediction_stream("multi-sse-stream").unwrap(); + let second = svc.subscribe_prediction_stream("multi-sse-stream").unwrap(); + drop(first); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + + drop(second); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn subscribe_prediction_stream_rejects_too_many_subscribers() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "subscriber-cap".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let mut subscriptions = Vec::new(); + for _ in 0..MAX_STREAM_SUBSCRIBERS { + subscriptions.push(svc.subscribe_prediction_stream("subscriber-cap").unwrap()); + } + + assert!(matches!( + svc.subscribe_prediction_stream("subscriber-cap"), + Err(SubscribePredictionStreamError::TooManySubscribers) + )); + } + + #[tokio::test] + async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "completed-sync-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let entry = svc.predictions.get("completed-sync-stream").unwrap(); + let mut prediction = entry.prediction.lock().unwrap(); + prediction.set_succeeded(crate::PredictionOutput::Single(serde_json::json!("done"))); + } + + let subscription = svc + .subscribe_prediction_stream("completed-sync-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + #[tokio::test] async fn submit_returns_at_capacity_when_no_slots() { let svc = PredictionService::new_no_pool(); @@ -927,13 +1383,13 @@ mod tests { // First prediction takes the only slot let (_handle1, _slot1) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); // Second should fail with AtCapacity let result = svc - .submit_prediction("test-2".to_string(), serde_json::json!({}), None) + .submit_prediction("test-2".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::AtCapacity))); } @@ -953,6 +1409,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -970,6 +1427,42 @@ mod tests { assert_eq!(orch_ref.register_count(), 1); } + #[tokio::test] + async fn predict_forwards_cancel_token_set_before_registration() { + let svc = PredictionService::new_no_pool(); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CancelRecordingOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (handle, slot) = svc + .submit_prediction( + "pre-register-cancel".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + handle.cancel_token().cancel(); + + let result = tokio::time::timeout( + Duration::from_millis(100), + svc.predict( + slot, + serde_json::json!({}), + std::collections::HashMap::new(), + ), + ) + .await + .expect("prediction should observe cancellation after registration"); + + assert!(matches!(result, Err(PredictionError::Cancelled))); + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + #[tokio::test] async fn health_shows_busy_when_all_slots_used() { let svc = PredictionService::new_no_pool(); @@ -986,7 +1479,7 @@ mod tests { // After acquiring slot let (_handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); let health = svc.health().await; @@ -1009,6 +1502,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1048,6 +1542,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1087,6 +1582,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1113,7 +1609,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-cancel".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1139,7 +1640,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-guard".to_string(), serde_json::json!({}), None) + .submit_prediction("test-guard".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -1163,7 +1664,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-disarm".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1187,7 +1693,12 @@ mod tests { svc.set_health(Health::Ready).await; let (_handle, _slot) = svc - .submit_prediction("test-remove".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-remove".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index aa9875340e..5904737391 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -1,12 +1,17 @@ //! HTTP route handlers. +use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use axum::{ Router, extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, - response::{IntoResponse, Json}, + response::{ + IntoResponse, Json, Response, + sse::{Event, KeepAlive, Sse}, + }, routing::{get, post, put}, }; use serde::{Deserialize, Serialize}; @@ -14,8 +19,12 @@ use serde::{Deserialize, Serialize}; #[cfg(test)] use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; +use crate::prediction::SharedPredictionStreamEvent; use crate::predictor::PredictionError; -use crate::service::{CreatePredictionError, HealthSnapshot, PredictionService}; +use crate::service::{ + CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription, + SubscribePredictionStreamError, +}; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; @@ -209,6 +218,63 @@ fn should_respond_async(headers: &HeaderMap) -> bool { .unwrap_or(false) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredictionResponseMode { + SyncJson, + AsyncJson, + AsyncSse, +} + +fn wants_sse(headers: &HeaderMap) -> bool { + headers + .get(axum::http::header::ACCEPT) + .and_then(|value| value.to_str().ok()) + .map(|accept| { + accept + .split(',') + .any(|part| part.trim().split(';').next() == Some("text/event-stream")) + }) + .unwrap_or(false) +} + +fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if wants_sse(headers) { + PredictionResponseMode::AsyncSse + } else if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + +fn streaming_not_supported_response() -> Response { + ( + StatusCode::NOT_ACCEPTABLE, + Json(serde_json::json!({ + "error": "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + })), + ) + .into_response() +} + +fn training_streaming_not_supported_response() -> Response { + ( + StatusCode::NOT_ACCEPTABLE, + Json(serde_json::json!({ + "error": "Training endpoints do not support streaming responses." + })), + ) + .into_response() +} + +fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + fn extract_trace_context(headers: &HeaderMap) -> TraceContext { TraceContext { traceparent: headers @@ -226,7 +292,7 @@ async fn create_prediction( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -235,7 +301,7 @@ async fn create_prediction( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = prediction_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -244,7 +310,7 @@ async fn create_prediction( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -256,7 +322,7 @@ async fn create_prediction_idempotent( Path(prediction_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -277,15 +343,23 @@ async fn create_prediction_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } + let response_mode = prediction_response_mode(&headers); + // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { - return (StatusCode::ACCEPTED, Json(response)); + if response_mode == PredictionResponseMode::AsyncSse { + if !service.supports_prediction_streaming().await { + return streaming_not_supported_response(); + } + return stream_prediction_response(service, &prediction_id); + } + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -294,7 +368,7 @@ async fn create_prediction_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -333,10 +407,17 @@ async fn create_prediction_with_id( context: std::collections::HashMap, webhook: Option, webhook_events_filter: Vec, - respond_async: bool, + response_mode: PredictionResponseMode, trace_context: TraceContext, is_training: bool, -) -> (StatusCode, Json) { +) -> Response { + if !is_training + && response_mode == PredictionResponseMode::AsyncSse + && !service.supports_prediction_streaming().await + { + return streaming_not_supported_response(); + } + // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -365,7 +446,8 @@ async fn create_prediction_with_id( return ( StatusCode::UNPROCESSABLE_ENTITY, Json(serde_json::json!({ "detail": detail })), - ); + ) + .into_response(); } let webhook_sender = build_webhook_sender( @@ -376,7 +458,12 @@ async fn create_prediction_with_id( // Submit prediction: creates Prediction, acquires slot, registers in service let (handle, unregistered_slot) = match service - .submit_prediction(prediction_id.clone(), input.clone(), webhook_sender) + .submit_prediction( + prediction_id.clone(), + input.clone(), + webhook_sender, + response_mode == PredictionResponseMode::AsyncSse, + ) .await { Ok(r) => r, @@ -388,7 +475,8 @@ async fn create_prediction_with_id( "error": msg, "status": "failed" })), - ); + ) + .into_response(); } Err(CreatePredictionError::AtCapacity) => { return ( @@ -397,14 +485,27 @@ async fn create_prediction_with_id( "error": "At capacity - all prediction slots busy", "status": "failed" })), - ); + ) + .into_response(); } }; let prediction = unregistered_slot.prediction(); // Async mode: spawn background task, return immediately - if respond_async { + if response_mode != PredictionResponseMode::SyncJson { + let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { + match service.subscribe_prediction_stream(&prediction_id) { + Ok(subscription) => Some(subscription), + Err(error) => { + service.remove_prediction(&prediction_id); + return stream_subscription_error_response(error); + } + } + } else { + None + }; + let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -417,13 +518,19 @@ async fn create_prediction_with_id( service_clone.remove_prediction(&id_for_cleanup); }); + if response_mode == PredictionResponseMode::AsyncSse { + let subscription = sse_subscription.expect("SSE subscription requested"); + return stream_prediction_subscription_response(subscription); + } + return ( StatusCode::ACCEPTED, Json(serde_json::json!({ "id": prediction_id, "status": "starting" })), - ); + ) + .into_response(); } // Sync mode: spawn prediction into a background task so the slot lifetime @@ -489,6 +596,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::InvalidInput(msg)) => { let metrics = build_metrics(&user_metrics); @@ -502,6 +610,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::NotReady) => { let msg = PredictionError::NotReady.to_string(); @@ -514,6 +623,7 @@ async fn create_prediction_with_id( "status": "failed" })), ) + .into_response() } Err(PredictionError::Failed(msg)) => { let metrics = build_metrics(&user_metrics); @@ -528,6 +638,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::Cancelled) => { let metrics = build_metrics(&user_metrics); @@ -540,6 +651,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } } } @@ -557,6 +669,123 @@ async fn cancel_prediction( } } +fn stream_event_to_sse(event: SharedPredictionStreamEvent) -> Event { + Event::default() + .event(event.event_name()) + .json_data(event.json_data()) + .expect("prediction stream events serialize to JSON") +} + +fn prediction_sse_stream( + subscription: PredictionStreamSubscription, +) -> impl futures::Stream> { + let (replay, replay_skipped, receiver, guard) = subscription.into_parts(); + + struct StreamState { + replay: std::collections::VecDeque, + replay_skipped: u64, + // Drop order matters: receiver must drop before guard so stream_receiver_count() + // reaches zero before the guard decides whether to cancel on disconnect. + receiver: tokio::sync::broadcast::Receiver, + _guard: crate::service::PredictionStreamGuard, + done: bool, + } + + futures::stream::unfold( + StreamState { + replay, + replay_skipped, + receiver, + _guard: guard, + done: false, + }, + |mut state| async move { + if state.done { + return None; + } + + if state.replay_skipped > 0 { + let skipped = state.replay_skipped; + state.replay_skipped = 0; + state.done = true; + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream replay truncated; events were dropped", + "skipped": skipped, + })) + .expect("SSE replay truncation error serializes to JSON"); + return Some((Ok(event), state)); + } + + if let Some(event) = state.replay.pop_front() { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + Some((Ok(stream_event_to_sse(event)), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + state.done = true; + // In the future, this could become backpressure or cursor-based replay. + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream lagged; events were dropped", + "skipped": skipped, + })) + .expect("SSE lag error serializes to JSON"); + Some((Ok(event), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => None, + } + }, + ) +} + +fn stream_prediction_response(service: Arc, prediction_id: &str) -> Response { + let subscription = match service.subscribe_prediction_stream(prediction_id) { + Ok(subscription) => subscription, + Err(error) => return stream_subscription_error_response(error), + }; + + stream_prediction_subscription_response(subscription) +} + +fn stream_subscription_error_response(error: SubscribePredictionStreamError) -> Response { + match error { + SubscribePredictionStreamError::NotFound => ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(), + SubscribePredictionStreamError::TooManySubscribers => ( + StatusCode::TOO_MANY_REQUESTS, + Json(serde_json::json!({"error": "Too many stream subscribers"})), + ) + .into_response(), + SubscribePredictionStreamError::Unavailable => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": "Prediction stream unavailable"})), + ) + .into_response(), + } +} + +fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response { + Sse::new(prediction_sse_stream(subscription)) + .keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) + .into_response() +} + async fn shutdown(State(service): State>) -> impl IntoResponse { tracing::info!("Shutdown requested via HTTP"); service.trigger_shutdown(); @@ -582,7 +811,11 @@ async fn create_training( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { + if wants_sse(&headers) { + return training_streaming_not_supported_response(); + } + let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -591,7 +824,7 @@ async fn create_training( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -600,7 +833,7 @@ async fn create_training( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -612,7 +845,11 @@ async fn create_training_idempotent( Path(training_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { + if wants_sse(&headers) { + return training_streaming_not_supported_response(); + } + let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -633,15 +870,16 @@ async fn create_training_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } // Idempotent: return existing state if already submitted if let Some(response) = service.get_prediction_response(&training_id) { - return (StatusCode::ACCEPTED, Json(response)); + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -650,7 +888,7 @@ async fn create_training_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -897,6 +1135,15 @@ mod tests { service } + async fn enable_prediction_streaming(service: &PredictionService) { + service + .set_schema(serde_json::json!({ + "paths": {"/predictions": {"post": {"x-cog-streaming": true}}}, + "components": {"schemas": {"Input": {"type": "object", "properties": {}}}} + })) + .await; + } + #[tokio::test] async fn health_check_ready_with_orchestrator() { let service = create_ready_service().await; @@ -955,6 +1202,315 @@ mod tests { assert_eq!(json["status"], "starting"); } + #[tokio::test] + async fn prediction_post_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + enable_prediction_streaming(&service).await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/predictions") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + let sse = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn prediction_post_with_sse_accept_rejects_when_not_opted_in() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/predictions") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + ); + } + + #[tokio::test] + async fn lagged_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "lagged-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("lagged-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("lagged SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!(sse.contains("SSE stream lagged"), "SSE body: {sse}"); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn truncated_replay_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "truncated-replay".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let subscription = service + .subscribe_prediction_stream("truncated-replay") + .unwrap(); + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("truncated replay SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!( + sse.contains("SSE stream replay truncated"), + "SSE body: {sse}" + ); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn failed_prediction_sse_stream_emits_completed_event() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "failed-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("failed-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + prediction.set_processing(); + prediction.set_failed("boom".to_string()); + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = response.into_body().collect().await.unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"failed""#), "SSE body: {sse}"); + assert!(sse.contains(r#""error":"boom""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn canceled_prediction_sse_stream_emits_completed_event() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "canceled-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("canceled-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + prediction.set_processing(); + prediction.set_canceled(); + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = response.into_body().collect().await.unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"canceled""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn prediction_put_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + enable_prediction_streaming(&service).await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn prediction_put_existing_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + enable_prediction_streaming(&service).await; + let (_handle, _slot) = service + .submit_prediction( + "existing-sse-put".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/existing-sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn stream_prediction_route_is_removed() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "removed-stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/removed-stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; @@ -1119,6 +1675,54 @@ mod tests { assert_eq!(json["status"], "succeeded"); } + #[tokio::test] + async fn training_post_with_sse_accept_rejects() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/trainings") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "Training endpoints do not support streaming responses." + ); + } + + #[tokio::test] + async fn training_put_with_sse_accept_rejects() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/trainings/train-sse") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "Training endpoints do not support streaming responses." + ); + } + #[tokio::test] async fn training_idempotent_put() { let service = create_ready_service().await; diff --git a/docs/environment.md b/docs/environment.md index 011a03e240..bbf02e2a90 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -1,12 +1,12 @@ # Environment variables -This guide lists the environment variables that change how Cog functions. +This reference lists the public Cog-specific environment variables that change how Cog behaves. ## Build-time variables ### `COG_SDK_WHEEL` -Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. +Controls which Cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** @@ -18,98 +18,262 @@ Controls which cog Python SDK wheel is installed in the Docker image during `cog | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | -**Default behavior:** +**Default behaviour:** -- **Release builds**: Installs latest cog from PyPI -- **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI - -**Examples:** +- Release builds install the latest Cog SDK from PyPI. +- Development builds auto-detect a wheel in `dist/`, then fall back to the latest Cog SDK from PyPI. ```console -# Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build - -# Use local development wheel $ COG_SDK_WHEEL=dist cog build - -# Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) -2. `$REPO_ROOT/dist/` (if REPO_ROOT is set) +2. `$REPO_ROOT/dist/` (if `REPO_ROOT` is set) 3. `/dist/` (via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based inference server. -**Supported values:** Same as `COG_SDK_WHEEL` +**Supported values:** Same as `COG_SDK_WHEEL`. -**Default behavior:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. - -**Examples:** +**Default behaviour:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. ```console -# Use local development wheel $ COGLET_WHEEL=dist cog build - -# Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` -## Runtime variables +### `COG_CA_CERT` + +Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (for example, Cloudflare WARP). + +**Supported values:** + +| Value | Description | +| -------------------------------- | ----------------------------------------------------------- | +| `/path/to/cert.crt` | Path to a single PEM certificate file | +| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | +| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | +| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | + +The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. + +```console +$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_CA_CERT=/etc/custom-certs/ cog build +$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +``` + +### `COG_OPENAPI_SCHEMA` + +Uses a pre-built OpenAPI schema instead of generating one from the configured predict or train reference. + +The value must be a path to a JSON schema file. Cog reads that file during schema generation and embeds it in the built image. + +```console +$ COG_OPENAPI_SCHEMA=./openapi.json cog build +``` + +## CLI and local cache variables ### `COG_NO_UPDATE_CHECK` -By default, Cog automatically checks for updates -and notifies you if there is a new version available. +Disables Cog's automatic update check. Set it to any non-empty value. + +```console +$ COG_NO_UPDATE_CHECK=1 cog build +``` + +### `COG_NO_COLOR` -To disable this behavior, -set the `COG_NO_UPDATE_CHECK` environment variable to any value. +Disables coloured CLI output. Set it to any non-empty value. + +Cog also honours the standard `NO_COLOR` environment variable. + +```console +$ COG_NO_COLOR=1 cog predict -i prompt="hello" +``` + +### `COG_SKIP_DOCKER_CHECK` + +Skips the `cog doctor` Docker environment check. Set it to any non-empty value. + +```console +$ COG_SKIP_DOCKER_CHECK=1 cog doctor +``` + +### `COG_CACHE_DIR` + +Overrides Cog's local cache root. + +Cog currently uses this cache for the content-addressed weights store. If unset, Cog uses `$XDG_CACHE_HOME/cog` when `XDG_CACHE_HOME` is set, otherwise `$HOME/.cache/cog`. + +```console +$ COG_CACHE_DIR=/mnt/fast-cache cog weights pull +``` + +## Model reference and registry variables + +### `COG_MODEL` + +Overrides the full model reference used by commands that need a model destination, such as `cog push` and weights commands. + +The value is parsed as a complete model reference (`registry/repo`, `registry/repo:tag`, or `registry/repo@digest`). If no tag is supplied, Cog generates a timestamp tag. + +When `COG_MODEL` is set, it takes precedence over `COG_MODEL_REGISTRY`, `COG_MODEL_REPO`, and `COG_MODEL_TAG`. ```console -$ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check +$ COG_MODEL=r8.im/acme/my-model:v1 cog push +``` + +### `COG_MODEL_REGISTRY` + +Overrides only the registry host of the model reference. + +```console +$ COG_MODEL_REGISTRY=registry.example.com cog push +``` + +### `COG_MODEL_REPO` + +Overrides only the repository path of the model reference. The value must not include a registry host, tag, or digest. + +```console +$ COG_MODEL_REPO=acme/my-model cog push +``` + +### `COG_MODEL_TAG` + +Overrides only the tag of the model reference. + +Tags starting with `cog-` are reserved for tags that Cog generates internally and are rejected. + +```console +$ COG_MODEL_TAG=staging cog push +``` + +### `COG_REGISTRY_HOST` + +Changes the default Replicate-compatible registry host used by commands such as `cog login`, base image resolution, and model reference resolution. + +The default is `r8.im`. + +```console +$ COG_REGISTRY_HOST=registry.example.com cog login +``` + +## Runtime server variables + +These variables affect a running model server. Set them in `cog.yaml` under `environment`, pass them with `cog predict -e` or `cog serve -e`, or set them when running the built Docker image. + +### `COG_MAX_CONCURRENCY` + +Controls how many predictions the model server can run concurrently. + +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. + +```console +$ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model ``` ### `COG_SETUP_TIMEOUT` -Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. +Controls the maximum time, in seconds, allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server reports setup failure. + +By default, there is no timeout. Set to `0` to disable the timeout. Invalid values are ignored with a warning. -By default, there is no timeout — setup runs indefinitely. +```console +$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model +``` + +### `COG_LOG_LEVEL` -Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. +Controls Coglet runtime log verbosity when `RUST_LOG` is not set. + +Supported values are `debug`, `info`, `warn`, `warning`, and `error`. The default is `info`. ```console -$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model # 5-minute setup timeout +$ COG_LOG_LEVEL=debug docker run -p 5000:5000 my-model ``` -### `COG_CA_CERT` +### `COG_THROTTLE_RESPONSE_INTERVAL` -Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). +Controls how often asynchronous webhook `output` and `logs` events are sent, in seconds. -**Supported values:** +The default is `0.5` seconds. Invalid values are ignored and the default is used. `start` and `completed` webhook events are always sent immediately. -| Value | Description | -| -------------------------------- | ----------------------------------------------------------- | -| `/path/to/cert.crt` | Path to a single PEM certificate file | -| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | -| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | -| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | +```console +$ COG_THROTTLE_RESPONSE_INTERVAL=1 docker run -p 5000:5000 my-model +``` -The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. +### `COG_STREAM_HISTORY_CAPACITY` -**Examples:** +Controls how many server-sent event stream events are retained per prediction for replay when a client reconnects with `Accept: text/event-stream`. + +By default, Cog retains the most recent 1024 events per prediction. Set to `0` to disable replay history while keeping live streaming enabled. Invalid values are ignored with a warning and the default is used. ```console -# From a file -$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_STREAM_HISTORY_CAPACITY=0 docker run -p 5000:5000 my-model +$ COG_STREAM_HISTORY_CAPACITY=4096 docker run -p 5000:5000 my-model +``` -# From a directory of certs -$ COG_CA_CERT=/etc/custom-certs/ cog build +### `COG_WEIGHTS` -# Inline (e.g. from a CI secret) -$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +Provides a weights path or URL to a model whose `setup()` method accepts a `weights` parameter. + +```console +$ cog run -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="hello" +``` + +### `COG_USER_AGENT` + +Sets the `User-Agent` header used by Cog when downloading URL-backed `File` inputs. + +```console +$ COG_USER_AGENT="my-service/1.0" docker run -p 5000:5000 my-model +``` + +## Push tuning variables + +### `COG_PUSH_OCI` + +Enables Cog's OCI chunked push path for container image layers when set to `1`. If the OCI push fails with a non-fatal error, Cog falls back to Docker's native push path. + +```console +$ COG_PUSH_OCI=1 cog push +``` + +### `COG_PUSH_CONCURRENCY` + +Controls how many image layers or weight blobs Cog uploads concurrently during push operations. + +The default is `5`. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_CONCURRENCY=2 cog push +``` + +### `COG_PUSH_DEFAULT_CHUNK_SIZE` + +Sets the default multipart upload chunk size, in bytes, when the registry does not advertise a maximum chunk size. + +The default is 96 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_DEFAULT_CHUNK_SIZE=67108864 cog push +``` + +### `COG_PUSH_MULTIPART_THRESHOLD` + +Sets the minimum blob size, in bytes, before Cog uses multipart upload. + +The default is 128 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_MULTIPART_THRESHOLD=268435456 cog push ``` diff --git a/docs/http.md b/docs/http.md index 480ad23d38..81fa737d13 100644 --- a/docs/http.md +++ b/docs/http.md @@ -17,13 +17,18 @@ The server supports both synchronous and asynchronous prediction creation: and processes the prediction in the background. The client can create a prediction asynchronously -by setting the `Prefer: respond-async` header in their request. -When provided, the server responds immediately after starting the prediction -with `202 Accepted` status and a prediction object in status `processing`. +by setting the `Prefer: respond-async` header in their request +or by requesting a streamed response with `Accept: text/event-stream`. +With `Prefer: respond-async`, +the server responds immediately after starting the prediction +with `202 Accepted` status and a prediction object in status `starting`. +With `Accept: text/event-stream`, +the server responds with `200 OK` and keeps the response open +as a server-sent event stream. > [!NOTE] -> The only supported way to receive updates on the status of predictions -> started asynchronously is using [webhooks](#webhooks). +> For JSON responses, the only supported way to receive updates on the status +> of predictions started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, @@ -31,28 +36,163 @@ such that if a client calls this endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. -Instead, the client receives a `202 Accepted` response -with the initial state of the prediction. +Instead, the client receives the response type requested by the retry: +JSON for regular requests or a server-sent event stream for streaming requests. --- Here's a summary of the prediction creation endpoints: -| Endpoint | Header | Behavior | -| ---------------------------------- | ----------------------- | ---------------------------- | -| `POST /predictions` | - | Synchronous, non-idempotent | -| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | -| `PUT /predictions/` | - | Synchronous, idempotent | -| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| Endpoint | Header | Behavior | +| ---------------------------------- | --------------------------- | ---------------------------- | +| `POST /predictions` | - | Synchronous, non-idempotent | +| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | +| `POST /predictions` | `Accept: text/event-stream` | Streaming, non-idempotent | +| `PUT /predictions/` | - | Synchronous, idempotent | +| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| `PUT /predictions/` | `Accept: text/event-stream` | Streaming, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. +- Use streaming endpoints when you want to receive prediction lifecycle events + over the HTTP response as they happen. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. +## Streaming predictions with server-sent events + +To produce streamed prediction events, +the model must return an iterator and opt in to SSE streaming +with the `streaming` decorator. + +```python +from typing import Iterator + +from cog import BaseRunner, Input, streaming + + +class Runner(BaseRunner): + @streaming + def run(self, prompt: str = Input(description="Prompt")) -> Iterator[str]: + for token in generate_tokens(prompt): + yield token +``` + +The decorator can also be written as `@cog.streaming` +or, if imported directly from `cog`, `@streaming`. +The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. +Without the decorator, +iterator outputs still work in normal JSON responses, +but requests with `Accept: text/event-stream` return `406 Not Acceptable`. + +To consume a streamed prediction, +send the prediction request with `Accept: text/event-stream`: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +The server starts the prediction asynchronously +and keeps the HTTP response open as a server-sent event stream. +Each event has an `event` name and JSON `data` payload: + +```text +event: start +data: {"id":"abc123","status":"processing"} + +event: output +data: {"chunk":"Onions","index":0} + +event: output +data: {"chunk":" bloom","index":1} + +event: completed +data: {"id":"abc123","status":"succeeded","output":["Onions"," bloom"],"metrics":{"predict_time":0.42}} +``` + +Prediction streams can emit these event types: + +- `start`: The prediction started processing. +- `output`: The model yielded an output chunk. + The payload includes `chunk` and `index`. +- `log`: The model wrote to `stdout` or `stderr`. + The payload includes `source` and `data`. +- `metric`: The model recorded a custom metric. + The payload includes `name`, `value`, and `mode`. +- `completed`: The prediction reached a terminal state. + The payload is the final prediction object, + with `status` set to `succeeded`, `failed`, or `canceled`. + +For command-line clients, +use a client that prints the response as data arrives: + +```bash +curl -N \ + -H 'Accept: text/event-stream' \ + -H 'Content-Type: application/json' \ + -d '{"input":{"prompt":"Write a haiku about onions"}}' \ + http://localhost:5000/predictions +``` + +For browser clients, +use `fetch()` or another client that supports request bodies. +The browser `EventSource` API only supports `GET` requests, +so it cannot create a prediction with `POST /predictions` or +`PUT /predictions/`. + +```js +const response = await fetch("/predictions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), +}); + +const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); + +while (true) { + const { value, done } = await reader.read(); + if (done) break; + console.log(value); +} +``` + +Use `PUT /predictions/` when the client needs safe retries +or wants to reconnect to an in-flight prediction by ID: + +```http +PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +If the prediction is still running, +the server returns a stream for the existing prediction +instead of creating a duplicate prediction. +If earlier events have been dropped from the replay buffer, +the stream emits an `error` event and closes. +The replay buffer keeps the most recent 1024 events by default. +Set `COG_STREAM_HISTORY_CAPACITY` to change this limit, +or set it to `0` to disable replay history while keeping live streaming enabled. +Training endpoints do not support SSE streaming; +requests to `/trainings` with `Accept: text/event-stream` +return `406 Not Acceptable`. + ## Webhooks You can provide a `webhook` parameter in the client request body @@ -367,6 +507,11 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `PUT /predictions/` Make a single prediction. @@ -415,6 +560,13 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +If a prediction with the same ID is already running, +the server returns a stream for the existing prediction. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `POST /predictions//cancel` A client can cancel an asynchronous prediction by making a diff --git a/docs/llms.txt b/docs/llms.txt index 9c23e24c65..e8e9e1f7c7 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -608,13 +608,13 @@ See the [environment variables reference](environment.md) for the full list. # Environment variables -This guide lists the environment variables that change how Cog functions. +This reference lists the public Cog-specific environment variables that change how Cog behaves. ## Build-time variables ### `COG_SDK_WHEEL` -Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. +Controls which Cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** @@ -626,100 +626,264 @@ Controls which cog Python SDK wheel is installed in the Docker image during `cog | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | -**Default behavior:** +**Default behaviour:** -- **Release builds**: Installs latest cog from PyPI -- **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI - -**Examples:** +- Release builds install the latest Cog SDK from PyPI. +- Development builds auto-detect a wheel in `dist/`, then fall back to the latest Cog SDK from PyPI. ```console -# Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build - -# Use local development wheel $ COG_SDK_WHEEL=dist cog build - -# Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) -2. `$REPO_ROOT/dist/` (if REPO_ROOT is set) +2. `$REPO_ROOT/dist/` (if `REPO_ROOT` is set) 3. `/dist/` (via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based inference server. -**Supported values:** Same as `COG_SDK_WHEEL` +**Supported values:** Same as `COG_SDK_WHEEL`. -**Default behavior:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. - -**Examples:** +**Default behaviour:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. ```console -# Use local development wheel $ COGLET_WHEEL=dist cog build - -# Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` -## Runtime variables +### `COG_CA_CERT` + +Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (for example, Cloudflare WARP). + +**Supported values:** + +| Value | Description | +| -------------------------------- | ----------------------------------------------------------- | +| `/path/to/cert.crt` | Path to a single PEM certificate file | +| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | +| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | +| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | + +The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. + +```console +$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_CA_CERT=/etc/custom-certs/ cog build +$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +``` + +### `COG_OPENAPI_SCHEMA` + +Uses a pre-built OpenAPI schema instead of generating one from the configured predict or train reference. + +The value must be a path to a JSON schema file. Cog reads that file during schema generation and embeds it in the built image. + +```console +$ COG_OPENAPI_SCHEMA=./openapi.json cog build +``` + +## CLI and local cache variables ### `COG_NO_UPDATE_CHECK` -By default, Cog automatically checks for updates -and notifies you if there is a new version available. +Disables Cog's automatic update check. Set it to any non-empty value. + +```console +$ COG_NO_UPDATE_CHECK=1 cog build +``` + +### `COG_NO_COLOR` + +Disables coloured CLI output. Set it to any non-empty value. + +Cog also honours the standard `NO_COLOR` environment variable. + +```console +$ COG_NO_COLOR=1 cog predict -i prompt="hello" +``` + +### `COG_SKIP_DOCKER_CHECK` + +Skips the `cog doctor` Docker environment check. Set it to any non-empty value. + +```console +$ COG_SKIP_DOCKER_CHECK=1 cog doctor +``` + +### `COG_CACHE_DIR` + +Overrides Cog's local cache root. -To disable this behavior, -set the `COG_NO_UPDATE_CHECK` environment variable to any value. +Cog currently uses this cache for the content-addressed weights store. If unset, Cog uses `$XDG_CACHE_HOME/cog` when `XDG_CACHE_HOME` is set, otherwise `$HOME/.cache/cog`. ```console -$ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check +$ COG_CACHE_DIR=/mnt/fast-cache cog weights pull +``` + +## Model reference and registry variables + +### `COG_MODEL` + +Overrides the full model reference used by commands that need a model destination, such as `cog push` and weights commands. + +The value is parsed as a complete model reference (`registry/repo`, `registry/repo:tag`, or `registry/repo@digest`). If no tag is supplied, Cog generates a timestamp tag. + +When `COG_MODEL` is set, it takes precedence over `COG_MODEL_REGISTRY`, `COG_MODEL_REPO`, and `COG_MODEL_TAG`. + +```console +$ COG_MODEL=r8.im/acme/my-model:v1 cog push +``` + +### `COG_MODEL_REGISTRY` + +Overrides only the registry host of the model reference. + +```console +$ COG_MODEL_REGISTRY=registry.example.com cog push +``` + +### `COG_MODEL_REPO` + +Overrides only the repository path of the model reference. The value must not include a registry host, tag, or digest. + +```console +$ COG_MODEL_REPO=acme/my-model cog push +``` + +### `COG_MODEL_TAG` + +Overrides only the tag of the model reference. + +Tags starting with `cog-` are reserved for tags that Cog generates internally and are rejected. + +```console +$ COG_MODEL_TAG=staging cog push +``` + +### `COG_REGISTRY_HOST` + +Changes the default Replicate-compatible registry host used by commands such as `cog login`, base image resolution, and model reference resolution. + +The default is `r8.im`. + +```console +$ COG_REGISTRY_HOST=registry.example.com cog login +``` + +## Runtime server variables + +These variables affect a running model server. Set them in `cog.yaml` under `environment`, pass them with `cog predict -e` or `cog serve -e`, or set them when running the built Docker image. + +### `COG_MAX_CONCURRENCY` + +Controls how many predictions the model server can run concurrently. + +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. + +```console +$ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model ``` ### `COG_SETUP_TIMEOUT` -Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. +Controls the maximum time, in seconds, allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server reports setup failure. + +By default, there is no timeout. Set to `0` to disable the timeout. Invalid values are ignored with a warning. + +```console +$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model +``` -By default, there is no timeout — setup runs indefinitely. +### `COG_LOG_LEVEL` -Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. +Controls Coglet runtime log verbosity when `RUST_LOG` is not set. + +Supported values are `debug`, `info`, `warn`, `warning`, and `error`. The default is `info`. ```console -$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model # 5-minute setup timeout +$ COG_LOG_LEVEL=debug docker run -p 5000:5000 my-model ``` -### `COG_CA_CERT` +### `COG_THROTTLE_RESPONSE_INTERVAL` -Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). +Controls how often asynchronous webhook `output` and `logs` events are sent, in seconds. -**Supported values:** +The default is `0.5` seconds. Invalid values are ignored and the default is used. `start` and `completed` webhook events are always sent immediately. -| Value | Description | -| -------------------------------- | ----------------------------------------------------------- | -| `/path/to/cert.crt` | Path to a single PEM certificate file | -| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | -| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | -| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | +```console +$ COG_THROTTLE_RESPONSE_INTERVAL=1 docker run -p 5000:5000 my-model +``` -The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. +### `COG_STREAM_HISTORY_CAPACITY` + +Controls how many server-sent event stream events are retained per prediction for replay when a client reconnects with `Accept: text/event-stream`. -**Examples:** +By default, Cog retains the most recent 1024 events per prediction. Set to `0` to disable replay history while keeping live streaming enabled. Invalid values are ignored with a warning and the default is used. ```console -# From a file -$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_STREAM_HISTORY_CAPACITY=0 docker run -p 5000:5000 my-model +$ COG_STREAM_HISTORY_CAPACITY=4096 docker run -p 5000:5000 my-model +``` -# From a directory of certs -$ COG_CA_CERT=/etc/custom-certs/ cog build +### `COG_WEIGHTS` -# Inline (e.g. from a CI secret) -$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +Provides a weights path or URL to a model whose `setup()` method accepts a `weights` parameter. + +```console +$ cog run -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="hello" +``` + +### `COG_USER_AGENT` + +Sets the `User-Agent` header used by Cog when downloading URL-backed `File` inputs. + +```console +$ COG_USER_AGENT="my-service/1.0" docker run -p 5000:5000 my-model +``` + +## Push tuning variables + +### `COG_PUSH_OCI` + +Enables Cog's OCI chunked push path for container image layers when set to `1`. If the OCI push fails with a non-fatal error, Cog falls back to Docker's native push path. + +```console +$ COG_PUSH_OCI=1 cog push +``` + +### `COG_PUSH_CONCURRENCY` + +Controls how many image layers or weight blobs Cog uploads concurrently during push operations. + +The default is `5`. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_CONCURRENCY=2 cog push +``` + +### `COG_PUSH_DEFAULT_CHUNK_SIZE` + +Sets the default multipart upload chunk size, in bytes, when the registry does not advertise a maximum chunk size. + +The default is 96 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_DEFAULT_CHUNK_SIZE=67108864 cog push +``` + +### `COG_PUSH_MULTIPART_THRESHOLD` + +Sets the minimum blob size, in bytes, before Cog uses multipart upload. + +The default is 128 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_MULTIPART_THRESHOLD=268435456 cog push ``` @@ -1158,13 +1322,18 @@ The server supports both synchronous and asynchronous prediction creation: and processes the prediction in the background. The client can create a prediction asynchronously -by setting the `Prefer: respond-async` header in their request. -When provided, the server responds immediately after starting the prediction -with `202 Accepted` status and a prediction object in status `processing`. +by setting the `Prefer: respond-async` header in their request +or by requesting a streamed response with `Accept: text/event-stream`. +With `Prefer: respond-async`, +the server responds immediately after starting the prediction +with `202 Accepted` status and a prediction object in status `starting`. +With `Accept: text/event-stream`, +the server responds with `200 OK` and keeps the response open +as a server-sent event stream. > [!NOTE] -> The only supported way to receive updates on the status of predictions -> started asynchronously is using [webhooks](#webhooks). +> For JSON responses, the only supported way to receive updates on the status +> of predictions started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, @@ -1172,28 +1341,163 @@ such that if a client calls this endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. -Instead, the client receives a `202 Accepted` response -with the initial state of the prediction. +Instead, the client receives the response type requested by the retry: +JSON for regular requests or a server-sent event stream for streaming requests. --- Here's a summary of the prediction creation endpoints: -| Endpoint | Header | Behavior | -| ---------------------------------- | ----------------------- | ---------------------------- | -| `POST /predictions` | - | Synchronous, non-idempotent | -| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | -| `PUT /predictions/` | - | Synchronous, idempotent | -| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| Endpoint | Header | Behavior | +| ---------------------------------- | --------------------------- | ---------------------------- | +| `POST /predictions` | - | Synchronous, non-idempotent | +| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | +| `POST /predictions` | `Accept: text/event-stream` | Streaming, non-idempotent | +| `PUT /predictions/` | - | Synchronous, idempotent | +| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| `PUT /predictions/` | `Accept: text/event-stream` | Streaming, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. +- Use streaming endpoints when you want to receive prediction lifecycle events + over the HTTP response as they happen. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. +## Streaming predictions with server-sent events + +To produce streamed prediction events, +the model must return an iterator and opt in to SSE streaming +with the `streaming` decorator. + +```python +from typing import Iterator + +from cog import BaseRunner, Input, streaming + + +class Runner(BaseRunner): + @streaming + def run(self, prompt: str = Input(description="Prompt")) -> Iterator[str]: + for token in generate_tokens(prompt): + yield token +``` + +The decorator can also be written as `@cog.streaming` +or, if imported directly from `cog`, `@streaming`. +The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. +Without the decorator, +iterator outputs still work in normal JSON responses, +but requests with `Accept: text/event-stream` return `406 Not Acceptable`. + +To consume a streamed prediction, +send the prediction request with `Accept: text/event-stream`: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +The server starts the prediction asynchronously +and keeps the HTTP response open as a server-sent event stream. +Each event has an `event` name and JSON `data` payload: + +```text +event: start +data: {"id":"abc123","status":"processing"} + +event: output +data: {"chunk":"Onions","index":0} + +event: output +data: {"chunk":" bloom","index":1} + +event: completed +data: {"id":"abc123","status":"succeeded","output":["Onions"," bloom"],"metrics":{"predict_time":0.42}} +``` + +Prediction streams can emit these event types: + +- `start`: The prediction started processing. +- `output`: The model yielded an output chunk. + The payload includes `chunk` and `index`. +- `log`: The model wrote to `stdout` or `stderr`. + The payload includes `source` and `data`. +- `metric`: The model recorded a custom metric. + The payload includes `name`, `value`, and `mode`. +- `completed`: The prediction reached a terminal state. + The payload is the final prediction object, + with `status` set to `succeeded`, `failed`, or `canceled`. + +For command-line clients, +use a client that prints the response as data arrives: + +```bash +curl -N \ + -H 'Accept: text/event-stream' \ + -H 'Content-Type: application/json' \ + -d '{"input":{"prompt":"Write a haiku about onions"}}' \ + http://localhost:5000/predictions +``` + +For browser clients, +use `fetch()` or another client that supports request bodies. +The browser `EventSource` API only supports `GET` requests, +so it cannot create a prediction with `POST /predictions` or +`PUT /predictions/`. + +```js +const response = await fetch("/predictions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), +}); + +const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); + +while (true) { + const { value, done } = await reader.read(); + if (done) break; + console.log(value); +} +``` + +Use `PUT /predictions/` when the client needs safe retries +or wants to reconnect to an in-flight prediction by ID: + +```http +PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +If the prediction is still running, +the server returns a stream for the existing prediction +instead of creating a duplicate prediction. +If earlier events have been dropped from the replay buffer, +the stream emits an `error` event and closes. +The replay buffer keeps the most recent 1024 events by default. +Set `COG_STREAM_HISTORY_CAPACITY` to change this limit, +or set it to `0` to disable replay history while keeping live streaming enabled. +Training endpoints do not support SSE streaming; +requests to `/trainings` with `Accept: text/event-stream` +return `406 Not Acceptable`. + ## Webhooks You can provide a `webhook` parameter in the client request body @@ -1508,6 +1812,11 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `PUT /predictions/` Make a single prediction. @@ -1556,6 +1865,13 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +If a prediction with the same ID is already running, +the server returns a stream for the existing prediction. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `POST /predictions//cancel` A client can cancel an asynchronous prediction by making a @@ -1979,11 +2295,14 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` (or `@streaming` if imported directly from `cog`). The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. The decorated method must return `Iterator[...]`, `AsyncIterator[...]`, `ConcatenateIterator[...]`, or `AsyncConcatenateIterator[...]`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. + ```py -from cog import BaseRunner, Path from typing import Iterator +from cog import BaseRunner, Path, streaming class Runner(BaseRunner): + @streaming def run(self) -> Iterator[Path]: done = False while not done: @@ -1995,9 +2314,10 @@ If you have an [async `run()` method](#async-runners-and-concurrency), use `Asyn ```py from typing import AsyncIterator -from cog import BaseRunner, Path +from cog import BaseRunner, Path, streaming class Runner(BaseRunner): + @streaming async def run(self) -> AsyncIterator[Path]: done = False while not done: @@ -2008,9 +2328,10 @@ class Runner(BaseRunner): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BaseRunner, Path, ConcatenateIterator +from cog import BaseRunner, ConcatenateIterator, streaming class Runner(BaseRunner): + @streaming def run(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -2020,9 +2341,10 @@ class Runner(BaseRunner): Or for async `run()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BaseRunner, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BaseRunner, streaming class Runner(BaseRunner): + @streaming async def run(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/docs/python.md b/docs/python.md index c86ac4b9b4..5f5619a9af 100644 --- a/docs/python.md +++ b/docs/python.md @@ -263,11 +263,14 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` (or `@streaming` if imported directly from `cog`). The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. The decorated method must return `Iterator[...]`, `AsyncIterator[...]`, `ConcatenateIterator[...]`, or `AsyncConcatenateIterator[...]`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. + ```py -from cog import BaseRunner, Path from typing import Iterator +from cog import BaseRunner, Path, streaming class Runner(BaseRunner): + @streaming def run(self) -> Iterator[Path]: done = False while not done: @@ -279,9 +282,10 @@ If you have an [async `run()` method](#async-runners-and-concurrency), use `Asyn ```py from typing import AsyncIterator -from cog import BaseRunner, Path +from cog import BaseRunner, Path, streaming class Runner(BaseRunner): + @streaming async def run(self) -> AsyncIterator[Path]: done = False while not done: @@ -292,9 +296,10 @@ class Runner(BaseRunner): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BaseRunner, Path, ConcatenateIterator +from cog import BaseRunner, ConcatenateIterator, streaming class Runner(BaseRunner): + @streaming def run(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -304,9 +309,10 @@ class Runner(BaseRunner): Or for async `run()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BaseRunner, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BaseRunner, streaming class Runner(BaseRunner): + @streaming async def run(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/examples/streaming-text/.dockerignore b/examples/streaming-text/.dockerignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.dockerignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/.gitignore b/examples/streaming-text/.gitignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.gitignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md new file mode 100644 index 0000000000..77265e6f1e --- /dev/null +++ b/examples/streaming-text/README.md @@ -0,0 +1,51 @@ +# examples/streaming-text + +Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. + +This example shows how a Cog runner can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. + +## Run a normal prediction + +From this directory: + +```sh +cog predict -i prompt="Write a short haiku about databases" +``` + +This returns the final accumulated output after the prediction completes. + +## Stream output over HTTP + +Start the server: + +```sh +cog serve +``` + +Create a prediction and request an SSE response: + +```sh +curl -N -X PUT http://localhost:5000/predictions/streaming-demo \ + -H 'Content-Type: application/json' \ + -H 'Accept: text/event-stream' \ + -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' +``` + +The response includes `output` events as chunks are generated, followed by a `completed` event: + +```text +event: output +data: {"chunk":"Silent","index":0} + +event: output +data: {"chunk":" rows","index":1} + +event: completed +data: {"id":"streaming-demo","status":"succeeded",...} +``` + +## How it works + +`predict.py` defines `run() -> Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. + +The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. diff --git a/examples/streaming-text/cog.yaml b/examples/streaming-text/cog.yaml new file mode 100644 index 0000000000..68866b7aeb --- /dev/null +++ b/examples/streaming-text/cog.yaml @@ -0,0 +1,7 @@ +# Streaming text generation example using a small open-weight language model. + +build: + python_version: "3.12" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py new file mode 100644 index 0000000000..c94ddb04f3 --- /dev/null +++ b/examples/streaming-text/predict.py @@ -0,0 +1,66 @@ +from threading import Thread +from typing import Iterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from cog import BaseRunner, Input, streaming + +MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" + + +class Predictor(BaseRunner): + def setup(self) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if self.device == "cuda" else torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=dtype, + ).to(self.device) + self.model.eval() + + @streaming + def run( + self, + prompt: str = Input(description="Prompt to complete"), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate", + default=128, + ge=1, + le=512, + ), + ) -> Iterator[str]: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer([text], return_tensors="pt").to(self.device) + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + generation_kwargs = { + **inputs, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "temperature": 0.7, + "top_p": 0.9, + "pad_token_id": self.tokenizer.eos_token_id, + } + + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + try: + for chunk in streamer: + if chunk: + yield chunk + finally: + thread.join() diff --git a/examples/streaming-text/requirements.txt b/examples/streaming-text/requirements.txt new file mode 100644 index 0000000000..916933a3de --- /dev/null +++ b/examples/streaming-text/requirements.txt @@ -0,0 +1,3 @@ +torch==2.7.1 +transformers==4.51.3 +accelerate==1.6.0 diff --git a/integration-tests/tests/sse_requires_streaming_opt_in.txtar b/integration-tests/tests/sse_requires_streaming_opt_in.txtar new file mode 100644 index 0000000000..7ac62c51a0 --- /dev/null +++ b/integration-tests/tests/sse_requires_streaming_opt_in.txtar @@ -0,0 +1,28 @@ +# Test that SSE requires @cog.streaming while undecorated iterators still work as JSON. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +cog predict -i count=2 +stdout '"chunk-0"' +stdout '"chunk-1"' + +! curl -H Accept:text/event-stream PUT /predictions/no-streaming '{"id":"no-streaming","input":{"count":2}}' +stderr 'This model does not support streaming responses' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, count: int) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" diff --git a/integration-tests/tests/sse_stream_history_capacity.txtar b/integration-tests/tests/sse_stream_history_capacity.txtar new file mode 100644 index 0000000000..42b640b572 --- /dev/null +++ b/integration-tests/tests/sse_stream_history_capacity.txtar @@ -0,0 +1,41 @@ +# Test configurable SSE replay history capacity. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +# Capacity 2 should drop older replay events and close with an error for late subscribers. +curl -H Prefer:respond-async PUT /predictions/replay-truncated '{"id":"replay-truncated","input":{"count":4,"sleep_after":2}}' +stdout '"status":"starting"' + +exec sleep 1 + +curl -H Accept:text/event-stream PUT /predictions/replay-truncated '{"id":"replay-truncated","input":{"count":4,"sleep_after":2}}' +stdout 'event: error' +stdout 'SSE stream replay truncated' +stdout '"skipped":3' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" +environment: + - COG_STREAM_HISTORY_CAPACITY=2 + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor, Input, streaming + + +class Predictor(BasePredictor): + @streaming + def predict( + self, + count: int = Input(default=4), + sleep_after: float = Input(default=2.0), + ) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" + time.sleep(sleep_after) diff --git a/integration-tests/tests/sse_stream_history_disabled.txtar b/integration-tests/tests/sse_stream_history_disabled.txtar new file mode 100644 index 0000000000..6a36d9e88d --- /dev/null +++ b/integration-tests/tests/sse_stream_history_disabled.txtar @@ -0,0 +1,41 @@ +# Test that COG_STREAM_HISTORY_CAPACITY=0 disables SSE replay while live streaming still works. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Prefer:respond-async PUT /predictions/replay-disabled '{"id":"replay-disabled","input":{"count":3,"sleep_after":1}}' +stdout '"status":"starting"' + +exec sleep 0.5 + +curl -H Accept:text/event-stream PUT /predictions/replay-disabled '{"id":"replay-disabled","input":{"count":3,"sleep_after":1}}' +stdout 'event: completed' +stdout '"status":"succeeded"' +! stdout 'event: output' +! stdout 'SSE stream replay truncated' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" +environment: + - COG_STREAM_HISTORY_CAPACITY=0 + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor, Input, streaming + + +class Predictor(BasePredictor): + @streaming + def predict( + self, + count: int = Input(default=3), + sleep_after: float = Input(default=1.0), + ) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" + time.sleep(sleep_after) diff --git a/integration-tests/tests/sse_streaming_metrics.txtar b/integration-tests/tests/sse_streaming_metrics.txtar new file mode 100644 index 0000000000..e2eebaa475 --- /dev/null +++ b/integration-tests/tests/sse_streaming_metrics.txtar @@ -0,0 +1,42 @@ +# Test that custom metrics are emitted as SSE metric events and included in the completed payload. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Accept:text/event-stream PUT /predictions/sse-metrics '{"id":"sse-metrics","input":{}}' +stdout 'event: start' +stdout 'event: metric' +stdout '"name":"temperature"' +stdout '"value":0.7' +stdout '"mode":"replace"' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: metric' +stdout '"name":"token_count"' +stdout '"value":2' +stdout '"mode":"increment"' +stdout 'event: completed' +stdout '"status":"succeeded"' +stdout '"temperature":0.7' +stdout '"token_count":2' +stdout '"predict_time"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Iterator + +from cog import BasePredictor, current_scope, streaming + + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + scope = current_scope() + scope.record_metric("temperature", 0.7) + yield "chunk-1" + scope.record_metric("token_count", 2, mode="incr") diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar new file mode 100644 index 0000000000..23646106f5 --- /dev/null +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -0,0 +1,34 @@ +# Test that async generator output is available when predictions are created with SSE accept. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout 'event: start' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-2","index":1}' +stdout 'event: completed' +stdout '"status":"succeeded"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor, streaming + + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + time.sleep(0.25) + yield "chunk-1" + time.sleep(0.25) + yield "chunk-2" diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index 9435ea8248..f15f4471ca 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -187,38 +187,40 @@ func buildOpenAPISpec(info *PredictorInfo) map[string]any { }) // Main endpoint (predict or train) - paths.Set(endpoint, map[string]any{ - "post": map[string]any{ - "summary": summary, - "description": description, - "operationId": opID, - "requestBody": map[string]any{ + mainOperation := map[string]any{ + "summary": summary, + "description": description, + "operationId": opID, + "requestBody": map[string]any{ + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": requestRef}, + }, + }, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "Successful Response", "content": map[string]any{ "application/json": map[string]any{ - "schema": map[string]any{"$ref": requestRef}, + "schema": map[string]any{"$ref": responseRef}, }, }, }, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Successful Response", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": responseRef}, - }, - }, - }, - "422": map[string]any{ - "description": "Validation Error", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, - }, + "422": map[string]any{ + "description": "Validation Error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, }, }, }, }, - }) + } + if !isTrain && info.SupportsStreaming { + mainOperation["x-cog-streaming"] = true + } + paths.Set(endpoint, map[string]any{"post": mainOperation}) // Cancel endpoint paths.Set(cancelEP, map[string]any{ diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index a3a9f7a8a6..d9b528bb9f 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -657,6 +657,55 @@ func TestOutputConcatenateIterator(t *testing.T) { assert.Equal(t, "concatenate", output["x-cog-array-display"]) } +func TestPredictionOperationIncludesStreamingExtensionWhenEnabled(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + SupportsStreaming: true, + } + + spec := parseSpec(t, info) + postPath := getPath(spec, "paths", "/predictions", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) + assert.Equal(t, true, post["x-cog-streaming"]) +} + +func TestPredictionOperationOmitsStreamingExtensionByDefault(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + postPath := getPath(spec, "paths", "/predictions", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) + _, ok := post["x-cog-streaming"] + assert.False(t, ok) +} + +func TestTrainingOperationOmitsStreamingExtensionWhenEnabled(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModeTrain, + SupportsStreaming: true, + } + + spec := parseSpec(t, info) + postPath := getPath(spec, "paths", "/trainings", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) + _, ok := post["x-cog-streaming"] + assert.False(t, ok) +} + func TestOutputObject(t *testing.T) { inputs := NewOrderedMap[string, InputField]() fields := NewOrderedMap[string, SchemaField]() diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 427445fecf..e17e9d8a58 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -59,7 +59,6 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi if mode == schema.ModeTrain { methodName = "train" } - fileCtx := &pythonFileContext{ root: root, source: source, @@ -117,11 +116,16 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi if err != nil { return nil, err } + supportsStreaming := functionSupportsStreaming(target.node, target.file.source, target.file.imports) + if supportsStreaming && !supportsStreamingOutput(output) { + return nil, schema.WrapError(schema.ErrUnsupportedType, "@streaming requires Iterator[...] or ConcatenateIterator[...] return type", nil) + } return &schema.PredictorInfo{ - Inputs: inputs, - Output: output, - Mode: mode, + Inputs: inputs, + Output: output, + Mode: mode, + SupportsStreaming: supportsStreaming, }, nil } @@ -670,6 +674,70 @@ func UnwrapFunction(node *sitter.Node) *sitter.Node { return nil } +func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + if node.Type() == "function_definition" { + parent := node.Parent() + if parent == nil || parent.Type() != "decorated_definition" { + return false + } + node = parent + } + + if node.Type() != "decorated_definition" { + return false + } + for _, child := range NamedChildren(node) { + if child.Type() != "decorator" { + continue + } + if decoratorIsCogStreaming(child, source, imports) { + return true + } + } + return false +} + +func decoratorIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + for _, child := range NamedChildren(node) { + switch child.Type() { + case "attribute": + return attributeIsCogStreaming(child, source, imports) + case "identifier": + return identifierIsCogStreaming(child, source, imports) + case "call": + callee := child.ChildByFieldName("function") + if callee == nil { + return false + } + switch callee.Type() { + case "attribute": + return attributeIsCogStreaming(callee, source, imports) + case "identifier": + return identifierIsCogStreaming(callee, source, imports) + } + } + } + return false +} + +func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + parts := strings.SplitN(Content(node, source), ".", 2) + if len(parts) != 2 || parts[1] != "streaming" { + return false + } + entry, ok := imports.Names.Get(parts[0]) + return ok && entry.Module == "cog" && entry.Original == "cog" +} + +func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + entry, ok := imports.Names.Get(Content(node, source)) + return ok && entry.Module == "cog" && entry.Original == "streaming" +} + +func supportsStreamingOutput(output schema.SchemaType) bool { + return output.Kind == schema.SchemaIterator || output.Kind == schema.SchemaConcatIterator +} + func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 07849303fe..30fc1c15e3 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -575,6 +575,133 @@ class Predictor(BasePredictor): require.Equal(t, schema.ErrConcatIteratorNotStr, se.Kind) } +func TestStreamingDecoratorQualifiedOptIn(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorQualifiedAliasOptIn(t *testing.T) { + source := ` +import cog as c +from typing import Iterator + +class Predictor(c.BasePredictor): + @c.streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming +from typing import Iterator + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedAliasOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming as stream +from typing import Iterator + +class Predictor(BasePredictor): + @stream + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorRequiresIteratorOutput(t *testing.T) { + source := ` +from cog import BasePredictor, streaming + +class Predictor(BasePredictor): + @streaming + def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "@streaming requires") +} + +func TestStreamingDecoratorIgnoredWhenNotFromCog(t *testing.T) { + source := ` +from other import streaming +from typing import Iterator +from cog import BasePredictor + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorParameterizedFormOptIn(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming() + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedParameterizedFormOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming +from typing import Iterator + +class Predictor(BasePredictor): + @streaming() + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorClassLevelIgnored(t *testing.T) { + source := ` +import cog +from typing import Iterator + +@cog.streaming +class Predictor(cog.BasePredictor): + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path @@ -660,6 +787,28 @@ def train(n: int) -> Path: require.Equal(t, 1, info.Inputs.Len()) } +func TestTrainModeClassDecoratedMethod(t *testing.T) { + source := ` +from functools import wraps +from cog import BasePredictor, Path + +def noop(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + return wrapper + +class Trainer(BasePredictor): + @noop + def train(self, n: int) -> Path: + pass +` + info, err := ParsePredictor([]byte(source), "Trainer", schema.ModeTrain, "") + require.NoError(t, err) + require.Equal(t, schema.ModeTrain, info.Mode) + require.Equal(t, 1, info.Inputs.Len()) +} + // --------------------------------------------------------------------------- // Non-BasePredictor class (just has predict method) // --------------------------------------------------------------------------- diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 5a7c7ce611..1b0250e8f7 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -197,9 +197,10 @@ func (f *InputField) IsRequired() bool { // PredictorInfo is the top-level extraction result. type PredictorInfo struct { - Inputs *OrderedMap[string, InputField] - Output SchemaType - Mode Mode + Inputs *OrderedMap[string, InputField] + Output SchemaType + Mode Mode + SupportsStreaming bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index e3ae649fa6..775e2ba0e9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -20,6 +20,8 @@ def run( """ import sys as _sys +from collections.abc import Callable +from typing import TypeVar, overload from coglet import CancelationException as CancelationException @@ -38,6 +40,29 @@ def run( URLPath, ) +_F = TypeVar("_F", bound=Callable[..., object]) + + +@overload +def streaming(fn: _F) -> _F: + pass + + +@overload +def streaming(fn: None = None) -> Callable[[_F], _F]: + pass + + +def streaming(fn: _F | None = None) -> _F | Callable[[_F], _F]: + """Mark a predict handler as supporting streaming responses.""" + + def decorate(inner: _F) -> _F: + return inner + + if fn is None: + return decorate + return decorate(fn) + # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning @@ -134,6 +159,8 @@ def current_scope() -> object: "CancelationException", # Metrics "current_scope", + # Decorators + "streaming", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 653658ce11..fdbbd6622b 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -10,6 +10,7 @@ Path, Secret, URLFile, + streaming, ) @@ -132,3 +133,25 @@ def test_async_concatenate_iterator_is_abstract(self) -> None: from typing import AsyncIterator assert issubclass(AsyncConcatenateIterator, AsyncIterator) + + +class TestStreamingDecorator: + """Tests for the streaming opt-in decorator.""" + + def test_streaming_returns_same_object(self) -> None: + def predict() -> str: + return "ok" + + decorated = streaming(predict) + + assert decorated is predict + assert not hasattr(predict, "__cog_streaming__") + + def test_streaming_call_form_returns_same_object(self) -> None: + def predict() -> str: + return "ok" + + decorated = streaming()(predict) + + assert decorated is predict + assert not hasattr(predict, "__cog_streaming__")