From edb56718e65418b20d479f5b5740a73f10d5dcfb Mon Sep 17 00:00:00 2001 From: Dave Page Date: Wed, 11 Mar 2026 14:25:23 +0000 Subject: [PATCH 1/8] Fixean issue where LLM responses are not streamed or rendered properly in the AI Assistant. Fixes #9734 --- docs/en_US/release_notes_9_14.rst | 1 + web/pgadmin/llm/chat.py | 116 +++- web/pgadmin/llm/client.py | 45 +- web/pgadmin/llm/prompts/nlq.py | 15 +- web/pgadmin/llm/providers/anthropic.py | 222 ++++++- web/pgadmin/llm/providers/docker.py | 209 ++++++- web/pgadmin/llm/providers/ollama.py | 172 +++++- web/pgadmin/llm/providers/openai.py | 213 ++++++- web/pgadmin/static/js/Theme/dark.js | 1 + web/pgadmin/static/js/Theme/high_contrast.js | 1 + web/pgadmin/static/js/Theme/light.js | 1 + web/pgadmin/tools/sqleditor/__init__.py | 111 ++-- .../js/components/sections/NLQChatPanel.jsx | 576 +++++++++++++++--- .../tools/sqleditor/tests/test_nlq_chat.py | 281 ++++++++- 14 files changed, 1767 insertions(+), 197 deletions(-) diff --git a/docs/en_US/release_notes_9_14.rst b/docs/en_US/release_notes_9_14.rst index 59619a5e4de..36ae2d679e1 100644 --- a/docs/en_US/release_notes_9_14.rst +++ b/docs/en_US/release_notes_9_14.rst @@ -41,5 +41,6 @@ Bug fixes | `Issue #9721 `_ - Fixed an issue where permissions page is not completely accessible on full scroll. | `Issue #9729 `_ - Fixed an issue where some LLM models would not use database tools in the AI assistant, instead returning text descriptions of tool calls. | `Issue #9732 `_ - Improve the AI Assistant user prompt to be more descriptive of the actual functionality. + | `Issue #9734 `_ - Fixed an issue where LLM responses are not streamed or rendered properly in the AI Assistant. | `Issue #9736 `_ - Fix an issue where the AI Assistant was not retaining conversation context between messages, with chat history compaction to manage token budgets. | `Issue #9740 `_ - Fixed an issue where the AI Assistant input textbox sometimes swallows the first character of input. diff --git a/web/pgadmin/llm/chat.py b/web/pgadmin/llm/chat.py index 40e99219111..8c9fc1593eb 100644 --- a/web/pgadmin/llm/chat.py +++ b/web/pgadmin/llm/chat.py @@ -14,10 +14,11 @@ """ import json -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union from pgadmin.llm.client import get_llm_client, is_llm_available -from pgadmin.llm.models import Message, StopReason +from pgadmin.llm.models import Message, LLMResponse, StopReason from pgadmin.llm.tools import DATABASE_TOOLS, execute_tool, DatabaseToolError from pgadmin.llm.utils import get_max_tool_iterations @@ -153,6 +154,117 @@ def chat_with_database( ) +def chat_with_database_stream( + user_message: str, + sid: int, + did: int, + conversation_history: Optional[list[Message]] = None, + system_prompt: Optional[str] = None, + max_tool_iterations: Optional[int] = None, + provider: Optional[str] = None, + model: Optional[str] = None +) -> Generator[Union[str, tuple[str, list[Message]]], None, None]: + """ + Stream an LLM chat conversation with database tool access. + + Like chat_with_database, but yields text chunks as the final + response streams in. During tool-use iterations, no text is + yielded (tools are executed silently). + + Yields: + str: Text content chunks from the final LLM response. + + The last item yielded is a tuple of + (final_response_text, updated_conversation_history). + + Raises: + LLMClientError: If the LLM request fails. + RuntimeError: If LLM is not available or max iterations exceeded. + """ + if not is_llm_available(): + raise RuntimeError("LLM is not configured. Please configure an LLM " + "provider in Preferences > AI.") + + client = get_llm_client(provider=provider, model=model) + if not client: + raise RuntimeError("Failed to create LLM client") + + messages = list(conversation_history) if conversation_history else [] + messages.append(Message.user(user_message)) + + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + + if max_tool_iterations is None: + max_tool_iterations = get_max_tool_iterations() + + iteration = 0 + while iteration < max_tool_iterations: + iteration += 1 + + # Stream the LLM response, yielding text chunks as they arrive + response = None + for item in client.chat_stream( + messages=messages, + tools=DATABASE_TOOLS, + system_prompt=system_prompt + ): + if isinstance(item, LLMResponse): + response = item + elif isinstance(item, str): + yield item + + if response is None: + raise RuntimeError("No response received from LLM") + + messages.append(response.to_message()) + + if response.stop_reason != StopReason.TOOL_USE: + # Final response - yield the completion tuple + yield (response.content, messages) + return + + # Signal that tools are being executed so the caller can + # reset streaming state and show a thinking indicator + yield ('tool_use', [tc.name for tc in response.tool_calls]) + + # Execute tool calls + tool_results = [] + for tool_call in response.tool_calls: + try: + result = execute_tool( + tool_name=tool_call.name, + arguments=tool_call.arguments, + sid=sid, + did=did + ) + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps(result, default=str), + is_error=False + )) + except (DatabaseToolError, ValueError) as e: + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps({"error": str(e)}), + is_error=True + )) + except Exception as e: + tool_results.append(Message.tool_result( + tool_call_id=tool_call.id, + content=json.dumps({ + "error": f"Unexpected error: {str(e)}" + }), + is_error=True + )) + + messages.extend(tool_results) + + raise RuntimeError( + f"Exceeded maximum tool iterations ({max_tool_iterations})" + ) + + def single_query( question: str, sid: int, diff --git a/web/pgadmin/llm/client.py b/web/pgadmin/llm/client.py index e11b8c15d5a..e860eec4497 100644 --- a/web/pgadmin/llm/client.py +++ b/web/pgadmin/llm/client.py @@ -10,7 +10,8 @@ """Base LLM client interface and factory.""" from abc import ABC, abstractmethod -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union from pgadmin.llm.models import ( Message, Tool, LLMResponse, LLMError @@ -74,6 +75,48 @@ def chat( """ pass + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """ + Stream a chat response from the LLM. + + Yields text chunks (str) as they arrive, then yields + a final LLMResponse with the complete response metadata. + + The default implementation falls back to non-streaming chat(). + + Args: + messages: List of conversation messages. + tools: Optional list of tools the LLM can use. + system_prompt: Optional system prompt to set context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0 = deterministic). + **kwargs: Additional provider-specific parameters. + + Yields: + str: Text content chunks as they arrive. + LLMResponse: Final response with complete metadata (last item). + """ + # Default: fall back to non-streaming + response = self.chat( + messages=messages, + tools=tools, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + **kwargs + ) + if response.content: + yield response.content + yield response + def validate_connection(self) -> tuple[bool, Optional[str]]: """ Validate the connection to the LLM provider. diff --git a/web/pgadmin/llm/prompts/nlq.py b/web/pgadmin/llm/prompts/nlq.py index e40854292c4..cfbeb035058 100644 --- a/web/pgadmin/llm/prompts/nlq.py +++ b/web/pgadmin/llm/prompts/nlq.py @@ -35,13 +35,10 @@ - Use explicit column names instead of SELECT * - For UPDATE/DELETE, always include WHERE clauses -Once you have explored the database structure using the tools above, \ -provide your final answer as a JSON object in this exact format: -{"sql": "YOUR SQL QUERY HERE", "explanation": "Brief explanation"} - -Rules for the final response: -- Return ONLY the JSON object, no other text -- No markdown code blocks -- If you need clarification, set "sql" to null and put \ -your question in "explanation" +Response format: +- Always put SQL in fenced code blocks with the sql language tag +- You may include multiple SQL blocks if the request needs \ +multiple statements +- Briefly explain what each query does +- If you need clarification, just ask — no code blocks needed """ diff --git a/web/pgadmin/llm/providers/anthropic.py b/web/pgadmin/llm/providers/anthropic.py index 1ab9e2b0427..03284d7bf24 100644 --- a/web/pgadmin/llm/providers/anthropic.py +++ b/web/pgadmin/llm/providers/anthropic.py @@ -10,10 +10,12 @@ """Anthropic Claude LLM client implementation.""" import json +import socket import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -284,3 +286,221 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Anthropic.""" + payload = { + 'model': self._model, + 'max_tokens': max_tokens, + 'messages': self._convert_messages(messages), + 'stream': True + } + + if system_prompt: + payload['system'] = system_prompt + + if temperature > 0: + payload['temperature'] = temperature + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + 'x-api-key': self._api_key, + 'anthropic-version': API_VERSION + } + + request = urllib.request.Request( + API_URL, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_anthropic_stream(response) + finally: + response.close() + + def _read_anthropic_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Anthropic SSE stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls = [] + current_tool_block = None + tool_input_json = '' + stop_reason_str = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line.startswith('event: '): + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + event_type = data.get('type', '') + + if event_type == 'message_start': + msg = data.get('message', {}) + model_name = msg.get('model', self._model) + u = msg.get('usage', {}) + usage = Usage( + input_tokens=u.get('input_tokens', 0), + output_tokens=u.get('output_tokens', 0), + total_tokens=( + u.get('input_tokens', 0) + + u.get('output_tokens', 0) + ) + ) + + elif event_type == 'content_block_start': + block = data.get('content_block', {}) + if block.get('type') == 'tool_use': + current_tool_block = { + 'id': block.get('id', str(uuid.uuid4())), + 'name': block.get('name', '') + } + tool_input_json = '' + + elif event_type == 'content_block_delta': + delta = data.get('delta', {}) + if delta.get('type') == 'text_delta': + text = delta.get('text', '') + if text: + content_parts.append(text) + yield text + elif delta.get('type') == 'input_json_delta': + tool_input_json += delta.get( + 'partial_json', '' + ) + + elif event_type == 'content_block_stop': + if current_tool_block is not None: + try: + arguments = json.loads( + tool_input_json + ) if tool_input_json else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=current_tool_block['id'], + name=current_tool_block['name'], + arguments=arguments + )) + current_tool_block = None + tool_input_json = '' + + elif event_type == 'message_delta': + delta = data.get('delta', {}) + stop_reason_str = delta.get('stop_reason') + u = data.get('usage', {}) + if u: + usage = Usage( + input_tokens=usage.input_tokens, + output_tokens=u.get( + 'output_tokens', + usage.output_tokens + ), + total_tokens=( + usage.input_tokens + + u.get( + 'output_tokens', + usage.output_tokens + ) + ) + ) + + # Build final response + stop_reason_map = { + 'end_turn': StopReason.END_TURN, + 'tool_use': StopReason.TOOL_USE, + 'max_tokens': StopReason.MAX_TOKENS, + 'stop_sequence': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + stop_reason_str or '', StopReason.UNKNOWN + ) + + yield LLMResponse( + content=''.join(content_parts), + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/docker.py b/web/pgadmin/llm/providers/docker.py index 4fa6ccda2cb..fc373b9dfdf 100644 --- a/web/pgadmin/llm/providers/docker.py +++ b/web/pgadmin/llm/providers/docker.py @@ -18,7 +18,8 @@ import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -354,3 +355,209 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Docker Model Runner.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'temperature': temperature, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json' + } + + url = f'{self._api_url}/engines/v1/chat/completions' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=300, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}. " + f"Is Docker Model Runner running at " + f"{self._api_url}?", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + text_chunk = delta.get('content') + if text_chunk: + content_parts.append(text_chunk) + yield text_chunk + + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/ollama.py b/web/pgadmin/llm/providers/ollama.py index 8d38b72facd..2432ee2c48c 100644 --- a/web/pgadmin/llm/providers/ollama.py +++ b/web/pgadmin/llm/providers/ollama.py @@ -10,10 +10,10 @@ """Ollama LLM client implementation.""" import json -import re import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid from pgadmin.llm.client import LLMClient, LLMClientError @@ -220,7 +220,7 @@ def _make_request(self, payload: dict) -> dict: message=error_msg, code=str(e.code), provider=self.provider_name, - retryable=e.code in (500, 502, 503, 504) + retryable=e.code in (429, 500, 502, 503, 504) )) except urllib.error.URLError as e: raise LLMClientError(LLMError( @@ -231,8 +231,6 @@ def _make_request(self, payload: dict) -> dict: def _parse_response(self, data: dict) -> LLMResponse: """Parse the Ollama API response into an LLMResponse.""" - import re - message = data.get('message', {}) content = message.get('content', '') tool_calls = [] @@ -285,3 +283,167 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Ollama.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'stream': True, + 'options': { + 'num_predict': max_tokens, + 'temperature': temperature + } + } + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + url = f'{self._api_url}/api/chat' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers={'Content-Type': 'application/json'}, + method='POST' + ) + + try: + response = urllib.request.urlopen(request, timeout=300) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get('error', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Cannot connect to Ollama: {e.reason}", + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_ollama_stream(response) + finally: + response.close() + + def _read_ollama_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Ollama NDJSON stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls = [] + done_reason = None + model_name = self._model + input_tokens = 0 + output_tokens = 0 + final_data = None + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line: + continue + + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + + msg = data.get('message', {}) + + # Text content + text = msg.get('content', '') + if text: + content_parts.append(text) + yield text + + # Tool calls (in final message) + for tc in msg.get('tool_calls', []): + func = tc.get('function', {}) + arguments = func.get('arguments', {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=str(uuid.uuid4()), + name=func.get('name', ''), + arguments=arguments + )) + + if data.get('done'): + final_data = data + done_reason = data.get('done_reason', '') + model_name = data.get('model', self._model) + input_tokens = data.get('prompt_eval_count', 0) + output_tokens = data.get('eval_count', 0) + + # Build final response + if tool_calls: + stop_reason = StopReason.TOOL_USE + elif done_reason == 'stop': + stop_reason = StopReason.END_TURN + elif done_reason == 'length': + stop_reason = StopReason.MAX_TOKENS + else: + stop_reason = StopReason.UNKNOWN + + yield LLMResponse( + content=''.join(content_parts), + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens + ), + raw_response=final_data + ) diff --git a/web/pgadmin/llm/providers/openai.py b/web/pgadmin/llm/providers/openai.py index 73c020fb7ba..b0cafcfee2b 100644 --- a/web/pgadmin/llm/providers/openai.py +++ b/web/pgadmin/llm/providers/openai.py @@ -14,7 +14,8 @@ import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -355,3 +356,213 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from OpenAI.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'temperature': temperature, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self._api_key}' + } + + request = urllib.request.Request( + API_URL, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading — it returns as soon + as a complete line arrives from the server, unlike read() + which blocks until a buffer fills up. + """ + content_parts = [] + # tool_calls_data: {index: {id, name, arguments_str}} + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + # Extract usage from the final chunk + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + # Text content + text_chunk = delta.get('content') + if text_chunk: + content_parts.append(text_chunk) + yield text_chunk + + # Tool calls (accumulate) + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + # Build final response + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/static/js/Theme/dark.js b/web/pgadmin/static/js/Theme/dark.js index 5deb73324d5..b087e49e0ad 100644 --- a/web/pgadmin/static/js/Theme/dark.js +++ b/web/pgadmin/static/js/Theme/dark.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#1b71b5', + hyperlinkColor: '#6CB4EE', borderColor: '#4a4a4a', inputBorderColor: '#6b6b6b', inputDisabledBg: 'inherit', diff --git a/web/pgadmin/static/js/Theme/high_contrast.js b/web/pgadmin/static/js/Theme/high_contrast.js index 184153cb273..0a4a5083cf5 100644 --- a/web/pgadmin/static/js/Theme/high_contrast.js +++ b/web/pgadmin/static/js/Theme/high_contrast.js @@ -87,6 +87,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#84D6FF', + hyperlinkColor: '#84D6FF', borderColor: '#A6B7C8', inputBorderColor: '#8B9CAD', inputDisabledBg: '#1F2932', diff --git a/web/pgadmin/static/js/Theme/light.js b/web/pgadmin/static/js/Theme/light.js index 093928cfef1..00847f91425 100644 --- a/web/pgadmin/static/js/Theme/light.js +++ b/web/pgadmin/static/js/Theme/light.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#326690', + hyperlinkColor: '#1a0dab', iconLoaderUrl: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A");', iconLoaderSmall: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A")', dashboardPgDoc: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 42 42\' style=\'enable-background:new 0 0 42 42;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bstroke:%23000000;stroke-width:3.3022;%7D .st1%7Bfill:%23336791;%7D .st2%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:round;%7D .st3%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:bevel;%7D .st4%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.3669;%7D .st5%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.1835;%7D .st6%7Bfill:none;stroke:%23FFFFFF;stroke-width:0.2649;stroke-linecap:round;stroke-linejoin:round;%7D%0A%3C/style%3E%3Cg id=\'orginal\'%3E%3C/g%3E%3Cg id=\'Layer_x0020_3\'%3E%3Cpath class=\'st0\' d=\'M31.3,30c0.3-2.1,0.2-2.4,1.7-2.1l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6c2-0.9,3.1-2.4,1.2-2 c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4 c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8 c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8l-0.1,0.3c0.5,0.4,0.4,2.7,0.5,4.4 c0.1,1.7,0.2,3.2,0.5,4.1c0.3,0.9,0.7,3.3,3.9,2.6C29.1,38.3,31.1,37.5,31.3,30\'/%3E%3Cpath class=\'st1\' d=\'M38.3,25.3c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0 c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8 c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8 l-0.1,0.3c0.5,0.4,0.8,2.4,0.7,4.3c-0.1,1.9-0.1,3.2,0.3,4.2c0.4,1,0.7,3.3,3.9,2.6c2.6-0.6,4-2,4.2-4.5c0.1-1.7,0.4-1.5,0.5-3 l0.2-0.7c0.3-2.3,0-3.1,1.7-2.8l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6C39,26.4,40.2,24.9,38.3,25.3L38.3,25.3z\'/%3E%3Cpath class=\'st2\' d=\'M21.8,26.6c-0.1,4.4,0,8.8,0.5,9.8c0.4,1.1,1.3,3.2,4.5,2.5c2.6-0.6,3.6-1.7,4-4.1c0.3-1.8,0.9-6.7,1-7.7\'/%3E%3Cpath class=\'st2\' d=\'M18,4.7c0,0-14.3-5.8-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.2-3.7,3.2-3.7\'/%3E%3Cpath class=\'st2\' d=\'M25.7,3.6c-0.5,0.2,7.9-3.1,12.7,3c1.7,2.2-0.3,11-5,18\'/%3E%3Cpath class=\'st3\' d=\'M33.5,24.6c0,0,0.3,1.5,4.7,0.6c1.9-0.4,0.8,1.1-1.2,2c-1.6,0.8-5.3,0.9-5.3-0.1 C31.6,24.5,33.6,25.3,33.5,24.6c-0.1-0.6-1.1-1.2-1.7-2.7c-0.5-1.3-7.3-11.2,1.9-9.7c0.3-0.1-2.4-8.7-11-8.9 c-8.6-0.1-8.3,10.6-8.3,10.6\'/%3E%3Cpath class=\'st2\' d=\'M19.4,25.6c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8c0.5-0.8,0-2-0.7-2.3 C20.5,25.1,20,24.9,19.4,25.6L19.4,25.6z\'/%3E%3Cpath class=\'st2\' d=\'M19.3,25.5c-0.1-0.8,0.3-1.7,0.7-2.8c0.6-1.6,2-3.3,0.9-8.5c-0.8-3.9-6.5-0.8-6.5-0.3c0,0.5,0.3,2.7-0.1,5.2 c-0.5,3.3,2.1,6,5,5.7\'/%3E%3Cpath class=\'st4\' d=\'M18,13.8c0,0.2,0.3,0.7,0.8,0.7c0.5,0.1,0.9-0.3,0.9-0.5c0-0.2-0.3-0.4-0.8-0.4C18.4,13.6,18,13.7,18,13.8 L18,13.8z\'/%3E%3Cpath class=\'st5\' d=\'M32,13.5c0,0.2-0.3,0.7-0.8,0.7c-0.5,0.1-0.9-0.3-0.9-0.5c0-0.2,0.3-0.4,0.8-0.4C31.6,13.2,32,13.3,32,13.5 L32,13.5z\'/%3E%3Cpath class=\'st2\' d=\'M33.7,12.2c0.1,1.4-0.3,2.4-0.4,3.9c-0.1,2.2,1,4.7-0.6,7.2\'/%3E%3Cpath class=\'st6\' d=\'M2.7,6.6\'/%3E%3C/g%3E%3C/svg%3E%0A")', diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 73f11059438..1d98e1b6341 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2842,7 +2842,7 @@ def nlq_chat_stream(trans_id): """ from flask import stream_with_context from pgadmin.llm.utils import is_llm_enabled - from pgadmin.llm.chat import chat_with_database + from pgadmin.llm.chat import chat_with_database_stream from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT # Check if LLM is configured @@ -2908,75 +2908,56 @@ def generate(): provider=provider ) - # Call the LLM with database tools and history - response_text, updated_history = chat_with_database( + # Stream the LLM response with database tools + response_text = '' + updated_history = [] + for item in chat_with_database_stream( user_message=user_message, sid=trans_obj.sid, did=trans_obj.did, system_prompt=NLQ_SYSTEM_PROMPT, conversation_history=conversation_history - ) - - # Try to parse the response as JSON - sql = None - explanation = '' - - # First, try to extract JSON from markdown code blocks - json_text = response_text.strip() - - # Look for ```json ... ``` blocks - json_match = re.search( - r'```json\s*\n?(.*?)\n?```', - json_text, + ): + if isinstance(item, str): + # Text chunk from streaming LLM response + yield _nlq_sse_event({ + 'type': 'text_delta', + 'content': item + }) + elif isinstance(item, tuple) and \ + item[0] == 'tool_use': + # Tool execution in progress - reset streaming + yield _nlq_sse_event({ + 'type': 'thinking', + 'message': gettext( + 'Querying the database...' + ) + }) + elif isinstance(item, tuple): + # Final result: (response_text, messages) + response_text = item[0] + if len(item) > 1: + updated_history = item[1] + + # Extract SQL from markdown code fences + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, re.DOTALL ) - if json_match: - json_text = json_match.group(1).strip() - else: - # Also try to find a plain JSON object in the response - # Look for {"sql": ... } pattern anywhere in the text - sql_pattern = ( - r'\{["\']?sql["\']?\s*:\s*' - r'(?:null|"[^"]*"|\'[^\']*\').*?\}' - ) - plain_json_match = re.search(sql_pattern, json_text, re.DOTALL) - if plain_json_match: - json_text = plain_json_match.group(0) - - try: - result = json.loads(json_text) - sql = result.get('sql') - explanation = result.get('explanation', '') - except (json.JSONDecodeError, TypeError): - # If not valid JSON, try to extract SQL from the response - # Look for SQL code blocks first - sql_match = re.search( - r'```sql\s*\n?(.*?)\n?```', - response_text, - re.DOTALL - ) - if sql_match: - sql = sql_match.group(1).strip() - else: - # Check for malformed tool call text patterns - # Some models output tool calls as text instead of - # proper tool use blocks - tool_call_match = re.search( - r'\s*' - r'\s*(.*?)\s*', - response_text, - re.DOTALL - ) - if tool_call_match: - sql = tool_call_match.group(1).strip() - explanation = gettext( - 'Generated SQL query from your request.' - ) - else: - # No parseable JSON or SQL block found - # Treat the response as an explanation/error message - explanation = response_text.strip() - # Don't set sql - leave it as None + sql = ';\n\n'.join( + block.strip() for block in sql_blocks + ) if sql_blocks else None + + # Fallback: try JSON format in case LLM ignored + # the markdown instruction + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass # Generate a conversation ID if not provided if not conversation_id: @@ -2995,11 +2976,11 @@ def generate(): filter_conversational(updated_history) ] if updated_history else [] - # Send the final result + # Send the final result with full response content yield _nlq_sse_event({ 'type': 'complete', 'sql': sql, - 'explanation': explanation, + 'content': response_text, 'conversation_id': new_conversation_id, 'history': serialized_history }) diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index d484a04955f..1bed2427b04 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -23,6 +23,8 @@ import AddIcon from '@mui/icons-material/Add'; import ClearAllIcon from '@mui/icons-material/ClearAll'; import AutoFixHighIcon from '@mui/icons-material/AutoFixHigh'; import { format as formatSQL } from 'sql-formatter'; +import { marked } from 'marked'; +import DOMPurify from 'dompurify'; import gettext from 'sources/gettext'; import url_for from 'sources/url_for'; import getApiInstance from '../../../../../../static/js/api_instance'; @@ -106,7 +108,6 @@ const SQLPreviewBox = styled(Box)(({ theme }) => ({ borderRadius: theme.spacing(0.5), overflow: 'auto', '& .cm-editor': { - minHeight: '60px', maxHeight: '250px', }, '& .cm-scroller': { @@ -132,17 +133,173 @@ const ThinkingIndicator = styled(Box)(({ theme }) => ({ color: theme.palette.text.secondary, })); +const MarkdownContent = styled(Box)(({ theme }) => ({ + fontSize: theme.typography.body2.fontSize, + lineHeight: theme.typography.body2.lineHeight, + '& p': { margin: `${theme.spacing(0.5)} 0` }, + '& p:first-of-type': { marginTop: 0 }, + '& p:last-of-type': { marginBottom: 0 }, + '& code': { + backgroundColor: theme.palette.action.hover, + padding: '1px 4px', + borderRadius: 3, + fontSize: '0.85em', + fontFamily: 'monospace', + }, + '& pre': { + backgroundColor: theme.palette.action.hover, + padding: theme.spacing(1), + borderRadius: 4, + overflow: 'auto', + '& code': { + backgroundColor: 'transparent', + padding: 0, + }, + }, + '& h1, & h2, & h3, & h4, & h5, & h6': { + margin: `${theme.spacing(1)} 0 ${theme.spacing(0.5)} 0`, + lineHeight: 1.3, + }, + '& h1': { fontSize: '1.3em' }, + '& h2': { fontSize: '1.2em' }, + '& h3': { fontSize: '1.1em' }, + '& ul': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'disc !important', + }, + '& ol': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'decimal !important', + }, + '& li': { + margin: `${theme.spacing(0.25)} 0`, + display: 'list-item !important', + listStyle: 'inherit !important', + }, + '& ul ul': { listStyleType: 'circle !important' }, + '& ul ul ul': { listStyleType: 'square !important' }, + '& table': { + borderCollapse: 'collapse', + margin: `${theme.spacing(0.5)} 0`, + width: '100%', + }, + '& th, & td': { + border: `1px solid ${theme.otherVars.borderColor}`, + padding: `${theme.spacing(0.25)} ${theme.spacing(0.75)}`, + textAlign: 'left', + }, + '& th': { + backgroundColor: theme.palette.action.hover, + fontWeight: 600, + }, + '& blockquote': { + borderLeft: `3px solid ${theme.otherVars.borderColor}`, + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(1), + opacity: 0.85, + }, + '& strong': { fontWeight: 600 }, + '& a': { + color: theme.otherVars.hyperlinkColor, + textDecoration: 'underline', + }, +})); + // Message types const MESSAGE_TYPES = { USER: 'user', ASSISTANT: 'assistant', SQL: 'sql', THINKING: 'thinking', + STREAMING: 'streaming', ERROR: 'error', }; +/** + * Incrementally parse streaming markdown text into an ordered list of + * segments. Each segment is: + * { type: 'text', content: string } + * { type: 'code', language: string, content: string, complete: boolean } + * + * Handles ```language fenced code blocks. Segments appear in the order + * the LLM streams them so the renderer can map straight over the array. + */ +function parseMarkdownSegments(text) { + const segments = []; + let pos = 0; + + while (pos < text.length) { + const fenceIdx = text.indexOf('```', pos); + + if (fenceIdx === -1) { + // No more fences — rest is text + const content = text.substring(pos); + if (content) segments.push({ type: 'text', content }); + break; + } + + // Text before the fence + if (fenceIdx > pos) { + segments.push({ type: 'text', content: text.substring(pos, fenceIdx) }); + } + + // Parse opening fence line: ```language\n + const afterFence = text.substring(fenceIdx + 3); + const langMatch = /^([a-zA-Z]*)\n/.exec(afterFence); + if (!langMatch) { + // Language line not complete yet — wait for more tokens + break; + } + + const language = langMatch[1].toLowerCase(); + const codeStart = fenceIdx + 3 + langMatch[0].length; + + // Find closing fence + const closeIdx = text.indexOf('```', codeStart); + if (closeIdx === -1) { + // Still streaming code block content + segments.push({ + type: 'code', language, + content: text.substring(codeStart), + complete: false, + }); + break; + } + + // Complete code block — trim trailing newline before closing fence + let codeContent = text.substring(codeStart, closeIdx); + if (codeContent.endsWith('\n')) { + codeContent = codeContent.slice(0, -1); + } + segments.push({ + type: 'code', language, + content: codeContent, + complete: true, + }); + + // Move past closing ``` and optional trailing newline + pos = closeIdx + 3; + if (pos < text.length && text[pos] === '\n') pos++; + } + + return segments; +} + +/** + * Render a markdown text fragment to sanitized HTML. + * Uses marked for inline formatting (bold, italic, code, lists, tables, etc.) + * and DOMPurify to prevent XSS. + */ +function renderMarkdownText(text) { + if (!text) return ''; + const html = marked.parse(text, { gfm: true, breaks: true }); + return DOMPurify.sanitize(html); +} + // Single chat message component -function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) { +function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey, formatSqlWithPrefs }) { if (message.type === MESSAGE_TYPES.USER) { return ( @@ -152,58 +309,117 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) } if (message.type === MESSAGE_TYPES.SQL) { + const segments = message.content + ? parseMarkdownSegments(message.content) : []; + + // Fallback for messages without markdown content (old format) + if (segments.length === 0 && message.sql) { + return ( + + + + + {gettext('Generated SQL')} + + + + onInsertSQL(message.sql)}> + + + + + onReplaceSQL(message.sql)}> + + + + + navigator.clipboard.writeText(message.sql)}> + + + + + + + + + + {message.explanation && ( + {message.explanation} + )} + + ); + } + + // Render markdown segments with action buttons on code blocks return ( - {message.explanation && ( - - {message.explanation} - - )} - - - - {gettext('Generated SQL')} - - - - onInsertSQL(message.sql)} - > - - - - - onReplaceSQL(message.sql)} - > - - - - - navigator.clipboard.writeText(message.sql)} - > - - - - - - - - - + {segments.map((seg, idx) => { + if (seg.type === 'text') { + const content = seg.content?.trim(); + if (!content) return null; + return ( + 0 ? 1 : 0 }} + dangerouslySetInnerHTML={{ __html: renderMarkdownText(content) }} + /> + ); + } + + if (seg.type === 'code') { + const isSql = ['sql', 'pgsql', 'postgresql'].includes(seg.language); + const formattedCode = isSql ? formatSqlWithPrefs(seg.content) : seg.content; + + return ( + + + + {seg.language || gettext('Code')} + + + {isSql && ( + <> + + onInsertSQL(formattedCode)}> + + + + + onReplaceSQL(formattedCode)}> + + + + + )} + + navigator.clipboard.writeText(formattedCode)}> + + + + + + + + + + ); + } + + return null; + })} ); } @@ -224,6 +440,106 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) ); } + if (message.type === MESSAGE_TYPES.STREAMING) { + const segments = parseMarkdownSegments(message.content); + const BlinkingCursor = ( + + ); + + // No segments parsed yet — show raw text or spinner + if (segments.length === 0) { + return ( + + {message.content ? ( + + {message.content} + {BlinkingCursor} + + ) : ( + + + + {gettext('Generating response...')} + + + )} + + ); + } + + // Render markdown segments in order + const lastIdx = segments.length - 1; + return ( + + {segments.map((seg, idx) => { + const isLast = idx === lastIdx; + const cursor = isLast && !seg.complete ? BlinkingCursor : null; + + if (seg.type === 'code') { + return ( + + + + {seg.complete + ? (seg.language || gettext('Code')) + : gettext('Generating...')} + + + + + {seg.content} + {cursor} + + + + ); + } + + const content = seg.content?.trim(); + if (!content && !cursor) return null; + return ( + 0 ? 1 : 0, display: 'inline' }}> + + {cursor} + + ); + })} + + ); + } + if (message.type === MESSAGE_TYPES.ERROR) { return ( - {message.content} + ); } @@ -272,6 +590,8 @@ export function NLQChatPanel() { const readerRef = useRef(null); const stoppedRef = useRef(false); const clearedRef = useRef(false); + const streamingTextRef = useRef(''); + const streamingIdRef = useRef(null); const eventBus = useContext(QueryToolEventsContext); const queryToolCtx = useContext(QueryToolContext); const editorPrefs = usePreferences().getPreferencesForModule('editor'); @@ -448,9 +768,11 @@ export function NLQChatPanel() { const handleSubmit = async () => { if (!inputValue.trim() || isLoading) return; - // Reset stopped and cleared flags + // Reset stopped, cleared flags and streaming state stoppedRef.current = false; clearedRef.current = false; + streamingTextRef.current = ''; + streamingIdRef.current = null; // Fetch latest LLM provider/model info before submitting fetchLlmInfo(); @@ -553,36 +875,61 @@ export function NLQChatPanel() { // Check if user manually stopped (but not cleared) if (stoppedRef.current && !clearedRef.current) { - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + const streamId = streamingIdRef.current; + // If we have partial streaming content, show it as-is + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } + streamingTextRef.current = ''; + streamingIdRef.current = null; } } catch (error) { clearTimeout(timeoutId); abortControllerRef.current = null; readerRef.current = null; + const streamId = streamingIdRef.current; // If conversation was cleared, ignore all late errors if (clearedRef.current) { // Do nothing - conversation was wiped } else if (error.name === 'AbortError') { // Check if this was a user-initiated stop or a timeout if (stoppedRef.current) { - // User manually stopped - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + // User manually stopped - show partial content if any + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } } else { // Timeout occurred setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Request timed out. The query may be too complex. Please try a simpler request.'), @@ -591,13 +938,15 @@ export function NLQChatPanel() { } } else { setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Failed to generate SQL: ') + error.message, }, ]); } + streamingTextRef.current = ''; + streamingIdRef.current = null; } finally { setIsLoading(false); setThinkingMessageId(null); @@ -606,32 +955,78 @@ export function NLQChatPanel() { const handleSSEEvent = (event, thinkingId) => { switch (event.type) { - case 'thinking': - setMessages((prev) => - prev.map((m) => - m.id === thinkingId ? { ...m, content: event.message } : m - ) - ); + case 'thinking': { + const streamId = streamingIdRef.current; + if (streamId) { + // Transition from streaming back to thinking (tool use) + // Remove streaming message and re-add thinking indicator + streamingTextRef.current = ''; + streamingIdRef.current = null; + setMessages((prev) => [ + ...prev.filter((m) => m.id !== streamId), + { + type: MESSAGE_TYPES.THINKING, + content: event.message, + id: thinkingId, + }, + ]); + setThinkingMessageId(thinkingId); + } else { + setMessages((prev) => + prev.map((m) => + m.id === thinkingId ? { ...m, content: event.message } : m + ) + ); + } break; + } - case 'sql': - case 'complete': - // If sql is null/empty, show as regular assistant message (e.g., clarification questions) - if (!event.sql) { + case 'text_delta': + streamingTextRef.current += event.content; + if (!streamingIdRef.current) { + // First text chunk: replace thinking with streaming message + streamingIdRef.current = Date.now(); setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId), { - type: MESSAGE_TYPES.ASSISTANT, - content: event.explanation || gettext('I need more information to generate the SQL.'), + type: MESSAGE_TYPES.STREAMING, + content: streamingTextRef.current, + id: streamingIdRef.current, }, ]); } else { + // Update existing streaming message + const sid = streamingIdRef.current; + setMessages((prev) => + prev.map((m) => + m.id === sid ? { ...m, content: streamingTextRef.current } : m + ) + ); + } + break; + + case 'sql': + case 'complete': { + const streamId = streamingIdRef.current; + const content = event.content || event.explanation + || gettext('I need more information to generate the SQL.'); + // Use SQL type if there's SQL or any code fences in the response + const hasCodeBlocks = event.sql || (content && content.includes('```')); + if (hasCodeBlocks) { setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.SQL, - sql: formatSqlWithPrefs(event.sql), - explanation: event.explanation, + content, + sql: event.sql, + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content, }, ]); } @@ -641,16 +1036,22 @@ export function NLQChatPanel() { if (event.history) { setConversationHistory(event.history); } + // Reset streaming state + streamingTextRef.current = ''; + streamingIdRef.current = null; break; + } case 'error': setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamingIdRef.current), { type: MESSAGE_TYPES.ERROR, content: event.message, }, ]); + streamingTextRef.current = ''; + streamingIdRef.current = null; break; } }; @@ -745,6 +1146,7 @@ export function NLQChatPanel() { onReplaceSQL={handleReplaceSQL} textColors={textColors} cmKey={cmKey} + formatSqlWithPrefs={formatSqlWithPrefs} /> )) )} diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 38feff067e1..52e3f28b93c 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -10,6 +10,7 @@ """Tests for the NLQ (Natural Language Query) chat endpoint.""" import json +import re from unittest.mock import patch, MagicMock from pgadmin.utils.route import BaseTestGenerator @@ -43,8 +44,9 @@ class NLQChatTestCase(BaseTestGenerator): message='Find all users', expected_error=False, mock_response=( - '{"sql": "SELECT * FROM users;", ' - '"explanation": "Gets all users"}' + 'Here are all users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'This retrieves all rows from the users table.' ) )), ('NLQ Chat - With History', dict( @@ -108,16 +110,15 @@ def runTest(self): ) patches.append(mock_check_trans) - # Mock chat_with_database — patch the source module because the - # endpoint uses a local import (from pgadmin.llm.chat import ...) - # inside the function body, so there is no module-level binding - # to patch at the use site. + # Mock chat_with_database_stream mock_chat_patcher = None - mock_chat_obj = None if hasattr(self, 'mock_response'): + def mock_stream_gen(*args, **kwargs): + yield self.mock_response + yield (self.mock_response, []) mock_chat_patcher = patch( - 'pgadmin.llm.chat.chat_with_database', - return_value=(self.mock_response, []) + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen ) patches.append(mock_chat_patcher) @@ -133,8 +134,6 @@ def runTest(self): for p in patches: m = p.start() started_mocks.append(m) - if p is mock_chat_patcher: - mock_chat_obj = m try: # Make request @@ -166,22 +165,9 @@ def runTest(self): self.assertIn('text/event-stream', response.content_type) # Consume the SSE stream so the generator executes - # fully (including the chat_with_database call) + # fully (including the chat_with_database_stream call) _ = response.data - # Verify history was passed to chat_with_database - if hasattr(self, 'history') and mock_chat_obj: - mock_chat_obj.assert_called_once() - call_kwargs = mock_chat_obj.call_args.kwargs - conv_hist = call_kwargs.get( - 'conversation_history', [] - ) - self.assertTrue( - len(conv_hist) > 0, - 'conversation_history should be non-empty ' - 'when history is provided' - ) - finally: # Stop all patches for p in patches: @@ -216,3 +202,248 @@ def runTest(self): def tearDown(self): pass + + +class NLQSqlExtractionTestCase(BaseTestGenerator): + """Test cases for SQL extraction from markdown responses""" + + scenarios = [ + ('SQL Extraction - Single SQL block', dict( + response_text=( + 'Here is the query:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'This returns all users.' + ), + expected_sql='SELECT * FROM users;' + )), + ('SQL Extraction - Multiple SQL blocks', dict( + response_text=( + 'First get users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'Then get orders:\n\n' + '```sql\nSELECT * FROM orders;\n```' + ), + expected_sql='SELECT * FROM users;;\n\nSELECT * FROM orders;' + )), + ('SQL Extraction - pgsql language tag', dict( + response_text='```pgsql\nSELECT 1;\n```', + expected_sql='SELECT 1;' + )), + ('SQL Extraction - postgresql language tag', dict( + response_text='```postgresql\nSELECT 1;\n```', + expected_sql='SELECT 1;' + )), + ('SQL Extraction - No SQL blocks', dict( + response_text=( + 'I cannot generate a query without ' + 'knowing your table structure.' + ), + expected_sql=None + )), + ('SQL Extraction - Non-SQL code block only', dict( + response_text=( + 'Here is some Python:\n\n' + '```python\nprint("hello")\n```' + ), + expected_sql=None + )), + ('SQL Extraction - JSON fallback', dict( + response_text='{"sql": "SELECT 1;", "explanation": "test"}', + expected_sql='SELECT 1;' + )), + ('SQL Extraction - Multiline SQL', dict( + response_text=( + '```sql\n' + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100;\n' + '```' + ), + expected_sql=( + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100;' + ) + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SQL extraction from markdown response text""" + response_text = self.response_text + + # Extract SQL using the same regex as the endpoint + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, + re.DOTALL + ) + sql = ';\n\n'.join( + block.strip() for block in sql_blocks + ) if sql_blocks else None + + # JSON fallback + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass + + self.assertEqual(sql, self.expected_sql) + + def tearDown(self): + pass + + +class NLQStreamingSSETestCase(BaseTestGenerator): + """Test cases for SSE event format in streaming responses""" + + scenarios = [ + ('SSE - Text with SQL produces complete event', dict( + mock_response=( + '```sql\nSELECT 1;\n```' + ), + check_complete_has_sql=True + )), + ('SSE - Text without SQL has no sql field', dict( + mock_response='I need more information about your schema.', + check_complete_has_sql=False + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SSE events from NLQ streaming endpoint""" + trans_id = 12345 + + patches = [] + + mock_llm_enabled = patch( + 'pgadmin.llm.utils.is_llm_enabled', + return_value=True + ) + patches.append(mock_llm_enabled) + + mock_trans_obj = MagicMock() + mock_trans_obj.sid = 1 + mock_trans_obj.did = 1 + + mock_conn = MagicMock() + mock_conn.connected.return_value = True + + mock_session = {'sid': 1, 'did': 1} + + mock_check_trans = patch( + 'pgadmin.tools.sqleditor.check_transaction_status', + return_value=( + True, None, mock_conn, mock_trans_obj, mock_session + ) + ) + patches.append(mock_check_trans) + + def mock_stream_gen(*args, **kwargs): + # Yield text chunks + for chunk in [self.mock_response[i:i + 10] + for i in range(0, len(self.mock_response), 10)]: + yield chunk + # Yield final tuple + yield (self.mock_response, []) + + mock_chat = patch( + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen + ) + patches.append(mock_chat) + + mock_csrf = patch( + 'pgadmin.authenticate.mfa.utils.mfa_required', + lambda f: f + ) + patches.append(mock_csrf) + + for p in patches: + p.start() + + try: + response = self.tester.post( + f'/sqleditor/nlq/chat/{trans_id}/stream', + data=json.dumps({'message': 'test query'}), + content_type='application/json', + follow_redirects=True + ) + + self.assertEqual(response.status_code, 200) + self.assertIn('text/event-stream', response.content_type) + + # Parse SSE events + events = [] + raw = response.data.decode('utf-8') + for line in raw.split('\n'): + if line.startswith('data: '): + try: + events.append(json.loads(line[6:])) + except json.JSONDecodeError: + pass + + # Should have at least one text_delta and one complete + event_types = [e.get('type') for e in events] + self.assertIn('text_delta', event_types) + self.assertIn('complete', event_types) + + # Check the complete event + complete_events = [ + e for e in events if e.get('type') == 'complete' + ] + self.assertEqual(len(complete_events), 1) + complete = complete_events[0] + + # Verify content is present + self.assertIn('content', complete) + self.assertEqual(complete['content'], self.mock_response) + + # Verify SQL extraction + if self.check_complete_has_sql: + self.assertIsNotNone(complete.get('sql')) + else: + self.assertIsNone(complete.get('sql')) + + finally: + for p in patches: + p.stop() + + def tearDown(self): + pass + + +class NLQPromptMarkdownFormatTestCase(BaseTestGenerator): + """Test that NLQ prompt instructs markdown code fences""" + + scenarios = [ + ('NLQ Prompt - Markdown format', dict()), + ] + + def setUp(self): + pass + + def runTest(self): + """Test NLQ prompt requires markdown SQL code fences""" + from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT + + # Prompt should instruct use of fenced code blocks + self.assertIn('fenced code block', NLQ_SYSTEM_PROMPT.lower()) + self.assertIn('sql', NLQ_SYSTEM_PROMPT.lower()) + + # Should NOT instruct JSON format + self.assertNotIn('"sql":', NLQ_SYSTEM_PROMPT) + self.assertNotIn('"explanation":', NLQ_SYSTEM_PROMPT) + + def tearDown(self): + pass From d3e66726cd772c634935887999c19ff13ef85a7d Mon Sep 17 00:00:00 2001 From: Dave Page Date: Thu, 12 Mar 2026 10:01:27 +0000 Subject: [PATCH 2/8] Address CodeRabbit review feedback for streaming and SQL extraction. - Anthropic: preserve separators between text blocks in streaming to match _parse_response() behavior. - Docker: validate that the API URL points to a loopback address to constrain the request surface. - Docker/OpenAI: raise LLMClientError on empty streams instead of yielding blank LLMResponse objects, matching non-streaming behavior. - SQL extraction: strip trailing semicolons before joining blocks to avoid double semicolons in output. Co-Authored-By: Claude Opus 4.6 --- web/pgadmin/llm/providers/anthropic.py | 8 ++++++ web/pgadmin/llm/providers/docker.py | 28 +++++++++++++++++++ web/pgadmin/llm/providers/openai.py | 7 +++++ web/pgadmin/tools/sqleditor/__init__.py | 2 +- .../tools/sqleditor/tests/test_nlq_chat.py | 2 +- 5 files changed, 45 insertions(+), 2 deletions(-) diff --git a/web/pgadmin/llm/providers/anthropic.py b/web/pgadmin/llm/providers/anthropic.py index 03284d7bf24..86081474b41 100644 --- a/web/pgadmin/llm/providers/anthropic.py +++ b/web/pgadmin/llm/providers/anthropic.py @@ -392,6 +392,7 @@ def _read_anthropic_stream( stop_reason_str = None model_name = self._model usage = Usage() + in_text_block = False while True: line_bytes = response.readline() @@ -437,6 +438,13 @@ def _read_anthropic_stream( 'name': block.get('name', '') } tool_input_json = '' + elif block.get('type') == 'text': + # Emit a separator between text blocks to + # match _parse_response() which joins with '\n' + if in_text_block: + content_parts.append('\n') + yield '\n' + in_text_block = True elif event_type == 'content_block_delta': delta = data.get('delta', {}) diff --git a/web/pgadmin/llm/providers/docker.py b/web/pgadmin/llm/providers/docker.py index fc373b9dfdf..52132827e67 100644 --- a/web/pgadmin/llm/providers/docker.py +++ b/web/pgadmin/llm/providers/docker.py @@ -16,6 +16,7 @@ import json import socket import ssl +import urllib.parse import urllib.request import urllib.error from collections.abc import Generator @@ -43,6 +44,25 @@ DEFAULT_API_URL = 'http://localhost:12434' DEFAULT_MODEL = 'ai/qwen3-coder' +# Allowed loopback hostnames for the Docker endpoint +_LOOPBACK_HOSTS = {'localhost', '127.0.0.1', '::1', '[::1]'} + + +def _validate_loopback_url(url: str) -> None: + """Ensure the URL uses HTTP(S) and points to a loopback address.""" + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ('http', 'https'): + raise ValueError( + f"Docker Model Runner URL must use http or https, " + f"got: {parsed.scheme}" + ) + hostname = (parsed.hostname or '').lower() + if hostname not in _LOOPBACK_HOSTS: + raise ValueError( + f"Docker Model Runner URL must point to a loopback address " + f"(localhost/127.0.0.1/::1), got: {hostname}" + ) + class DockerClient(LLMClient): """ @@ -64,6 +84,7 @@ def __init__( model: Optional model name. Defaults to ai/qwen3-coder. """ self._api_url = (api_url or DEFAULT_API_URL).rstrip('/') + _validate_loopback_url(self._api_url) self._model = model or DEFAULT_MODEL @property @@ -554,6 +575,13 @@ def _read_openai_stream( finish_reason or '', StopReason.UNKNOWN ) + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + yield LLMResponse( content=content, tool_calls=tool_calls, diff --git a/web/pgadmin/llm/providers/openai.py b/web/pgadmin/llm/providers/openai.py index b0cafcfee2b..e8653a3acb4 100644 --- a/web/pgadmin/llm/providers/openai.py +++ b/web/pgadmin/llm/providers/openai.py @@ -559,6 +559,13 @@ def _read_openai_stream( finish_reason or '', StopReason.UNKNOWN ) + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + yield LLMResponse( content=content, tool_calls=tool_calls, diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 1d98e1b6341..44980526be2 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2946,7 +2946,7 @@ def generate(): re.DOTALL ) sql = ';\n\n'.join( - block.strip() for block in sql_blocks + block.strip().rstrip(';') for block in sql_blocks ) if sql_blocks else None # Fallback: try JSON format in case LLM ignored diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 52e3f28b93c..0cfbbf2ea88 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -223,7 +223,7 @@ class NLQSqlExtractionTestCase(BaseTestGenerator): 'Then get orders:\n\n' '```sql\nSELECT * FROM orders;\n```' ), - expected_sql='SELECT * FROM users;;\n\nSELECT * FROM orders;' + expected_sql='SELECT * FROM users;\n\nSELECT * FROM orders' )), ('SQL Extraction - pgsql language tag', dict( response_text='```pgsql\nSELECT 1;\n```', From dde88025db860dffb6a1d2f0d36f184071117610 Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 11:08:50 +0000 Subject: [PATCH 3/8] Address remaining CodeRabbit review feedback for streaming and rendering. - Use distinct 3-tuple ('complete', text, messages) for completion events to avoid ambiguity with ('tool_use', [...]) 2-tuples in chat streaming. - Pass conversation history from request into chat_with_database_stream() so follow-up NLQ turns retain context. - Add re.IGNORECASE to SQL fence regex for case-insensitive matching. - Render MarkdownContent as block element instead of span to avoid invalid DOM when response contains paragraphs, lists, or tables. - Keep stop notice as a separate message instead of appending to partial markdown, preventing it from being swallowed by open code fences. - Snapshot streamingIdRef before setMessages in error handler to avoid race condition where ref is cleared before React executes the updater. Co-Authored-By: Claude Opus 4.6 (1M context) --- web/pgadmin/llm/chat.py | 9 ++++--- web/pgadmin/tools/sqleditor/__init__.py | 13 +++++----- .../js/components/sections/NLQChatPanel.jsx | 25 +++++++++++++------ .../tools/sqleditor/tests/test_nlq_chat.py | 8 +++--- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/web/pgadmin/llm/chat.py b/web/pgadmin/llm/chat.py index 8c9fc1593eb..2e32118d275 100644 --- a/web/pgadmin/llm/chat.py +++ b/web/pgadmin/llm/chat.py @@ -174,8 +174,8 @@ def chat_with_database_stream( Yields: str: Text content chunks from the final LLM response. - The last item yielded is a tuple of - (final_response_text, updated_conversation_history). + The last item yielded is a 3-tuple of + ('complete', final_response_text, updated_conversation_history). Raises: LLMClientError: If the LLM request fails. @@ -220,8 +220,9 @@ def chat_with_database_stream( messages.append(response.to_message()) if response.stop_reason != StopReason.TOOL_USE: - # Final response - yield the completion tuple - yield (response.content, messages) + # Final response - yield a 3-tuple to distinguish from + # the 2-tuple tool_use event + yield ('complete', response.content, messages) return # Signal that tools are being executed so the caller can diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 44980526be2..7cd11e6ca3a 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2933,17 +2933,18 @@ def generate(): 'Querying the database...' ) }) - elif isinstance(item, tuple): - # Final result: (response_text, messages) - response_text = item[0] - if len(item) > 1: - updated_history = item[1] + elif isinstance(item, tuple) and \ + item[0] == 'complete': + # Final result: ('complete', response_text, messages) + response_text = item[1] + if len(item) > 2: + updated_history = item[2] # Extract SQL from markdown code fences sql_blocks = re.findall( r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', response_text, - re.DOTALL + re.DOTALL | re.IGNORECASE ) sql = ';\n\n'.join( block.strip().rstrip(';') for block in sql_blocks diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index 1bed2427b04..f2f666f8516 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -527,9 +527,8 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey, fo const content = seg.content?.trim(); if (!content && !cursor) return null; return ( - 0 ? 1 : 0, display: 'inline' }}> + 0 ? 1 : 0 }}> {cursor} @@ -876,13 +875,18 @@ export function NLQChatPanel() { // Check if user manually stopped (but not cleared) if (stoppedRef.current && !clearedRef.current) { const streamId = streamingIdRef.current; - // If we have partial streaming content, show it as-is + // If we have partial streaming content, show it separately + // from the stop notice to avoid breaking open markdown fences if (streamingTextRef.current) { setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ASSISTANT, - content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + content: streamingTextRef.current, + }, + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), }, ]); } else { @@ -908,13 +912,17 @@ export function NLQChatPanel() { } else if (error.name === 'AbortError') { // Check if this was a user-initiated stop or a timeout if (stoppedRef.current) { - // User manually stopped - show partial content if any + // User manually stopped - show partial content separately if (streamingTextRef.current) { setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ASSISTANT, - content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + content: streamingTextRef.current, + }, + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), }, ]); } else { @@ -1042,9 +1050,10 @@ export function NLQChatPanel() { break; } - case 'error': + case 'error': { + const streamId = streamingIdRef.current; setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId && m.id !== streamingIdRef.current), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: event.message, diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 0cfbbf2ea88..57805045fce 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -115,7 +115,7 @@ def runTest(self): if hasattr(self, 'mock_response'): def mock_stream_gen(*args, **kwargs): yield self.mock_response - yield (self.mock_response, []) + yield ('complete', self.mock_response, []) mock_chat_patcher = patch( 'pgadmin.llm.chat.chat_with_database_stream', side_effect=mock_stream_gen @@ -280,7 +280,7 @@ def runTest(self): sql_blocks = re.findall( r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', response_text, - re.DOTALL + re.DOTALL | re.IGNORECASE ) sql = ';\n\n'.join( block.strip() for block in sql_blocks @@ -354,8 +354,8 @@ def mock_stream_gen(*args, **kwargs): for chunk in [self.mock_response[i:i + 10] for i in range(0, len(self.mock_response), 10)]: yield chunk - # Yield final tuple - yield (self.mock_response, []) + # Yield final 3-tuple + yield ('complete', self.mock_response, []) mock_chat = patch( 'pgadmin.llm.chat.chat_with_database_stream', From 8d89076bdf6c4b96498e7846c0e14709f3724a6c Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 11:49:14 +0000 Subject: [PATCH 4/8] Address CodeRabbit review feedback for streaming providers and history. - Fix critical NameError: use self._api_url instead of undefined API_URL in anthropic and openai streaming _process_stream() methods. - Match sync path auth handling: conditionally set API key headers in streaming paths for both anthropic and openai providers. - Remove unconditional temperature from openai streaming payload to match sync path compatibility approach. - Add URL scheme validation to OllamaClient.__init__ to prevent unsafe local/resource access via non-http schemes. - Guard ollama streaming finalizer: raise error when stream drops without a done frame and no content was received. - Update chat.py type hint and docstring for 3-tuple completion event. - Serialize and return filtered conversation history in the complete SSE event so the client can round-trip it on follow-up turns. - Store and send conversation history from NLQChatPanel, clear on conversation reset. - Fix JSON-fallback SQL render path: clear content when SQL was extracted without fenced blocks so ChatMessage uses sql-only renderer. Co-Authored-By: Claude Opus 4.6 (1M context) --- web/pgadmin/llm/chat.py | 2 +- web/pgadmin/llm/providers/anthropic.py | 6 +++-- web/pgadmin/llm/providers/ollama.py | 21 +++++++++++++-- web/pgadmin/llm/providers/openai.py | 7 ++--- web/pgadmin/tools/sqleditor/__init__.py | 26 +++++++++---------- .../js/components/sections/NLQChatPanel.jsx | 7 +++-- 6 files changed, 46 insertions(+), 23 deletions(-) diff --git a/web/pgadmin/llm/chat.py b/web/pgadmin/llm/chat.py index 2e32118d275..68759a85c94 100644 --- a/web/pgadmin/llm/chat.py +++ b/web/pgadmin/llm/chat.py @@ -163,7 +163,7 @@ def chat_with_database_stream( max_tool_iterations: Optional[int] = None, provider: Optional[str] = None, model: Optional[str] = None -) -> Generator[Union[str, tuple[str, list[Message]]], None, None]: +) -> Generator[Union[str, tuple], None, None]: """ Stream an LLM chat conversation with database tool access. diff --git a/web/pgadmin/llm/providers/anthropic.py b/web/pgadmin/llm/providers/anthropic.py index 86081474b41..c7b91a56e8e 100644 --- a/web/pgadmin/llm/providers/anthropic.py +++ b/web/pgadmin/llm/providers/anthropic.py @@ -329,12 +329,14 @@ def _process_stream( """Make a streaming request and yield chunks.""" headers = { 'Content-Type': 'application/json', - 'x-api-key': self._api_key, 'anthropic-version': API_VERSION } + if self._api_key: + headers['x-api-key'] = self._api_key + request = urllib.request.Request( - API_URL, + self._api_url, data=json.dumps(payload).encode('utf-8'), headers=headers, method='POST' diff --git a/web/pgadmin/llm/providers/ollama.py b/web/pgadmin/llm/providers/ollama.py index 2432ee2c48c..5e3cbe09e20 100644 --- a/web/pgadmin/llm/providers/ollama.py +++ b/web/pgadmin/llm/providers/ollama.py @@ -10,6 +10,7 @@ """Ollama LLM client implementation.""" import json +import urllib.parse import urllib.request import urllib.error from collections.abc import Generator @@ -47,6 +48,14 @@ def __init__(self, api_url: str, model: Optional[str] = None): self._api_url = api_url.rstrip('/') self._model = model or DEFAULT_MODEL + # Validate URL scheme to prevent unsafe access + parsed = urllib.parse.urlparse(self._api_url) + if parsed.scheme not in ('http', 'https'): + raise ValueError( + f"Ollama URL must use http or https scheme, " + f"got: {parsed.scheme}" + ) + @property def provider_name(self) -> str: return 'ollama' @@ -425,7 +434,15 @@ def _read_ollama_stream( input_tokens = data.get('prompt_eval_count', 0) output_tokens = data.get('eval_count', 0) - # Build final response + # Build final response — only if the stream completed normally + content = ''.join(content_parts) + if final_data is None and not content and not tool_calls: + raise LLMClientError(LLMError( + message="Stream ended without a complete response", + provider=self.provider_name, + retryable=True + )) + if tool_calls: stop_reason = StopReason.TOOL_USE elif done_reason == 'stop': @@ -436,7 +453,7 @@ def _read_ollama_stream( stop_reason = StopReason.UNKNOWN yield LLMResponse( - content=''.join(content_parts), + content=content, tool_calls=tool_calls, stop_reason=stop_reason, model=model_name, diff --git a/web/pgadmin/llm/providers/openai.py b/web/pgadmin/llm/providers/openai.py index e8653a3acb4..2b0e2072917 100644 --- a/web/pgadmin/llm/providers/openai.py +++ b/web/pgadmin/llm/providers/openai.py @@ -379,7 +379,6 @@ def chat_stream( 'model': self._model, 'messages': converted_messages, 'max_completion_tokens': max_tokens, - 'temperature': temperature, 'stream': True, 'stream_options': {'include_usage': True} } @@ -404,11 +403,13 @@ def _process_stream( """Make a streaming request and yield chunks.""" headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self._api_key}' } + if self._api_key: + headers['Authorization'] = f'Bearer {self._api_key}' + request = urllib.request.Request( - API_URL, + self._api_url, data=json.dumps(payload).encode('utf-8'), headers=headers, method='POST' diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 7cd11e6ca3a..dbfcb3d000d 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2910,7 +2910,7 @@ def generate(): # Stream the LLM response with database tools response_text = '' - updated_history = [] + updated_messages = [] for item in chat_with_database_stream( user_message=user_message, sid=trans_obj.sid, @@ -2937,8 +2937,7 @@ def generate(): item[0] == 'complete': # Final result: ('complete', response_text, messages) response_text = item[1] - if len(item) > 2: - updated_history = item[2] + updated_messages = item[2] # Extract SQL from markdown code fences sql_blocks = re.findall( @@ -2966,16 +2965,17 @@ def generate(): else: new_conversation_id = conversation_id - # Serialize updated history for the frontend. - # Only include conversational messages (user + final - # assistant responses) to keep history size manageable. - # Internal tool call/result messages are ephemeral to - # each turn and don't need to round-trip. + # Filter and serialize the conversation history so the + # client can round-trip it on follow-up turns from pgadmin.llm.compaction import filter_conversational - serialized_history = [ - m.to_dict() for m in - filter_conversational(updated_history) - ] if updated_history else [] + filtered = filter_conversational(updated_messages) + history = [ + { + 'role': m.role.value, + 'content': m.content, + } + for m in filtered + ] # Send the final result with full response content yield _nlq_sse_event({ @@ -2983,7 +2983,7 @@ def generate(): 'sql': sql, 'content': response_text, 'conversation_id': new_conversation_id, - 'history': serialized_history + 'history': history }) except Exception as e: diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index f2f666f8516..ed04ad237c8 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -729,7 +729,6 @@ export function NLQChatPanel() { setMessages([]); setConversationId(null); setConversationHistory([]); - setIsLoading(false); }; // Stop the current request @@ -1021,11 +1020,15 @@ export function NLQChatPanel() { // Use SQL type if there's SQL or any code fences in the response const hasCodeBlocks = event.sql || (content && content.includes('```')); if (hasCodeBlocks) { + // When SQL was extracted via JSON fallback (no fenced blocks), + // clear content so ChatMessage uses the sql-only render path + const msgContent = (content && content.includes('```')) + ? content : null; setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.SQL, - content, + content: msgContent, sql: event.sql, }, ]); From 4c695959599c184eeb03595762233ab77b3ab25c Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 13:06:40 +0000 Subject: [PATCH 5/8] Fix missing closing brace in NLQChatPanel switch statement. Adding block scoping to the error case introduced an unmatched brace that prevented the switch statement from closing properly, causing an eslint parse error. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../sqleditor/static/js/components/sections/NLQChatPanel.jsx | 1 + 1 file changed, 1 insertion(+) diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index ed04ad237c8..838f60c8831 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -1066,6 +1066,7 @@ export function NLQChatPanel() { streamingIdRef.current = null; break; } + } }; const handleKeyDown = (e) => { From 5470e944fd44b4d1292617ada92ddfc5e7362907 Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 13:08:57 +0000 Subject: [PATCH 6/8] Fix missing compaction module and SQL extraction test. - Replace compaction module imports with inline history deserialization and filtering since compaction.py is on a different branch. - Add rstrip(';') to SQL extraction test to match production code, fixing double-semicolon assertion failure. Co-Authored-By: Claude Opus 4.6 (1M context) --- web/pgadmin/tools/sqleditor/__init__.py | 52 +++++++++++-------- .../tools/sqleditor/tests/test_nlq_chat.py | 2 +- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index dbfcb3d000d..f7cb83ba0c5 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2886,10 +2886,7 @@ def nlq_chat_stream(trans_id): def generate(): """Generator for SSE events.""" import secrets as py_secrets - from pgadmin.llm.compaction import ( - deserialize_history, compact_history - ) - from pgadmin.llm.utils import get_default_provider + from pgadmin.llm.models import Message, Role try: # Send thinking status @@ -2898,15 +2895,22 @@ def generate(): 'message': gettext('Analyzing your request...') }) - # Deserialize and compact conversation history + # Deserialize conversation history if provided conversation_history = None if history_data: - conversation_history = deserialize_history(history_data) - provider = get_default_provider() or 'openai' - conversation_history = compact_history( - conversation_history, - provider=provider - ) + conversation_history = [] + for item in (history_data or []): + if not isinstance(item, dict): + continue + role_str = item.get('role', '') + content = item.get('content', '') + try: + role = Role(role_str) + except ValueError: + continue + conversation_history.append( + Message(role=role, content=content) + ) # Stream the LLM response with database tools response_text = '' @@ -2965,17 +2969,21 @@ def generate(): else: new_conversation_id = conversation_id - # Filter and serialize the conversation history so the - # client can round-trip it on follow-up turns - from pgadmin.llm.compaction import filter_conversational - filtered = filter_conversational(updated_messages) - history = [ - { - 'role': m.role.value, - 'content': m.content, - } - for m in filtered - ] + # Serialize the conversation history so the client can + # round-trip it on follow-up turns. Only keep user + # messages and final assistant responses (no tool calls). + history = [] + for m in updated_messages: + if m.role == Role.USER: + history.append({ + 'role': m.role.value, + 'content': m.content, + }) + elif m.role == Role.ASSISTANT and not m.tool_calls: + history.append({ + 'role': m.role.value, + 'content': m.content, + }) # Send the final result with full response content yield _nlq_sse_event({ diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 57805045fce..07bed58ccc5 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -283,7 +283,7 @@ def runTest(self): re.DOTALL | re.IGNORECASE ) sql = ';\n\n'.join( - block.strip() for block in sql_blocks + block.strip().rstrip(';') for block in sql_blocks ) if sql_blocks else None # JSON fallback From 933a69e2093c7176cbaf286eef532380cbc0b37b Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 13:16:49 +0000 Subject: [PATCH 7/8] Fix SQL extraction test expected values after rstrip(';') change. The rstrip(';') applied to each block before joining means single blocks and the last block in multi-block joins no longer have trailing semicolons. Update expected values to match. Co-Authored-By: Claude Opus 4.6 (1M context) --- web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 07bed58ccc5..360827abfbc 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -214,7 +214,7 @@ class NLQSqlExtractionTestCase(BaseTestGenerator): '```sql\nSELECT * FROM users;\n```\n\n' 'This returns all users.' ), - expected_sql='SELECT * FROM users;' + expected_sql='SELECT * FROM users' )), ('SQL Extraction - Multiple SQL blocks', dict( response_text=( @@ -227,11 +227,11 @@ class NLQSqlExtractionTestCase(BaseTestGenerator): )), ('SQL Extraction - pgsql language tag', dict( response_text='```pgsql\nSELECT 1;\n```', - expected_sql='SELECT 1;' + expected_sql='SELECT 1' )), ('SQL Extraction - postgresql language tag', dict( response_text='```postgresql\nSELECT 1;\n```', - expected_sql='SELECT 1;' + expected_sql='SELECT 1' )), ('SQL Extraction - No SQL blocks', dict( response_text=( @@ -264,7 +264,7 @@ class NLQSqlExtractionTestCase(BaseTestGenerator): 'SELECT u.name, o.total\n' 'FROM users u\n' 'JOIN orders o ON u.id = o.user_id\n' - 'WHERE o.total > 100;' + 'WHERE o.total > 100' ) )), ] From 73bc218db9e8c324a08aa9d8ab9da980bd8ebb5a Mon Sep 17 00:00:00 2001 From: Dave Page Date: Mon, 16 Mar 2026 15:34:32 +0000 Subject: [PATCH 8/8] Strictly guard Ollama stream: raise if no terminal done frame received. Truncated content from a dropped connection should not be treated as a complete response, even if partial text was streamed. Always raise when final_data is None, matching CodeRabbit's recommendation. Co-Authored-By: Claude Opus 4.6 (1M context) --- web/pgadmin/llm/providers/ollama.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/web/pgadmin/llm/providers/ollama.py b/web/pgadmin/llm/providers/ollama.py index 5e3cbe09e20..1706e3d8ead 100644 --- a/web/pgadmin/llm/providers/ollama.py +++ b/web/pgadmin/llm/providers/ollama.py @@ -434,15 +434,17 @@ def _read_ollama_stream( input_tokens = data.get('prompt_eval_count', 0) output_tokens = data.get('eval_count', 0) - # Build final response — only if the stream completed normally - content = ''.join(content_parts) - if final_data is None and not content and not tool_calls: + # Ensure the stream completed with a terminal done frame; + # truncated content from a dropped connection is unreliable. + if final_data is None: raise LLMClientError(LLMError( - message="Stream ended without a complete response", + message="Ollama stream ended before terminal done frame", provider=self.provider_name, retryable=True )) + content = ''.join(content_parts) + if tool_calls: stop_reason = StopReason.TOOL_USE elif done_reason == 'stop':