From 2b0b483c606a7dabe3997a23d19ab3cd705bc112 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 12:32:10 -0400 Subject: [PATCH 01/11] feat: expose prediction SSE streams --- crates/coglet/src/orchestrator.rs | 6 +- crates/coglet/src/prediction.rs | 232 ++++++++++++++- crates/coglet/src/service.rs | 274 +++++++++++++++++- crates/coglet/src/transport/http/routes.rs | 149 +++++++++- .../tests/sse_streaming_output.txtar | 36 +++ 5 files changed, 680 insertions(+), 17 deletions(-) create mode 100644 integration-tests/tests/sse_streaming_output.txtar diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index c0aec8b9ed..5694338e35 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -966,7 +966,7 @@ async fn run_event_loop( Ok(SlotResponse::LogLine { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_log(&data); + p.append_log_source(source, &data); (Some(p.id().to_string()), false) } else { (None, true) @@ -1015,10 +1015,10 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::OutputChunk { output, index: _ }) => { + Ok(SlotResponse::OutputChunk { output, index }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_output(output); + p.append_output_chunk(output, index); false } else { true diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 81ab018b40..1e86a20754 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -7,7 +7,7 @@ use std::time::Instant; use tokio::sync::Notify; pub use tokio_util::sync::CancellationToken; -use crate::bridge::protocol::MetricMode; +use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -64,6 +64,70 @@ impl PredictionOutput { } } +#[derive(Debug, Clone)] +pub enum PredictionStreamEvent { + Start { + id: String, + status: String, + }, + Output { + chunk: serde_json::Value, + index: u64, + }, + Log { + source: LogSource, + data: String, + }, + Metric { + name: String, + value: serde_json::Value, + mode: MetricMode, + }, + Completed { + payload: serde_json::Value, + }, +} + +pub struct PredictionStreamReplay { + pub replay: Vec, + pub receiver: tokio::sync::broadcast::Receiver, +} + +impl PredictionStreamEvent { + pub fn event_name(&self) -> &'static str { + match self { + Self::Start { .. } => "start", + Self::Output { .. } => "output", + Self::Log { .. } => "log", + Self::Metric { .. } => "metric", + Self::Completed { .. } => "completed", + } + } + + pub fn json_data(&self) -> serde_json::Value { + match self { + Self::Start { id, status } => serde_json::json!({ + "id": id, + "status": status, + }), + Self::Output { chunk, index } => serde_json::json!({ + "chunk": chunk, + "index": index, + }), + Self::Log { source, data } => serde_json::json!({ + "source": source, + "data": data, + }), + Self::Metric { name, value, mode } => serde_json::json!({ + "name": name, + "value": value, + "mode": mode, + }), + Self::Completed { payload } => payload.clone(), + } + } +} + /// Prediction lifecycle state. pub struct Prediction { id: String, @@ -76,12 +140,16 @@ pub struct Prediction { error: Option, webhook: Option, completion: Arc, + stream_tx: tokio::sync::broadcast::Sender, + stream_history: Vec, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } impl Prediction { pub fn new(id: String, webhook: Option) -> Self { + let (stream_tx, _) = tokio::sync::broadcast::channel(1024); + Self { id, cancel_token: CancellationToken::new(), @@ -93,6 +161,8 @@ impl Prediction { error: None, webhook, completion: Arc::new(Notify::new()), + stream_tx, + stream_history: Vec::new(), metrics: HashMap::new(), } } @@ -105,6 +175,26 @@ impl Prediction { self.cancel_token.clone() } + pub fn subscribe_stream(&self) -> tokio::sync::broadcast::Receiver { + self.stream_tx.subscribe() + } + + pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { + PredictionStreamReplay { + replay: self.stream_history.clone(), + receiver: self.stream_tx.subscribe(), + } + } + + pub fn stream_receiver_count(&self) -> usize { + self.stream_tx.receiver_count() + } + + fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + self.stream_history.push(event.clone()); + let _ = self.stream_tx.send(event); + } + pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } @@ -119,6 +209,10 @@ impl Prediction { pub fn set_processing(&mut self) { self.status = PredictionStatus::Processing; + self.emit_stream_event(PredictionStreamEvent::Start { + id: self.id.clone(), + status: self.status.as_str().to_string(), + }); self.fire_webhook(WebhookEventType::Start); } @@ -128,6 +222,9 @@ impl Prediction { } self.status = PredictionStatus::Succeeded; self.output = Some(output); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); // notify_one stores a permit so a future .notified().await will // consume it immediately. notify_waiters only wakes currently- @@ -144,6 +241,9 @@ impl Prediction { } self.status = PredictionStatus::Failed; self.error = Some(error); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -153,6 +253,9 @@ impl Prediction { return; } self.status = PredictionStatus::Canceled; + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -162,7 +265,15 @@ impl Prediction { } pub fn append_log(&mut self, data: &str) { + self.append_log_source(LogSource::Stdout, data); + } + + pub fn append_log_source(&mut self, source: LogSource, data: &str) { self.logs.push_str(data); + self.emit_stream_event(PredictionStreamEvent::Log { + source, + data: data.to_string(), + }); self.fire_webhook(WebhookEventType::Logs); } @@ -184,6 +295,12 @@ impl Prediction { return; } + self.emit_stream_event(PredictionStreamEvent::Metric { + name: name.clone(), + value: value.clone(), + mode, + }); + // Dot-path resolution: "a.b.c" → nested objects let parts: Vec<&str> = name.split('.').collect(); if parts.len() > 1 { @@ -297,7 +414,16 @@ impl Prediction { } pub fn append_output(&mut self, output: serde_json::Value) { - self.outputs.push(output); + let index = self.outputs.len() as u64; + self.append_output_chunk(output, index); + } + + pub fn append_output_chunk(&mut self, output: serde_json::Value, index: u64) { + self.outputs.push(output.clone()); + self.emit_stream_event(PredictionStreamEvent::Output { + chunk: output, + index, + }); self.fire_webhook(WebhookEventType::Output); } @@ -484,6 +610,108 @@ mod tests { assert_eq!(pred.outputs().len(), 2); } + #[tokio::test] + async fn prediction_stream_emits_start_output_log_and_completed() { + let mut prediction = Prediction::new("pred_stream".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.append_log("loading\n"); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let start = rx.recv().await.unwrap(); + assert_eq!(start.event_name(), "start"); + assert_eq!( + start.json_data(), + serde_json::json!({"id":"pred_stream","status":"processing"}) + ); + + let output = rx.recv().await.unwrap(); + assert_eq!(output.event_name(), "output"); + assert_eq!( + output.json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + + let log = rx.recv().await.unwrap(); + assert_eq!(log.event_name(), "log"); + assert_eq!( + log.json_data(), + serde_json::json!({"source":"stdout","data":"loading\n"}) + ); + + let completed = rx.recv().await.unwrap(); + assert_eq!(completed.event_name(), "completed"); + assert_eq!(completed.json_data()["id"], "pred_stream"); + assert_eq!(completed.json_data()["status"], "succeeded"); + assert_eq!( + completed.json_data()["output"], + serde_json::json!(["hello"]) + ); + } + + #[tokio::test] + async fn prediction_stream_emits_metric_event() { + let mut prediction = Prediction::new("pred_metric".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_metric( + "tokens".to_string(), + serde_json::json!(1), + MetricMode::Increment, + ); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "metric"); + assert_eq!( + event.json_data(), + serde_json::json!({ + "name":"tokens", + "value":1, + "mode":"increment" + }) + ); + } + + #[tokio::test] + async fn prediction_stream_preserves_log_source() { + let mut prediction = Prediction::new("pred_log_source".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.append_log_source(crate::bridge::protocol::LogSource::Stderr, "warning\n"); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "log"); + assert_eq!( + event.json_data(), + serde_json::json!({"source":"stderr","data":"warning\n"}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_includes_already_emitted_events() { + let mut prediction = Prediction::new("pred_replay".to_string(), None); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let replay = prediction.subscribe_stream_replay(); + let events: Vec<&str> = replay + .replay + .iter() + .map(|event| event.event_name()) + .collect(); + + assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 300ccab9e7..80a42676bf 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -79,6 +79,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, + respond_async: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -106,6 +107,49 @@ impl PredictionHandle { } } +pub struct PredictionStreamSubscription { + id: String, + replay: Vec, + receiver: tokio::sync::broadcast::Receiver, + guard: PredictionStreamGuard, +} + +impl PredictionStreamSubscription { + pub fn prediction_id(&self) -> &str { + &self.id + } + + pub fn into_parts( + self, + ) -> ( + Vec, + tokio::sync::broadcast::Receiver, + PredictionStreamGuard, + ) { + (self.replay, self.receiver, self.guard) + } +} + +pub struct PredictionStreamGuard { + id: String, + service: Arc, + respond_async: bool, +} + +impl Drop for PredictionStreamGuard { + fn drop(&mut self) { + if self.respond_async { + return; + } + + if self.service.stream_receiver_count(&self.id) == 0 + && !self.service.prediction_is_terminal(&self.id) + { + self.service.cancel(&self.id); + } + } +} + /// Guard for sync predictions - cancels on drop unless disarmed. /// /// When the HTTP connection drops (client disconnect), axum drops the @@ -415,6 +459,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, + respond_async: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -442,6 +487,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, + respond_async, }, ); @@ -469,6 +515,45 @@ impl PredictionService { Some(response) } + pub fn subscribe_prediction_stream( + self: &Arc, + id: &str, + ) -> Option { + let entry = self.predictions.get(id)?; + let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); + let respond_async = entry.respond_async; + Some(PredictionStreamSubscription { + id: id.to_string(), + replay: stream.replay, + receiver: stream.receiver, + guard: PredictionStreamGuard { + id: id.to_string(), + service: Arc::clone(self), + respond_async, + }, + }) + } + + fn stream_receiver_count(&self, id: &str) -> usize { + self.predictions + .get(id) + .and_then(|entry| { + entry + .prediction + .lock() + .ok() + .map(|p| p.stream_receiver_count()) + }) + .unwrap_or(0) + } + + fn prediction_is_terminal(&self, id: &str) -> bool { + self.predictions + .get(id) + .and_then(|entry| entry.prediction.lock().ok().map(|p| p.is_terminal())) + .unwrap_or(true) + } + /// Run a prediction to completion via orchestrator. pub async fn predict( &self, @@ -760,6 +845,51 @@ mod tests { } } + struct CountingCancelOrchestrator { + cancel_count: AtomicUsize, + } + + impl CountingCancelOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CountingCancelOrchestrator { + async fn register_prediction( + &self, + _slot_id: SlotId, + _prediction: Arc>, + _idle_sender: tokio::sync::oneshot::Sender, + ) { + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -863,7 +993,7 @@ mod tests { let svc = PredictionService::new_no_pool(); let result = svc - .submit_prediction("test".to_string(), serde_json::json!({}), None) + .submit_prediction("test".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::NotReady))); } @@ -908,7 +1038,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -916,6 +1046,113 @@ mod tests { assert!(svc.prediction_exists("test-1")); } + #[tokio::test] + async fn subscribe_prediction_stream_returns_receiver_for_existing_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("stream-test".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("stream-test").unwrap(); + assert_eq!(subscription.prediction_id(), "stream-test"); + } + + #[tokio::test] + async fn dropping_only_sync_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "sync-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "async-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("async-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + + #[tokio::test] + async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "completed-sync-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + { + let entry = svc.predictions.get("completed-sync-stream").unwrap(); + let mut prediction = entry.prediction.lock().unwrap(); + prediction.set_succeeded(crate::PredictionOutput::Single(serde_json::json!("done"))); + } + + let subscription = svc + .subscribe_prediction_stream("completed-sync-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + #[tokio::test] async fn submit_returns_at_capacity_when_no_slots() { let svc = PredictionService::new_no_pool(); @@ -927,13 +1164,13 @@ mod tests { // First prediction takes the only slot let (_handle1, _slot1) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); // Second should fail with AtCapacity let result = svc - .submit_prediction("test-2".to_string(), serde_json::json!({}), None) + .submit_prediction("test-2".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::AtCapacity))); } @@ -953,6 +1190,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -986,7 +1224,7 @@ mod tests { // After acquiring slot let (_handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); let health = svc.health().await; @@ -1009,6 +1247,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1048,6 +1287,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1087,6 +1327,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1113,7 +1354,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-cancel".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1139,7 +1385,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-guard".to_string(), serde_json::json!({}), None) + .submit_prediction("test-guard".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -1163,7 +1409,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-disarm".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1187,7 +1438,12 @@ mod tests { svc.set_health(Health::Ready).await; let (_handle, _slot) = svc - .submit_prediction("test-remove".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-remove".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index aa9875340e..fd1b7cdc8f 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -1,12 +1,17 @@ //! HTTP route handlers. +use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use axum::{ Router, extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, - response::{IntoResponse, Json}, + response::{ + IntoResponse, Json, Response, + sse::{Event, KeepAlive, Sse}, + }, routing::{get, post, put}, }; use serde::{Deserialize, Serialize}; @@ -15,7 +20,9 @@ use serde::{Deserialize, Serialize}; use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; use crate::predictor::PredictionError; -use crate::service::{CreatePredictionError, HealthSnapshot, PredictionService}; +use crate::service::{ + CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription, +}; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; @@ -376,7 +383,12 @@ async fn create_prediction_with_id( // Submit prediction: creates Prediction, acquires slot, registers in service let (handle, unregistered_slot) = match service - .submit_prediction(prediction_id.clone(), input.clone(), webhook_sender) + .submit_prediction( + prediction_id.clone(), + input.clone(), + webhook_sender, + respond_async, + ) .await { Ok(r) => r, @@ -557,6 +569,80 @@ async fn cancel_prediction( } } +fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event { + Event::default() + .event(event.event_name()) + .json_data(event.json_data()) + .expect("prediction stream events serialize to JSON") +} + +fn prediction_sse_stream( + subscription: PredictionStreamSubscription, +) -> impl futures::Stream> { + let (replay, receiver, guard) = subscription.into_parts(); + + struct StreamState { + replay: std::collections::VecDeque, + receiver: tokio::sync::broadcast::Receiver, + _guard: crate::service::PredictionStreamGuard, + done: bool, + } + + futures::stream::unfold( + StreamState { + replay: replay.into(), + receiver, + _guard: guard, + done: false, + }, + |mut state| async move { + if state.done { + return None; + } + + if let Some(event) = state.replay.pop_front() { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + + loop { + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => return None, + } + } + }, + ) +} + +async fn stream_prediction( + State(service): State>, + Path(prediction_id): Path, +) -> Response { + let Some(subscription) = service.subscribe_prediction_stream(&prediction_id) else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + + Sse::new(prediction_sse_stream(subscription)) + .keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) + .into_response() +} + async fn shutdown(State(service): State>) -> impl IntoResponse { tracing::info!("Shutdown requested via HTTP"); service.trigger_shutdown(); @@ -679,6 +765,7 @@ pub fn routes(service: Arc) -> Router { .route("/shutdown", post(shutdown)) .route("/predictions", post(create_prediction)) .route("/predictions/{id}", put(create_prediction_idempotent)) + .route("/predictions/{id}/stream", get(stream_prediction)) .route("/predictions/{id}/cancel", post(cancel_prediction)) .route("/trainings", post(create_training)) .route("/trainings/{id}", put(create_training_idempotent)) @@ -955,6 +1042,62 @@ mod tests { assert_eq!(json["status"], "starting"); } + #[tokio::test] + async fn stream_prediction_unknown_id_returns_404() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/missing/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let json = response_json(response).await; + assert_eq!(json["error"], "Prediction not found"); + } + + #[tokio::test] + async fn stream_prediction_existing_id_returns_sse() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar new file mode 100644 index 0000000000..76e7d4d47d --- /dev/null +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -0,0 +1,36 @@ +# Test that async generator output is available over the SSE stream endpoint. + +[short] skip 'requires Docker build' + +cog build -t $TEST_IMAGE +cog serve --upload-url http://unused/ + +curl -H Prefer:respond-async PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout '"status":"starting"' + +curl -N -H Accept:text/event-stream GET /predictions/sse-stream-test/stream +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-2","index":1}' +stdout 'event: completed' +stdout '"status":"succeeded"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[str]: + time.sleep(0.25) + yield "chunk-1" + time.sleep(0.25) + yield "chunk-2" From c61e323d0b6a24e89962a4e1c267c4b5a79f0a8e Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 13:00:14 -0400 Subject: [PATCH 02/11] docs: add streaming text example --- examples/streaming-text/.dockerignore | 3 ++ examples/streaming-text/.gitignore | 3 ++ examples/streaming-text/README.md | 58 +++++++++++++++++++++ examples/streaming-text/cog.yaml | 7 +++ examples/streaming-text/predict.py | 64 ++++++++++++++++++++++++ examples/streaming-text/requirements.txt | 3 ++ 6 files changed, 138 insertions(+) create mode 100644 examples/streaming-text/.dockerignore create mode 100644 examples/streaming-text/.gitignore create mode 100644 examples/streaming-text/README.md create mode 100644 examples/streaming-text/cog.yaml create mode 100644 examples/streaming-text/predict.py create mode 100644 examples/streaming-text/requirements.txt diff --git a/examples/streaming-text/.dockerignore b/examples/streaming-text/.dockerignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.dockerignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/.gitignore b/examples/streaming-text/.gitignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.gitignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md new file mode 100644 index 0000000000..947d5b6625 --- /dev/null +++ b/examples/streaming-text/README.md @@ -0,0 +1,58 @@ +# examples/streaming-text + +Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. + +This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks from the Server-Sent Events stream endpoint. + +## Run a normal prediction + +From this directory: + +```sh +cog predict -i prompt="Write a short haiku about databases" +``` + +This returns the final accumulated output after the prediction completes. + +## Stream output over HTTP + +Start the server: + +```sh +cog serve +``` + +Create an async prediction with a fixed ID: + +```sh +curl -s -X PUT http://localhost:5000/predictions/streaming-demo \ + -H 'Content-Type: application/json' \ + -H 'Prefer: respond-async' \ + -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' +``` + +Then subscribe to its stream: + +```sh +curl -N -H 'Accept: text/event-stream' \ + http://localhost:5000/predictions/streaming-demo/stream +``` + +The response includes `output` events as chunks are generated, followed by a `completed` event: + +```text +event: output +data: {"chunk":"Silent","index":0} + +event: output +data: {"chunk":" rows","index":1} + +event: completed +data: {"id":"streaming-demo","status":"succeeded",...} +``` + +## How it works + +`predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. + +The normal prediction response still contains the accumulated output for compatibility. The stream endpoint is useful when clients want to display tokens as they arrive. diff --git a/examples/streaming-text/cog.yaml b/examples/streaming-text/cog.yaml new file mode 100644 index 0000000000..68866b7aeb --- /dev/null +++ b/examples/streaming-text/cog.yaml @@ -0,0 +1,7 @@ +# Streaming text generation example using a small open-weight language model. + +build: + python_version: "3.12" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py new file mode 100644 index 0000000000..3ed8772e54 --- /dev/null +++ b/examples/streaming-text/predict.py @@ -0,0 +1,64 @@ +from threading import Thread +from typing import Iterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from cog import BasePredictor, Input + +MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" + + +class Predictor(BasePredictor): + def setup(self) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if self.device == "cuda" else torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=dtype, + ).to(self.device) + self.model.eval() + + def predict( + self, + prompt: str = Input(description="Prompt to complete"), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate", + default=128, + ge=1, + le=512, + ), + ) -> Iterator[str]: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer([text], return_tensors="pt").to(self.device) + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + generation_kwargs = { + **inputs, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "temperature": 0.7, + "top_p": 0.9, + "pad_token_id": self.tokenizer.eos_token_id, + } + + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + for chunk in streamer: + if chunk: + yield chunk + + thread.join() diff --git a/examples/streaming-text/requirements.txt b/examples/streaming-text/requirements.txt new file mode 100644 index 0000000000..916933a3de --- /dev/null +++ b/examples/streaming-text/requirements.txt @@ -0,0 +1,3 @@ +torch==2.7.1 +transformers==4.51.3 +accelerate==1.6.0 From bf064d22475bd6b0f7a3cbf0008f8b372f4f48eb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 15:58:26 -0400 Subject: [PATCH 03/11] feat: stream predictions via accept header --- crates/coglet/src/transport/http/routes.rs | 196 ++++++++++++++---- examples/streaming-text/README.md | 17 +- .../tests/sse_streaming_output.txtar | 7 +- 3 files changed, 163 insertions(+), 57 deletions(-) diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index fd1b7cdc8f..1cdc6f6c18 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -216,6 +216,43 @@ fn should_respond_async(headers: &HeaderMap) -> bool { .unwrap_or(false) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredictionResponseMode { + SyncJson, + AsyncJson, + AsyncSse, +} + +fn wants_sse(headers: &HeaderMap) -> bool { + headers + .get(axum::http::header::ACCEPT) + .and_then(|value| value.to_str().ok()) + .map(|accept| { + accept + .split(',') + .any(|part| part.trim().split(';').next() == Some("text/event-stream")) + }) + .unwrap_or(false) +} + +fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if wants_sse(headers) { + PredictionResponseMode::AsyncSse + } else if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + +fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + fn extract_trace_context(headers: &HeaderMap) -> TraceContext { TraceContext { traceparent: headers @@ -233,7 +270,7 @@ async fn create_prediction( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -242,7 +279,7 @@ async fn create_prediction( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = prediction_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -251,7 +288,7 @@ async fn create_prediction( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -263,7 +300,7 @@ async fn create_prediction_idempotent( Path(prediction_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -284,15 +321,20 @@ async fn create_prediction_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } + let response_mode = prediction_response_mode(&headers); + // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { - return (StatusCode::ACCEPTED, Json(response)); + if response_mode == PredictionResponseMode::AsyncSse { + return stream_prediction_response(service, &prediction_id); + } + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -301,7 +343,7 @@ async fn create_prediction_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -340,10 +382,10 @@ async fn create_prediction_with_id( context: std::collections::HashMap, webhook: Option, webhook_events_filter: Vec, - respond_async: bool, + response_mode: PredictionResponseMode, trace_context: TraceContext, is_training: bool, -) -> (StatusCode, Json) { +) -> Response { // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -372,7 +414,8 @@ async fn create_prediction_with_id( return ( StatusCode::UNPROCESSABLE_ENTITY, Json(serde_json::json!({ "detail": detail })), - ); + ) + .into_response(); } let webhook_sender = build_webhook_sender( @@ -387,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - respond_async, + response_mode != PredictionResponseMode::SyncJson, ) .await { @@ -400,7 +443,8 @@ async fn create_prediction_with_id( "error": msg, "status": "failed" })), - ); + ) + .into_response(); } Err(CreatePredictionError::AtCapacity) => { return ( @@ -409,14 +453,15 @@ async fn create_prediction_with_id( "error": "At capacity - all prediction slots busy", "status": "failed" })), - ); + ) + .into_response(); } }; let prediction = unregistered_slot.prediction(); // Async mode: spawn background task, return immediately - if respond_async { + if response_mode != PredictionResponseMode::SyncJson { let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -429,13 +474,18 @@ async fn create_prediction_with_id( service_clone.remove_prediction(&id_for_cleanup); }); + if response_mode == PredictionResponseMode::AsyncSse { + return stream_prediction_response(service, &prediction_id); + } + return ( StatusCode::ACCEPTED, Json(serde_json::json!({ "id": prediction_id, "status": "starting" })), - ); + ) + .into_response(); } // Sync mode: spawn prediction into a background task so the slot lifetime @@ -501,6 +551,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::InvalidInput(msg)) => { let metrics = build_metrics(&user_metrics); @@ -514,6 +565,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::NotReady) => { let msg = PredictionError::NotReady.to_string(); @@ -526,6 +578,7 @@ async fn create_prediction_with_id( "status": "failed" })), ) + .into_response() } Err(PredictionError::Failed(msg)) => { let metrics = build_metrics(&user_metrics); @@ -540,6 +593,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::Cancelled) => { let metrics = build_metrics(&user_metrics); @@ -552,6 +606,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } } } @@ -622,11 +677,8 @@ fn prediction_sse_stream( ) } -async fn stream_prediction( - State(service): State>, - Path(prediction_id): Path, -) -> Response { - let Some(subscription) = service.subscribe_prediction_stream(&prediction_id) else { +fn stream_prediction_response(service: Arc, prediction_id: &str) -> Response { + let Some(subscription) = service.subscribe_prediction_stream(prediction_id) else { return ( StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Prediction not found"})), @@ -668,7 +720,7 @@ async fn create_training( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -677,7 +729,7 @@ async fn create_training( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -686,7 +738,7 @@ async fn create_training( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -698,7 +750,7 @@ async fn create_training_idempotent( Path(training_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -719,15 +771,16 @@ async fn create_training_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } // Idempotent: return existing state if already submitted if let Some(response) = service.get_prediction_response(&training_id) { - return (StatusCode::ACCEPTED, Json(response)); + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -736,7 +789,7 @@ async fn create_training_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -765,7 +818,6 @@ pub fn routes(service: Arc) -> Router { .route("/shutdown", post(shutdown)) .route("/predictions", post(create_prediction)) .route("/predictions/{id}", put(create_prediction_idempotent)) - .route("/predictions/{id}/stream", get(stream_prediction)) .route("/predictions/{id}/cancel", post(cancel_prediction)) .route("/trainings", post(create_training)) .route("/trainings/{id}", put(create_training_idempotent)) @@ -1043,31 +1095,67 @@ mod tests { } #[tokio::test] - async fn stream_prediction_unknown_id_returns_404() { + async fn prediction_post_with_sse_accept_returns_sse() { let service = create_ready_service().await; let app = routes(service); let response = app .oneshot( - Request::get("/predictions/missing/stream") + Request::post("/predictions") + .header("content-type", "application/json") .header("accept", "text/event-stream") - .body(Body::empty()) + .body(Body::from(r#"{"input":{}}"#)) .unwrap(), ) .await .unwrap(); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - let json = response_json(response).await; - assert_eq!(json["error"], "Prediction not found"); + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn prediction_put_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); } #[tokio::test] - async fn stream_prediction_existing_id_returns_sse() { + async fn prediction_put_existing_with_sse_accept_returns_sse() { let service = create_ready_service().await; let (_handle, _slot) = service .submit_prediction( - "stream-route".to_string(), + "existing-sse-put".to_string(), serde_json::json!({}), None, true, @@ -1078,9 +1166,10 @@ mod tests { let response = app .oneshot( - Request::get("/predictions/stream-route/stream") + Request::put("/predictions/existing-sse-put") + .header("content-type", "application/json") .header("accept", "text/event-stream") - .body(Body::empty()) + .body(Body::from(r#"{"input":{}}"#)) .unwrap(), ) .await @@ -1098,6 +1187,33 @@ mod tests { ); } + #[tokio::test] + async fn stream_prediction_route_is_removed() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "removed-stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/removed-stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md index 947d5b6625..645664289b 100644 --- a/examples/streaming-text/README.md +++ b/examples/streaming-text/README.md @@ -2,7 +2,7 @@ Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. -This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks from the Server-Sent Events stream endpoint. +This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. ## Run a normal prediction @@ -22,22 +22,15 @@ Start the server: cog serve ``` -Create an async prediction with a fixed ID: +Create a prediction and request an SSE response: ```sh -curl -s -X PUT http://localhost:5000/predictions/streaming-demo \ +curl -N -X PUT http://localhost:5000/predictions/streaming-demo \ -H 'Content-Type: application/json' \ - -H 'Prefer: respond-async' \ + -H 'Accept: text/event-stream' \ -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' ``` -Then subscribe to its stream: - -```sh -curl -N -H 'Accept: text/event-stream' \ - http://localhost:5000/predictions/streaming-demo/stream -``` - The response includes `output` events as chunks are generated, followed by a `completed` event: ```text @@ -55,4 +48,4 @@ data: {"id":"streaming-demo","status":"succeeded",...} `predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. -The normal prediction response still contains the accumulated output for compatibility. The stream endpoint is useful when clients want to display tokens as they arrive. +The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 76e7d4d47d..5f567cfd6e 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -1,14 +1,11 @@ -# Test that async generator output is available over the SSE stream endpoint. +# Test that async generator output is available when predictions are created with SSE accept. [short] skip 'requires Docker build' cog build -t $TEST_IMAGE cog serve --upload-url http://unused/ -curl -H Prefer:respond-async PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' -stdout '"status":"starting"' - -curl -N -H Accept:text/event-stream GET /predictions/sse-stream-test/stream +curl -N -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From 308ecffb4fb0dc9b69627d0e8c00907a26ccc20d Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 16:04:18 -0400 Subject: [PATCH 04/11] fix: bound prediction stream replay history --- crates/coglet/src/prediction.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 1e86a20754..d40f1b38d7 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -10,6 +10,8 @@ pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; +const MAX_STREAM_HISTORY_EVENTS: usize = 1024; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { Starting, @@ -191,6 +193,9 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { + self.stream_history.remove(0); + } self.stream_history.push(event.clone()); let _ = self.stream_tx.send(event); } @@ -712,6 +717,28 @@ mod tests { assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); } + #[tokio::test] + async fn prediction_stream_replay_is_bounded_to_recent_events() { + let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None); + + prediction.set_processing(); + for index in 0..1100 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":76,"index":76}) + ); + assert_eq!( + replay.replay[MAX_STREAM_HISTORY_EVENTS - 1].json_data(), + serde_json::json!({"chunk":1099,"index":1099}) + ); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); From 8acf5f95755ae039146ef112095ce694e3bb5847 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 16:59:47 -0400 Subject: [PATCH 05/11] fix: harden SSE prediction streaming --- crates/coglet/src/orchestrator.rs | 50 ++++-- crates/coglet/src/prediction.rs | 7 + crates/coglet/src/service.rs | 163 ++++++++++++++++-- crates/coglet/src/transport/http/routes.rs | 159 +++++++++++++++-- .../tests/sse_streaming_output.txtar | 3 +- 5 files changed, 336 insertions(+), 46 deletions(-) diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 5694338e35..8fb6735eb9 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -7,7 +7,7 @@ //! 4. Run event loop routing responses to predictions //! 5. On worker crash: fail all predictions, shut down -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -353,15 +353,18 @@ pub struct OrchestratorReady { pub setup_logs: String, } +type RegisterPredictionMessage = ( + SlotId, + Arc>, + tokio::sync::oneshot::Sender, + tokio::sync::oneshot::Sender<()>, +); + pub struct OrchestratorHandle { child: Child, ctrl_writer: Arc>>>, - register_tx: mpsc::Sender<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + register_tx: mpsc::Sender, healthcheck_tx: mpsc::Sender>, cancel_tx: mpsc::Sender, slot_ids: Vec, @@ -375,10 +378,12 @@ impl Orchestrator for OrchestratorHandle { prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ) { + let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let _ = self .register_tx - .send((slot_id, prediction, idle_sender)) + .send((slot_id, prediction, idle_sender, ack_tx)) .await; + let _ = ack_rx.await; } async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> { @@ -698,11 +703,7 @@ async fn run_event_loop( SlotId, FramedRead>, )>, - mut register_rx: mpsc::Receiver<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + mut register_rx: mpsc::Receiver, mut healthcheck_rx: mpsc::Receiver>, mut cancel_rx: mpsc::Receiver, pool: Arc, @@ -718,6 +719,7 @@ async fn run_event_loop( let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; let mut pending_uploads: HashMap>> = HashMap::new(); + let mut pending_cancellations: HashSet = HashSet::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -923,17 +925,19 @@ async fn run_event_loop( } } None => { - tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); + tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); + pending_cancellations.insert(prediction_id); } } } - Some((slot_id, prediction, idle_sender)) = register_rx.recv() => { + Some((slot_id, prediction, idle_sender, registered_tx)) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); + let _ = registered_tx.send(()); continue; } }; @@ -949,6 +953,24 @@ async fn run_event_loop( ); tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); + let pending_cancel = pending_cancellations.remove(&prediction_id); + let _ = registered_tx.send(()); + if pending_cancel { + tracing::info!( + target: "coglet::prediction", + %prediction_id, + %slot_id, + "Applying pending cancellation" + ); + let mut writer = ctrl_writer.lock().await; + if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await { + tracing::error!( + %slot_id, + error = %e, + "Failed to send pending cancel request to worker" + ); + } + } } Some((slot_id, result)) = slot_msg_rx.recv() => { diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index d40f1b38d7..10c87fe83e 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -92,6 +92,7 @@ pub enum PredictionStreamEvent { pub struct PredictionStreamReplay { pub replay: Vec, + pub skipped: u64, pub receiver: tokio::sync::broadcast::Receiver, } @@ -144,6 +145,7 @@ pub struct Prediction { completion: Arc, stream_tx: tokio::sync::broadcast::Sender, stream_history: Vec, + stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } @@ -165,6 +167,7 @@ impl Prediction { completion: Arc::new(Notify::new()), stream_tx, stream_history: Vec::new(), + stream_history_skipped: 0, metrics: HashMap::new(), } } @@ -184,6 +187,7 @@ impl Prediction { pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { PredictionStreamReplay { replay: self.stream_history.clone(), + skipped: self.stream_history_skipped, receiver: self.stream_tx.subscribe(), } } @@ -195,6 +199,7 @@ impl Prediction { fn emit_stream_event(&mut self, event: PredictionStreamEvent) { if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { self.stream_history.remove(0); + self.stream_history_skipped += 1; } self.stream_history.push(event.clone()); let _ = self.stream_tx.send(event); @@ -710,6 +715,7 @@ mod tests { .collect(); assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!(replay.skipped, 0); assert_eq!( replay.replay[1].json_data(), serde_json::json!({"chunk":"hello","index":0}) @@ -729,6 +735,7 @@ mod tests { let replay = prediction.subscribe_stream_replay(); assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 80a42676bf..30a1c85163 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -79,7 +79,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, - respond_async: bool, + cancel_on_stream_drop: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -110,6 +110,7 @@ impl PredictionHandle { pub struct PredictionStreamSubscription { id: String, replay: Vec, + skipped: u64, receiver: tokio::sync::broadcast::Receiver, guard: PredictionStreamGuard, } @@ -123,22 +124,23 @@ impl PredictionStreamSubscription { self, ) -> ( Vec, + u64, tokio::sync::broadcast::Receiver, PredictionStreamGuard, ) { - (self.replay, self.receiver, self.guard) + (self.replay, self.skipped, self.receiver, self.guard) } } pub struct PredictionStreamGuard { id: String, service: Arc, - respond_async: bool, + cancel_on_stream_drop: bool, } impl Drop for PredictionStreamGuard { fn drop(&mut self) { - if self.respond_async { + if !self.cancel_on_stream_drop { return; } @@ -459,7 +461,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, - respond_async: bool, + cancel_on_stream_drop: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -487,7 +489,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, - respond_async, + cancel_on_stream_drop, }, ); @@ -521,15 +523,16 @@ impl PredictionService { ) -> Option { let entry = self.predictions.get(id)?; let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); - let respond_async = entry.respond_async; + let cancel_on_stream_drop = entry.cancel_on_stream_drop; Some(PredictionStreamSubscription { id: id.to_string(), replay: stream.replay, + skipped: stream.skipped, receiver: stream.receiver, guard: PredictionStreamGuard { id: id.to_string(), service: Arc::clone(self), - respond_async, + cancel_on_stream_drop, }, }) } @@ -626,6 +629,22 @@ impl PredictionService { ))); } + let was_cancelled_before_send = try_lock_prediction(&prediction_arc) + .map(|p| p.is_canceled()) + .unwrap_or(false); + if was_cancelled_before_send + && let Err(e) = state + .orchestrator + .cancel_by_prediction_id(&prediction_id) + .await + { + tracing::error!( + prediction_id = %prediction_id, + error = %e, + "Failed to forward pending cancellation after registration" + ); + } + // Wait for prediction to complete // Check if already terminal first to avoid race with fast completions let (already_terminal, completion) = { @@ -890,6 +909,58 @@ mod tests { } } + struct CancelRecordingOrchestrator { + cancel_count: AtomicUsize, + prediction: std::sync::Mutex>>>, + } + + impl CancelRecordingOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + prediction: std::sync::Mutex::new(None), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CancelRecordingOrchestrator { + async fn register_prediction( + &self, + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + ) { + *self.prediction.lock().unwrap() = Some(prediction); + let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate()); + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + if let Some(prediction) = self.prediction.lock().unwrap().as_ref() { + prediction.lock().unwrap().set_canceled(); + } + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -1074,9 +1145,31 @@ mod tests { svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; + let (_handle, _slot) = svc + .submit_prediction("sync-stream".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_json_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + let (_handle, _slot) = svc .submit_prediction( - "sync-stream".to_string(), + "async-json-stream".to_string(), serde_json::json!({}), None, false, @@ -1084,15 +1177,17 @@ mod tests { .await .unwrap(); - let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + let subscription = svc + .subscribe_prediction_stream("async-json-stream") + .unwrap(); drop(subscription); tokio::time::sleep(Duration::from_millis(25)).await; - assert_eq!(orchestrator_ref.cancel_count(), 1); + assert_eq!(orchestrator_ref.cancel_count(), 0); } #[tokio::test] - async fn dropping_async_stream_subscription_does_not_cancel_prediction() { + async fn dropping_live_sse_stream_subscription_cancels_prediction() { let svc = Arc::new(PredictionService::new_no_pool()); let pool = create_test_pool(1).await; let orchestrator = Arc::new(CountingCancelOrchestrator::new()); @@ -1103,7 +1198,7 @@ mod tests { let (_handle, _slot) = svc .submit_prediction( - "async-stream".to_string(), + "live-sse-stream".to_string(), serde_json::json!({}), None, true, @@ -1111,11 +1206,11 @@ mod tests { .await .unwrap(); - let subscription = svc.subscribe_prediction_stream("async-stream").unwrap(); + let subscription = svc.subscribe_prediction_stream("live-sse-stream").unwrap(); drop(subscription); tokio::time::sleep(Duration::from_millis(25)).await; - assert_eq!(orchestrator_ref.cancel_count(), 0); + assert_eq!(orchestrator_ref.cancel_count(), 1); } #[tokio::test] @@ -1133,7 +1228,7 @@ mod tests { "completed-sync-stream".to_string(), serde_json::json!({}), None, - false, + true, ) .await .unwrap(); @@ -1208,6 +1303,42 @@ mod tests { assert_eq!(orch_ref.register_count(), 1); } + #[tokio::test] + async fn predict_forwards_cancel_token_set_before_registration() { + let svc = PredictionService::new_no_pool(); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CancelRecordingOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (handle, slot) = svc + .submit_prediction( + "pre-register-cancel".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + handle.cancel_token().cancel(); + + let result = tokio::time::timeout( + Duration::from_millis(100), + svc.predict( + slot, + serde_json::json!({}), + std::collections::HashMap::new(), + ), + ) + .await + .expect("prediction should observe cancellation after registration"); + + assert!(matches!(result, Err(PredictionError::Cancelled))); + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + #[tokio::test] async fn health_shows_busy_when_all_slots_used() { let svc = PredictionService::new_no_pool(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 1cdc6f6c18..869f0307a6 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -430,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - response_mode != PredictionResponseMode::SyncJson, + response_mode != PredictionResponseMode::AsyncJson, ) .await { @@ -462,6 +462,12 @@ async fn create_prediction_with_id( // Async mode: spawn background task, return immediately if response_mode != PredictionResponseMode::SyncJson { + let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { + service.subscribe_prediction_stream(&prediction_id) + } else { + None + }; + let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -475,7 +481,14 @@ async fn create_prediction_with_id( }); if response_mode == PredictionResponseMode::AsyncSse { - return stream_prediction_response(service, &prediction_id); + let Some(subscription) = sse_subscription else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + return stream_prediction_subscription_response(subscription); } return ( @@ -634,10 +647,11 @@ fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event fn prediction_sse_stream( subscription: PredictionStreamSubscription, ) -> impl futures::Stream> { - let (replay, receiver, guard) = subscription.into_parts(); + let (replay, replay_skipped, receiver, guard) = subscription.into_parts(); struct StreamState { replay: std::collections::VecDeque, + replay_skipped: u64, receiver: tokio::sync::broadcast::Receiver, _guard: crate::service::PredictionStreamGuard, done: bool, @@ -646,6 +660,7 @@ fn prediction_sse_stream( futures::stream::unfold( StreamState { replay: replay.into(), + replay_skipped, receiver, _guard: guard, done: false, @@ -655,23 +670,44 @@ fn prediction_sse_stream( return None; } + if state.replay_skipped > 0 { + let skipped = state.replay_skipped; + state.replay_skipped = 0; + state.done = true; + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream replay truncated; events were dropped", + "skipped": skipped, + })) + .expect("SSE replay truncation error serializes to JSON"); + return Some((Ok(event), state)); + } + if let Some(event) = state.replay.pop_front() { state.done = event.event_name() == "completed"; return Some((Ok(stream_event_to_sse(event)), state)); } - loop { - match state.receiver.recv().await { - Ok(event) => { - state.done = event.event_name() == "completed"; - return Some((Ok(stream_event_to_sse(event)), state)); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { - tracing::warn!(skipped, "SSE prediction stream receiver lagged"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => return None, + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + Some((Ok(stream_event_to_sse(event)), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + state.done = true; + // In the future, this could become backpressure or cursor-based replay. + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream lagged; events were dropped", + "skipped": skipped, + })) + .expect("SSE lag error serializes to JSON"); + Some((Ok(event), state)) } + Err(tokio::sync::broadcast::error::RecvError::Closed) => None, } }, ) @@ -686,6 +722,10 @@ fn stream_prediction_response(service: Arc, prediction_id: &s .into_response(); }; + stream_prediction_subscription_response(subscription) +} + +fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response { Sse::new(prediction_sse_stream(subscription)) .keep_alive( KeepAlive::new() @@ -1120,6 +1160,97 @@ mod tests { "unexpected content-type: {:?}", content_type ); + + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + let sse = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn lagged_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "lagged-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("lagged-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("lagged SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!(sse.contains("SSE stream lagged"), "SSE body: {sse}"); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn truncated_replay_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "truncated-replay".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let subscription = service + .subscribe_prediction_stream("truncated-replay") + .unwrap(); + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("truncated replay SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!( + sse.contains("SSE stream replay truncated"), + "SSE body: {sse}" + ); + assert!(sse.contains("skipped"), "SSE body: {sse}"); } #[tokio::test] diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 5f567cfd6e..32c757008f 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -2,10 +2,9 @@ [short] skip 'requires Docker build' -cog build -t $TEST_IMAGE cog serve --upload-url http://unused/ -curl -N -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From ff3140eb437f87d0a152d8752f9a4fc777c95ceb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 12:05:00 -0400 Subject: [PATCH 06/11] fix: address SSE review feedback --- crates/coglet/src/prediction.rs | 22 +++++++++---------- crates/coglet/src/service.rs | 3 +++ crates/coglet/src/transport/http/routes.rs | 2 +- .../tests/sse_streaming_output.txtar | 1 + 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 10c87fe83e..5c3419dc9a 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -1,6 +1,6 @@ //! Prediction state tracking. -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use std::time::Instant; @@ -10,7 +10,7 @@ pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; -const MAX_STREAM_HISTORY_EVENTS: usize = 1024; +const STREAM_EVENT_BUFFER_CAPACITY: usize = 1024; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { @@ -144,7 +144,7 @@ pub struct Prediction { webhook: Option, completion: Arc, stream_tx: tokio::sync::broadcast::Sender, - stream_history: Vec, + stream_history: VecDeque, stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, @@ -152,7 +152,7 @@ pub struct Prediction { impl Prediction { pub fn new(id: String, webhook: Option) -> Self { - let (stream_tx, _) = tokio::sync::broadcast::channel(1024); + let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_EVENT_BUFFER_CAPACITY); Self { id, @@ -166,7 +166,7 @@ impl Prediction { webhook, completion: Arc::new(Notify::new()), stream_tx, - stream_history: Vec::new(), + stream_history: VecDeque::new(), stream_history_skipped: 0, metrics: HashMap::new(), } @@ -186,7 +186,7 @@ impl Prediction { pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { PredictionStreamReplay { - replay: self.stream_history.clone(), + replay: self.stream_history.iter().cloned().collect(), skipped: self.stream_history_skipped, receiver: self.stream_tx.subscribe(), } @@ -197,11 +197,11 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { - if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { - self.stream_history.remove(0); + if self.stream_history.len() == STREAM_EVENT_BUFFER_CAPACITY { + self.stream_history.pop_front(); self.stream_history_skipped += 1; } - self.stream_history.push(event.clone()); + self.stream_history.push_back(event.clone()); let _ = self.stream_tx.send(event); } @@ -734,14 +734,14 @@ mod tests { let replay = prediction.subscribe_stream_replay(); - assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!(replay.replay.len(), STREAM_EVENT_BUFFER_CAPACITY); assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) ); assert_eq!( - replay.replay[MAX_STREAM_HISTORY_EVENTS - 1].json_data(), + replay.replay[STREAM_EVENT_BUFFER_CAPACITY - 1].json_data(), serde_json::json!({"chunk":1099,"index":1099}) ); } diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 30a1c85163..d7a85726cb 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -144,6 +144,9 @@ impl Drop for PredictionStreamGuard { return; } + // Prediction cleanup may remove the service entry before the SSE response + // finishes draining. Missing entries deliberately report zero receivers and + // terminal state so this guard cannot cancel an already-cleaned prediction. if self.service.stream_receiver_count(&self.id) == 0 && !self.service.prediction_is_terminal(&self.id) { diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 869f0307a6..d5bf7e0064 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -430,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - response_mode != PredictionResponseMode::AsyncJson, + response_mode == PredictionResponseMode::AsyncSse, ) .await { diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 32c757008f..4d3f2102c9 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -5,6 +5,7 @@ cog serve --upload-url http://unused/ curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout 'event: start' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From 8c9c9826e8a21628fe4200da23e96e63dfb0ac46 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 16:05:13 -0400 Subject: [PATCH 07/11] feat: make prediction streaming opt-in --- architecture/02-schema.md | 73 ++++++++++--------- crates/coglet/src/service.rs | 19 +++++ crates/coglet/src/transport/http/routes.rs | 56 ++++++++++++++ docs/llms.txt | 18 +++-- docs/python.md | 18 +++-- .../tests/sse_requires_streaming_opt_in.txtar | 27 +++++++ .../tests/sse_streaming_output.txtar | 3 +- pkg/schema/openapi.go | 48 ++++++------ pkg/schema/openapi_test.go | 28 +++++++ pkg/schema/python/parser.go | 51 +++++++++++-- pkg/schema/python/parser_test.go | 71 ++++++++++++++++++ pkg/schema/types.go | 7 +- python/cog/__init__.py | 11 +++ python/tests/test_types.py | 14 ++++ 14 files changed, 366 insertions(+), 78 deletions(-) create mode 100644 integration-tests/tests/sse_requires_streaming_opt_in.txtar diff --git a/architecture/02-schema.md b/architecture/02-schema.md index bd217199f3..26bc99cbdf 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -186,40 +186,43 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Output Types -| Python | SchemaType | JSON Schema | -| -------------------------- | ------------------------ | --------------------------------------------------------------- | -| `str` | `SchemaPrimitive` | `{"type": "string"}` | -| `int` | `SchemaPrimitive` | `{"type": "integer"}` | -| `float` | `SchemaPrimitive` | `{"type": "number"}` | -| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | -| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | -| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | -| `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | -| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | -| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | -| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | +| Python | SchemaType | JSON Schema | +| -------------------------------- | --------------------------------------- | --------------------------------------------------------------- | +| `str` | `SchemaPrimitive` | `{"type": "string"}` | +| `int` | `SchemaPrimitive` | `{"type": "integer"}` | +| `float` | `SchemaPrimitive` | `{"type": "number"}` | +| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | +| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | +| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | +| `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | +| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | +| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | +| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | | `Annotated[list[T], cog.Opaque]` | `SchemaArray(SchemaPrimitive(TypeAny))` | `{"type": "array", "items": {"type": "object"}}` | -| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | -| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | -| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | -| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | +| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | +| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | +| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | +| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | ### Unsupported Output Types -| Python | Error | -| --------------------------- | -------------------------------------------------------------------- | -| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | -| `Union[A, B]` | Ambiguous for downstream consumers | +| Python | Error | +| --------------------------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | +| `Union[A, B]` | Ambiguous for downstream consumers | | External package types | Cannot be statically analyzed — define as BaseModel, use .pyi stub, or mark JSON-shaped values with `Annotated[..., cog.Opaque]` | ## Cog-Specific Extensions -| Extension | Purpose | -| --------------------- | ------------------------------------------------- | -| `x-order` | Preserves parameter order from function signature | -| `x-cog-array-type` | Marks iterators vs regular arrays | -| `x-cog-array-display` | Hints for how to display streaming output | -| `x-cog-secret` | Marks sensitive inputs | +| Extension | Purpose | +| --------------------- | --------------------------------------------------- | +| `x-order` | Preserves parameter order from function signature | +| `x-cog-array-type` | Marks iterators vs regular arrays | +| `x-cog-array-display` | Hints for how to display streaming output | +| `x-cog-secret` | Marks sensitive inputs | +| `x-cog-streaming` | Marks prediction operations that accept SSE clients | + +Iterator output types describe the shape of accumulated JSON output. SSE response support is a separate prediction operation capability and is only advertised when the prediction handler opts in with `@cog.streaming`. ## Where the Schema Lives @@ -311,12 +314,12 @@ A simplified example showing a multi-file predictor with structured output: ## Code References -| File | Purpose | -| ----------------------------- | -------------------------------------------------------------------- | -| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | -| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | -| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | -| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | -| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | -| `pkg/schema/errors.go` | Typed schema error kinds | -| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | +| File | Purpose | +| --------------------------- | -------------------------------------------------------------------- | +| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | +| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | +| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | +| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | +| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | +| `pkg/schema/errors.go` | Typed schema error kinds | +| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index d7a85726cb..bd33d9e5b6 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -226,6 +226,7 @@ pub struct PredictionService { schema: RwLock>, input_validator: RwLock>, train_validator: RwLock>, + supports_prediction_streaming: RwLock, } impl PredictionService { @@ -245,6 +246,7 @@ impl PredictionService { schema: RwLock::new(None), input_validator: RwLock::new(None), train_validator: RwLock::new(None), + supports_prediction_streaming: RwLock::new(false), } } @@ -299,6 +301,10 @@ impl PredictionService { self.train_validator.read().await.is_some() } + pub async fn supports_prediction_streaming(&self) -> bool { + *self.supports_prediction_streaming.read().await + } + /// Get the permit pool from orchestrator. pub async fn pool(&self) -> Option> { if let Some(ref state) = *self.orchestrator.read().await { @@ -359,6 +365,9 @@ impl PredictionService { } pub async fn set_schema(&self, schema: serde_json::Value) { + let supports_prediction_streaming = Self::schema_supports_prediction_streaming(&schema); + *self.supports_prediction_streaming.write().await = supports_prediction_streaming; + // Compile input validators from the schema components let validator = InputValidator::from_openapi_schema(&schema); if let Some(v) = &validator { @@ -382,6 +391,16 @@ impl PredictionService { *self.schema.write().await = Some(schema); } + fn schema_supports_prediction_streaming(schema: &serde_json::Value) -> bool { + schema + .get("paths") + .and_then(|paths| paths.get("/predictions")) + .and_then(|path| path.get("post")) + .and_then(|operation| operation.get("x-cog-streaming")) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + } + pub async fn schema(&self) -> Option { self.schema.read().await.clone() } diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index d5bf7e0064..6822fdead3 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -245,6 +245,16 @@ fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { } } +fn streaming_not_supported_response() -> Response { + ( + StatusCode::NOT_ACCEPTABLE, + Json(serde_json::json!({ + "error": "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + })), + ) + .into_response() +} + fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { if should_respond_async(headers) { PredictionResponseMode::AsyncJson @@ -330,6 +340,9 @@ async fn create_prediction_idempotent( // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { if response_mode == PredictionResponseMode::AsyncSse { + if !service.supports_prediction_streaming().await { + return streaming_not_supported_response(); + } return stream_prediction_response(service, &prediction_id); } return (StatusCode::ACCEPTED, Json(response)).into_response(); @@ -386,6 +399,13 @@ async fn create_prediction_with_id( trace_context: TraceContext, is_training: bool, ) -> Response { + if !is_training + && response_mode == PredictionResponseMode::AsyncSse + && !service.supports_prediction_streaming().await + { + return streaming_not_supported_response(); + } + // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -1076,6 +1096,15 @@ mod tests { service } + async fn enable_prediction_streaming(service: &PredictionService) { + service + .set_schema(serde_json::json!({ + "paths": {"/predictions": {"post": {"x-cog-streaming": true}}}, + "components": {"schemas": {"Input": {"type": "object", "properties": {}}}} + })) + .await; + } + #[tokio::test] async fn health_check_ready_with_orchestrator() { let service = create_ready_service().await; @@ -1137,6 +1166,7 @@ mod tests { #[tokio::test] async fn prediction_post_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let app = routes(service); let response = app @@ -1168,6 +1198,30 @@ mod tests { assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); } + #[tokio::test] + async fn prediction_post_with_sse_accept_rejects_when_not_opted_in() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/predictions") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + ); + } + #[tokio::test] async fn lagged_prediction_sse_stream_emits_error_and_closes() { let service = Arc::new(PredictionService::new_no_pool()); @@ -1256,6 +1310,7 @@ mod tests { #[tokio::test] async fn prediction_put_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let app = routes(service); let response = app @@ -1284,6 +1339,7 @@ mod tests { #[tokio::test] async fn prediction_put_existing_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let (_handle, _slot) = service .submit_prediction( "existing-sse-put".to_string(), diff --git a/docs/llms.txt b/docs/llms.txt index b38aa99586..ead330f0c2 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1975,13 +1975,17 @@ class Predictor(BasePredictor): Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To define streaming-shaped output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. + +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the `predict()` method with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py -from cog import BasePredictor, Path from typing import Iterator +from cog import BasePredictor, Path, streaming + class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[Path]: done = False while not done: @@ -1993,9 +1997,11 @@ If you have an [async `predict()` method](#async-predictors-and-concurrency), us ```py from typing import AsyncIterator -from cog import BasePredictor, Path + +from cog import BasePredictor, Path, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncIterator[Path]: done = False while not done: @@ -2006,9 +2012,10 @@ class Predictor(BasePredictor): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BasePredictor, ConcatenateIterator, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -2018,9 +2025,10 @@ class Predictor(BasePredictor): Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BasePredictor, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/docs/python.md b/docs/python.md index 06d2a8adc3..76c294728d 100644 --- a/docs/python.md +++ b/docs/python.md @@ -259,13 +259,17 @@ class Predictor(BasePredictor): Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To define streaming-shaped output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. + +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the `predict()` method with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py -from cog import BasePredictor, Path from typing import Iterator +from cog import BasePredictor, Path, streaming + class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[Path]: done = False while not done: @@ -277,9 +281,11 @@ If you have an [async `predict()` method](#async-predictors-and-concurrency), us ```py from typing import AsyncIterator -from cog import BasePredictor, Path + +from cog import BasePredictor, Path, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncIterator[Path]: done = False while not done: @@ -290,9 +296,10 @@ class Predictor(BasePredictor): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BasePredictor, ConcatenateIterator, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -302,9 +309,10 @@ class Predictor(BasePredictor): Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BasePredictor, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/integration-tests/tests/sse_requires_streaming_opt_in.txtar b/integration-tests/tests/sse_requires_streaming_opt_in.txtar new file mode 100644 index 0000000000..7edf04fab1 --- /dev/null +++ b/integration-tests/tests/sse_requires_streaming_opt_in.txtar @@ -0,0 +1,27 @@ +# Test that SSE requires @cog.streaming while undecorated iterators still work as JSON. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +cog predict -i count=2 +stdout '"output":\["chunk-0","chunk-1"\]' + +! curl -H Accept:text/event-stream PUT /predictions/no-streaming '{"id":"no-streaming","input":{"count":2}}' +stderr 'This model does not support streaming responses' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, count: int) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 4d3f2102c9..23646106f5 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -22,10 +22,11 @@ predict: "predict.py:Predictor" import time from typing import Iterator -from cog import BasePredictor +from cog import BasePredictor, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[str]: time.sleep(0.25) yield "chunk-1" diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index 9435ea8248..f15f4471ca 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -187,38 +187,40 @@ func buildOpenAPISpec(info *PredictorInfo) map[string]any { }) // Main endpoint (predict or train) - paths.Set(endpoint, map[string]any{ - "post": map[string]any{ - "summary": summary, - "description": description, - "operationId": opID, - "requestBody": map[string]any{ + mainOperation := map[string]any{ + "summary": summary, + "description": description, + "operationId": opID, + "requestBody": map[string]any{ + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": requestRef}, + }, + }, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "Successful Response", "content": map[string]any{ "application/json": map[string]any{ - "schema": map[string]any{"$ref": requestRef}, + "schema": map[string]any{"$ref": responseRef}, }, }, }, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Successful Response", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": responseRef}, - }, - }, - }, - "422": map[string]any{ - "description": "Validation Error", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, - }, + "422": map[string]any{ + "description": "Validation Error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, }, }, }, }, - }) + } + if !isTrain && info.SupportsStreaming { + mainOperation["x-cog-streaming"] = true + } + paths.Set(endpoint, map[string]any{"post": mainOperation}) // Cancel endpoint paths.Set(cancelEP, map[string]any{ diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index a3a9f7a8a6..9cc037c684 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -657,6 +657,34 @@ func TestOutputConcatenateIterator(t *testing.T) { assert.Equal(t, "concatenate", output["x-cog-array-display"]) } +func TestPredictionOperationIncludesStreamingExtensionWhenEnabled(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + SupportsStreaming: true, + } + + spec := parseSpec(t, info) + post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + assert.Equal(t, true, post["x-cog-streaming"]) +} + +func TestPredictionOperationOmitsStreamingExtensionByDefault(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + _, ok := post["x-cog-streaming"] + assert.False(t, ok) +} + func TestOutputObject(t *testing.T) { inputs := NewOrderedMap[string, InputField]() fields := NewOrderedMap[string, SchemaField]() diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index c711ea0960..3fbac6bf65 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -60,10 +60,15 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi methodName = "train" } - funcNode, err := findTargetFunction(root, source, predictRef, methodName) + targetNode, err := findTargetFunction(root, source, predictRef, methodName) if err != nil { return nil, err } + supportsStreaming := functionSupportsStreaming(targetNode, source, imports) + funcNode := UnwrapFunction(targetNode) + if funcNode == nil { + return nil, schema.WrapError(schema.ErrParse, "target is not a function", nil) + } // 6. Check if method (has self first param) paramsNode := funcNode.ChildByFieldName("parameters") @@ -100,9 +105,10 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi } return &schema.PredictorInfo{ - Inputs: inputs, - Output: output, - Mode: mode, + Inputs: inputs, + Output: output, + Mode: mode, + SupportsStreaming: supportsStreaming, }, nil } @@ -651,6 +657,39 @@ func UnwrapFunction(node *sitter.Node) *sitter.Node { return nil } +func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + if node.Type() != "decorated_definition" { + return false + } + for _, child := range NamedChildren(node) { + if child.Type() != "decorator" { + continue + } + if decoratorIsCogStreaming(child, source, imports) { + return true + } + } + return false +} + +func decoratorIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + for _, child := range NamedChildren(node) { + switch child.Type() { + case "attribute": + return Content(child, source) == "cog.streaming" + case "identifier": + if Content(child, source) != "streaming" { + return false + } + entry, ok := imports.Names.Get("streaming") + return ok && entry.Module == "cog" && entry.Original == "streaming" + case "call": + return false + } + } + return false +} + func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { @@ -1205,7 +1244,7 @@ func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName if nameNode != nil { name := Content(nameNode, source) if name == predictRef || name == methodName { - return funcNode, nil + return child, nil } } } @@ -1226,7 +1265,7 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil && Content(nameNode, source) == methodName { - return funcNode, nil + return child, nil } } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 2b3d59445a..f82058e36a 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -455,6 +455,77 @@ class Predictor(BasePredictor): require.Equal(t, schema.ErrConcatIteratorNotStr, se.Kind) } +func TestStreamingDecoratorQualifiedOptIn(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming +from typing import Iterator + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorIgnoredWhenNotFromCog(t *testing.T) { + source := ` +from other import streaming +from typing import Iterator +from cog import BasePredictor + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorParameterizedFormIgnored(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming() + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorClassLevelIgnored(t *testing.T) { + source := ` +import cog +from typing import Iterator + +@cog.streaming +class Predictor(cog.BasePredictor): + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 5a7c7ce611..1b0250e8f7 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -197,9 +197,10 @@ func (f *InputField) IsRequired() bool { // PredictorInfo is the top-level extraction result. type PredictorInfo struct { - Inputs *OrderedMap[string, InputField] - Output SchemaType - Mode Mode + Inputs *OrderedMap[string, InputField] + Output SchemaType + Mode Mode + SupportsStreaming bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 6c1fb8ce44..f7f4e1d461 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -20,6 +20,8 @@ def predict( """ import sys as _sys +from collections.abc import Callable +from typing import TypeVar from coglet import CancelationException as CancelationException @@ -38,6 +40,14 @@ def predict( URLPath, ) +F = TypeVar("F", bound=Callable[..., object]) + + +def streaming(fn: F) -> F: + """Mark a predict handler as supporting streaming responses.""" + fn.__cog_streaming__ = True # type: ignore[attr-defined] + return fn + # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning @@ -133,6 +143,7 @@ def current_scope() -> object: "CancelationException", # Metrics "current_scope", + "streaming", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 653658ce11..6b4aead4d7 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -10,6 +10,7 @@ Path, Secret, URLFile, + streaming, ) @@ -132,3 +133,16 @@ def test_async_concatenate_iterator_is_abstract(self) -> None: from typing import AsyncIterator assert issubclass(AsyncConcatenateIterator, AsyncIterator) + + +class TestStreamingDecorator: + """Tests for the streaming opt-in decorator.""" + + def test_streaming_marks_function_and_returns_same_object(self) -> None: + def predict() -> str: + return "ok" + + decorated = streaming(predict) + + assert decorated is predict + assert predict.__cog_streaming__ is True # type: ignore[attr-defined] From 38665ac969e8bfd05fe3c66cc69d4d1a3795abf5 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 16:19:26 -0400 Subject: [PATCH 08/11] fix: match iterator CLI output in SSE opt-in test --- integration-tests/tests/sse_requires_streaming_opt_in.txtar | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integration-tests/tests/sse_requires_streaming_opt_in.txtar b/integration-tests/tests/sse_requires_streaming_opt_in.txtar index 7edf04fab1..7ac62c51a0 100644 --- a/integration-tests/tests/sse_requires_streaming_opt_in.txtar +++ b/integration-tests/tests/sse_requires_streaming_opt_in.txtar @@ -5,7 +5,8 @@ cog serve --upload-url http://unused/ cog predict -i count=2 -stdout '"output":\["chunk-0","chunk-1"\]' +stdout '"chunk-0"' +stdout '"chunk-1"' ! curl -H Accept:text/event-stream PUT /predictions/no-streaming '{"id":"no-streaming","input":{"count":2}}' stderr 'This model does not support streaming responses' From aee98854a2c20fca627d1f050138f54d3b3b5891 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 20 May 2026 11:52:46 -0400 Subject: [PATCH 09/11] fix: address streaming SSE review feedback --- crates/coglet/src/orchestrator.rs | 54 ++++-- crates/coglet/src/prediction.rs | 30 ++-- crates/coglet/src/service.rs | 137 +++++++++++++--- crates/coglet/src/transport/http/routes.rs | 182 +++++++++++++++++++-- docs/llms.txt | 2 +- docs/python.md | 2 +- examples/streaming-text/predict.py | 3 +- pkg/schema/openapi_test.go | 8 +- pkg/schema/python/parser.go | 34 +++- pkg/schema/python/parser_test.go | 32 +++- python/cog/__init__.py | 25 ++- python/tests/test_types.py | 9 + 12 files changed, 440 insertions(+), 78 deletions(-) diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 8fb6735eb9..b49ba1fc5e 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -29,6 +29,8 @@ use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; use crate::prediction::Prediction; +const MAX_PENDING_CANCELLATIONS: usize = 1000; + /// Upload a file to a signed endpoint, returning the final URL. /// /// Matches the behavior of Python cog's `put_file_to_signed_endpoint`: @@ -353,12 +355,12 @@ pub struct OrchestratorReady { pub setup_logs: String, } -type RegisterPredictionMessage = ( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - tokio::sync::oneshot::Sender<()>, -); +struct RegisterPredictionMessage { + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + registered_ack: tokio::sync::oneshot::Sender<()>, +} pub struct OrchestratorHandle { child: Child, @@ -381,7 +383,12 @@ impl Orchestrator for OrchestratorHandle { let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let _ = self .register_tx - .send((slot_id, prediction, idle_sender, ack_tx)) + .send(RegisterPredictionMessage { + slot_id, + prediction, + idle_sender, + registered_ack: ack_tx, + }) .await; let _ = ack_rx.await; } @@ -693,6 +700,18 @@ pub async fn spawn_worker( }) } +fn record_pending_cancellation(pending_cancellations: &mut HashSet, prediction_id: String) { + if pending_cancellations.len() >= MAX_PENDING_CANCELLATIONS { + tracing::warn!( + prediction_id = %prediction_id, + cap = MAX_PENDING_CANCELLATIONS, + "Dropping pending cancellation because the pending cancellation buffer is full" + ); + return; + } + pending_cancellations.insert(prediction_id); +} + #[allow(clippy::too_many_arguments)] async fn run_event_loop( mut ctrl_reader: FramedRead>, @@ -926,18 +945,18 @@ async fn run_event_loop( } None => { tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); - pending_cancellations.insert(prediction_id); + record_pending_cancellation(&mut pending_cancellations, prediction_id); } } } - Some((slot_id, prediction, idle_sender, registered_tx)) = register_rx.recv() => { + Some(RegisterPredictionMessage { slot_id, prediction, idle_sender, registered_ack }) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); - let _ = registered_tx.send(()); + let _ = registered_ack.send(()); continue; } }; @@ -954,7 +973,7 @@ async fn run_event_loop( tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); let pending_cancel = pending_cancellations.remove(&prediction_id); - let _ = registered_tx.send(()); + let _ = registered_ack.send(()); if pending_cancel { tracing::info!( target: "coglet::prediction", @@ -1277,6 +1296,19 @@ mod tests { assert_eq!(result.into_values(), Vec::::new()); } + #[test] + fn record_pending_cancellation_caps_stored_ids() { + let mut pending = HashSet::new(); + for index in 0..MAX_PENDING_CANCELLATIONS { + record_pending_cancellation(&mut pending, format!("pred-{index}")); + } + + record_pending_cancellation(&mut pending, "overflow".to_string()); + + assert_eq!(pending.len(), MAX_PENDING_CANCELLATIONS); + assert!(!pending.contains("overflow")); + } + #[test] fn wrap_outputs_schema_array_single_item() { // List[Path] with num_outputs=1 → ["url"] not "url" diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 5c3419dc9a..c7a4ddf0d7 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -10,7 +10,8 @@ pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; -const STREAM_EVENT_BUFFER_CAPACITY: usize = 1024; +pub const STREAM_CHANNEL_CAPACITY: usize = 1024; +pub const STREAM_HISTORY_CAPACITY: usize = 1024; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { @@ -90,10 +91,12 @@ pub enum PredictionStreamEvent { }, } +pub type SharedPredictionStreamEvent = Arc; + pub struct PredictionStreamReplay { - pub replay: Vec, + pub replay: VecDeque, pub skipped: u64, - pub receiver: tokio::sync::broadcast::Receiver, + pub receiver: tokio::sync::broadcast::Receiver, } impl PredictionStreamEvent { @@ -143,8 +146,8 @@ pub struct Prediction { error: Option, webhook: Option, completion: Arc, - stream_tx: tokio::sync::broadcast::Sender, - stream_history: VecDeque, + stream_tx: tokio::sync::broadcast::Sender, + stream_history: VecDeque, stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, @@ -152,7 +155,7 @@ pub struct Prediction { impl Prediction { pub fn new(id: String, webhook: Option) -> Self { - let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_EVENT_BUFFER_CAPACITY); + let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_CHANNEL_CAPACITY); Self { id, @@ -180,13 +183,15 @@ impl Prediction { self.cancel_token.clone() } - pub fn subscribe_stream(&self) -> tokio::sync::broadcast::Receiver { + pub fn subscribe_stream( + &self, + ) -> tokio::sync::broadcast::Receiver { self.stream_tx.subscribe() } pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { PredictionStreamReplay { - replay: self.stream_history.iter().cloned().collect(), + replay: self.stream_history.clone(), skipped: self.stream_history_skipped, receiver: self.stream_tx.subscribe(), } @@ -197,11 +202,12 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { - if self.stream_history.len() == STREAM_EVENT_BUFFER_CAPACITY { + if self.stream_history.len() == STREAM_HISTORY_CAPACITY { self.stream_history.pop_front(); self.stream_history_skipped += 1; } - self.stream_history.push_back(event.clone()); + let event = Arc::new(event); + self.stream_history.push_back(Arc::clone(&event)); let _ = self.stream_tx.send(event); } @@ -734,14 +740,14 @@ mod tests { let replay = prediction.subscribe_stream_replay(); - assert_eq!(replay.replay.len(), STREAM_EVENT_BUFFER_CAPACITY); + assert_eq!(replay.replay.len(), STREAM_HISTORY_CAPACITY); assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) ); assert_eq!( - replay.replay[STREAM_EVENT_BUFFER_CAPACITY - 1].json_data(), + replay.replay[STREAM_HISTORY_CAPACITY - 1].json_data(), serde_json::json!({"chunk":1099,"index":1099}) ); } diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index bd33d9e5b6..5904b20049 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -20,7 +20,10 @@ use crate::health::{Health, SetupResult}; use crate::input_validation::InputValidator; use crate::orchestrator::{HealthcheckResult, Orchestrator}; use crate::permit::{PermitPool, PredictionSlot, UnregisteredPredictionSlot}; -use crate::prediction::{CancellationToken, Prediction, PredictionStatus}; +use crate::prediction::{ + CancellationToken, Prediction, PredictionStatus, STREAM_CHANNEL_CAPACITY, + SharedPredictionStreamEvent, +}; use crate::predictor::{PredictionError, PredictionOutput, PredictionResult}; use crate::version::VersionInfo; use crate::webhook::WebhookSender; @@ -50,6 +53,16 @@ pub enum CreatePredictionError { AtCapacity, } +const MAX_STREAM_SUBSCRIBERS: usize = STREAM_CHANNEL_CAPACITY; + +#[derive(Debug, thiserror::Error)] +pub enum SubscribePredictionStreamError { + #[error("Prediction not found")] + NotFound, + #[error("Too many stream subscribers")] + TooManySubscribers, +} + /// Snapshot of service health for transports to query. #[derive(Debug, Clone)] pub struct HealthSnapshot { @@ -109,9 +122,9 @@ impl PredictionHandle { pub struct PredictionStreamSubscription { id: String, - replay: Vec, + replay: std::collections::VecDeque, skipped: u64, - receiver: tokio::sync::broadcast::Receiver, + receiver: tokio::sync::broadcast::Receiver, guard: PredictionStreamGuard, } @@ -123,9 +136,9 @@ impl PredictionStreamSubscription { pub fn into_parts( self, ) -> ( - Vec, + std::collections::VecDeque, u64, - tokio::sync::broadcast::Receiver, + tokio::sync::broadcast::Receiver, PredictionStreamGuard, ) { (self.replay, self.skipped, self.receiver, self.guard) @@ -542,17 +555,30 @@ impl PredictionService { pub fn subscribe_prediction_stream( self: &Arc, id: &str, - ) -> Option { - let entry = self.predictions.get(id)?; - let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); + ) -> Result { + let entry = self + .predictions + .get(id) + .ok_or(SubscribePredictionStreamError::NotFound)?; + let stream = { + let prediction = entry + .prediction + .lock() + .map_err(|_| SubscribePredictionStreamError::NotFound)?; + if prediction.stream_receiver_count() >= MAX_STREAM_SUBSCRIBERS { + return Err(SubscribePredictionStreamError::TooManySubscribers); + } + prediction.subscribe_stream_replay() + }; let cancel_on_stream_drop = entry.cancel_on_stream_drop; - Some(PredictionStreamSubscription { - id: id.to_string(), + let id = id.to_string(); + Ok(PredictionStreamSubscription { + id: id.clone(), replay: stream.replay, skipped: stream.skipped, receiver: stream.receiver, guard: PredictionStreamGuard { - id: id.to_string(), + id, service: Arc::clone(self), cancel_on_stream_drop, }, @@ -742,15 +768,7 @@ impl PredictionService { .ok() .and_then(|guard| guard.as_ref().map(|s| Arc::clone(&s.orchestrator))); if let Some(orch) = orchestrator { - tokio::spawn(async move { - if let Err(e) = orch.cancel_by_prediction_id(&id_owned).await { - tracing::error!( - prediction_id = %id_owned, - error = %e, - "Failed to send cancel to orchestrator" - ); - } - }); + spawn_orchestrator_cancel(orch, id_owned); } true } else { @@ -772,6 +790,22 @@ impl PredictionService { } } +fn spawn_orchestrator_cancel(orch: Arc, id: String) { + let Ok(handle) = tokio::runtime::Handle::try_current() else { + tracing::warn!(prediction_id = %id, "No tokio runtime available to cancel prediction"); + return; + }; + handle.spawn(async move { + if let Err(e) = orch.cancel_by_prediction_id(&id).await { + tracing::error!( + prediction_id = %id, + error = %e, + "Failed to send cancel to orchestrator" + ); + } + }); +} + /// Build a `SlotRequest::Predict`, spilling the input to disk if it exceeds /// `MAX_INLINE_IPC_SIZE`. This prevents IPC frame overflow on the slot socket. /// @@ -1235,6 +1269,69 @@ mod tests { assert_eq!(orchestrator_ref.cancel_count(), 1); } + #[tokio::test] + async fn dropping_one_of_two_sync_stream_subscriptions_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "multi-sse-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let first = svc.subscribe_prediction_stream("multi-sse-stream").unwrap(); + let second = svc.subscribe_prediction_stream("multi-sse-stream").unwrap(); + drop(first); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + + drop(second); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn subscribe_prediction_stream_rejects_too_many_subscribers() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "subscriber-cap".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let mut subscriptions = Vec::new(); + for _ in 0..MAX_STREAM_SUBSCRIBERS { + subscriptions.push(svc.subscribe_prediction_stream("subscriber-cap").unwrap()); + } + + assert!(matches!( + svc.subscribe_prediction_stream("subscriber-cap"), + Err(SubscribePredictionStreamError::TooManySubscribers) + )); + } + #[tokio::test] async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() { let svc = Arc::new(PredictionService::new_no_pool()); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 6822fdead3..499ae65ecc 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -19,9 +19,11 @@ use serde::{Deserialize, Serialize}; #[cfg(test)] use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; +use crate::prediction::SharedPredictionStreamEvent; use crate::predictor::PredictionError; use crate::service::{ CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription, + SubscribePredictionStreamError, }; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; @@ -255,6 +257,16 @@ fn streaming_not_supported_response() -> Response { .into_response() } +fn training_streaming_not_supported_response() -> Response { + ( + StatusCode::NOT_ACCEPTABLE, + Json(serde_json::json!({ + "error": "Training endpoints do not support streaming responses." + })), + ) + .into_response() +} + fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { if should_respond_async(headers) { PredictionResponseMode::AsyncJson @@ -483,7 +495,7 @@ async fn create_prediction_with_id( // Async mode: spawn background task, return immediately if response_mode != PredictionResponseMode::SyncJson { let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { - service.subscribe_prediction_stream(&prediction_id) + Some(service.subscribe_prediction_stream(&prediction_id)) } else { None }; @@ -501,12 +513,9 @@ async fn create_prediction_with_id( }); if response_mode == PredictionResponseMode::AsyncSse { - let Some(subscription) = sse_subscription else { - return ( - StatusCode::NOT_FOUND, - Json(serde_json::json!({"error": "Prediction not found"})), - ) - .into_response(); + let subscription = match sse_subscription.expect("SSE subscription requested") { + Ok(subscription) => subscription, + Err(error) => return stream_subscription_error_response(error), }; return stream_prediction_subscription_response(subscription); } @@ -657,7 +666,7 @@ async fn cancel_prediction( } } -fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event { +fn stream_event_to_sse(event: SharedPredictionStreamEvent) -> Event { Event::default() .event(event.event_name()) .json_data(event.json_data()) @@ -670,16 +679,16 @@ fn prediction_sse_stream( let (replay, replay_skipped, receiver, guard) = subscription.into_parts(); struct StreamState { - replay: std::collections::VecDeque, + replay: std::collections::VecDeque, replay_skipped: u64, - receiver: tokio::sync::broadcast::Receiver, + receiver: tokio::sync::broadcast::Receiver, _guard: crate::service::PredictionStreamGuard, done: bool, } futures::stream::unfold( StreamState { - replay: replay.into(), + replay, replay_skipped, receiver, _guard: guard, @@ -734,17 +743,29 @@ fn prediction_sse_stream( } fn stream_prediction_response(service: Arc, prediction_id: &str) -> Response { - let Some(subscription) = service.subscribe_prediction_stream(prediction_id) else { - return ( - StatusCode::NOT_FOUND, - Json(serde_json::json!({"error": "Prediction not found"})), - ) - .into_response(); + let subscription = match service.subscribe_prediction_stream(prediction_id) { + Ok(subscription) => subscription, + Err(error) => return stream_subscription_error_response(error), }; stream_prediction_subscription_response(subscription) } +fn stream_subscription_error_response(error: SubscribePredictionStreamError) -> Response { + match error { + SubscribePredictionStreamError::NotFound => ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(), + SubscribePredictionStreamError::TooManySubscribers => ( + StatusCode::TOO_MANY_REQUESTS, + Json(serde_json::json!({"error": "Too many stream subscribers"})), + ) + .into_response(), + } +} + fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response { Sse::new(prediction_sse_stream(subscription)) .keep_alive( @@ -781,6 +802,10 @@ async fn create_training( headers: HeaderMap, body: Option>, ) -> Response { + if wants_sse(&headers) { + return training_streaming_not_supported_response(); + } + let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -811,6 +836,10 @@ async fn create_training_idempotent( headers: HeaderMap, body: Option>, ) -> Response { + if wants_sse(&headers) { + return training_streaming_not_supported_response(); + } + let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -1307,6 +1336,77 @@ mod tests { assert!(sse.contains("skipped"), "SSE body: {sse}"); } + #[tokio::test] + async fn failed_prediction_sse_stream_emits_completed_event() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "failed-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("failed-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + prediction.set_processing(); + prediction.set_failed("boom".to_string()); + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = response.into_body().collect().await.unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"failed""#), "SSE body: {sse}"); + assert!(sse.contains(r#""error":"boom""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn canceled_prediction_sse_stream_emits_completed_event() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "canceled-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("canceled-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + prediction.set_processing(); + prediction.set_canceled(); + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = response.into_body().collect().await.unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"canceled""#), "SSE body: {sse}"); + } + #[tokio::test] async fn prediction_put_with_sse_accept_returns_sse() { let service = create_ready_service().await; @@ -1565,6 +1665,54 @@ mod tests { assert_eq!(json["status"], "succeeded"); } + #[tokio::test] + async fn training_post_with_sse_accept_rejects() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/trainings") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "Training endpoints do not support streaming responses." + ); + } + + #[tokio::test] + async fn training_put_with_sse_accept_rejects() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/trainings/train-sse") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "Training endpoints do not support streaming responses." + ); + } + #[tokio::test] async fn training_idempotent_put() { let service = create_ready_service().await; diff --git a/docs/llms.txt b/docs/llms.txt index 9e73ed8d18..5809b003e4 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1979,7 +1979,7 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. -To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming`, `@cog.streaming()`, `@streaming`, or `@streaming()` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py from typing import Iterator diff --git a/docs/python.md b/docs/python.md index 192e2c0db3..d100d4cc0a 100644 --- a/docs/python.md +++ b/docs/python.md @@ -263,7 +263,7 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. -To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming`, `@cog.streaming()`, `@streaming`, or `@streaming()` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py from typing import Iterator diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py index 3ed8772e54..4f0d11c462 100644 --- a/examples/streaming-text/predict.py +++ b/examples/streaming-text/predict.py @@ -4,7 +4,7 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer -from cog import BasePredictor, Input +from cog import BasePredictor, Input, streaming MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" @@ -21,6 +21,7 @@ def setup(self) -> None: ).to(self.device) self.model.eval() + @streaming def predict( self, prompt: str = Input(description="Prompt to complete"), diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index 9cc037c684..e3d1b9e80c 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -667,7 +667,9 @@ func TestPredictionOperationIncludesStreamingExtensionWhenEnabled(t *testing.T) } spec := parseSpec(t, info) - post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + postPath := getPath(spec, "paths", "/predictions", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) assert.Equal(t, true, post["x-cog-streaming"]) } @@ -680,7 +682,9 @@ func TestPredictionOperationOmitsStreamingExtensionByDefault(t *testing.T) { } spec := parseSpec(t, info) - post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + postPath := getPath(spec, "paths", "/predictions", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) _, ok := post["x-cog-streaming"] assert.False(t, ok) } diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index eb1b0dd080..a32c73a602 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -698,20 +698,42 @@ func decoratorIsCogStreaming(node *sitter.Node, source []byte, imports *schema.I for _, child := range NamedChildren(node) { switch child.Type() { case "attribute": - return Content(child, source) == "cog.streaming" + return attributeIsCogStreaming(child, source, imports) case "identifier": - if Content(child, source) != "streaming" { + return identifierIsCogStreaming(child, source, imports) + case "call": + callee := child.ChildByFieldName("function") + if callee == nil { return false } - entry, ok := imports.Names.Get("streaming") - return ok && entry.Module == "cog" && entry.Original == "streaming" - case "call": - return false + switch callee.Type() { + case "attribute": + return attributeIsCogStreaming(callee, source, imports) + case "identifier": + return identifierIsCogStreaming(callee, source, imports) + } } } return false } +func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + parts := strings.SplitN(Content(node, source), ".", 2) + if len(parts) != 2 || parts[1] != "streaming" { + return false + } + entry, ok := imports.Names.Get(parts[0]) + return ok && entry.Module == "cog" && entry.Original == "cog" +} + +func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + if Content(node, source) != "streaming" { + return false + } + entry, ok := imports.Names.Get("streaming") + return ok && entry.Module == "cog" && entry.Original == "streaming" +} + func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index a4b1ae91de..a94ab04a37 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -589,6 +589,20 @@ class Predictor(cog.BasePredictor): require.True(t, info.SupportsStreaming) } +func TestStreamingDecoratorQualifiedAliasOptIn(t *testing.T) { + source := ` +import cog as c +from typing import Iterator + +class Predictor(c.BasePredictor): + @c.streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + func TestStreamingDecoratorImportedOptIn(t *testing.T) { source := ` from cog import BasePredictor, streaming @@ -618,7 +632,7 @@ class Predictor(BasePredictor): require.False(t, info.SupportsStreaming) } -func TestStreamingDecoratorParameterizedFormIgnored(t *testing.T) { +func TestStreamingDecoratorParameterizedFormOptIn(t *testing.T) { source := ` import cog from typing import Iterator @@ -629,7 +643,21 @@ class Predictor(cog.BasePredictor): yield "hello" ` info := parse(t, source, "Predictor") - require.False(t, info.SupportsStreaming) + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedParameterizedFormOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming +from typing import Iterator + +class Predictor(BasePredictor): + @streaming() + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) } func TestStreamingDecoratorClassLevelIgnored(t *testing.T) { diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 3e7c6fe457..7745a8b627 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -21,7 +21,7 @@ def run( import sys as _sys from collections.abc import Callable -from typing import TypeVar +from typing import TypeVar, overload from coglet import CancelationException as CancelationException @@ -40,13 +40,27 @@ def run( URLPath, ) -F = TypeVar("F", bound=Callable[..., object]) +_F = TypeVar("_F", bound=Callable[..., object]) -def streaming(fn: F) -> F: +@overload +def streaming(fn: _F) -> _F: ... + + +@overload +def streaming(fn: None = None) -> Callable[[_F], _F]: ... + + +def streaming(fn: _F | None = None) -> _F | Callable[[_F], _F]: """Mark a predict handler as supporting streaming responses.""" - fn.__cog_streaming__ = True # type: ignore[attr-defined] - return fn + + def decorate(inner: _F) -> _F: + inner.__cog_streaming__ = True # type: ignore[attr-defined] + return inner + + if fn is None: + return decorate + return decorate(fn) # --------------------------------------------------------------------------- @@ -144,6 +158,7 @@ def current_scope() -> object: "CancelationException", # Metrics "current_scope", + # Decorators "streaming", # Deprecated compat shims "ExperimentalFeatureWarning", diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 6b4aead4d7..3981ccdcef 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -146,3 +146,12 @@ def predict() -> str: assert decorated is predict assert predict.__cog_streaming__ is True # type: ignore[attr-defined] + + def test_streaming_call_form_marks_function_and_returns_same_object(self) -> None: + def predict() -> str: + return "ok" + + decorated = streaming()(predict) + + assert decorated is predict + assert predict.__cog_streaming__ is True # type: ignore[attr-defined] From fdd938fa662c5e97f8ce5a2364d16d6cf7955fb9 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 21 May 2026 15:29:13 -0400 Subject: [PATCH 10/11] feat: configure streaming replay history --- crates/coglet/src/prediction.rs | 134 +++++- docs/environment.md | 260 +++++++++-- docs/http.md | 179 ++++++- docs/llms.txt | 439 +++++++++++++++--- .../tests/sse_stream_history_capacity.txtar | 41 ++ .../tests/sse_stream_history_disabled.txtar | 41 ++ .../tests/sse_streaming_metrics.txtar | 42 ++ 7 files changed, 1006 insertions(+), 130 deletions(-) create mode 100644 integration-tests/tests/sse_stream_history_capacity.txtar create mode 100644 integration-tests/tests/sse_stream_history_disabled.txtar create mode 100644 integration-tests/tests/sse_streaming_metrics.txtar diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index c7a4ddf0d7..2b52e91129 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -11,7 +11,8 @@ use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; pub const STREAM_CHANNEL_CAPACITY: usize = 1024; -pub const STREAM_HISTORY_CAPACITY: usize = 1024; +pub const DEFAULT_STREAM_HISTORY_CAPACITY: usize = 1024; +const STREAM_HISTORY_CAPACITY_ENV: &str = "COG_STREAM_HISTORY_CAPACITY"; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { @@ -148,6 +149,7 @@ pub struct Prediction { completion: Arc, stream_tx: tokio::sync::broadcast::Sender, stream_history: VecDeque, + stream_history_capacity: usize, stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, @@ -156,6 +158,7 @@ pub struct Prediction { impl Prediction { pub fn new(id: String, webhook: Option) -> Self { let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_CHANNEL_CAPACITY); + let stream_history_capacity = stream_history_capacity_from_env(); Self { id, @@ -170,6 +173,7 @@ impl Prediction { completion: Arc::new(Notify::new()), stream_tx, stream_history: VecDeque::new(), + stream_history_capacity, stream_history_skipped: 0, metrics: HashMap::new(), } @@ -202,15 +206,47 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { - if self.stream_history.len() == STREAM_HISTORY_CAPACITY { - self.stream_history.pop_front(); - self.stream_history_skipped += 1; - } let event = Arc::new(event); - self.stream_history.push_back(Arc::clone(&event)); + if self.stream_history_capacity > 0 { + if self.stream_history.len() == self.stream_history_capacity { + self.stream_history.pop_front(); + self.stream_history_skipped += 1; + } + self.stream_history.push_back(Arc::clone(&event)); + } let _ = self.stream_tx.send(event); } +} +fn stream_history_capacity_from_env() -> usize { + match std::env::var(STREAM_HISTORY_CAPACITY_ENV) { + Ok(value) => match value.parse::() { + Ok(capacity) => capacity, + Err(error) => { + tracing::warn!( + env_var = STREAM_HISTORY_CAPACITY_ENV, + value, + error = %error, + default = DEFAULT_STREAM_HISTORY_CAPACITY, + "Invalid stream history capacity; using default" + ); + DEFAULT_STREAM_HISTORY_CAPACITY + } + }, + Err(std::env::VarError::NotPresent) => DEFAULT_STREAM_HISTORY_CAPACITY, + Err(error) => { + tracing::warn!( + env_var = STREAM_HISTORY_CAPACITY_ENV, + error = %error, + default = DEFAULT_STREAM_HISTORY_CAPACITY, + "Invalid stream history capacity; using default" + ); + DEFAULT_STREAM_HISTORY_CAPACITY + } + } +} + +impl Prediction { pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } @@ -553,6 +589,36 @@ impl Prediction { #[cfg(test)] mod tests { use super::*; + use std::ffi::OsString; + use std::sync::{Mutex, MutexGuard, OnceLock}; + + struct StreamHistoryCapacityEnvGuard { + previous: Option, + _lock: MutexGuard<'static, ()>, + } + + impl Drop for StreamHistoryCapacityEnvGuard { + fn drop(&mut self) { + match &self.previous { + Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) }, + None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) }, + } + } + } + + fn set_stream_history_capacity(value: Option<&str>) -> StreamHistoryCapacityEnvGuard { + static LOCK: OnceLock> = OnceLock::new(); + let lock = LOCK.get_or_init(|| Mutex::new(())).lock().unwrap(); + let previous = std::env::var_os(STREAM_HISTORY_CAPACITY_ENV); + match value { + Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) }, + None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) }, + } + StreamHistoryCapacityEnvGuard { + previous, + _lock: lock, + } + } #[test] fn status_is_terminal() { @@ -707,6 +773,7 @@ mod tests { #[tokio::test] async fn prediction_stream_replay_includes_already_emitted_events() { + let _guard = set_stream_history_capacity(None); let mut prediction = Prediction::new("pred_replay".to_string(), None); prediction.set_processing(); @@ -731,6 +798,7 @@ mod tests { #[tokio::test] async fn prediction_stream_replay_is_bounded_to_recent_events() { + let _guard = set_stream_history_capacity(None); let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None); prediction.set_processing(); @@ -740,18 +808,68 @@ mod tests { let replay = prediction.subscribe_stream_replay(); - assert_eq!(replay.replay.len(), STREAM_HISTORY_CAPACITY); + assert_eq!(replay.replay.len(), DEFAULT_STREAM_HISTORY_CAPACITY); assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) ); assert_eq!( - replay.replay[STREAM_HISTORY_CAPACITY - 1].json_data(), + replay.replay[DEFAULT_STREAM_HISTORY_CAPACITY - 1].json_data(), serde_json::json!({"chunk":1099,"index":1099}) ); } + #[tokio::test] + async fn prediction_stream_replay_uses_configured_history_capacity() { + let _guard = set_stream_history_capacity(Some("2")); + let mut prediction = Prediction::new("pred_replay_configured".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + prediction.append_output_chunk(serde_json::json!(1), 1); + prediction.append_output_chunk(serde_json::json!(2), 2); + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), 2); + assert_eq!(replay.skipped, 1); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":1,"index":1}) + ); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":2,"index":2}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_can_be_disabled_with_zero_capacity() { + let _guard = set_stream_history_capacity(Some("0")); + let mut prediction = Prediction::new("pred_replay_disabled".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + prediction.append_output_chunk(serde_json::json!(1), 1); + + let replay = prediction.subscribe_stream_replay(); + + assert!(replay.replay.is_empty()); + assert_eq!(replay.skipped, 0); + } + + #[tokio::test] + async fn prediction_stream_replay_uses_default_for_invalid_capacity() { + let _guard = set_stream_history_capacity(Some("nope")); + let mut prediction = Prediction::new("pred_replay_invalid".to_string(), None); + + prediction.append_output_chunk(serde_json::json!(0), 0); + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), 1); + assert_eq!(replay.skipped, 0); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); diff --git a/docs/environment.md b/docs/environment.md index 072e013a36..818c08daba 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -1,12 +1,12 @@ # Environment variables -This guide lists the environment variables that change how Cog functions. +This reference lists the public Cog-specific environment variables that change how Cog behaves. ## Build-time variables ### `COG_SDK_WHEEL` -Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. +Controls which Cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** @@ -18,98 +18,262 @@ Controls which cog Python SDK wheel is installed in the Docker image during `cog | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | -**Default behavior:** +**Default behaviour:** -- **Release builds**: Installs latest cog from PyPI -- **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI - -**Examples:** +- Release builds install the latest Cog SDK from PyPI. +- Development builds auto-detect a wheel in `dist/`, then fall back to the latest Cog SDK from PyPI. ```console -# Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build - -# Use local development wheel $ COG_SDK_WHEEL=dist cog build - -# Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) -2. `$REPO_ROOT/dist/` (if REPO_ROOT is set) +2. `$REPO_ROOT/dist/` (if `REPO_ROOT` is set) 3. `/dist/` (via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` -Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. +Controls which Coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. -**Supported values:** Same as `COG_SDK_WHEEL` +**Supported values:** Same as `COG_SDK_WHEEL`. -**Default behavior:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. - -**Examples:** +**Default behaviour:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. ```console -# Use local development wheel $ COGLET_WHEEL=dist cog build - -# Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` -## Runtime variables +### `COG_CA_CERT` + +Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (for example, Cloudflare WARP). + +**Supported values:** + +| Value | Description | +| -------------------------------- | ----------------------------------------------------------- | +| `/path/to/cert.crt` | Path to a single PEM certificate file | +| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | +| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | +| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | + +The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. + +```console +$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_CA_CERT=/etc/custom-certs/ cog build +$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +``` + +### `COG_OPENAPI_SCHEMA` + +Uses a pre-built OpenAPI schema instead of generating one from the configured predict or train reference. + +The value must be a path to a JSON schema file. Cog reads that file during schema generation and embeds it in the built image. + +```console +$ COG_OPENAPI_SCHEMA=./openapi.json cog build +``` + +## CLI and local cache variables ### `COG_NO_UPDATE_CHECK` -By default, Cog automatically checks for updates -and notifies you if there is a new version available. +Disables Cog's automatic update check. Set it to any non-empty value. + +```console +$ COG_NO_UPDATE_CHECK=1 cog build +``` + +### `COG_NO_COLOR` -To disable this behavior, -set the `COG_NO_UPDATE_CHECK` environment variable to any value. +Disables coloured CLI output. Set it to any non-empty value. + +Cog also honours the standard `NO_COLOR` environment variable. + +```console +$ COG_NO_COLOR=1 cog predict -i prompt="hello" +``` + +### `COG_SKIP_DOCKER_CHECK` + +Skips the `cog doctor` Docker environment check. Set it to any non-empty value. + +```console +$ COG_SKIP_DOCKER_CHECK=1 cog doctor +``` + +### `COG_CACHE_DIR` + +Overrides Cog's local cache root. + +Cog currently uses this cache for the content-addressed weights store. If unset, Cog uses `$XDG_CACHE_HOME/cog` when `XDG_CACHE_HOME` is set, otherwise `$HOME/.cache/cog`. + +```console +$ COG_CACHE_DIR=/mnt/fast-cache cog weights pull +``` + +## Model reference and registry variables + +### `COG_MODEL` + +Overrides the full model reference used by commands that need a model destination, such as `cog push` and weights commands. + +The value is parsed as a complete model reference (`registry/repo`, `registry/repo:tag`, or `registry/repo@digest`). If no tag is supplied, Cog generates a timestamp tag. + +When `COG_MODEL` is set, it takes precedence over `COG_MODEL_REGISTRY`, `COG_MODEL_REPO`, and `COG_MODEL_TAG`. ```console -$ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check +$ COG_MODEL=r8.im/acme/my-model:v1 cog push +``` + +### `COG_MODEL_REGISTRY` + +Overrides only the registry host of the model reference. + +```console +$ COG_MODEL_REGISTRY=registry.example.com cog push +``` + +### `COG_MODEL_REPO` + +Overrides only the repository path of the model reference. The value must not include a registry host, tag, or digest. + +```console +$ COG_MODEL_REPO=acme/my-model cog push +``` + +### `COG_MODEL_TAG` + +Overrides only the tag of the model reference. + +Tags starting with `cog-` are reserved for tags that Cog generates internally and are rejected. + +```console +$ COG_MODEL_TAG=staging cog push +``` + +### `COG_REGISTRY_HOST` + +Changes the default Replicate-compatible registry host used by commands such as `cog login`, base image resolution, and model reference resolution. + +The default is `r8.im`. + +```console +$ COG_REGISTRY_HOST=registry.example.com cog login +``` + +## Runtime server variables + +These variables affect a running model server. Set them in `cog.yaml` under `environment`, pass them with `cog predict -e` or `cog serve -e`, or set them when running the built Docker image. + +### `COG_MAX_CONCURRENCY` + +Controls how many predictions the model server can run concurrently. + +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. + +```console +$ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model ``` ### `COG_SETUP_TIMEOUT` -Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. +Controls the maximum time, in seconds, allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server reports setup failure. + +By default, there is no timeout. Set to `0` to disable the timeout. Invalid values are ignored with a warning. -By default, there is no timeout — setup runs indefinitely. +```console +$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model +``` + +### `COG_LOG_LEVEL` -Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. +Controls Coglet runtime log verbosity when `RUST_LOG` is not set. + +Supported values are `debug`, `info`, `warn`, `warning`, and `error`. The default is `info`. ```console -$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model # 5-minute setup timeout +$ COG_LOG_LEVEL=debug docker run -p 5000:5000 my-model ``` -### `COG_CA_CERT` +### `COG_THROTTLE_RESPONSE_INTERVAL` -Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). +Controls how often asynchronous webhook `output` and `logs` events are sent, in seconds. -**Supported values:** +The default is `0.5` seconds. Invalid values are ignored and the default is used. `start` and `completed` webhook events are always sent immediately. -| Value | Description | -| -------------------------------- | ----------------------------------------------------------- | -| `/path/to/cert.crt` | Path to a single PEM certificate file | -| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | -| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | -| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | +```console +$ COG_THROTTLE_RESPONSE_INTERVAL=1 docker run -p 5000:5000 my-model +``` -The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. +### `COG_STREAM_HISTORY_CAPACITY` -**Examples:** +Controls how many server-sent event stream events are retained per prediction for replay when a client reconnects with `Accept: text/event-stream`. + +By default, Cog retains the most recent 1024 events per prediction. Set to `0` to disable replay history while keeping live streaming enabled. Invalid values are ignored with a warning and the default is used. ```console -# From a file -$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_STREAM_HISTORY_CAPACITY=0 docker run -p 5000:5000 my-model +$ COG_STREAM_HISTORY_CAPACITY=4096 docker run -p 5000:5000 my-model +``` -# From a directory of certs -$ COG_CA_CERT=/etc/custom-certs/ cog build +### `COG_WEIGHTS` -# Inline (e.g. from a CI secret) -$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +Provides a weights path or URL to a model whose `setup()` method accepts a `weights` parameter. + +```console +$ cog run -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="hello" +``` + +### `COG_USER_AGENT` + +Sets the `User-Agent` header used by Cog when downloading URL-backed `File` inputs. + +```console +$ COG_USER_AGENT="my-service/1.0" docker run -p 5000:5000 my-model +``` + +## Push tuning variables + +### `COG_PUSH_OCI` + +Enables Cog's OCI chunked push path for container image layers when set to `1`. If the OCI push fails with a non-fatal error, Cog falls back to Docker's native push path. + +```console +$ COG_PUSH_OCI=1 cog push +``` + +### `COG_PUSH_CONCURRENCY` + +Controls how many image layers or weight blobs Cog uploads concurrently during push operations. + +The default is `5`. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_CONCURRENCY=2 cog push +``` + +### `COG_PUSH_DEFAULT_CHUNK_SIZE` + +Sets the default multipart upload chunk size, in bytes, when the registry does not advertise a maximum chunk size. + +The default is 96 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_DEFAULT_CHUNK_SIZE=67108864 cog push +``` + +### `COG_PUSH_MULTIPART_THRESHOLD` + +Sets the minimum blob size, in bytes, before Cog uses multipart upload. + +The default is 128 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_MULTIPART_THRESHOLD=268435456 cog push ``` diff --git a/docs/http.md b/docs/http.md index 480ad23d38..dedd20dc40 100644 --- a/docs/http.md +++ b/docs/http.md @@ -17,13 +17,18 @@ The server supports both synchronous and asynchronous prediction creation: and processes the prediction in the background. The client can create a prediction asynchronously -by setting the `Prefer: respond-async` header in their request. -When provided, the server responds immediately after starting the prediction -with `202 Accepted` status and a prediction object in status `processing`. +by setting the `Prefer: respond-async` header in their request +or by requesting a streamed response with `Accept: text/event-stream`. +With `Prefer: respond-async`, +the server responds immediately after starting the prediction +with `202 Accepted` status and a prediction object in status `starting`. +With `Accept: text/event-stream`, +the server responds with `200 OK` and keeps the response open +as a server-sent event stream. > [!NOTE] -> The only supported way to receive updates on the status of predictions -> started asynchronously is using [webhooks](#webhooks). +> For JSON responses, the only supported way to receive updates on the status +> of predictions started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, @@ -31,28 +36,164 @@ such that if a client calls this endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. -Instead, the client receives a `202 Accepted` response -with the initial state of the prediction. +Instead, the client receives the response type requested by the retry: +JSON for regular requests or a server-sent event stream for streaming requests. --- Here's a summary of the prediction creation endpoints: -| Endpoint | Header | Behavior | -| ---------------------------------- | ----------------------- | ---------------------------- | -| `POST /predictions` | - | Synchronous, non-idempotent | -| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | -| `PUT /predictions/` | - | Synchronous, idempotent | -| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| Endpoint | Header | Behavior | +| ---------------------------------- | --------------------------- | ---------------------------- | +| `POST /predictions` | - | Synchronous, non-idempotent | +| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | +| `POST /predictions` | `Accept: text/event-stream` | Streaming, non-idempotent | +| `PUT /predictions/` | - | Synchronous, idempotent | +| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| `PUT /predictions/` | `Accept: text/event-stream` | Streaming, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. +- Use streaming endpoints when you want to receive prediction lifecycle events + over the HTTP response as they happen. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. +## Streaming predictions with server-sent events + +To produce streamed prediction events, +the model must return an iterator and opt in to SSE streaming +with the `streaming` decorator. + +```python +from typing import Iterator + +from cog import BaseRunner, Input, streaming + + +class Runner(BaseRunner): + @streaming + def run(self, prompt: str = Input(description="Prompt")) -> Iterator[str]: + for token in generate_tokens(prompt): + yield token +``` + +The decorator can also be written as `@streaming()`, +`@cog.streaming`, or `@cog.streaming()`. +Without the decorator, +iterator outputs still work in normal JSON responses, +but requests with `Accept: text/event-stream` return `406 Not Acceptable`. + +To consume a streamed prediction, +send the prediction request with `Accept: text/event-stream`: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +The server starts the prediction asynchronously +and keeps the HTTP response open as a server-sent event stream. +Each event has an `event` name and JSON `data` payload: + +```text +event: start +data: {"id":"abc123","status":"processing"} + +event: output +data: {"chunk":"Onions","index":0} + +event: output +data: {"chunk":" bloom","index":1} + +event: completed +data: {"id":"abc123","status":"succeeded","output":["Onions"," bloom"],"metrics":{"predict_time":0.42}} +``` + +Prediction streams can emit these event types: + +- `start`: The prediction started processing. +- `output`: The model yielded an output chunk. + The payload includes `chunk` and `index`. +- `log`: The model wrote to `stdout` or `stderr`. + The payload includes `source` and `data`. +- `metric`: The model recorded a custom metric. + The payload includes `name`, `value`, and `mode`. +- `completed`: The prediction reached a terminal state. + The payload is the final prediction object, + with `status` set to `succeeded`, `failed`, or `canceled`. + +For command-line clients, +use a client that prints the response as data arrives: + +```bash +curl -N \ + -H 'Accept: text/event-stream' \ + -H 'Content-Type: application/json' \ + -d '{"input":{"prompt":"Write a haiku about onions"}}' \ + http://localhost:5000/predictions +``` + +For browser clients, +use `fetch()` or another client that supports request bodies. +The browser `EventSource` API only supports `GET` requests, +so it cannot create a prediction with `POST /predictions` or +`PUT /predictions/`. + +```js +const response = await fetch("/predictions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), +}); + +const reader = response.body + .pipeThrough(new TextDecoderStream()) + .getReader(); + +while (true) { + const { value, done } = await reader.read(); + if (done) break; + console.log(value); +} +``` + +Use `PUT /predictions/` when the client needs safe retries +or wants to reconnect to an in-flight prediction by ID: + +```http +PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +If the prediction is still running, +the server returns a stream for the existing prediction +instead of creating a duplicate prediction. +If earlier events have been dropped from the replay buffer, +the stream emits an `error` event and closes. +The replay buffer keeps the most recent 1024 events by default. +Set `COG_STREAM_HISTORY_CAPACITY` to change this limit, +or set it to `0` to disable replay history while keeping live streaming enabled. +Training endpoints do not support SSE streaming; +requests to `/trainings` with `Accept: text/event-stream` +return `406 Not Acceptable`. + ## Webhooks You can provide a `webhook` parameter in the client request body @@ -367,6 +508,11 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `PUT /predictions/` Make a single prediction. @@ -415,6 +561,13 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +If a prediction with the same ID is already running, +the server returns a stream for the existing prediction. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `POST /predictions//cancel` A client can cancel an asynchronous prediction by making a diff --git a/docs/llms.txt b/docs/llms.txt index 5809b003e4..f17c1fb9b4 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -608,13 +608,13 @@ See the [environment variables reference](environment.md) for the full list. # Environment variables -This guide lists the environment variables that change how Cog functions. +This reference lists the public Cog-specific environment variables that change how Cog behaves. ## Build-time variables ### `COG_SDK_WHEEL` -Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. +Controls which Cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** @@ -626,100 +626,264 @@ Controls which cog Python SDK wheel is installed in the Docker image during `cog | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | -**Default behavior:** +**Default behaviour:** -- **Release builds**: Installs latest cog from PyPI -- **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI - -**Examples:** +- Release builds install the latest Cog SDK from PyPI. +- Development builds auto-detect a wheel in `dist/`, then fall back to the latest Cog SDK from PyPI. ```console -# Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build - -# Use local development wheel $ COG_SDK_WHEEL=dist cog build - -# Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) -2. `$REPO_ROOT/dist/` (if REPO_ROOT is set) +2. `$REPO_ROOT/dist/` (if `REPO_ROOT` is set) 3. `/dist/` (via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` -Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. +Controls which Coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. -**Supported values:** Same as `COG_SDK_WHEEL` +**Supported values:** Same as `COG_SDK_WHEEL`. -**Default behavior:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. - -**Examples:** +**Default behaviour:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. ```console -# Use local development wheel $ COGLET_WHEEL=dist cog build - -# Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` -## Runtime variables +### `COG_CA_CERT` + +Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (for example, Cloudflare WARP). + +**Supported values:** + +| Value | Description | +| -------------------------------- | ----------------------------------------------------------- | +| `/path/to/cert.crt` | Path to a single PEM certificate file | +| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | +| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | +| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | + +The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. + +```console +$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_CA_CERT=/etc/custom-certs/ cog build +$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +``` + +### `COG_OPENAPI_SCHEMA` + +Uses a pre-built OpenAPI schema instead of generating one from the configured predict or train reference. + +The value must be a path to a JSON schema file. Cog reads that file during schema generation and embeds it in the built image. + +```console +$ COG_OPENAPI_SCHEMA=./openapi.json cog build +``` + +## CLI and local cache variables ### `COG_NO_UPDATE_CHECK` -By default, Cog automatically checks for updates -and notifies you if there is a new version available. +Disables Cog's automatic update check. Set it to any non-empty value. + +```console +$ COG_NO_UPDATE_CHECK=1 cog build +``` + +### `COG_NO_COLOR` + +Disables coloured CLI output. Set it to any non-empty value. + +Cog also honours the standard `NO_COLOR` environment variable. + +```console +$ COG_NO_COLOR=1 cog predict -i prompt="hello" +``` + +### `COG_SKIP_DOCKER_CHECK` + +Skips the `cog doctor` Docker environment check. Set it to any non-empty value. + +```console +$ COG_SKIP_DOCKER_CHECK=1 cog doctor +``` + +### `COG_CACHE_DIR` + +Overrides Cog's local cache root. -To disable this behavior, -set the `COG_NO_UPDATE_CHECK` environment variable to any value. +Cog currently uses this cache for the content-addressed weights store. If unset, Cog uses `$XDG_CACHE_HOME/cog` when `XDG_CACHE_HOME` is set, otherwise `$HOME/.cache/cog`. ```console -$ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check +$ COG_CACHE_DIR=/mnt/fast-cache cog weights pull +``` + +## Model reference and registry variables + +### `COG_MODEL` + +Overrides the full model reference used by commands that need a model destination, such as `cog push` and weights commands. + +The value is parsed as a complete model reference (`registry/repo`, `registry/repo:tag`, or `registry/repo@digest`). If no tag is supplied, Cog generates a timestamp tag. + +When `COG_MODEL` is set, it takes precedence over `COG_MODEL_REGISTRY`, `COG_MODEL_REPO`, and `COG_MODEL_TAG`. + +```console +$ COG_MODEL=r8.im/acme/my-model:v1 cog push +``` + +### `COG_MODEL_REGISTRY` + +Overrides only the registry host of the model reference. + +```console +$ COG_MODEL_REGISTRY=registry.example.com cog push +``` + +### `COG_MODEL_REPO` + +Overrides only the repository path of the model reference. The value must not include a registry host, tag, or digest. + +```console +$ COG_MODEL_REPO=acme/my-model cog push +``` + +### `COG_MODEL_TAG` + +Overrides only the tag of the model reference. + +Tags starting with `cog-` are reserved for tags that Cog generates internally and are rejected. + +```console +$ COG_MODEL_TAG=staging cog push +``` + +### `COG_REGISTRY_HOST` + +Changes the default Replicate-compatible registry host used by commands such as `cog login`, base image resolution, and model reference resolution. + +The default is `r8.im`. + +```console +$ COG_REGISTRY_HOST=registry.example.com cog login +``` + +## Runtime server variables + +These variables affect a running model server. Set them in `cog.yaml` under `environment`, pass them with `cog predict -e` or `cog serve -e`, or set them when running the built Docker image. + +### `COG_MAX_CONCURRENCY` + +Controls how many predictions the model server can run concurrently. + +By default, Cog runs one prediction at a time. Invalid values are ignored and the default of `1` is used. + +```console +$ COG_MAX_CONCURRENCY=4 docker run -p 5000:5000 my-model ``` ### `COG_SETUP_TIMEOUT` -Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. +Controls the maximum time, in seconds, allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server reports setup failure. -By default, there is no timeout — setup runs indefinitely. +By default, there is no timeout. Set to `0` to disable the timeout. Invalid values are ignored with a warning. + +```console +$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model +``` -Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. +### `COG_LOG_LEVEL` + +Controls Coglet runtime log verbosity when `RUST_LOG` is not set. + +Supported values are `debug`, `info`, `warn`, `warning`, and `error`. The default is `info`. ```console -$ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model # 5-minute setup timeout +$ COG_LOG_LEVEL=debug docker run -p 5000:5000 my-model ``` -### `COG_CA_CERT` +### `COG_THROTTLE_RESPONSE_INTERVAL` -Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). +Controls how often asynchronous webhook `output` and `logs` events are sent, in seconds. -**Supported values:** +The default is `0.5` seconds. Invalid values are ignored and the default is used. `start` and `completed` webhook events are always sent immediately. -| Value | Description | -| -------------------------------- | ----------------------------------------------------------- | -| `/path/to/cert.crt` | Path to a single PEM certificate file | -| `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | -| `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | -| `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | +```console +$ COG_THROTTLE_RESPONSE_INTERVAL=1 docker run -p 5000:5000 my-model +``` -The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. +### `COG_STREAM_HISTORY_CAPACITY` + +Controls how many server-sent event stream events are retained per prediction for replay when a client reconnects with `Accept: text/event-stream`. -**Examples:** +By default, Cog retains the most recent 1024 events per prediction. Set to `0` to disable replay history while keeping live streaming enabled. Invalid values are ignored with a warning and the default is used. ```console -# From a file -$ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build +$ COG_STREAM_HISTORY_CAPACITY=0 docker run -p 5000:5000 my-model +$ COG_STREAM_HISTORY_CAPACITY=4096 docker run -p 5000:5000 my-model +``` -# From a directory of certs -$ COG_CA_CERT=/etc/custom-certs/ cog build +### `COG_WEIGHTS` -# Inline (e.g. from a CI secret) -$ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build +Provides a weights path or URL to a model whose `setup()` method accepts a `weights` parameter. + +```console +$ cog run -e COG_WEIGHTS=https://example.com/weights.tar -i prompt="hello" +``` + +### `COG_USER_AGENT` + +Sets the `User-Agent` header used by Cog when downloading URL-backed `File` inputs. + +```console +$ COG_USER_AGENT="my-service/1.0" docker run -p 5000:5000 my-model +``` + +## Push tuning variables + +### `COG_PUSH_OCI` + +Enables Cog's OCI chunked push path for container image layers when set to `1`. If the OCI push fails with a non-fatal error, Cog falls back to Docker's native push path. + +```console +$ COG_PUSH_OCI=1 cog push +``` + +### `COG_PUSH_CONCURRENCY` + +Controls how many image layers or weight blobs Cog uploads concurrently during push operations. + +The default is `5`. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_CONCURRENCY=2 cog push +``` + +### `COG_PUSH_DEFAULT_CHUNK_SIZE` + +Sets the default multipart upload chunk size, in bytes, when the registry does not advertise a maximum chunk size. + +The default is 96 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_DEFAULT_CHUNK_SIZE=67108864 cog push +``` + +### `COG_PUSH_MULTIPART_THRESHOLD` + +Sets the minimum blob size, in bytes, before Cog uses multipart upload. + +The default is 128 MiB. Invalid values and values less than `1` are ignored. + +```console +$ COG_PUSH_MULTIPART_THRESHOLD=268435456 cog push ``` @@ -1158,13 +1322,18 @@ The server supports both synchronous and asynchronous prediction creation: and processes the prediction in the background. The client can create a prediction asynchronously -by setting the `Prefer: respond-async` header in their request. -When provided, the server responds immediately after starting the prediction -with `202 Accepted` status and a prediction object in status `processing`. +by setting the `Prefer: respond-async` header in their request +or by requesting a streamed response with `Accept: text/event-stream`. +With `Prefer: respond-async`, +the server responds immediately after starting the prediction +with `202 Accepted` status and a prediction object in status `starting`. +With `Accept: text/event-stream`, +the server responds with `200 OK` and keeps the response open +as a server-sent event stream. > [!NOTE] -> The only supported way to receive updates on the status of predictions -> started asynchronously is using [webhooks](#webhooks). +> For JSON responses, the only supported way to receive updates on the status +> of predictions started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, @@ -1172,28 +1341,164 @@ such that if a client calls this endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. -Instead, the client receives a `202 Accepted` response -with the initial state of the prediction. +Instead, the client receives the response type requested by the retry: +JSON for regular requests or a server-sent event stream for streaming requests. --- Here's a summary of the prediction creation endpoints: -| Endpoint | Header | Behavior | -| ---------------------------------- | ----------------------- | ---------------------------- | -| `POST /predictions` | - | Synchronous, non-idempotent | -| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | -| `PUT /predictions/` | - | Synchronous, idempotent | -| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| Endpoint | Header | Behavior | +| ---------------------------------- | --------------------------- | ---------------------------- | +| `POST /predictions` | - | Synchronous, non-idempotent | +| `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | +| `POST /predictions` | `Accept: text/event-stream` | Streaming, non-idempotent | +| `PUT /predictions/` | - | Synchronous, idempotent | +| `PUT /predictions/` | `Prefer: respond-async` | Asynchronous, idempotent | +| `PUT /predictions/` | `Accept: text/event-stream` | Streaming, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. +- Use streaming endpoints when you want to receive prediction lifecycle events + over the HTTP response as they happen. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. +## Streaming predictions with server-sent events + +To produce streamed prediction events, +the model must return an iterator and opt in to SSE streaming +with the `streaming` decorator. + +```python +from typing import Iterator + +from cog import BaseRunner, Input, streaming + + +class Runner(BaseRunner): + @streaming + def run(self, prompt: str = Input(description="Prompt")) -> Iterator[str]: + for token in generate_tokens(prompt): + yield token +``` + +The decorator can also be written as `@streaming()`, +`@cog.streaming`, or `@cog.streaming()`. +Without the decorator, +iterator outputs still work in normal JSON responses, +but requests with `Accept: text/event-stream` return `406 Not Acceptable`. + +To consume a streamed prediction, +send the prediction request with `Accept: text/event-stream`: + +```http +POST /predictions HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +The server starts the prediction asynchronously +and keeps the HTTP response open as a server-sent event stream. +Each event has an `event` name and JSON `data` payload: + +```text +event: start +data: {"id":"abc123","status":"processing"} + +event: output +data: {"chunk":"Onions","index":0} + +event: output +data: {"chunk":" bloom","index":1} + +event: completed +data: {"id":"abc123","status":"succeeded","output":["Onions"," bloom"],"metrics":{"predict_time":0.42}} +``` + +Prediction streams can emit these event types: + +- `start`: The prediction started processing. +- `output`: The model yielded an output chunk. + The payload includes `chunk` and `index`. +- `log`: The model wrote to `stdout` or `stderr`. + The payload includes `source` and `data`. +- `metric`: The model recorded a custom metric. + The payload includes `name`, `value`, and `mode`. +- `completed`: The prediction reached a terminal state. + The payload is the final prediction object, + with `status` set to `succeeded`, `failed`, or `canceled`. + +For command-line clients, +use a client that prints the response as data arrives: + +```bash +curl -N \ + -H 'Accept: text/event-stream' \ + -H 'Content-Type: application/json' \ + -d '{"input":{"prompt":"Write a haiku about onions"}}' \ + http://localhost:5000/predictions +``` + +For browser clients, +use `fetch()` or another client that supports request bodies. +The browser `EventSource` API only supports `GET` requests, +so it cannot create a prediction with `POST /predictions` or +`PUT /predictions/`. + +```js +const response = await fetch("/predictions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), +}); + +const reader = response.body + .pipeThrough(new TextDecoderStream()) + .getReader(); + +while (true) { + const { value, done } = await reader.read(); + if (done) break; + console.log(value); +} +``` + +Use `PUT /predictions/` when the client needs safe retries +or wants to reconnect to an in-flight prediction by ID: + +```http +PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 +Content-Type: application/json; charset=utf-8 +Accept: text/event-stream + +{ + "input": {"prompt": "Write a haiku about onions"} +} +``` + +If the prediction is still running, +the server returns a stream for the existing prediction +instead of creating a duplicate prediction. +If earlier events have been dropped from the replay buffer, +the stream emits an `error` event and closes. +The replay buffer keeps the most recent 1024 events by default. +Set `COG_STREAM_HISTORY_CAPACITY` to change this limit, +or set it to `0` to disable replay history while keeping live streaming enabled. +Training endpoints do not support SSE streaming; +requests to `/trainings` with `Accept: text/event-stream` +return `406 Not Acceptable`. + ## Webhooks You can provide a `webhook` parameter in the client request body @@ -1508,6 +1813,11 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `PUT /predictions/` Make a single prediction. @@ -1556,6 +1866,13 @@ Content-Type: application/json } ``` +If the client sets the `Accept: text/event-stream` header, +the server starts the prediction asynchronously and responds with a +server-sent event stream. +If a prediction with the same ID is already running, +the server returns a stream for the existing prediction. +See [Streaming predictions with server-sent events](#streaming-predictions-with-server-sent-events). + ### `POST /predictions//cancel` A client can cancel an asynchronous prediction by making a diff --git a/integration-tests/tests/sse_stream_history_capacity.txtar b/integration-tests/tests/sse_stream_history_capacity.txtar new file mode 100644 index 0000000000..42b640b572 --- /dev/null +++ b/integration-tests/tests/sse_stream_history_capacity.txtar @@ -0,0 +1,41 @@ +# Test configurable SSE replay history capacity. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +# Capacity 2 should drop older replay events and close with an error for late subscribers. +curl -H Prefer:respond-async PUT /predictions/replay-truncated '{"id":"replay-truncated","input":{"count":4,"sleep_after":2}}' +stdout '"status":"starting"' + +exec sleep 1 + +curl -H Accept:text/event-stream PUT /predictions/replay-truncated '{"id":"replay-truncated","input":{"count":4,"sleep_after":2}}' +stdout 'event: error' +stdout 'SSE stream replay truncated' +stdout '"skipped":3' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" +environment: + - COG_STREAM_HISTORY_CAPACITY=2 + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor, Input, streaming + + +class Predictor(BasePredictor): + @streaming + def predict( + self, + count: int = Input(default=4), + sleep_after: float = Input(default=2.0), + ) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" + time.sleep(sleep_after) diff --git a/integration-tests/tests/sse_stream_history_disabled.txtar b/integration-tests/tests/sse_stream_history_disabled.txtar new file mode 100644 index 0000000000..6a36d9e88d --- /dev/null +++ b/integration-tests/tests/sse_stream_history_disabled.txtar @@ -0,0 +1,41 @@ +# Test that COG_STREAM_HISTORY_CAPACITY=0 disables SSE replay while live streaming still works. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Prefer:respond-async PUT /predictions/replay-disabled '{"id":"replay-disabled","input":{"count":3,"sleep_after":1}}' +stdout '"status":"starting"' + +exec sleep 0.5 + +curl -H Accept:text/event-stream PUT /predictions/replay-disabled '{"id":"replay-disabled","input":{"count":3,"sleep_after":1}}' +stdout 'event: completed' +stdout '"status":"succeeded"' +! stdout 'event: output' +! stdout 'SSE stream replay truncated' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" +environment: + - COG_STREAM_HISTORY_CAPACITY=0 + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor, Input, streaming + + +class Predictor(BasePredictor): + @streaming + def predict( + self, + count: int = Input(default=3), + sleep_after: float = Input(default=1.0), + ) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" + time.sleep(sleep_after) diff --git a/integration-tests/tests/sse_streaming_metrics.txtar b/integration-tests/tests/sse_streaming_metrics.txtar new file mode 100644 index 0000000000..e2eebaa475 --- /dev/null +++ b/integration-tests/tests/sse_streaming_metrics.txtar @@ -0,0 +1,42 @@ +# Test that custom metrics are emitted as SSE metric events and included in the completed payload. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +curl -H Accept:text/event-stream PUT /predictions/sse-metrics '{"id":"sse-metrics","input":{}}' +stdout 'event: start' +stdout 'event: metric' +stdout '"name":"temperature"' +stdout '"value":0.7' +stdout '"mode":"replace"' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: metric' +stdout '"name":"token_count"' +stdout '"value":2' +stdout '"mode":"increment"' +stdout 'event: completed' +stdout '"status":"succeeded"' +stdout '"temperature":0.7' +stdout '"token_count":2' +stdout '"predict_time"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Iterator + +from cog import BasePredictor, current_scope, streaming + + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + scope = current_scope() + scope.record_metric("temperature", 0.7) + yield "chunk-1" + scope.record_metric("token_count", 2, mode="incr") From 686adc272890e95c9f078f1b81bd275075bfcddb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 21 May 2026 15:54:46 -0400 Subject: [PATCH 11/11] fix: address streaming review regressions --- crates/coglet/src/service.rs | 23 ++++++---- crates/coglet/src/transport/http/routes.rs | 20 ++++++--- docs/http.md | 9 ++-- docs/llms.txt | 11 +++-- docs/python.md | 2 +- examples/streaming-text/README.md | 4 +- examples/streaming-text/predict.py | 17 ++++---- pkg/schema/openapi_test.go | 17 ++++++++ pkg/schema/python/parser.go | 16 ++++--- pkg/schema/python/parser_test.go | 50 ++++++++++++++++++++++ python/cog/__init__.py | 7 +-- python/tests/test_types.py | 8 ++-- 12 files changed, 135 insertions(+), 49 deletions(-) diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 5904b20049..20b22300af 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -61,6 +61,8 @@ pub enum SubscribePredictionStreamError { NotFound, #[error("Too many stream subscribers")] TooManySubscribers, + #[error("Prediction stream unavailable")] + Unavailable, } /// Snapshot of service health for transports to query. @@ -124,6 +126,8 @@ pub struct PredictionStreamSubscription { id: String, replay: std::collections::VecDeque, skipped: u64, + // Drop order matters: receiver must drop before guard so stream_receiver_count() + // reaches zero before the guard decides whether to cancel on disconnect. receiver: tokio::sync::broadcast::Receiver, guard: PredictionStreamGuard, } @@ -561,10 +565,9 @@ impl PredictionService { .get(id) .ok_or(SubscribePredictionStreamError::NotFound)?; let stream = { - let prediction = entry - .prediction - .lock() - .map_err(|_| SubscribePredictionStreamError::NotFound)?; + let Some(prediction) = try_lock_prediction(&entry.prediction) else { + return Err(SubscribePredictionStreamError::Unavailable); + }; if prediction.stream_receiver_count() >= MAX_STREAM_SUBSCRIBERS { return Err(SubscribePredictionStreamError::TooManySubscribers); } @@ -762,11 +765,13 @@ impl PredictionService { // Delegate to orchestrator to actually cancel the worker-side prediction. // This must be non-blocking since cancel() is sync, so we spawn a task. let id_owned = id.to_string(); - let orchestrator = self - .orchestrator - .try_read() - .ok() - .and_then(|guard| guard.as_ref().map(|s| Arc::clone(&s.orchestrator))); + let orchestrator = match self.orchestrator.try_read() { + Ok(guard) => guard.as_ref().map(|s| Arc::clone(&s.orchestrator)), + Err(_) => { + tracing::warn!(prediction_id = %id, "Skipped worker cancel: orchestrator lock unavailable"); + None + } + }; if let Some(orch) = orchestrator { spawn_orchestrator_cancel(orch, id_owned); } diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 499ae65ecc..5904737391 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -495,7 +495,13 @@ async fn create_prediction_with_id( // Async mode: spawn background task, return immediately if response_mode != PredictionResponseMode::SyncJson { let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { - Some(service.subscribe_prediction_stream(&prediction_id)) + match service.subscribe_prediction_stream(&prediction_id) { + Ok(subscription) => Some(subscription), + Err(error) => { + service.remove_prediction(&prediction_id); + return stream_subscription_error_response(error); + } + } } else { None }; @@ -513,10 +519,7 @@ async fn create_prediction_with_id( }); if response_mode == PredictionResponseMode::AsyncSse { - let subscription = match sse_subscription.expect("SSE subscription requested") { - Ok(subscription) => subscription, - Err(error) => return stream_subscription_error_response(error), - }; + let subscription = sse_subscription.expect("SSE subscription requested"); return stream_prediction_subscription_response(subscription); } @@ -681,6 +684,8 @@ fn prediction_sse_stream( struct StreamState { replay: std::collections::VecDeque, replay_skipped: u64, + // Drop order matters: receiver must drop before guard so stream_receiver_count() + // reaches zero before the guard decides whether to cancel on disconnect. receiver: tokio::sync::broadcast::Receiver, _guard: crate::service::PredictionStreamGuard, done: bool, @@ -763,6 +768,11 @@ fn stream_subscription_error_response(error: SubscribePredictionStreamError) -> Json(serde_json::json!({"error": "Too many stream subscribers"})), ) .into_response(), + SubscribePredictionStreamError::Unavailable => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": "Prediction stream unavailable"})), + ) + .into_response(), } } diff --git a/docs/http.md b/docs/http.md index dedd20dc40..81fa737d13 100644 --- a/docs/http.md +++ b/docs/http.md @@ -81,8 +81,9 @@ class Runner(BaseRunner): yield token ``` -The decorator can also be written as `@streaming()`, -`@cog.streaming`, or `@cog.streaming()`. +The decorator can also be written as `@cog.streaming` +or, if imported directly from `cog`, `@streaming`. +The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. Without the decorator, iterator outputs still work in normal JSON responses, but requests with `Accept: text/event-stream` return `406 Not Acceptable`. @@ -158,9 +159,7 @@ const response = await fetch("/predictions", { body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), }); -const reader = response.body - .pipeThrough(new TextDecoderStream()) - .getReader(); +const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); while (true) { const { value, done } = await reader.read(); diff --git a/docs/llms.txt b/docs/llms.txt index f17c1fb9b4..6b59681765 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1386,8 +1386,9 @@ class Runner(BaseRunner): yield token ``` -The decorator can also be written as `@streaming()`, -`@cog.streaming`, or `@cog.streaming()`. +The decorator can also be written as `@cog.streaming` +or, if imported directly from `cog`, `@streaming`. +The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. Without the decorator, iterator outputs still work in normal JSON responses, but requests with `Accept: text/event-stream` return `406 Not Acceptable`. @@ -1463,9 +1464,7 @@ const response = await fetch("/predictions", { body: JSON.stringify({ input: { prompt: "Write a haiku about onions" } }), }); -const reader = response.body - .pipeThrough(new TextDecoderStream()) - .getReader(); +const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); while (true) { const { value, done } = await reader.read(); @@ -2296,7 +2295,7 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. -To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming`, `@cog.streaming()`, `@streaming`, or `@streaming()` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` (or `@streaming` if imported directly from `cog`). The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. The decorated method must return `Iterator[...]`, `AsyncIterator[...]`, `ConcatenateIterator[...]`, or `AsyncConcatenateIterator[...]`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py from typing import Iterator diff --git a/docs/python.md b/docs/python.md index d100d4cc0a..8ccab4208a 100644 --- a/docs/python.md +++ b/docs/python.md @@ -263,7 +263,7 @@ Cog models can stream output as the `run()` method is running. For example, a la To support streaming output in your Cog model, add `from typing import Iterator` to your `run.py` file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `run()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. -To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming`, `@cog.streaming()`, `@streaming`, or `@streaming()` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the prediction method (`run()` or `predict()`) with `@cog.streaming` (or `@streaming` if imported directly from `cog`). The parenthesized forms `@cog.streaming()` and `@streaming()` are also accepted. The decorated method must return `Iterator[...]`, `AsyncIterator[...]`, `ConcatenateIterator[...]`, or `AsyncConcatenateIterator[...]`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py from typing import Iterator diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md index 645664289b..77265e6f1e 100644 --- a/examples/streaming-text/README.md +++ b/examples/streaming-text/README.md @@ -2,7 +2,7 @@ Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. -This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. +This example shows how a Cog runner can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. ## Run a normal prediction @@ -46,6 +46,6 @@ data: {"id":"streaming-demo","status":"succeeded",...} ## How it works -`predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. +`predict.py` defines `run() -> Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py index 4f0d11c462..c94ddb04f3 100644 --- a/examples/streaming-text/predict.py +++ b/examples/streaming-text/predict.py @@ -4,12 +4,12 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer -from cog import BasePredictor, Input, streaming +from cog import BaseRunner, Input, streaming MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" -class Predictor(BasePredictor): +class Predictor(BaseRunner): def setup(self) -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if self.device == "cuda" else torch.float32 @@ -22,7 +22,7 @@ def setup(self) -> None: self.model.eval() @streaming - def predict( + def run( self, prompt: str = Input(description="Prompt to complete"), max_new_tokens: int = Input( @@ -58,8 +58,9 @@ def predict( thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() - for chunk in streamer: - if chunk: - yield chunk - - thread.join() + try: + for chunk in streamer: + if chunk: + yield chunk + finally: + thread.join() diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index e3d1b9e80c..d9b528bb9f 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -689,6 +689,23 @@ func TestPredictionOperationOmitsStreamingExtensionByDefault(t *testing.T) { assert.False(t, ok) } +func TestTrainingOperationOmitsStreamingExtensionWhenEnabled(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModeTrain, + SupportsStreaming: true, + } + + spec := parseSpec(t, info) + postPath := getPath(spec, "paths", "/trainings", "post") + require.NotNil(t, postPath) + post := postPath.(map[string]any) + _, ok := post["x-cog-streaming"] + assert.False(t, ok) +} + func TestOutputObject(t *testing.T) { inputs := NewOrderedMap[string, InputField]() fields := NewOrderedMap[string, SchemaField]() diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index a32c73a602..e17e9d8a58 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -76,7 +76,6 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi if err != nil { return nil, err } - supportsStreaming := functionSupportsStreaming(target.node, target.file.source, target.file.imports) funcNode := target.node targetSource := target.file.source actualMethodName := methodName @@ -117,6 +116,10 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi if err != nil { return nil, err } + supportsStreaming := functionSupportsStreaming(target.node, target.file.source, target.file.imports) + if supportsStreaming && !supportsStreamingOutput(output) { + return nil, schema.WrapError(schema.ErrUnsupportedType, "@streaming requires Iterator[...] or ConcatenateIterator[...] return type", nil) + } return &schema.PredictorInfo{ Inputs: inputs, @@ -727,13 +730,14 @@ func attributeIsCogStreaming(node *sitter.Node, source []byte, imports *schema.I } func identifierIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { - if Content(node, source) != "streaming" { - return false - } - entry, ok := imports.Names.Get("streaming") + entry, ok := imports.Names.Get(Content(node, source)) return ok && entry.Module == "cog" && entry.Original == "streaming" } +func supportsStreamingOutput(output schema.SchemaType) bool { + return output.Kind == schema.SchemaIterator || output.Kind == schema.SchemaConcatIterator +} + func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { @@ -1514,7 +1518,7 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil && Content(nameNode, source) == methodName { - return child, nil + return funcNode, nil } } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index a94ab04a37..30fc1c15e3 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -617,6 +617,34 @@ class Predictor(BasePredictor): require.True(t, info.SupportsStreaming) } +func TestStreamingDecoratorImportedAliasOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming as stream +from typing import Iterator + +class Predictor(BasePredictor): + @stream + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorRequiresIteratorOutput(t *testing.T) { + source := ` +from cog import BasePredictor, streaming + +class Predictor(BasePredictor): + @streaming + def predict(self) -> str: + return "hello" +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrUnsupportedType, se.Kind) + require.Contains(t, se.Message, "@streaming requires") +} + func TestStreamingDecoratorIgnoredWhenNotFromCog(t *testing.T) { source := ` from other import streaming @@ -759,6 +787,28 @@ def train(n: int) -> Path: require.Equal(t, 1, info.Inputs.Len()) } +func TestTrainModeClassDecoratedMethod(t *testing.T) { + source := ` +from functools import wraps +from cog import BasePredictor, Path + +def noop(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + return wrapper + +class Trainer(BasePredictor): + @noop + def train(self, n: int) -> Path: + pass +` + info, err := ParsePredictor([]byte(source), "Trainer", schema.ModeTrain, "") + require.NoError(t, err) + require.Equal(t, schema.ModeTrain, info.Mode) + require.Equal(t, 1, info.Inputs.Len()) +} + // --------------------------------------------------------------------------- // Non-BasePredictor class (just has predict method) // --------------------------------------------------------------------------- diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 7745a8b627..775e2ba0e9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -44,18 +44,19 @@ def run( @overload -def streaming(fn: _F) -> _F: ... +def streaming(fn: _F) -> _F: + pass @overload -def streaming(fn: None = None) -> Callable[[_F], _F]: ... +def streaming(fn: None = None) -> Callable[[_F], _F]: + pass def streaming(fn: _F | None = None) -> _F | Callable[[_F], _F]: """Mark a predict handler as supporting streaming responses.""" def decorate(inner: _F) -> _F: - inner.__cog_streaming__ = True # type: ignore[attr-defined] return inner if fn is None: diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 3981ccdcef..fdbbd6622b 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -138,20 +138,20 @@ def test_async_concatenate_iterator_is_abstract(self) -> None: class TestStreamingDecorator: """Tests for the streaming opt-in decorator.""" - def test_streaming_marks_function_and_returns_same_object(self) -> None: + def test_streaming_returns_same_object(self) -> None: def predict() -> str: return "ok" decorated = streaming(predict) assert decorated is predict - assert predict.__cog_streaming__ is True # type: ignore[attr-defined] + assert not hasattr(predict, "__cog_streaming__") - def test_streaming_call_form_marks_function_and_returns_same_object(self) -> None: + def test_streaming_call_form_returns_same_object(self) -> None: def predict() -> str: return "ok" decorated = streaming()(predict) assert decorated is predict - assert predict.__cog_streaming__ is True # type: ignore[attr-defined] + assert not hasattr(predict, "__cog_streaming__")