diff --git a/README.md b/README.md index 5008ca0..c939db5 100644 --- a/README.md +++ b/README.md @@ -832,6 +832,37 @@ The `IAudioInference` class supports the following parameters: - `duration`: Duration of the generated audio in seconds - `includeCost`: Whether to include cost information in the response +### Text inference streaming + +To stream text inference (e.g. LLM chat) over HTTP using Server-Sent Events (SSE), set `deliveryMethod="stream"`. The SDK yields content chunks (strings) and a final `IText` with usage and cost: + +```python +import asyncio +from runware import Runware, ITextInference, ITextInferenceMessage + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request = ITextInference( + model="runware:qwen3-thinking@1", + messages=[ITextInferenceMessage(role="user", content="Explain photosynthesis in one sentence.")], + deliveryMethod="stream", + includeCost=True, + ) + + stream = await runware.textInference(request) + async for chunk in stream: + if isinstance(chunk, str): + print(chunk, end="", flush=True) + else: + print(chunk) + +asyncio.run(main()) +``` + +Streaming uses the same concurrency limit as other requests (`RUNWARE_MAX_CONCURRENT_REQUESTS`); note that a concurrency slot is held for the entire duration of the stream. To allow longer streams, set `RUNWARE_TEXT_STREAM_TIMEOUT` (milliseconds; default 600000). + ### Model Upload To upload model using the Runware API, you can use the `uploadModel` method of the `Runware` class. 
Here are examples: @@ -1068,6 +1099,9 @@ RUNWARE_AUDIO_INFERENCE_TIMEOUT=300000 # Audio generation (default: 5 min) RUNWARE_AUDIO_POLLING_DELAY=1000 # Delay between status checks (default: 1 sec) RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts for audio inference (default: 240, ~4 min total) +# Text Operations (milliseconds) +RUNWARE_TEXT_STREAM_TIMEOUT=600000 # Text inference streaming (SSE) read timeout (default: 10 min) + # Other Operations (milliseconds) RUNWARE_PROMPT_ENHANCE_TIMEOUT=60000 # Prompt enhancement (default: 1 min) RUNWARE_WEBHOOK_TIMEOUT=30000 # Webhook acknowledgment (default: 30 sec) diff --git a/requirements.txt b/requirements.txt index e4cce40..4d38ca9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ aiofiles>=23.2.1 +httpx>=0.27.0 python-dotenv>=1.0.1 -websockets>=12.0 \ No newline at end of file +websockets>=12.0 diff --git a/runware/base.py b/runware/base.py index b060f80..66a94a6 100644 --- a/runware/base.py +++ b/runware/base.py @@ -1,5 +1,6 @@ import asyncio import inspect +import json import logging import os import re @@ -7,8 +8,9 @@ from dataclasses import asdict, is_dataclass, fields from enum import Enum from random import uniform -from typing import List, Optional, Union, Callable, Any, Dict, Tuple +from typing import List, Optional, Union, Callable, Any, Dict, Tuple, AsyncIterator +import httpx from websockets.protocol import State from .logging_config import configure_logging @@ -62,11 +64,13 @@ IUploadMediaRequest, ITextInference, IText, + ITextInputs, ) from .types import IImage, IError, SdkType, ListenerType from .utils import ( BASE_RUNWARE_URLS, getUUID, + get_http_url_from_ws_url, fileToBase64, createImageFromResponse, createEnhancedPromptsFromResponse, @@ -84,6 +88,7 @@ createAsyncTaskResponse, VIDEO_INITIAL_TIMEOUT, TEXT_INITIAL_TIMEOUT, + TEXT_STREAM_READ_TIMEOUT, VIDEO_POLLING_DELAY, WEBHOOK_TIMEOUT, IMAGE_INFERENCE_TIMEOUT, @@ -2092,7 +2097,20 @@ async def _inference3d(self, request3d: 
I3dInference) -> Union[List[I3d], IAsync await self.ensureConnection() return await self._request3d(request3d) - async def textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: + async def textInference( + self, requestText: ITextInference + ) -> Union[List[IText], IAsyncTaskResponse, AsyncIterator[Union[str, IText]]]: + delivery_method_enum = ( + requestText.deliveryMethod + if isinstance(requestText.deliveryMethod, EDeliveryMethod) + else EDeliveryMethod(requestText.deliveryMethod) + ) + if delivery_method_enum == EDeliveryMethod.STREAM: + async def stream_with_semaphore() -> AsyncIterator[Union[str, IText]]: + async with self._request_semaphore: + async for chunk in self._requestTextStream(requestText): + yield chunk + return stream_with_semaphore() async with self._request_semaphore: return await self._retry_async_with_reconnect( self._requestText, @@ -2281,26 +2299,129 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]: "deliveryMethod": requestText.deliveryMethod, "messages": [asdict(m) for m in requestText.messages], } - if requestText.maxTokens is not None: - request_object["maxTokens"] = requestText.maxTokens - if requestText.temperature is not None: - request_object["temperature"] = requestText.temperature - if requestText.topP is not None: - request_object["topP"] = requestText.topP - if requestText.topK is not None: - request_object["topK"] = requestText.topK if requestText.seed is not None: request_object["seed"] = requestText.seed - if requestText.stopSequences is not None: - request_object["stopSequences"] = requestText.stopSequences if requestText.includeCost is not None: request_object["includeCost"] = requestText.includeCost + if requestText.includeUsage is not None: + request_object["includeUsage"] = requestText.includeUsage + if requestText.numberResults is not None: + request_object["numberResults"] = requestText.numberResults + self._addOptionalField(request_object, 
requestText.settings) + self._addOptionalField(request_object, requestText.inputs) self._addProviderSettings(request_object, requestText) return request_object + async def _message_from_http_status_error(self, exc: httpx.HTTPStatusError) -> str: + """ + Build a short, user-facing message from an HTTP error response. + Matches WebSocket auth errors where possible (e.g. invalid API key). + """ + resp = exc.response + try: + await resp.aread() + except Exception: + pass + status = resp.status_code + try: + data = resp.json() + if isinstance(data, dict): + msg = data.get("message") + if isinstance(msg, str) and msg.strip(): + return msg.strip() + err = data.get("error") + if isinstance(err, dict): + inner = err.get("message") + if isinstance(inner, str) and inner.strip(): + return inner.strip() + if isinstance(err, str) and err.strip(): + return err.strip() + except Exception: + pass + if status == 401: + return "Invalid API key. Get one at https://my.runware.ai/signup" + return f"HTTP {status} error for {resp.request.url}" + + async def _requestTextStream( + self, requestText: ITextInference + ) -> AsyncIterator[Union[str, IText]]: + requestText.taskUUID = requestText.taskUUID or getUUID() + request_object = self._buildTextRequest(requestText) + body = [request_object] + http_url = get_http_url_from_ws_url(self._url or "") + headers = { + "Accept": "text/event-stream", + "Authorization": f"Bearer {self._apiKey}", + "Content-Type": "application/json", + } + accumulated_text = "" + try: + async with httpx.AsyncClient(timeout=TEXT_STREAM_READ_TIMEOUT / 1000) as client: + async with client.stream( + "POST", + http_url, + json=body, + headers=headers, + ) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: + continue + payload = line.replace("data:", "", 1).strip() + if payload == "[DONE]": + return + try: + line_obj = json.loads(payload) + except json.JSONDecodeError: + continue + data = line_obj.get("data") or line_obj + 
if data.get("error") is not None: + raise RunwareAPIError(data["error"]) + + delta = data.get("delta") or {} + finishReason = data.get("finishReason") + + content_chunk = delta.get("text") + if content_chunk: + accumulated_text += content_chunk + yield content_chunk + if finishReason is not None: + yield instantiateDataclass( + IText, + { + **data, + "taskType": data.get("taskType"), + "text": data.get("text") or accumulated_text, + "finishReason": finishReason, + }, + ) + return + except httpx.HTTPStatusError as e: + msg = await self._message_from_http_status_error(e) + if e.response.status_code == 401: + self._invalidAPIkey = msg + self._reconnection_manager.on_auth_failure() + raise ConnectionError(msg) from e + raise RunwareAPIError({"message": msg, "statusCode": e.response.status_code}) from e + except RunwareAPIError: + raise + except Exception as e: + raise RunwareAPIError({"message": str(e)}) from e + async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]: await self.ensureConnection() requestText.taskUUID = requestText.taskUUID or getUUID() + + + if requestText.inputs: + inputs = requestText.inputs + if isinstance(inputs, dict): + inputs = ITextInputs(**inputs) + requestText.inputs = inputs + + if inputs.images: + inputs.images = await process_image(inputs.images) + request_object = self._buildTextRequest(requestText) if requestText.webhookURL: diff --git a/runware/types.py b/runware/types.py index d1fdf3a..246e8e8 100644 --- a/runware/types.py +++ b/runware/types.py @@ -106,6 +106,7 @@ class EOpenPosePreProcessor(Enum): class EDeliveryMethod(Enum): SYNC = "sync" ASYNC = "async" + STREAM = "stream" class OperationState(Enum): """State machine for pending operations.""" @@ -810,9 +811,26 @@ def request_key(self) -> str: return "texSlat" +@dataclass +class ITextInferenceTool(SerializableMixin): + """Tool definition for text inference (e.g. 
function-calling / JSON-schema tools).""" + + name: str + description: str + input_schema: Dict[str, Any] + + +@dataclass +class ITextInferenceToolChoice(SerializableMixin): + """Selects how tools are used (provider-specific shape, e.g. type + name).""" + + type: str + name: Optional[str] = None + + @dataclass class ISettings(SerializableMixin): - # Image + # Image / Text temperature: Optional[float] = None systemPrompt: Optional[str] = None topP: Optional[float] = None @@ -851,6 +869,13 @@ class ISettings(SerializableMixin): expressiveness: Optional[str] = None removeBackground: Optional[bool] = None backgroundColor: Optional[str] = None + # Text + maxTokens: Optional[int] = None + topK: Optional[int] = None + stopSequences: Optional[List[str]] = None + thinkingLevel: Optional[str] = None + tools: Optional[List[Union[ITextInferenceTool, Dict[str, Any]]]] = None + toolChoice: Optional[Union[ITextInferenceToolChoice, Dict[str, Any]]] = None def __post_init__(self): if self.sparseStructure is not None and isinstance(self.sparseStructure, dict): @@ -859,6 +884,13 @@ def __post_init__(self): self.shapeSlat = IShapeSlat(**self.shapeSlat) if self.texSlat is not None and isinstance(self.texSlat, dict): self.texSlat = ITexSlat(**self.texSlat) + if self.tools is not None: + self.tools = [ + ITextInferenceTool(**t) if isinstance(t, dict) else t + for t in self.tools + ] + if self.toolChoice is not None and isinstance(self.toolChoice, dict): + self.toolChoice = ITextInferenceToolChoice(**self.toolChoice) @property def request_key(self) -> str: @@ -909,6 +941,15 @@ def __post_init__(self): self.referenceImages = self.references +@dataclass +class ITextInputs(SerializableMixin): + images: Optional[List[Union[str, File]]] = None + + @property + def request_key(self) -> str: + return "inputs" + + @dataclass class IAudioInput(SerializableMixin): id: Optional[str] = None @@ -1400,6 +1441,8 @@ class IGoogleProviderSettings(BaseProviderSettings): generateAudio: Optional[bool] = None 
enhancePrompt: Optional[bool] = None search: Optional[bool] = None + searchLatitude: Optional[float] = None + searchLongitude: Optional[float] = None resizeMode: Optional[str] = None safetyTolerance: Optional[str] = None @@ -1788,24 +1831,61 @@ class ITextInferenceMessage: content: str +@dataclass +class ITextInferenceCompletionTokensDetails: + reasoningTokens: Optional[int] = None + + +@dataclass +class ITextInferenceUsageModality: + modality: Optional[str] = None + tokens: Optional[int] = None + cost: Optional[float] = None + costDisplay: Optional[str] = None + + +@dataclass +class ITextInferenceUsageTokenPromptCache: + modalities: Optional[List[ITextInferenceUsageModality]] = None + billableTokens: Optional[int] = None + cost: Optional[float] = None + costDisplay: Optional[str] = None + + +@dataclass +class ITextInferenceUsageTokenCompletion: + billableTokens: Optional[int] = None + textTokens: Optional[int] = None + reasoningTokens: Optional[int] = None + cost: Optional[float] = None + costDisplay: Optional[str] = None + + +@dataclass +class ITextInferenceUsageTokensBreakdown: + prompt: Optional[ITextInferenceUsageTokenPromptCache] = None + cache: Optional[ITextInferenceUsageTokenPromptCache] = None + completion: Optional[ITextInferenceUsageTokenCompletion] = None + + +@dataclass +class ITextInferenceUsageCostBreakdown: + tokens: Optional[ITextInferenceUsageTokensBreakdown] = None + total: Optional[float] = None + totalDisplay: Optional[str] = None + + @dataclass class ITextInferenceUsage: promptTokens: Optional[int] = None completionTokens: Optional[int] = None totalTokens: Optional[int] = None thinkingTokens: Optional[int] = None + completionTokensDetails: Optional[ITextInferenceCompletionTokensDetails] = None + costBreakdown: Optional[ITextInferenceUsageCostBreakdown] = None -@dataclass -class IGoogleTextProviderSettings(BaseProviderSettings): - thinkingLevel: Optional[str] = None - - @property - def provider_key(self) -> str: - return "google" - - 
-TextProviderSettings = IGoogleTextProviderSettings +TextProviderSettings = IGoogleProviderSettings @dataclass @@ -1815,16 +1895,20 @@ class ITextInference: taskUUID: Optional[str] = None deliveryMethod: str = "sync" numberResults: Optional[int] = 1 - maxTokens: Optional[int] = None - temperature: Optional[float] = None - topP: Optional[float] = None - topK: Optional[int] = None - seed: Optional[int] = None - stopSequences: Optional[List[str]] = None + seed: Optional[int] = None includeCost: Optional[bool] = None + includeUsage: Optional[bool] = None + settings: Optional[Union[ISettings, Dict[str, Any]]] = None + inputs: Optional[Union[ITextInputs, Dict[str, Any]]] = None providerSettings: Optional[TextProviderSettings] = None webhookURL: Optional[str] = None + def __post_init__(self) -> None: + if self.settings is not None and isinstance(self.settings, dict): + self.settings = ISettings(**self.settings) + if self.inputs is not None and isinstance(self.inputs, dict): + self.inputs = ITextInputs(**self.inputs) + @dataclass class IText: @@ -1835,6 +1919,9 @@ class IText: usage: Optional[ITextInferenceUsage] = None cost: Optional[float] = None status: Optional[str] = None + reasoningContent: Optional[List[str]] = None + seed: Optional[int] = None + thoughtSignature: Optional[str] = None @dataclass diff --git a/runware/utils.py b/runware/utils.py index 55661d2..52160e6 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -43,6 +43,25 @@ Environment.TEST: "ws://localhost:8080", } +# HTTP REST base URL for streaming (e.g. textInference with deliveryMethod=stream) +BASE_RUNWARE_HTTP_URLS = { + Environment.PRODUCTION: "https://api.runware.ai/v1", + Environment.TEST: "http://localhost:8080", +} + +# Map each WebSocket base URL to its HTTP counterpart (for streaming requests). 
+_WS_TO_HTTP = { + BASE_RUNWARE_URLS[Environment.PRODUCTION]: BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION], + BASE_RUNWARE_URLS[Environment.TEST]: BASE_RUNWARE_HTTP_URLS[Environment.TEST], +} + + +def get_http_url_from_ws_url(ws_url: str) -> str: + """Return the HTTP counterpart of ws_url; empty or unrecognized URLs fall back to the production HTTP endpoint.""" + if not ws_url: + return BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION] + return _WS_TO_HTTP.get(ws_url, BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]) + RETRY_SDK_COUNTS = { "GLOBAL": 2, @@ -126,6 +145,14 @@ 30000 )) +# Text streaming read timeout (milliseconds) +# Maximum time to wait for data on the SSE stream; long to avoid ReadTimeout mid-stream +# Used in: _requestTextStream() for deliveryMethod=stream +TEXT_STREAM_READ_TIMEOUT = int(os.environ.get( + "RUNWARE_TEXT_STREAM_TIMEOUT", + 600000 +)) + # Audio generation timeout (milliseconds) # Maximum time to wait for audio generation completion # Used in: _waitForAudioCompletion() for single audio generation