From 6b121f54100e9616c759bab20c6d6f7fe101455c Mon Sep 17 00:00:00 2001 From: Muhammad Sohaib Date: Thu, 12 Feb 2026 17:40:21 +0500 Subject: [PATCH 1/4] Switch from Anthropic API to local Ollama LLM and add CLAUDE.md Replaced Anthropic Claude with Ollama (llama3.2) for free local inference. Changed RAG query flow to search-first approach instead of tool calling, since small local models can't handle tool calls reliably. Co-Authored-By: Claude Opus 4.6 --- .env.example | 5 +- CLAUDE.md | 53 +++++++++++++ backend/ai_generator.py | 166 ++++++++++++++-------------------------- backend/config.py | 6 +- backend/rag_system.py | 42 +++++----- backend/search_tools.py | 49 ++++++------ pyproject.toml | 2 +- uv.lock | 60 +-------------- 8 files changed, 166 insertions(+), 217 deletions(-) create mode 100644 CLAUDE.md diff --git a/.env.example b/.env.example index 18b34cb7e..e305c873f 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,3 @@ -# Copy this file to .env and add your actual API key -ANTHROPIC_API_KEY=your-anthropic-api-key-here \ No newline at end of file +# Ollama settings (free local LLM - no API key needed) +OLLAMA_BASE_URL=http://localhost:11434 +OLLAMA_MODEL=llama3.1:8b diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..f5fbc4653 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,53 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Running the Application + +```bash +# Install dependencies +uv sync + +# Start the server (from project root) +cd backend && uv run uvicorn app:app --reload --port 8000 + +# Or use the shell script +./run.sh +``` + +Requires Ollama running locally with a model installed (configured in `.env`). +The app is served at http://localhost:8000 with API docs at http://localhost:8000/docs. + +## Environment + +- Python 3.13+ with `uv` package manager +- **Always use `uv` to run commands and manage dependencies. Never use `pip` directly.** +- `.env` file in project root with `OLLAMA_BASE_URL` and `OLLAMA_MODEL` +- No test suite exists currently + +## Architecture + +This is a RAG (Retrieval-Augmented Generation) chatbot for course materials. FastAPI backend serves both the API and the frontend static files. + +**Query flow:** Frontend → `app.py` (FastAPI) → `rag_system.py` (orchestrator) → searches `vector_store.py` (ChromaDB) → sends query + retrieved chunks to `ai_generator.py` (Ollama LLM) → response back to frontend. 
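A quick way to exercise this flow end to end (a minimal sketch, assuming the server is running on port 8000 — the question text is illustrative):

```python
# Minimal smoke test for the query flow described above.
# httpx is already a project dependency.
import httpx

resp = httpx.post(
    "http://localhost:8000/api/query",
    json={"query": "What does the MCP course cover?"},  # session_id is optional
    timeout=120.0,  # the first request may wait for Ollama to load the model
)
data = resp.json()
print(data["answer"])
print(data["sources"])  # sources retrieved from ChromaDB for this query
```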
+ +**Key design decisions:** +- Search-first approach: every query searches ChromaDB before calling the LLM (no tool calling — small local models can't handle it reliably) +- Two ChromaDB collections: `course_catalog` (metadata for fuzzy course name matching) and `course_content` (chunked text for semantic search) +- Embeddings via `all-MiniLM-L6-v2` sentence-transformer +- Session history is in-memory only (lost on restart), capped at 2 exchanges per session +- Documents are chunked at 800 chars with 100 char overlap +- Course documents in `docs/` are auto-loaded on server startup + +**Backend modules:** +- `app.py` — API endpoints and static file serving +- `rag_system.py` — orchestrates search → LLM → session flow +- `vector_store.py` — ChromaDB wrapper with semantic search and course/lesson filtering +- `ai_generator.py` — Ollama API client +- `document_processor.py` — parses course .txt/.pdf/.docx files into structured chunks +- `search_tools.py` — search tool abstraction with source tracking +- `session_manager.py` — per-session conversation memory +- `models.py` — `Course`, `Lesson`, `CourseChunk` dataclasses +- `config.py` — loads settings from `.env` + +**Frontend:** plain HTML/JS/CSS in `frontend/`, uses `marked.js` for markdown rendering. No build step. diff --git a/backend/ai_generator.py b/backend/ai_generator.py index 0363ca90c..198a1a62c 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -1,25 +1,18 @@ -import anthropic +import httpx +import json from typing import List, Optional, Dict, Any + class AIGenerator: - """Handles interactions with Anthropic's Claude API for generating responses""" - - # Static system prompt to avoid rebuilding on each call - SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information. + """Handles interactions with Ollama's local LLM API for generating responses""" -Search Tool Usage: -- Use the search tool **only** for questions about specific course content or detailed educational materials -- **One search per query maximum** -- Synthesize search results into accurate, fact-based responses -- If search yields no results, state this clearly without offering alternatives + # Static system prompt to avoid rebuilding on each call + SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content. Response Protocol: -- **General knowledge questions**: Answer using existing knowledge without searching -- **Course-specific questions**: Search first, then answer -- **No meta-commentary**: - - Provide direct answers only — no reasoning process, search explanations, or question-type analysis - - Do not mention "based on the search results" - +- Answer questions using the provided course context +- If the context doesn't contain relevant information, say so clearly +- Do not mention "based on the context" or "based on the search results" All responses must be: 1. **Brief, Concise and focused** - Get to the point quickly @@ -28,108 +21,61 @@ class AIGenerator: 4. **Example-supported** - Include relevant examples when they aid understanding Provide only the direct answer to what was asked. 
""" - - def __init__(self, api_key: str, model: str): - self.client = anthropic.Anthropic(api_key=api_key) + + def __init__(self, base_url: str, model: str): + self.base_url = base_url.rstrip("/") self.model = model - - # Pre-build base API parameters - self.base_params = { + # 120s timeout to handle first-request model loading + self.http_client = httpx.Client(timeout=120.0) + + def _call_ollama(self, messages: List[Dict]) -> Dict: + """Make a POST request to the Ollama chat API.""" + payload = { "model": self.model, - "temperature": 0, - "max_tokens": 800 + "messages": messages, + "stream": False, + "options": { + "temperature": 0, + "num_predict": 800 + } } - + + resp = self.http_client.post(f"{self.base_url}/api/chat", json=payload) + resp.raise_for_status() + return resp.json() + def generate_response(self, query: str, - conversation_history: Optional[str] = None, - tools: Optional[List] = None, - tool_manager=None) -> str: + context: str = "", + conversation_history: Optional[str] = None) -> str: """ - Generate AI response with optional tool usage and conversation context. - + Generate AI response using provided context. + Args: query: The user's question or request + context: Retrieved course content to answer from conversation_history: Previous messages for context - tools: Available tools the AI can use - tool_manager: Manager to execute tools - + Returns: Generated response as string """ - - # Build system content efficiently - avoid string ops when possible - system_content = ( - f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}" - if conversation_history - else self.SYSTEM_PROMPT - ) - - # Prepare API call parameters efficiently - api_params = { - **self.base_params, - "messages": [{"role": "user", "content": query}], - "system": system_content - } - - # Add tools if available - if tools: - api_params["tools"] = tools - api_params["tool_choice"] = {"type": "auto"} - - # Get response from Claude - response = self.client.messages.create(**api_params) - - # Handle tool execution if needed - if response.stop_reason == "tool_use" and tool_manager: - return self._handle_tool_execution(response, api_params, tool_manager) - - # Return direct response - return response.content[0].text - - def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager): - """ - Handle execution of tool calls and get follow-up response. 
- - Args: - initial_response: The response containing tool use requests - base_params: Base API parameters - tool_manager: Manager to execute tools - - Returns: - Final response text after tool execution - """ - # Start with existing messages - messages = base_params["messages"].copy() - - # Add AI's tool use response - messages.append({"role": "assistant", "content": initial_response.content}) - - # Execute all tool calls and collect results - tool_results = [] - for content_block in initial_response.content: - if content_block.type == "tool_use": - tool_result = tool_manager.execute_tool( - content_block.name, - **content_block.input - ) - - tool_results.append({ - "type": "tool_result", - "tool_use_id": content_block.id, - "content": tool_result - }) - - # Add tool results as single message - if tool_results: - messages.append({"role": "user", "content": tool_results}) - - # Prepare final API call without tools - final_params = { - **self.base_params, - "messages": messages, - "system": base_params["system"] - } - - # Get final response - final_response = self.client.messages.create(**final_params) - return final_response.content[0].text \ No newline at end of file + + # Build system content + system_content = self.SYSTEM_PROMPT + if conversation_history: + system_content += f"\n\nPrevious conversation:\n{conversation_history}" + + # Build user message with context + if context: + user_content = f"Course context:\n{context}\n\nQuestion: {query}" + else: + user_content = query + + # Build messages list + messages = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": user_content} + ] + + # Get response from Ollama + data = self._call_ollama(messages) + return data.get("message", {}).get("content", "") diff --git a/backend/config.py b/backend/config.py index d9f6392ef..30dc10079 100644 --- a/backend/config.py +++ b/backend/config.py @@ -8,9 +8,9 @@ @dataclass class Config: """Configuration settings for the RAG system""" - # Anthropic API settings - ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") - ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" + # Ollama settings (free local LLM) + OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") + OLLAMA_MODEL: str = os.getenv("OLLAMA_MODEL", "llama3.2") # Embedding model settings EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" diff --git a/backend/rag_system.py b/backend/rag_system.py index 50d848c8e..b58df4b9d 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -16,7 +16,7 @@ def __init__(self, config): # Initialize core components self.document_processor = DocumentProcessor(config.CHUNK_SIZE, config.CHUNK_OVERLAP) self.vector_store = VectorStore(config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS) - self.ai_generator = AIGenerator(config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL) + self.ai_generator = AIGenerator(config.OLLAMA_BASE_URL, config.OLLAMA_MODEL) self.session_manager = SessionManager(config.MAX_HISTORY) # Initialize search tools @@ -101,42 +101,40 @@ def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> T def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]: """ - Process a user query using the RAG system with tool-based search. - + Process a user query using the RAG system. + + Searches course content first, then passes results as context to the LLM. 
+ Args: query: User's question session_id: Optional session ID for conversation context - + Returns: - Tuple of (response, sources list - empty for tool-based approach) + Tuple of (response, sources list) """ - # Create prompt for the AI with clear instructions - prompt = f"""Answer this question about course materials: {query}""" - + # Search for relevant content first + context = self.search_tool.execute(query=query) + + # Get sources from the search + sources = self.search_tool.last_sources.copy() + self.search_tool.last_sources = [] + # Get conversation history if session exists history = None if session_id: history = self.session_manager.get_conversation_history(session_id) - - # Generate response using AI with tools + + # Generate response with retrieved context response = self.ai_generator.generate_response( - query=prompt, - conversation_history=history, - tools=self.tool_manager.get_tool_definitions(), - tool_manager=self.tool_manager + query=query, + context=context, + conversation_history=history ) - - # Get sources from the search tool - sources = self.tool_manager.get_last_sources() - # Reset sources after retrieving them - self.tool_manager.reset_sources() - # Update conversation history if session_id: self.session_manager.add_exchange(session_id, query, response) - - # Return response with sources from tool searches + return response, sources def get_course_analytics(self) -> Dict: diff --git a/backend/search_tools.py b/backend/search_tools.py index adfe82352..e044c81c7 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -8,7 +8,7 @@ class Tool(ABC): @abstractmethod def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return OpenAI-compatible tool definition for Ollama""" pass @abstractmethod @@ -25,27 +25,30 @@ def __init__(self, vector_store: VectorStore): self.last_sources = [] # Track sources from last search def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return OpenAI-compatible tool definition for Ollama""" return { - "name": "search_course_content", - "description": "Search course materials with smart course name matching and lesson filtering", - "input_schema": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in the course content" + "type": "function", + "function": { + "name": "search_course_content", + "description": "Search course materials with smart course name matching and lesson filtering", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in the course content" + }, + "course_name": { + "type": "string", + "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" + }, + "lesson_number": { + "type": "integer", + "description": "Specific lesson number to search within (e.g. 1, 2, 3)" + } }, - "course_name": { - "type": "string", - "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" - }, - "lesson_number": { - "type": "integer", - "description": "Specific lesson number to search within (e.g. 
1, 2, 3)" - } - }, - "required": ["query"] + "required": ["query"] + } } } @@ -122,14 +125,16 @@ def __init__(self): def register_tool(self, tool: Tool): """Register any tool that implements the Tool interface""" tool_def = tool.get_tool_definition() - tool_name = tool_def.get("name") + # OpenAI-compatible format nests name under "function" + func_def = tool_def.get("function", {}) + tool_name = func_def.get("name") or tool_def.get("name") if not tool_name: raise ValueError("Tool must have a 'name' in its definition") self.tools[tool_name] = tool def get_tool_definitions(self) -> list: - """Get all tool definitions for Anthropic tool calling""" + """Get all tool definitions for Ollama tool calling""" return [tool.get_tool_definition() for tool in self.tools.values()] def execute_tool(self, tool_name: str, **kwargs) -> str: diff --git a/pyproject.toml b/pyproject.toml index 3f05e2de0..ceb8140a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "chromadb==1.0.15", - "anthropic==0.58.2", + "httpx>=0.27.0", "sentence-transformers==5.0.0", "fastapi==0.116.1", "uvicorn==0.35.0", diff --git a/uv.lock b/uv.lock index 9ae65c557..1c5694ffe 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.13" [[package]] @@ -11,24 +11,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] -[[package]] -name = "anthropic" -version = "0.58.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "jiter" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/95/b9/ab06c586aa5a5e7499017cee5ebab94ee260e75975c45395f32b8592abdd/anthropic-0.58.2.tar.gz", hash = "sha256:86396cc45530a83acea25ae6bca9f86656af81e3d598b4d22a1300e0e4cf8df8", size = 425125, upload-time = "2025-07-18T13:38:55.94Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/f2/68d908ff308c9a65af5749ec31952e01d32f19ea073b0268affc616e6ebc/anthropic-0.58.2-py3-none-any.whl", hash = "sha256:3742181c634c725f337b71096839b6404145e33a8e190c75387c4028b825864d", size = 292896, upload-time = "2025-07-18T13:38:54.782Z" }, -] - [[package]] name = "anyio" version = "4.9.0" @@ -482,42 +464,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] -[[package]] -name = "jiter" -version = "0.10.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/9d/ae7ddb4b8ab3fb1b51faf4deb36cb48a4fbbd7cb36bad6a5fca4741306f7/jiter-0.10.0.tar.gz", hash = "sha256:07a7142c38aacc85194391108dc91b5b57093c978a9932bd86a36862759d9500", size = 162759, upload-time = "2025-05-18T19:04:59.73Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/b0/279597e7a270e8d22623fea6c5d4eeac328e7d95c236ed51a2b884c54f70/jiter-0.10.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = 
"sha256:e0588107ec8e11b6f5ef0e0d656fb2803ac6cf94a96b2b9fc675c0e3ab5e8644", size = 311617, upload-time = "2025-05-18T19:04:02.078Z" }, - { url = "https://files.pythonhosted.org/packages/91/e3/0916334936f356d605f54cc164af4060e3e7094364add445a3bc79335d46/jiter-0.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cafc4628b616dc32530c20ee53d71589816cf385dd9449633e910d596b1f5c8a", size = 318947, upload-time = "2025-05-18T19:04:03.347Z" }, - { url = "https://files.pythonhosted.org/packages/6a/8e/fd94e8c02d0e94539b7d669a7ebbd2776e51f329bb2c84d4385e8063a2ad/jiter-0.10.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:520ef6d981172693786a49ff5b09eda72a42e539f14788124a07530f785c3ad6", size = 344618, upload-time = "2025-05-18T19:04:04.709Z" }, - { url = "https://files.pythonhosted.org/packages/6f/b0/f9f0a2ec42c6e9c2e61c327824687f1e2415b767e1089c1d9135f43816bd/jiter-0.10.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:554dedfd05937f8fc45d17ebdf298fe7e0c77458232bcb73d9fbbf4c6455f5b3", size = 368829, upload-time = "2025-05-18T19:04:06.912Z" }, - { url = "https://files.pythonhosted.org/packages/e8/57/5bbcd5331910595ad53b9fd0c610392ac68692176f05ae48d6ce5c852967/jiter-0.10.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5bc299da7789deacf95f64052d97f75c16d4fc8c4c214a22bf8d859a4288a1c2", size = 491034, upload-time = "2025-05-18T19:04:08.222Z" }, - { url = "https://files.pythonhosted.org/packages/9b/be/c393df00e6e6e9e623a73551774449f2f23b6ec6a502a3297aeeece2c65a/jiter-0.10.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5161e201172de298a8a1baad95eb85db4fb90e902353b1f6a41d64ea64644e25", size = 388529, upload-time = "2025-05-18T19:04:09.566Z" }, - { url = "https://files.pythonhosted.org/packages/42/3e/df2235c54d365434c7f150b986a6e35f41ebdc2f95acea3036d99613025d/jiter-0.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e2227db6ba93cb3e2bf67c87e594adde0609f146344e8207e8730364db27041", size = 350671, upload-time = "2025-05-18T19:04:10.98Z" }, - { url = "https://files.pythonhosted.org/packages/c6/77/71b0b24cbcc28f55ab4dbfe029f9a5b73aeadaba677843fc6dc9ed2b1d0a/jiter-0.10.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:15acb267ea5e2c64515574b06a8bf393fbfee6a50eb1673614aa45f4613c0cca", size = 390864, upload-time = "2025-05-18T19:04:12.722Z" }, - { url = "https://files.pythonhosted.org/packages/6a/d3/ef774b6969b9b6178e1d1e7a89a3bd37d241f3d3ec5f8deb37bbd203714a/jiter-0.10.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:901b92f2e2947dc6dfcb52fd624453862e16665ea909a08398dde19c0731b7f4", size = 522989, upload-time = "2025-05-18T19:04:14.261Z" }, - { url = "https://files.pythonhosted.org/packages/0c/41/9becdb1d8dd5d854142f45a9d71949ed7e87a8e312b0bede2de849388cb9/jiter-0.10.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d0cb9a125d5a3ec971a094a845eadde2db0de85b33c9f13eb94a0c63d463879e", size = 513495, upload-time = "2025-05-18T19:04:15.603Z" }, - { url = "https://files.pythonhosted.org/packages/9c/36/3468e5a18238bdedae7c4d19461265b5e9b8e288d3f86cd89d00cbb48686/jiter-0.10.0-cp313-cp313-win32.whl", hash = "sha256:48a403277ad1ee208fb930bdf91745e4d2d6e47253eedc96e2559d1e6527006d", size = 211289, upload-time = "2025-05-18T19:04:17.541Z" }, - { url = "https://files.pythonhosted.org/packages/7e/07/1c96b623128bcb913706e294adb5f768fb7baf8db5e1338ce7b4ee8c78ef/jiter-0.10.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:75f9eb72ecb640619c29bf714e78c9c46c9c4eaafd644bf78577ede459f330d4", size = 205074, upload-time = "2025-05-18T19:04:19.21Z" }, - { url = "https://files.pythonhosted.org/packages/54/46/caa2c1342655f57d8f0f2519774c6d67132205909c65e9aa8255e1d7b4f4/jiter-0.10.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:28ed2a4c05a1f32ef0e1d24c2611330219fed727dae01789f4a335617634b1ca", size = 318225, upload-time = "2025-05-18T19:04:20.583Z" }, - { url = "https://files.pythonhosted.org/packages/43/84/c7d44c75767e18946219ba2d703a5a32ab37b0bc21886a97bc6062e4da42/jiter-0.10.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14a4c418b1ec86a195f1ca69da8b23e8926c752b685af665ce30777233dfe070", size = 350235, upload-time = "2025-05-18T19:04:22.363Z" }, - { url = "https://files.pythonhosted.org/packages/01/16/f5a0135ccd968b480daad0e6ab34b0c7c5ba3bc447e5088152696140dcb3/jiter-0.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d7bfed2fe1fe0e4dda6ef682cee888ba444b21e7a6553e03252e4feb6cf0adca", size = 207278, upload-time = "2025-05-18T19:04:23.627Z" }, - { url = "https://files.pythonhosted.org/packages/1c/9b/1d646da42c3de6c2188fdaa15bce8ecb22b635904fc68be025e21249ba44/jiter-0.10.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:5e9251a5e83fab8d87799d3e1a46cb4b7f2919b895c6f4483629ed2446f66522", size = 310866, upload-time = "2025-05-18T19:04:24.891Z" }, - { url = "https://files.pythonhosted.org/packages/ad/0e/26538b158e8a7c7987e94e7aeb2999e2e82b1f9d2e1f6e9874ddf71ebda0/jiter-0.10.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:023aa0204126fe5b87ccbcd75c8a0d0261b9abdbbf46d55e7ae9f8e22424eeb8", size = 318772, upload-time = "2025-05-18T19:04:26.161Z" }, - { url = "https://files.pythonhosted.org/packages/7b/fb/d302893151caa1c2636d6574d213e4b34e31fd077af6050a9c5cbb42f6fb/jiter-0.10.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c189c4f1779c05f75fc17c0c1267594ed918996a231593a21a5ca5438445216", size = 344534, upload-time = "2025-05-18T19:04:27.495Z" }, - { url = "https://files.pythonhosted.org/packages/01/d8/5780b64a149d74e347c5128d82176eb1e3241b1391ac07935693466d6219/jiter-0.10.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:15720084d90d1098ca0229352607cd68256c76991f6b374af96f36920eae13c4", size = 369087, upload-time = "2025-05-18T19:04:28.896Z" }, - { url = "https://files.pythonhosted.org/packages/e8/5b/f235a1437445160e777544f3ade57544daf96ba7e96c1a5b24a6f7ac7004/jiter-0.10.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4f2fb68e5f1cfee30e2b2a09549a00683e0fde4c6a2ab88c94072fc33cb7426", size = 490694, upload-time = "2025-05-18T19:04:30.183Z" }, - { url = "https://files.pythonhosted.org/packages/85/a9/9c3d4617caa2ff89cf61b41e83820c27ebb3f7b5fae8a72901e8cd6ff9be/jiter-0.10.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce541693355fc6da424c08b7edf39a2895f58d6ea17d92cc2b168d20907dee12", size = 388992, upload-time = "2025-05-18T19:04:32.028Z" }, - { url = "https://files.pythonhosted.org/packages/68/b1/344fd14049ba5c94526540af7eb661871f9c54d5f5601ff41a959b9a0bbd/jiter-0.10.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31c50c40272e189d50006ad5c73883caabb73d4e9748a688b216e85a9a9ca3b9", size = 351723, upload-time = "2025-05-18T19:04:33.467Z" }, - { url = "https://files.pythonhosted.org/packages/41/89/4c0e345041186f82a31aee7b9d4219a910df672b9fef26f129f0cda07a29/jiter-0.10.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:fa3402a2ff9815960e0372a47b75c76979d74402448509ccd49a275fa983ef8a", size = 392215, upload-time = "2025-05-18T19:04:34.827Z" }, - { url = "https://files.pythonhosted.org/packages/55/58/ee607863e18d3f895feb802154a2177d7e823a7103f000df182e0f718b38/jiter-0.10.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:1956f934dca32d7bb647ea21d06d93ca40868b505c228556d3373cbd255ce853", size = 522762, upload-time = "2025-05-18T19:04:36.19Z" }, - { url = "https://files.pythonhosted.org/packages/15/d0/9123fb41825490d16929e73c212de9a42913d68324a8ce3c8476cae7ac9d/jiter-0.10.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:fcedb049bdfc555e261d6f65a6abe1d5ad68825b7202ccb9692636c70fcced86", size = 513427, upload-time = "2025-05-18T19:04:37.544Z" }, - { url = "https://files.pythonhosted.org/packages/d8/b3/2bd02071c5a2430d0b70403a34411fc519c2f227da7b03da9ba6a956f931/jiter-0.10.0-cp314-cp314-win32.whl", hash = "sha256:ac509f7eccca54b2a29daeb516fb95b6f0bd0d0d8084efaf8ed5dfc7b9f0b357", size = 210127, upload-time = "2025-05-18T19:04:38.837Z" }, - { url = "https://files.pythonhosted.org/packages/03/0c/5fe86614ea050c3ecd728ab4035534387cd41e7c1855ef6c031f1ca93e3f/jiter-0.10.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5ed975b83a2b8639356151cef5c0d597c68376fc4922b45d0eb384ac058cfa00", size = 318527, upload-time = "2025-05-18T19:04:40.612Z" }, - { url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" }, -] - [[package]] name = "joblib" version = "1.5.1" @@ -1552,9 +1498,9 @@ name = "starting-codebase" version = "0.1.0" source = { virtual = "." 
} dependencies = [ - { name = "anthropic" }, { name = "chromadb" }, { name = "fastapi" }, + { name = "httpx" }, { name = "python-dotenv" }, { name = "python-multipart" }, { name = "sentence-transformers" }, @@ -1563,9 +1509,9 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "anthropic", specifier = "==0.58.2" }, { name = "chromadb", specifier = "==1.0.15" }, { name = "fastapi", specifier = "==0.116.1" }, + { name = "httpx", specifier = ">=0.27.0" }, { name = "python-dotenv", specifier = "==1.1.1" }, { name = "python-multipart", specifier = "==0.0.20" }, { name = "sentence-transformers", specifier = "==5.0.0" }, From 6f279815485700d9899450d0c435093ab7cee9de Mon Sep 17 00:00:00 2001 From: Muhammad Sohaib Date: Fri, 13 Feb 2026 16:17:09 +0500 Subject: [PATCH 2/4] Add course outline tool with markdown formatting and UI improvements - Add CourseOutlineTool that retrieves full course outline (title, link, all lessons) from course catalog - Update system prompt to enforce listing every lesson for outline queries - Augment outline queries with explicit LLM instructions to prevent summarization - Add source links with clickable chips, new chat button, and breaks:true for markdown rendering Co-Authored-By: Claude Opus 4.6 --- backend/ai_generator.py | 1 + backend/app.py | 20 ++++++- backend/rag_system.py | 37 ++++++++++-- backend/search_tools.py | 88 ++++++++++++++++++++++++----- backend/vector_store.py | 24 ++++++++ frontend/index.html | 7 ++- frontend/script.js | 37 +++++++++--- frontend/style.css | 122 ++++++++++++++++++++++++++++++++++++++-- 8 files changed, 300 insertions(+), 36 deletions(-) diff --git a/backend/ai_generator.py b/backend/ai_generator.py index 198a1a62c..c8b4e9dd1 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -13,6 +13,7 @@ class AIGenerator: - Answer questions using the provided course context - If the context doesn't contain relevant information, say so clearly - Do not mention "based on the context" or "based on the search results" +- For outline, structure, or "what lessons" queries: You MUST list EVERY lesson exactly as provided in the context using markdown bullet points. Output the course title, course link, then each lesson as a bullet point "- **Lesson N:** Title". Do NOT summarize, group, or skip any lessons. All responses must be: 1. 
**Brief, Concise and focused** - Get to the point quickly diff --git a/backend/app.py b/backend/app.py index 5a69d741d..491e2b4f2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -40,10 +40,19 @@ class QueryRequest(BaseModel): query: str session_id: Optional[str] = None +class SourceItem(BaseModel): + """A single source with title and optional link""" + title: str + link: Optional[str] = None + class QueryResponse(BaseModel): """Response model for course queries""" answer: str - sources: List[str] + sources: List[SourceItem] + session_id: str + +class ClearSessionRequest(BaseModel): + """Request model for clearing a session""" session_id: str class CourseStats(BaseModel): @@ -73,6 +82,15 @@ async def query_documents(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +@app.post("/api/sessions/clear") +async def clear_session(request: ClearSessionRequest): + """Clear a chat session to free server memory""" + try: + rag_system.session_manager.clear_session(request.session_id) + return {"status": "ok"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + @app.get("/api/courses", response_model=CourseStats) async def get_course_stats(): """Get course analytics and statistics""" diff --git a/backend/rag_system.py b/backend/rag_system.py index b58df4b9d..6fd9da826 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -4,7 +4,7 @@ from vector_store import VectorStore from ai_generator import AIGenerator from session_manager import SessionManager -from search_tools import ToolManager, CourseSearchTool +from search_tools import ToolManager, CourseSearchTool, CourseOutlineTool from models import Course, Lesson, CourseChunk class RAGSystem: @@ -22,7 +22,9 @@ def __init__(self, config): # Initialize search tools self.tool_manager = ToolManager() self.search_tool = CourseSearchTool(self.vector_store) + self.outline_tool = CourseOutlineTool(self.vector_store) self.tool_manager.register_tool(self.search_tool) + self.tool_manager.register_tool(self.outline_tool) def add_course_document(self, file_path: str) -> Tuple[Course, int]: """ @@ -99,7 +101,7 @@ def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> T return total_courses, total_chunks - def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]: + def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[Dict]]: """ Process a user query using the RAG system. 
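The hunk above changes the shape of the sources that `query()` returns — plain strings become dicts validated by the new `SourceItem` model added to `app.py` above. A hedged illustration of the two shapes (titles and links are made up):

```python
# Before this patch: sources were bare strings.
sources_old = ["Building RAG Chatbots - Lesson 1"]

# After this patch: dicts with an optional link, matching SourceItem.
# A missing link is rendered as a plain (non-clickable) chip in the UI.
sources_new = [
    {"title": "Building RAG Chatbots - Lesson 1", "link": "https://example.com/lesson1"},
    {"title": "Intro to MCP", "link": None},
]
```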
@@ -113,11 +115,34 @@ def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List Tuple of (response, sources list) """ # Search for relevant content first - context = self.search_tool.execute(query=query) + content_context = self.search_tool.execute(query=query) - # Get sources from the search - sources = self.search_tool.last_sources.copy() + # Also try the outline tool (adds course structure when relevant) + outline_context = self.outline_tool.execute(course_name=query) + has_outline = outline_context and not outline_context.startswith("No course found") + + # Combine contexts — outline first so the LLM sees it prominently + context_parts = [] + if has_outline: + context_parts.append(outline_context) + if content_context: + context_parts.append(content_context) + context = "\n\n".join(context_parts) + + # Gather sources from both tools + sources = self.search_tool.last_sources.copy() + self.outline_tool.last_sources.copy() self.search_tool.last_sources = [] + self.outline_tool.last_sources = [] + + # For outline queries, augment the query with an explicit instruction + effective_query = query + if has_outline: + effective_query = ( + f"{query}\n\n" + "IMPORTANT: The context contains a complete course outline. " + "You MUST list every single lesson exactly as shown — do not skip, " + "group, or summarize any lessons." + ) # Get conversation history if session exists history = None @@ -126,7 +151,7 @@ def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List # Generate response with retrieved context response = self.ai_generator.generate_response( - query=query, + query=effective_query, context=context, conversation_history=history ) diff --git a/backend/search_tools.py b/backend/search_tools.py index e044c81c7..4cf36092d 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Protocol +from typing import Dict, Any, Optional, List, Protocol from abc import ABC, abstractmethod from vector_store import VectorStore, SearchResults @@ -91,31 +91,93 @@ def execute(self, query: str, course_name: Optional[str] = None, lesson_number: def _format_results(self, results: SearchResults) -> str: """Format search results with course and lesson context""" formatted = [] - sources = [] # Track sources for the UI - + sources: List[Dict[str, Any]] = [] + seen = set() # Deduplicate by (course_title, lesson_num) + for doc, meta in zip(results.documents, results.metadata): course_title = meta.get('course_title', 'unknown') lesson_num = meta.get('lesson_number') - + # Build context header header = f"[{course_title}" if lesson_num is not None: header += f" - Lesson {lesson_num}" header += "]" - - # Track source for the UI - source = course_title - if lesson_num is not None: - source += f" - Lesson {lesson_num}" - sources.append(source) - + + # Track source for the UI (deduplicated) + key = (course_title, lesson_num) + if key not in seen: + seen.add(key) + title = course_title + if lesson_num is not None: + title += f" - Lesson {lesson_num}" + link = None + if lesson_num is not None: + link = self.store.get_lesson_link(course_title, lesson_num) + if not link: + link = self.store.get_course_link(course_title) + sources.append({"title": title, "link": link}) + formatted.append(f"{header}\n{doc}") - + # Store sources for retrieval self.last_sources = sources - + return "\n\n".join(formatted) +class CourseOutlineTool(Tool): + """Tool for retrieving course outlines from the course catalog""" + + def 
__init__(self, vector_store: VectorStore): + self.store = vector_store + self.last_sources = [] + + def get_tool_definition(self) -> Dict[str, Any]: + return { + "type": "function", + "function": { + "name": "get_course_outline", + "description": "Get the full outline of a course including title, link, and all lessons", + "parameters": { + "type": "object", + "properties": { + "course_name": { + "type": "string", + "description": "Course title or partial name (e.g. 'MCP', 'RAG chatbot')" + } + }, + "required": ["course_name"] + } + } + } + + def execute(self, course_name: str) -> str: + outline = self.store.get_course_outline(course_name) + if not outline: + return f"No course found matching '{course_name}'." + + title = outline["title"] + link = outline.get("course_link") or "N/A" + lessons = outline.get("lessons", []) + + # Track source + self.last_sources = [{"title": title, "link": outline.get("course_link")}] + + lines = [ + f"COURSE OUTLINE (list every lesson as a markdown bullet point):", + f"**Course:** {title}", + f"**Course Link:** {link}", + f"**Total Lessons:** {len(lessons)}", + "", + ] + for lesson in lessons: + num = lesson.get("lesson_number", "?") + ltitle = lesson.get("lesson_title", "Untitled") + lines.append(f"- **Lesson {num}:** {ltitle}") + + return "\n".join(lines) + + class ToolManager: """Manages available tools for the AI""" diff --git a/backend/vector_store.py b/backend/vector_store.py index 390abe71c..cd95f75d1 100644 --- a/backend/vector_store.py +++ b/backend/vector_store.py @@ -246,6 +246,30 @@ def get_course_link(self, course_title: str) -> Optional[str]: print(f"Error getting course link: {e}") return None + def get_course_outline(self, course_name: str) -> Optional[Dict[str, Any]]: + """Get full course outline (title, link, lessons) by fuzzy course name match""" + import json + # Resolve to exact course title + course_title = self._resolve_course_name(course_name) + if not course_title: + return None + + try: + results = self.course_catalog.get(ids=[course_title]) + if results and results['metadatas']: + meta = results['metadatas'][0] + lessons = [] + if meta.get('lessons_json'): + lessons = json.loads(meta['lessons_json']) + return { + "title": meta.get('title', course_title), + "course_link": meta.get('course_link'), + "lessons": lessons + } + except Exception as e: + print(f"Error getting course outline: {e}") + return None + def get_lesson_link(self, course_title: str, lesson_number: int) -> Optional[str]: """Get lesson link for a given course title and lesson number""" import json diff --git a/frontend/index.html b/frontend/index.html index f8e25a62f..26f2decc3 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -7,7 +7,7 @@ Course Materials Assistant - +
@@ -19,6 +19,9 @@
[This hunk and the surrounding index.html hunks were garbled by extraction — the HTML tags were stripped, leaving only hunk headers, bare text, and +/- markers. Recoverable intent: the @@ -7,7 +7,7 @@ hunk above changes one line in the <head>; this hunk adds three lines near the "Course Materials Assistant" sidebar heading (the New Chat button wired up in script.js and styled in style.css below); the @@ -76,6 +79,6 @@ hunk changes one line just before the close of <body>. A hypothetical reconstruction of the button markup follows.]
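The reconstruction below is a guess: only the id ("newChatBtn", from script.js) and class ("new-chat-btn", from style.css) are attested in this patch; the label and exact placement are assumptions.

```html
<!-- Hypothetical reconstruction of the stripped sidebar markup -->
<button id="newChatBtn" class="new-chat-btn">+ New Chat</button>
```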

diff --git a/frontend/script.js b/frontend/script.js
index 562a8a363..98b080d5f 100644
--- a/frontend/script.js
+++ b/frontend/script.js
@@ -5,7 +5,7 @@ const API_URL = '/api';
 let currentSessionId = null;
 
 // DOM elements
-let chatMessages, chatInput, sendButton, totalCourses, courseTitles;
+let chatMessages, chatInput, sendButton, totalCourses, courseTitles, newChatBtn;
 
 // Initialize
 document.addEventListener('DOMContentLoaded', () => {
@@ -15,7 +15,8 @@ document.addEventListener('DOMContentLoaded', () => {
     sendButton = document.getElementById('sendButton');
     totalCourses = document.getElementById('totalCourses');
     courseTitles = document.getElementById('courseTitles');
-    
+    newChatBtn = document.getElementById('newChatBtn');
+
     setupEventListeners();
     createNewSession();
     loadCourseStats();
@@ -28,8 +29,20 @@ function setupEventListeners() {
     chatInput.addEventListener('keypress', (e) => {
         if (e.key === 'Enter') sendMessage();
     });
-    
-    
+
+    // New chat button
+    newChatBtn.addEventListener('click', () => {
+        if (currentSessionId) {
+            fetch(`${API_URL}/sessions/clear`, {
+                method: 'POST',
+                headers: { 'Content-Type': 'application/json' },
+                body: JSON.stringify({ session_id: currentSessionId })
+            }).catch(() => {}); // fire-and-forget
+        }
+        createNewSession();
+        chatInput.focus();
+    });
+
     // Suggested questions
     document.querySelectorAll('.suggested-item').forEach(button => {
         button.addEventListener('click', (e) => {
@@ -117,15 +130,23 @@ function addMessage(content, type, sources = null, isWelcome = false) {
     messageDiv.id = `message-${messageId}`;
 
     // Convert markdown to HTML for assistant messages
-    const displayContent = type === 'assistant' ? marked.parse(content) : escapeHtml(content);
+    const displayContent = type === 'assistant' ? marked.parse(content, { breaks: true }) : escapeHtml(content);
 
     let html = `
[Rest of this hunk garbled by extraction — the HTML inside the template literals was stripped. Recoverable intent: `html` wraps ${displayContent} in the message markup; when sources exist, each source becomes a chip string in `sourceChips` (a link chip when `s.link` is set, a plain chip otherwise, each prefixed with an `icon`); and the removed block that rendered a flat "Sources" section containing ${sources.join(', ')} is replaced by a collapsible section whose summary reads "${sources.length} Source(s)" and whose body holds the chips. A hypothetical reconstruction follows.]
`; } diff --git a/frontend/style.css b/frontend/style.css index 825d03675..47b2a7fac 100644 --- a/frontend/style.css +++ b/frontend/style.css @@ -103,6 +103,37 @@ header h1 { background: var(--text-secondary); } +/* New Chat Button */ +.new-chat-btn { + display: block; + padding: 0.5rem 0; + background: none; + border: none; + cursor: pointer; + font-size: 0.875rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + color: var(--text-secondary); + text-align: left; + margin-bottom: 0; + transition: color 0.2s ease; + appearance: none; + -webkit-appearance: none; + list-style: none; + outline: none; + font-family: inherit; +} + +.new-chat-btn:hover { + color: var(--primary-color); +} + +.new-chat-btn:focus { + outline: none; + color: var(--primary-color); +} + .sidebar-section { margin-bottom: 1.5rem; } @@ -220,29 +251,108 @@ header h1 { /* Collapsible Sources */ .sources-collapsible { - margin-top: 0.5rem; - font-size: 0.75rem; + margin-top: 0.75rem; + font-size: 0.8rem; color: var(--text-secondary); + background: rgba(255, 255, 255, 0.03); + border: 1px solid var(--border-color); + border-radius: 12px; + overflow: hidden; } .sources-collapsible summary { cursor: pointer; - padding: 0.25rem 0.5rem; + padding: 0.5rem 0.75rem; user-select: none; - font-weight: 500; + font-weight: 600; + font-size: 0.75rem; + text-transform: uppercase; + letter-spacing: 0.5px; + display: flex; + align-items: center; + gap: 0.35rem; + transition: color 0.2s ease, background 0.2s ease; } .sources-collapsible summary:hover { color: var(--text-primary); + background: rgba(255, 255, 255, 0.03); +} + +.sources-collapsible summary::-webkit-details-marker { + display: none; +} + +.sources-collapsible summary::after { + content: ''; + display: inline-block; + width: 6px; + height: 6px; + border-right: 2px solid currentColor; + border-bottom: 2px solid currentColor; + transform: rotate(-45deg); + transition: transform 0.2s ease; + margin-left: auto; +} + +.sources-collapsible[open] summary::after { + transform: rotate(45deg); +} + +.source-header-icon { + flex-shrink: 0; + opacity: 0.7; } .sources-collapsible[open] summary { - margin-bottom: 0.25rem; + margin-bottom: 0; + border-bottom: 1px solid var(--border-color); } .sources-content { - padding: 0 0.5rem 0.25rem 1.5rem; + padding: 0.5rem; + display: flex; + flex-wrap: wrap; + gap: 0.4rem; +} + +.source-chip { + display: inline-flex; + align-items: center; + gap: 0.35rem; + padding: 0.35rem 0.7rem; + background: rgba(37, 99, 235, 0.1); + border: 1px solid rgba(37, 99, 235, 0.2); + border-radius: 20px; + color: var(--primary-color); + font-size: 0.75rem; + font-weight: 500; + text-decoration: none; + transition: all 0.2s ease; + line-height: 1.3; +} + +a.source-chip:hover { + background: rgba(37, 99, 235, 0.2); + border-color: rgba(37, 99, 235, 0.4); + color: #60a5fa; + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(37, 99, 235, 0.15); +} + +a.source-chip:active { + transform: translateY(0); +} + +.source-icon { + flex-shrink: 0; + opacity: 0.7; +} + +span.source-chip { color: var(--text-secondary); + background: rgba(255, 255, 255, 0.05); + border-color: var(--border-color); } /* Markdown formatting styles */ From 2aa223e97a1ed7d48fb524995f3a372a7ee2682b Mon Sep 17 00:00:00 2001 From: Muhammad Sohaib Date: Sun, 15 Feb 2026 03:03:28 +0500 Subject: [PATCH 3/4] Update RAG chatbot with improved AI generation, search tools, frontend UI, and tests Co-Authored-By: Claude Opus 4.6 --- backend/ai_generator.py | 117 +++++++- 
backend/rag_system.py | 63 ++--- backend/search_tools.py | 41 ++- backend/tests/__init__.py | 0 backend/tests/conftest.py | 185 +++++++++++++ backend/tests/test_ai_generator.py | 424 +++++++++++++++++++++++++++++ backend/tests/test_api.py | 131 +++++++++ backend/tests/test_rag_system.py | 175 ++++++++++++ backend/tests/test_search_tools.py | 205 ++++++++++++++ frontend/index.html | 20 +- frontend/script.js | 19 ++ frontend/style.css | 107 ++++++++ pyproject.toml | 13 + uv.lock | 42 +++ 14 files changed, 1485 insertions(+), 57 deletions(-) create mode 100644 backend/tests/__init__.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_ai_generator.py create mode 100644 backend/tests/test_api.py create mode 100644 backend/tests/test_rag_system.py create mode 100644 backend/tests/test_search_tools.py diff --git a/backend/ai_generator.py b/backend/ai_generator.py index c8b4e9dd1..a3922d221 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -1,11 +1,13 @@ import httpx import json -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Callable, Tuple class AIGenerator: """Handles interactions with Ollama's local LLM API for generating responses""" + MAX_TOOL_ROUNDS = 3 + # Static system prompt to avoid rebuilding on each call SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content. @@ -15,6 +17,16 @@ class AIGenerator: - Do not mention "based on the context" or "based on the search results" - For outline, structure, or "what lessons" queries: You MUST list EVERY lesson exactly as provided in the context using markdown bullet points. Output the course title, course link, then each lesson as a bullet point "- **Lesson N:** Title". Do NOT summarize, group, or skip any lessons. +Tool Usage: +- You have access to search tools for finding course content. Use them when the user asks a question. +- When the user asks about a SPECIFIC lesson (e.g. "lesson 5 of MCP"), use search_course_content with course_name AND lesson_number to get that lesson's actual content. Do NOT just return the course outline. +- Use get_course_outline only when the user asks to LIST or OVERVIEW all lessons in a course. +- For complex queries, break them into steps: first find the relevant course/lesson, then search for specific content. +- After receiving tool results, synthesize them into a well-structured answer. Do NOT dump raw tool output. +- Structure your final answer using markdown: use headings, bullet points, and bold for key terms. +- Only include information that directly answers the user's question — ignore irrelevant tool results. +- If tool results are empty or unhelpful, say so honestly instead of fabricating an answer. + All responses must be: 1. **Brief, Concise and focused** - Get to the point quickly 2. 
**Educational** - Maintain instructional value @@ -29,7 +41,7 @@ def __init__(self, base_url: str, model: str): # 120s timeout to handle first-request model loading self.http_client = httpx.Client(timeout=120.0) - def _call_ollama(self, messages: List[Dict]) -> Dict: + def _call_ollama(self, messages: List[Dict], tools: Optional[List] = None) -> Dict: """Make a POST request to the Ollama chat API.""" payload = { "model": self.model, @@ -40,14 +52,105 @@ def _call_ollama(self, messages: List[Dict]) -> Dict: "num_predict": 800 } } + if tools is not None: + payload["tools"] = tools resp = self.http_client.post(f"{self.base_url}/api/chat", json=payload) resp.raise_for_status() return resp.json() + def _execute_tool_call(self, tool_call: dict, tool_executor: Callable) -> Tuple[str, str]: + """Extract tool name/args from a tool call and execute via tool_executor.""" + func = tool_call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", {}) + try: + result = tool_executor(name, **arguments) + return name, str(result) + except Exception as e: + return name, f"Tool error: {e}" + + def _parse_tool_call_from_text(self, content: str) -> Optional[dict]: + """Try to recover a tool call that the LLM wrote as plain text. + + Small models sometimes emit JSON-like text instead of using the + structured tool_calls field. We attempt a best-effort parse. + """ + import re + if not content: + return None + + # Try to find JSON-ish blob with "name" and "parameters" + # Handle unquoted identifiers like: {"name": get_course_outline, ...} + text = content.strip() + # Quick gate: must look like it's trying to be a tool call + if '"name"' not in text and "'name'" not in text: + return None + + # Fix common malformed JSON: unquoted values after "name": + text = re.sub(r'"name"\s*:\s*([a-zA-Z_]\w*)', r'"name": "\1"', text) + # Fix single quotes to double quotes + text = text.replace("'", '"') + + try: + obj = json.loads(text) + except json.JSONDecodeError: + # Try to extract just the JSON object + match = re.search(r'\{.*\}', text, re.DOTALL) + if not match: + return None + try: + obj = json.loads(match.group()) + except json.JSONDecodeError: + return None + + name = obj.get("name") + params = obj.get("parameters") or obj.get("arguments") or {} + if not name: + return None + + return {"function": {"name": name, "arguments": params}} + + def _run_tool_round(self, messages: List[Dict], tools: List, + tool_executor: Callable, remaining_rounds: int) -> str: + """Recursively call Ollama, executing tool calls until the LLM produces text.""" + # If rounds remain, offer tools; otherwise force a text-only response + data = self._call_ollama( + messages, tools=tools if remaining_rounds > 0 else None + ) + + assistant_msg = data.get("message", {}) + tool_calls = assistant_msg.get("tool_calls") + content = assistant_msg.get("content", "") + + # Fallback: if the LLM wrote the tool call as plain text, parse it + if not tool_calls and content and remaining_rounds > 0: + recovered = self._parse_tool_call_from_text(content) + if recovered: + tool_calls = [recovered] + # Rewrite assistant_msg so context stays consistent + assistant_msg = {"role": "assistant", "content": "", "tool_calls": tool_calls} + + # Base cases: no tool calls or no remaining rounds + if not tool_calls or remaining_rounds <= 0: + return content + + # Append the assistant message (with its tool_calls) to context + messages.append(assistant_msg) + + # Execute each tool call and append results + for tc in tool_calls: + name, result = 
self._execute_tool_call(tc, tool_executor) + print(f"[Tool Call] {name}({tc.get('function', {}).get('arguments', {})}) -> {len(result)} chars") + messages.append({"role": "tool", "content": result}) + + return self._run_tool_round(messages, tools, tool_executor, remaining_rounds - 1) + def generate_response(self, query: str, context: str = "", - conversation_history: Optional[str] = None) -> str: + conversation_history: Optional[str] = None, + tools: Optional[List] = None, + tool_executor: Optional[Callable] = None) -> str: """ Generate AI response using provided context. @@ -55,6 +158,8 @@ def generate_response(self, query: str, query: The user's question or request context: Retrieved course content to answer from conversation_history: Previous messages for context + tools: Optional tool definitions for Ollama function calling + tool_executor: Callable(name, **kwargs) to execute tool calls Returns: Generated response as string @@ -77,6 +182,10 @@ def generate_response(self, query: str, {"role": "user", "content": user_content} ] - # Get response from Ollama + # Tool-calling path: recursive loop + if tools is not None and tool_executor is not None: + return self._run_tool_round(messages, tools, tool_executor, self.MAX_TOOL_ROUNDS) + + # Default path: single call, no tools data = self._call_ollama(messages) return data.get("message", {}).get("content", "") diff --git a/backend/rag_system.py b/backend/rag_system.py index 6fd9da826..0af1b8dde 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -101,11 +101,16 @@ def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> T return total_courses, total_chunks + def _is_search_error(self, result: str) -> bool: + """Check if a search tool result is an error/empty message, not real content.""" + error_prefixes = ("Search error:", "No relevant content found", "No course found") + return result.startswith(error_prefixes) + def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[Dict]]: """ Process a user query using the RAG system. - Searches course content first, then passes results as context to the LLM. + The LLM decides which tools to call via the tool-calling loop in AIGenerator. Args: query: User's question @@ -114,49 +119,31 @@ def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List Returns: Tuple of (response, sources list) """ - # Search for relevant content first - content_context = self.search_tool.execute(query=query) - - # Also try the outline tool (adds course structure when relevant) - outline_context = self.outline_tool.execute(course_name=query) - has_outline = outline_context and not outline_context.startswith("No course found") - - # Combine contexts — outline first so the LLM sees it prominently - context_parts = [] - if has_outline: - context_parts.append(outline_context) - if content_context: - context_parts.append(content_context) - context = "\n\n".join(context_parts) - - # Gather sources from both tools - sources = self.search_tool.last_sources.copy() + self.outline_tool.last_sources.copy() - self.search_tool.last_sources = [] - self.outline_tool.last_sources = [] - - # For outline queries, augment the query with an explicit instruction - effective_query = query - if has_outline: - effective_query = ( - f"{query}\n\n" - "IMPORTANT: The context contains a complete course outline. " - "You MUST list every single lesson exactly as shown — do not skip, " - "group, or summarize any lessons." 
- ) + # Reset sources from previous query + self.tool_manager.reset_sources() - # Get conversation history if session exists + # Get conversation history history = None if session_id: history = self.session_manager.get_conversation_history(session_id) - # Generate response with retrieved context - response = self.ai_generator.generate_response( - query=effective_query, - context=context, - conversation_history=history - ) + # Let the LLM decide which tools to call + try: + response = self.ai_generator.generate_response( + query=query, + conversation_history=history, + tools=self.tool_manager.get_tool_definitions(), + tool_executor=self.tool_manager.execute_tool, + ) + except Exception as e: + print(f"AI generation error: {e}") + sources = self.tool_manager.get_last_sources() + return "Sorry, I'm having trouble generating a response right now. Please try again.", sources + + # Collect sources from all tools used during the conversation + sources = self.tool_manager.get_last_sources() - # Update conversation history + # Update session history if session_id: self.session_manager.add_exchange(session_id, query, response) diff --git a/backend/search_tools.py b/backend/search_tools.py index 4cf36092d..42dc7a39a 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -52,19 +52,29 @@ def get_tool_definition(self) -> Dict[str, Any]: } } - def execute(self, query: str, course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str: + def execute(self, query: str = "", course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str: """ Execute the search tool with given parameters. - + Args: query: What to search for course_name: Optional course filter lesson_number: Optional lesson filter - + Returns: Formatted search results or error message """ - + # Coerce lesson_number to int (LLMs sometimes send it as a string) + if lesson_number is not None: + try: + lesson_number = int(lesson_number) + except (ValueError, TypeError): + lesson_number = None + + # If query is empty but we have filters, use a generic query + if not query and (course_name or lesson_number is not None): + query = f"{course_name or ''} lesson {lesson_number or ''}".strip() + # Use the vector store's unified search interface results = self.store.search( query=query, @@ -120,8 +130,8 @@ def _format_results(self, results: SearchResults) -> str: formatted.append(f"{header}\n{doc}") - # Store sources for retrieval - self.last_sources = sources + # Accumulate sources (may be called multiple times per query) + self.last_sources.extend(sources) return "\n\n".join(formatted) @@ -160,8 +170,8 @@ def execute(self, course_name: str) -> str: link = outline.get("course_link") or "N/A" lessons = outline.get("lessons", []) - # Track source - self.last_sources = [{"title": title, "link": outline.get("course_link")}] + # Accumulate sources (may be called multiple times per query) + self.last_sources.extend([{"title": title, "link": outline.get("course_link")}]) lines = [ f"COURSE OUTLINE (list every lesson as a markdown bullet point):", @@ -207,12 +217,17 @@ def execute_tool(self, tool_name: str, **kwargs) -> str: return self.tools[tool_name].execute(**kwargs) def get_last_sources(self) -> list: - """Get sources from the last search operation""" - # Check all tools for last_sources attribute + """Get sources from all tools, deduplicated by (title, link).""" + seen = set() + sources = [] for tool in self.tools.values(): - if hasattr(tool, 'last_sources') and tool.last_sources: - return tool.last_sources 
- return [] + if hasattr(tool, 'last_sources'): + for s in tool.last_sources: + key = (s.get("title"), s.get("link")) + if key not in seen: + seen.add(key) + sources.append(s) + return sources def reset_sources(self): """Reset sources from all tools that track sources""" diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 000000000..2df6d3693 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,185 @@ +"""Shared fixtures for RAG chatbot tests.""" + +import sys +import os +import pytest +from unittest.mock import MagicMock, patch +from dataclasses import dataclass + +# Add backend to path so imports work +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from vector_store import SearchResults +from models import Course, Lesson, CourseChunk + + +# --------------------------------------------------------------------------- +# API test fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_rag_system(): + """A fully mocked RAGSystem for API endpoint tests.""" + rag = MagicMock() + rag.query.return_value = ( + "RAG stands for Retrieval-Augmented Generation.", + [{"title": "Building RAG Chatbots", "link": "https://example.com/rag"}], + ) + rag.session_manager.create_session.return_value = "session_1" + rag.get_course_analytics.return_value = { + "total_courses": 2, + "course_titles": ["Building RAG Chatbots", "Intro to MCP"], + } + return rag + + +@pytest.fixture +def test_app(mock_rag_system): + """A lightweight FastAPI app with the same endpoints as app.py, + but without static-file mounting or startup document loading.""" + from fastapi import FastAPI, HTTPException + from pydantic import BaseModel + from typing import List, Optional + + api = FastAPI() + + class QueryRequest(BaseModel): + query: str + session_id: Optional[str] = None + + class SourceItem(BaseModel): + title: str + link: Optional[str] = None + + class QueryResponse(BaseModel): + answer: str + sources: List[SourceItem] + session_id: str + + class ClearSessionRequest(BaseModel): + session_id: str + + class CourseStats(BaseModel): + total_courses: int + course_titles: List[str] + + rag = mock_rag_system + + @api.post("/api/query", response_model=QueryResponse) + async def query_documents(request: QueryRequest): + try: + session_id = request.session_id + if not session_id: + session_id = rag.session_manager.create_session() + answer, sources = rag.query(request.query, session_id) + return QueryResponse(answer=answer, sources=sources, session_id=session_id) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @api.post("/api/sessions/clear") + async def clear_session(request: ClearSessionRequest): + try: + rag.session_manager.clear_session(request.session_id) + return {"status": "ok"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @api.get("/api/courses", response_model=CourseStats) + async def get_course_stats(): + try: + analytics = rag.get_course_analytics() + return CourseStats( + total_courses=analytics["total_courses"], + course_titles=analytics["course_titles"], + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return api + + +@pytest.fixture +def client(test_app): + """HTTPX TestClient (sync) for the test FastAPI app.""" + from starlette.testclient import TestClient + return 
TestClient(test_app) + + +@dataclass +class MockConfig: + """Minimal config for testing without .env or real services.""" + OLLAMA_BASE_URL: str = "http://localhost:11434" + OLLAMA_MODEL: str = "test-model" + EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" + CHUNK_SIZE: int = 800 + CHUNK_OVERLAP: int = 100 + MAX_RESULTS: int = 5 + MAX_HISTORY: int = 2 + CHROMA_PATH: str = "./test_chroma_db" + + +@pytest.fixture +def mock_config(): + return MockConfig() + + +@pytest.fixture +def sample_search_results(): + """Realistic search results as returned by VectorStore.search().""" + return SearchResults( + documents=[ + "Lesson 1 content: RAG stands for Retrieval-Augmented Generation. It combines retrieval and generation.", + "Lesson 2 content: Vector databases store embeddings for semantic search.", + ], + metadata=[ + {"course_title": "Building RAG Chatbots", "lesson_number": 1, "chunk_index": 0}, + {"course_title": "Building RAG Chatbots", "lesson_number": 2, "chunk_index": 3}, + ], + distances=[0.25, 0.42], + ) + + +@pytest.fixture +def empty_search_results(): + """Empty search results (no matching documents).""" + return SearchResults(documents=[], metadata=[], distances=[]) + + +@pytest.fixture +def error_search_results(): + """Search results with an error.""" + return SearchResults(documents=[], metadata=[], distances=[], error="Search error: collection is empty") + + +@pytest.fixture +def sample_course(): + """A sample Course object for testing.""" + return Course( + title="Building RAG Chatbots", + course_link="https://example.com/rag-course", + instructor="Test Instructor", + lessons=[ + Lesson(lesson_number=1, title="Introduction to RAG", lesson_link="https://example.com/lesson1"), + Lesson(lesson_number=2, title="Vector Databases", lesson_link="https://example.com/lesson2"), + Lesson(lesson_number=3, title="Building the Pipeline", lesson_link="https://example.com/lesson3"), + ], + ) + + +@pytest.fixture +def mock_vector_store(): + """A fully mocked VectorStore (avoids ChromaDB + embedding model).""" + store = MagicMock() + store.search.return_value = SearchResults( + documents=[ + "RAG combines retrieval with generation for better answers.", + ], + metadata=[ + {"course_title": "Building RAG Chatbots", "lesson_number": 1, "chunk_index": 0}, + ], + distances=[0.3], + ) + store.get_lesson_link.return_value = "https://example.com/lesson1" + store.get_course_link.return_value = "https://example.com/rag-course" + store.get_course_outline.return_value = None # default: no outline match + return store diff --git a/backend/tests/test_ai_generator.py b/backend/tests/test_ai_generator.py new file mode 100644 index 000000000..6718fde0a --- /dev/null +++ b/backend/tests/test_ai_generator.py @@ -0,0 +1,424 @@ +"""Tests for AIGenerator in ai_generator.py. 
+ +These tests verify that the AI generator: +- Correctly formats messages with search context for the LLM +- Includes the system prompt in every call +- Handles conversation history +- Propagates connection/HTTP errors (which cause 'query failed' in the frontend) +- Handles unexpected response shapes gracefully +""" + +import pytest +from unittest.mock import MagicMock, patch, PropertyMock +import httpx + +from ai_generator import AIGenerator + + +@pytest.fixture +def generator(): + """AIGenerator pointed at a fake Ollama URL.""" + return AIGenerator(base_url="http://localhost:11434", model="test-model") + + +@pytest.fixture +def mock_ollama_success(monkeypatch): + """Patch httpx.Client.post to return a successful Ollama response.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"role": "assistant", "content": "RAG combines retrieval with generation."} + } + mock_response.raise_for_status = MagicMock() + + mock_post = MagicMock(return_value=mock_response) + monkeypatch.setattr(httpx.Client, "post", mock_post) + return mock_post + + +# --------------------------------------------------------------------------- +# Message formatting tests +# --------------------------------------------------------------------------- + +class TestGenerateResponseFormatting: + """Verify the messages sent to Ollama are correctly structured.""" + + def test_context_is_included_in_user_message(self, generator, mock_ollama_success): + """When context is provided, it should appear in the user message.""" + generator.generate_response( + query="What is RAG?", + context="[Building RAG Chatbots - Lesson 1]\nRAG stands for Retrieval-Augmented Generation.", + ) + + # Inspect the payload sent to Ollama + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert "Course context:" in user_msg["content"] + assert "RAG stands for Retrieval-Augmented Generation" in user_msg["content"] + assert "What is RAG?" in user_msg["content"] + + def test_no_context_sends_query_only(self, generator, mock_ollama_success): + """Without context, the user message should be just the query.""" + generator.generate_response(query="Hello") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert user_msg["content"] == "Hello" + assert "Course context:" not in user_msg["content"] + + def test_system_prompt_always_present(self, generator, mock_ollama_success): + """The system prompt should always be the first message.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + system_msg = messages[0] + assert system_msg["role"] == "system" + assert "AI assistant" in system_msg["content"] + + def test_conversation_history_appended_to_system(self, generator, mock_ollama_success): + """Conversation history should be appended to the system prompt.""" + history = "User: What is RAG?\nAssistant: RAG is retrieval-augmented generation." 
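+        # NB: a plain-text transcript, the shape that SessionManager's
+        # get_conversation_history() is assumed to hand the generator.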
+ + generator.generate_response(query="Tell me more", conversation_history=history) + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + system_msg = messages[0] + assert "Previous conversation:" in system_msg["content"] + assert "What is RAG?" in system_msg["content"] + + def test_empty_context_treated_as_no_context(self, generator, mock_ollama_success): + """An empty string context should NOT add 'Course context:' prefix.""" + generator.generate_response(query="Hello", context="") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert user_msg["content"] == "Hello" + + +# --------------------------------------------------------------------------- +# Response extraction tests +# --------------------------------------------------------------------------- + +class TestGenerateResponseOutput: + """Verify the response is correctly extracted from the Ollama reply.""" + + def test_returns_content_string(self, generator, mock_ollama_success): + """generate_response should return the content string from Ollama.""" + result = generator.generate_response(query="What is RAG?", context="some context") + + assert result == "RAG combines retrieval with generation." + + def test_handles_missing_message_key(self, generator, monkeypatch): + """If Ollama returns unexpected JSON, should return empty string (not crash).""" + mock_response = MagicMock() + mock_response.json.return_value = {"unexpected": "format"} + mock_response.raise_for_status = MagicMock() + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + result = generator.generate_response(query="test") + + assert result == "" + + def test_handles_missing_content_key(self, generator, monkeypatch): + """If 'message' exists but has no 'content', should return empty string.""" + mock_response = MagicMock() + mock_response.json.return_value = {"message": {"role": "assistant"}} + mock_response.raise_for_status = MagicMock() + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + result = generator.generate_response(query="test") + + assert result == "" + + +# --------------------------------------------------------------------------- +# Error handling tests — these expose the root cause of "query failed" +# --------------------------------------------------------------------------- + +class TestGenerateResponseErrors: + """Test error scenarios that cause the 'query failed' frontend error.""" + + def test_connection_error_propagates(self, generator, monkeypatch): + """If Ollama is not running, a ConnectError should propagate. + + This is the most likely cause of 'query failed' — the ai_generator + does NOT catch connection errors, so they bubble up through + rag_system.query() → app.py → HTTP 500 → frontend 'Query failed'. + """ + def raise_connect_error(*args, **kwargs): + raise httpx.ConnectError("Connection refused") + + monkeypatch.setattr(httpx.Client, "post", raise_connect_error) + + with pytest.raises(httpx.ConnectError): + generator.generate_response(query="What is RAG?", context="some context") + + def test_http_error_propagates(self, generator, monkeypatch): + """If Ollama returns a 4xx/5xx, the error should propagate. + + For example, if the model name is wrong, Ollama returns 404. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=mock_response + ) + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + with pytest.raises(httpx.HTTPStatusError): + generator.generate_response(query="test", context="ctx") + + def test_timeout_error_propagates(self, generator, monkeypatch): + """If Ollama takes too long, a timeout error should propagate.""" + def raise_timeout(*args, **kwargs): + raise httpx.ReadTimeout("Read timed out") + + monkeypatch.setattr(httpx.Client, "post", raise_timeout) + + with pytest.raises(httpx.ReadTimeout): + generator.generate_response(query="test", context="ctx") + + def test_generator_does_not_filter_context(self, generator, mock_ollama_success): + """AIGenerator passes context as-is — filtering is rag_system's job. + + The generator is a thin wrapper around Ollama. If it receives an error + string as context, it will include it. The fix for this lives in + rag_system.query(), which now filters error strings before they + reach the generator. + """ + error_as_context = "Search error: collection is empty" + + generator.generate_response(query="What is RAG?", context=error_as_context) + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + user_msg = next(m for m in payload["messages"] if m["role"] == "user") + + # Generator doesn't filter — it passes whatever it receives + assert "Search error:" in user_msg["content"] + + +# --------------------------------------------------------------------------- +# Request structure tests +# --------------------------------------------------------------------------- + +class TestOllamaRequestFormat: + """Verify the HTTP request to Ollama is well-formed.""" + + def test_request_uses_correct_endpoint(self, generator, mock_ollama_success): + """Should POST to /api/chat.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + url = call_args.args[0] if call_args.args else call_args[0][0] + assert url == "http://localhost:11434/api/chat" + + def test_request_includes_model(self, generator, mock_ollama_success): + """Payload should include the configured model name.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + assert payload["model"] == "test-model" + + def test_request_disables_streaming(self, generator, mock_ollama_success): + """stream should be False (we expect a single JSON response).""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + assert payload["stream"] is False + + +# --------------------------------------------------------------------------- +# Helper factories for tool-calling tests +# --------------------------------------------------------------------------- + +def _make_response(message_body: dict): + """Create a mock httpx response returning the given message dict.""" + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"message": message_body} + resp.raise_for_status = MagicMock() + return resp + + +def make_tool_call_response(tool_name: str, arguments: dict): + """Mock Ollama response that requests a tool call.""" + return _make_response({ + "role": "assistant", + "content": "", + 
"tool_calls": [ + {"function": {"name": tool_name, "arguments": arguments}} + ], + }) + + +def make_text_response(content: str): + """Mock Ollama response that returns plain text (no tool calls).""" + return _make_response({ + "role": "assistant", + "content": content, + }) + + +SAMPLE_TOOLS = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + + +# --------------------------------------------------------------------------- +# Sequential tool-calling tests +# --------------------------------------------------------------------------- + +class TestSequentialToolCalling: + """Verify the recursive tool-calling loop in generate_response.""" + + def test_single_tool_call_then_text(self, generator, monkeypatch): + """One tool call followed by a text response → 2 API calls, 1 execution.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "RAG"}), + make_text_response("Here is the answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="search results") + + result = generator.generate_response( + query="What is RAG?", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Here is the answer." + assert mock_post.call_count == 2 + executor.assert_called_once_with("search", query="RAG") + + def test_two_sequential_tool_calls(self, generator, monkeypatch): + """Two sequential tool calls then text → 3 API calls, 2 executions.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "course X"}), + make_tool_call_response("search", {"query": "lesson 4"}), + make_text_response("Final answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="result") + + result = generator.generate_response( + query="complex query", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Final answer." + assert mock_post.call_count == 3 + assert executor.call_count == 2 + + def test_no_tool_calls_returns_immediately(self, generator, monkeypatch): + """If the LLM returns text right away, no tool execution happens.""" + mock_post = MagicMock(side_effect=[ + make_text_response("Immediate answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock() + + result = generator.generate_response( + query="hi", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Immediate answer." + assert mock_post.call_count == 1 + executor.assert_not_called() + + def test_max_rounds_forces_final_call_without_tools(self, generator, monkeypatch): + """After MAX_TOOL_ROUNDS tool calls, the next call omits tools to force text.""" + # MAX_TOOL_ROUNDS is 3, so we need 3 tool-call rounds + 1 forced text + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "a"}), + make_tool_call_response("search", {"query": "b"}), + make_tool_call_response("search", {"query": "c"}), + make_text_response("Forced text."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="r") + + result = generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Forced text." 
+ # Fourth call should NOT have tools in its payload + fourth_call_payload = mock_post.call_args_list[3].kwargs.get("json") or mock_post.call_args_list[3][1].get("json") + assert "tools" not in fourth_call_payload + + def test_tool_error_passed_as_result(self, generator, monkeypatch): + """If tool_executor raises, the error string is sent as the tool result.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("bad_tool", {}), + make_text_response("Recovered."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(side_effect=ValueError("something broke")) + + result = generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Recovered." + # Verify the error was passed as a tool message + second_call_payload = mock_post.call_args_list[1].kwargs.get("json") or mock_post.call_args_list[1][1].get("json") + tool_msgs = [m for m in second_call_payload["messages"] if m["role"] == "tool"] + assert len(tool_msgs) == 1 + assert "Tool error:" in tool_msgs[0]["content"] + assert "something broke" in tool_msgs[0]["content"] + + def test_messages_accumulate_across_rounds(self, generator, monkeypatch): + """After 2 rounds, the final API call should contain the full conversation.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "first"}), + make_tool_call_response("search", {"query": "second"}), + make_text_response("Done."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="data") + + generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + final_payload = mock_post.call_args_list[2].kwargs.get("json") or mock_post.call_args_list[2][1].get("json") + msgs = final_payload["messages"] + roles = [m["role"] for m in msgs] + # system, user, assistant(tool_call), tool, assistant(tool_call), tool + assert roles == ["system", "user", "assistant", "tool", "assistant", "tool"] + + def test_tools_in_payload_when_provided(self, generator, monkeypatch): + """When tools are provided, the first API call payload should include them.""" + mock_post = MagicMock(side_effect=[ + make_text_response("answer"), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + + generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=MagicMock(), + ) + + first_payload = mock_post.call_args_list[0].kwargs.get("json") or mock_post.call_args_list[0][1].get("json") + assert "tools" in first_payload + assert first_payload["tools"] == SAMPLE_TOOLS + + def test_no_tools_backward_compatible(self, generator, mock_ollama_success): + """Without tools/tool_executor, behavior is unchanged — no tools in payload.""" + generator.generate_response(query="test") + + payload = mock_ollama_success.call_args.kwargs.get("json") or mock_ollama_success.call_args[1].get("json") + assert "tools" not in payload + assert mock_ollama_success.call_count == 1 diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py new file mode 100644 index 000000000..c6248f5ab --- /dev/null +++ b/backend/tests/test_api.py @@ -0,0 +1,131 @@ +"""Tests for the FastAPI API endpoints.""" + +import pytest +from unittest.mock import MagicMock + + +# ── POST /api/query ───────────────────────────────────────────────────── + + +class TestQueryEndpoint: + """Tests for the /api/query endpoint.""" + + def test_query_returns_answer_and_sources(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "What is RAG?"}) + 
assert resp.status_code == 200 + data = resp.json() + assert data["answer"] == "RAG stands for Retrieval-Augmented Generation." + assert len(data["sources"]) == 1 + assert data["sources"][0]["title"] == "Building RAG Chatbots" + + def test_query_creates_session_when_not_provided(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "Hello"}) + data = resp.json() + assert data["session_id"] == "session_1" + mock_rag_system.session_manager.create_session.assert_called_once() + + def test_query_uses_provided_session_id(self, client, mock_rag_system): + resp = client.post( + "/api/query", + json={"query": "Hello", "session_id": "existing_session"}, + ) + data = resp.json() + assert data["session_id"] == "existing_session" + mock_rag_system.session_manager.create_session.assert_not_called() + mock_rag_system.query.assert_called_once_with("Hello", "existing_session") + + def test_query_passes_query_to_rag_system(self, client, mock_rag_system): + client.post("/api/query", json={"query": "Tell me about MCP"}) + mock_rag_system.query.assert_called_once() + assert mock_rag_system.query.call_args[0][0] == "Tell me about MCP" + + def test_query_missing_query_field_returns_422(self, client): + resp = client.post("/api/query", json={}) + assert resp.status_code == 422 + + def test_query_empty_string_passes_through(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": ""}) + # FastAPI doesn't block empty strings by default; the endpoint processes it + assert resp.status_code == 200 + + def test_query_rag_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.query.side_effect = RuntimeError("Ollama unreachable") + resp = client.post("/api/query", json={"query": "anything"}) + assert resp.status_code == 500 + assert "Ollama unreachable" in resp.json()["detail"] + + def test_query_sources_with_no_link(self, client, mock_rag_system): + mock_rag_system.query.return_value = ( + "Answer", + [{"title": "Some Course"}], + ) + resp = client.post("/api/query", json={"query": "test"}) + data = resp.json() + assert data["sources"][0]["link"] is None + + def test_query_empty_sources(self, client, mock_rag_system): + mock_rag_system.query.return_value = ("No results", []) + resp = client.post("/api/query", json={"query": "obscure question"}) + data = resp.json() + assert data["sources"] == [] + + +# ── POST /api/sessions/clear ──────────────────────────────────────────── + + +class TestClearSessionEndpoint: + """Tests for the /api/sessions/clear endpoint.""" + + def test_clear_session_returns_ok(self, client, mock_rag_system): + resp = client.post( + "/api/sessions/clear", json={"session_id": "session_1"} + ) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + def test_clear_session_calls_session_manager(self, client, mock_rag_system): + client.post("/api/sessions/clear", json={"session_id": "session_42"}) + mock_rag_system.session_manager.clear_session.assert_called_once_with( + "session_42" + ) + + def test_clear_session_missing_body_returns_422(self, client): + resp = client.post("/api/sessions/clear", json={}) + assert resp.status_code == 422 + + def test_clear_session_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.session_manager.clear_session.side_effect = RuntimeError("boom") + resp = client.post( + "/api/sessions/clear", json={"session_id": "bad"} + ) + assert resp.status_code == 500 + + +# ── GET /api/courses ──────────────────────────────────────────────────── + + +class 
TestCoursesEndpoint: + """Tests for the /api/courses endpoint.""" + + def test_courses_returns_stats(self, client, mock_rag_system): + resp = client.get("/api/courses") + assert resp.status_code == 200 + data = resp.json() + assert data["total_courses"] == 2 + assert "Building RAG Chatbots" in data["course_titles"] + assert "Intro to MCP" in data["course_titles"] + + def test_courses_empty_catalog(self, client, mock_rag_system): + mock_rag_system.get_course_analytics.return_value = { + "total_courses": 0, + "course_titles": [], + } + resp = client.get("/api/courses") + data = resp.json() + assert data["total_courses"] == 0 + assert data["course_titles"] == [] + + def test_courses_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.get_course_analytics.side_effect = RuntimeError("db down") + resp = client.get("/api/courses") + assert resp.status_code == 500 diff --git a/backend/tests/test_rag_system.py b/backend/tests/test_rag_system.py new file mode 100644 index 000000000..4ead1e265 --- /dev/null +++ b/backend/tests/test_rag_system.py @@ -0,0 +1,175 @@ +"""Tests for the RAG system query pipeline in rag_system.py. + +These tests verify the tool-calling orchestration: + query() resets sources → passes tools + executor to AIGenerator → collects sources → returns response. + +All external dependencies (VectorStore, Ollama) are mocked to isolate the logic. +""" + +import pytest +from unittest.mock import MagicMock, patch, call +import httpx + +from vector_store import SearchResults + + +# --------------------------------------------------------------------------- +# Helpers to build a RAGSystem with mocked internals +# --------------------------------------------------------------------------- + +@pytest.fixture +def rag_system(mock_config, mock_vector_store): + """Build a RAGSystem with mocked VectorStore and AIGenerator.""" + with patch("rag_system.VectorStore", return_value=mock_vector_store), \ + patch("rag_system.AIGenerator") as MockAI, \ + patch("rag_system.DocumentProcessor"): + + mock_ai = MockAI.return_value + mock_ai.generate_response.return_value = "RAG is a technique that combines retrieval and generation." 
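+        # Import and construct inside the patch context so RAGSystem picks up
+        # the mocked VectorStore, AIGenerator, and DocumentProcessor classes.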
+ + from rag_system import RAGSystem + system = RAGSystem(mock_config) + + # Replace the vector store in the search tools with our mock + system.search_tool.store = mock_vector_store + system.outline_tool.store = mock_vector_store + system.ai_generator = mock_ai + + yield system + + +@pytest.fixture +def rag_system_with_ollama_error(mock_config, mock_vector_store): + """RAGSystem where the AI generator raises a connection error.""" + with patch("rag_system.VectorStore", return_value=mock_vector_store), \ + patch("rag_system.AIGenerator") as MockAI, \ + patch("rag_system.DocumentProcessor"): + + mock_ai = MockAI.return_value + mock_ai.generate_response.side_effect = httpx.ConnectError("Connection refused") + + from rag_system import RAGSystem + system = RAGSystem(mock_config) + + system.search_tool.store = mock_vector_store + system.outline_tool.store = mock_vector_store + system.ai_generator = mock_ai + + yield system + + +# --------------------------------------------------------------------------- +# Core query pipeline tests — tool-calling flow +# --------------------------------------------------------------------------- + +class TestRAGQueryPipeline: + """Test the main query() method orchestration.""" + + def test_query_passes_tool_definitions(self, rag_system): + """generate_response should receive tool definitions from the tool manager.""" + rag_system.query("What is RAG?") + + call_args = rag_system.ai_generator.generate_response.call_args + tools = call_args.kwargs.get("tools") + assert tools is not None + assert isinstance(tools, list) + assert len(tools) == 2 # search + outline tools + + # Verify tool names + names = {t["function"]["name"] for t in tools} + assert "search_course_content" in names + assert "get_course_outline" in names + + def test_query_passes_tool_executor(self, rag_system): + """generate_response should receive tool_executor bound to the tool manager.""" + rag_system.query("What is RAG?") + + call_args = rag_system.ai_generator.generate_response.call_args + executor = call_args.kwargs.get("tool_executor") + # Bound methods create new objects each access, so compare underlying function + instance + assert executor.__func__ is rag_system.tool_manager.execute_tool.__func__ + assert executor.__self__ is rag_system.tool_manager + + def test_query_returns_response_and_sources(self, rag_system): + """query() should return a (response_text, sources_list) tuple.""" + response, sources = rag_system.query("What is RAG?") + + assert response == "RAG is a technique that combines retrieval and generation." + assert isinstance(sources, list) + + def test_query_collects_sources_after_response(self, rag_system): + """Sources should be gathered from tools after the LLM finishes.""" + # Simulate the tool executor populating sources during generate_response + def fake_generate(**kwargs): + # Mimic what happens when the LLM calls the search tool + rag_system.tool_manager.execute_tool("search_course_content", query="RAG") + return "Here's what I found about RAG." + + rag_system.ai_generator.generate_response.side_effect = fake_generate + + response, sources = rag_system.query("What is RAG?") + + assert response == "Here's what I found about RAG." 
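+        # fake_generate drove the real ToolManager, so the mocked VectorStore's
+        # metadata should now surface as source titles.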
+ titles = [s["title"] for s in sources] + assert any("Building RAG Chatbots" in t for t in titles) + + def test_query_resets_sources_before_each_query(self, rag_system): + """reset_sources() should be called at the start of each query.""" + # Seed leftover sources from a "previous" query + rag_system.search_tool.last_sources = [{"title": "stale", "link": None}] + + # The new query should reset before doing anything + rag_system.query("What is RAG?") + + # After the query, the stale source should be gone. + # Since generate_response is a plain mock (no tool calls), sources should be empty. + _, sources = rag_system.query("Second question") + assert all(s["title"] != "stale" for s in sources) + + +# --------------------------------------------------------------------------- +# Error handling tests +# --------------------------------------------------------------------------- + +class TestRAGQueryErrorHandling: + """Tests that expose error handling issues in the RAG pipeline.""" + + def test_ollama_connection_error_returns_friendly_message(self, rag_system_with_ollama_error): + """Ollama being down returns a friendly error instead of crashing.""" + response, sources = rag_system_with_ollama_error.query("What is RAG?") + + assert "trouble generating a response" in response + + +# --------------------------------------------------------------------------- +# Session handling tests +# --------------------------------------------------------------------------- + +class TestRAGSessionHandling: + """Test conversation session integration.""" + + def test_query_without_session(self, rag_system): + """Query without session_id should still work (no history).""" + response, sources = rag_system.query("What is RAG?") + + call_args = rag_system.ai_generator.generate_response.call_args + history = call_args.kwargs.get("conversation_history") + assert history is None + + def test_query_with_new_session(self, rag_system): + """First query with a session_id should have no history.""" + response, sources = rag_system.query("What is RAG?", session_id="session_1") + + assert response is not None + + def test_query_with_existing_session(self, rag_system): + """Second query in same session should include history from first query.""" + rag_system.query("What is RAG?", session_id="session_1") + rag_system.query("Tell me more", session_id="session_1") + + # Second call should have conversation history + calls = rag_system.ai_generator.generate_response.call_args_list + second_call = calls[1] + history = second_call.kwargs.get("conversation_history") + assert history is not None + assert "What is RAG?" in history diff --git a/backend/tests/test_search_tools.py b/backend/tests/test_search_tools.py new file mode 100644 index 000000000..a190c5346 --- /dev/null +++ b/backend/tests/test_search_tools.py @@ -0,0 +1,205 @@ +"""Tests for CourseSearchTool.execute() in search_tools.py. 
+ +These tests mock the VectorStore to isolate the search tool logic and verify: +- Correct result formatting +- Error handling +- Filter pass-through (course_name, lesson_number) +- Source tracking and deduplication +""" + +import pytest +from unittest.mock import MagicMock, call +from search_tools import CourseSearchTool, CourseOutlineTool, ToolManager +from vector_store import SearchResults + + +# --------------------------------------------------------------------------- +# CourseSearchTool.execute() tests +# --------------------------------------------------------------------------- + +class TestCourseSearchToolExecute: + """Tests for the execute method of CourseSearchTool.""" + + def test_execute_returns_formatted_results(self, mock_vector_store, sample_search_results): + """Successful search should return formatted text with course/lesson headers.""" + mock_vector_store.search.return_value = sample_search_results + tool = CourseSearchTool(mock_vector_store) + + result = tool.execute(query="What is RAG?") + + assert "[Building RAG Chatbots - Lesson 1]" in result + assert "[Building RAG Chatbots - Lesson 2]" in result + assert "RAG stands for Retrieval-Augmented Generation" in result + assert "Vector databases store embeddings" in result + + def test_execute_with_empty_results(self, mock_vector_store, empty_search_results): + """Empty results should return a 'no content found' message.""" + mock_vector_store.search.return_value = empty_search_results + tool = CourseSearchTool(mock_vector_store) + + result = tool.execute(query="nonexistent topic") + + assert "No relevant content found" in result + + def test_execute_with_error(self, mock_vector_store, error_search_results): + """When VectorStore returns an error, execute should return the error string.""" + mock_vector_store.search.return_value = error_search_results + tool = CourseSearchTool(mock_vector_store) + + result = tool.execute(query="anything") + + assert "Search error" in result + + def test_execute_passes_course_name_filter(self, mock_vector_store, sample_search_results): + """course_name parameter should be forwarded to VectorStore.search().""" + mock_vector_store.search.return_value = sample_search_results + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="RAG", course_name="Building RAG Chatbots") + + mock_vector_store.search.assert_called_once_with( + query="RAG", + course_name="Building RAG Chatbots", + lesson_number=None, + ) + + def test_execute_passes_lesson_number_filter(self, mock_vector_store, sample_search_results): + """lesson_number parameter should be forwarded to VectorStore.search().""" + mock_vector_store.search.return_value = sample_search_results + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="RAG", lesson_number=2) + + mock_vector_store.search.assert_called_once_with( + query="RAG", + course_name=None, + lesson_number=2, + ) + + def test_execute_passes_both_filters(self, mock_vector_store, sample_search_results): + """Both course_name and lesson_number should be forwarded together.""" + mock_vector_store.search.return_value = sample_search_results + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="RAG", course_name="RAG Course", lesson_number=3) + + mock_vector_store.search.assert_called_once_with( + query="RAG", + course_name="RAG Course", + lesson_number=3, + ) + + def test_execute_tracks_sources(self, mock_vector_store, sample_search_results): + """After a successful search, last_sources should contain source info.""" + 
mock_vector_store.search.return_value = sample_search_results + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="What is RAG?") + + assert len(tool.last_sources) > 0 + assert tool.last_sources[0]["title"] == "Building RAG Chatbots - Lesson 1" + assert tool.last_sources[0]["link"] is not None + + def test_execute_deduplicates_sources(self, mock_vector_store): + """Multiple chunks from same course+lesson should produce one source entry.""" + mock_vector_store.search.return_value = SearchResults( + documents=["chunk 1 from lesson 1", "chunk 2 from lesson 1"], + metadata=[ + {"course_title": "Test Course", "lesson_number": 1, "chunk_index": 0}, + {"course_title": "Test Course", "lesson_number": 1, "chunk_index": 1}, + ], + distances=[0.1, 0.2], + ) + mock_vector_store.get_lesson_link.return_value = "https://example.com/l1" + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="test") + + assert len(tool.last_sources) == 1 + assert tool.last_sources[0]["title"] == "Test Course - Lesson 1" + + def test_execute_empty_results_includes_filter_info(self, mock_vector_store, empty_search_results): + """Empty results message should mention the active filters.""" + mock_vector_store.search.return_value = empty_search_results + tool = CourseSearchTool(mock_vector_store) + + result = tool.execute(query="test", course_name="MCP", lesson_number=5) + + assert "in course 'MCP'" in result + assert "in lesson 5" in result + + def test_execute_error_does_not_update_sources(self, mock_vector_store, error_search_results): + """When search errors, last_sources should NOT be updated (stays from init).""" + mock_vector_store.search.return_value = error_search_results + tool = CourseSearchTool(mock_vector_store) + + tool.execute(query="anything") + + # last_sources should still be the initial empty list + assert tool.last_sources == [] + + def test_execute_handles_metadata_without_lesson_number(self, mock_vector_store): + """Results without lesson_number should format without ' - Lesson X'.""" + mock_vector_store.search.return_value = SearchResults( + documents=["some general content"], + metadata=[{"course_title": "Intro Course", "chunk_index": 0}], + distances=[0.3], + ) + tool = CourseSearchTool(mock_vector_store) + + result = tool.execute(query="general") + + assert "[Intro Course]" in result + assert "Lesson" not in result.split("]")[0] # no lesson in header + + +# --------------------------------------------------------------------------- +# ToolManager tests +# --------------------------------------------------------------------------- + +class TestToolManager: + """Tests for ToolManager registration and execution.""" + + def test_register_and_execute_tool(self, mock_vector_store, sample_search_results): + """Registered tools should be executable by name.""" + mock_vector_store.search.return_value = sample_search_results + manager = ToolManager() + tool = CourseSearchTool(mock_vector_store) + manager.register_tool(tool) + + result = manager.execute_tool("search_course_content", query="RAG") + + assert result # non-empty string + mock_vector_store.search.assert_called_once() + + def test_execute_unknown_tool(self): + """Executing an unregistered tool should return an error message.""" + manager = ToolManager() + + result = manager.execute_tool("nonexistent_tool", query="test") + + assert "not found" in result + + def test_get_last_sources(self, mock_vector_store, sample_search_results): + """get_last_sources should return sources from the most recent search.""" + 
mock_vector_store.search.return_value = sample_search_results
+        manager = ToolManager()
+        tool = CourseSearchTool(mock_vector_store)
+        manager.register_tool(tool)
+
+        manager.execute_tool("search_course_content", query="RAG")
+        sources = manager.get_last_sources()
+
+        assert len(sources) > 0
+
+    def test_reset_sources(self, mock_vector_store, sample_search_results):
+        """reset_sources should clear sources from all tools."""
+        mock_vector_store.search.return_value = sample_search_results
+        manager = ToolManager()
+        tool = CourseSearchTool(mock_vector_store)
+        manager.register_tool(tool)
+
+        manager.execute_tool("search_course_content", query="RAG")
+        manager.reset_sources()
+
+        assert manager.get_last_sources() == []
diff --git a/frontend/index.html b/frontend/index.html
index 26f2decc3..d80dcedeb 100644
--- a/frontend/index.html
+++ b/frontend/index.html
[index.html hunks garbled in extraction: the HTML tags were stripped. What is
recoverable: the first hunk (@@ -7,9 +7,25 @@, beside the "Course Materials
Assistant" title) adds the theme-toggle button markup (id "themeToggle", with
the "theme-toggle", "icon-sun" and "icon-moon" hooks that script.js and
style.css below wire up); the second hunk (@@ -79,6 +95,6 @@) appears to
rewrite the trailing script tag line.]
diff --git a/frontend/script.js b/frontend/script.js
index 98b080d5f..88d218f3d 100644
--- a/frontend/script.js
+++ b/frontend/script.js
@@ -7,6 +7,23 @@ let currentSessionId = null;
 // DOM elements
 let chatMessages, chatInput, sendButton, totalCourses, courseTitles, newChatBtn;
 
+// Theme toggle
+function initTheme() {
+    const saved = localStorage.getItem('theme');
+    const theme = saved || 'dark';
+    document.documentElement.setAttribute('data-theme', theme);
+}
+
+function toggleTheme() {
+    const current = document.documentElement.getAttribute('data-theme');
+    const next = current === 'light' ? 'dark' : 'light';
+    document.documentElement.setAttribute('data-theme', next);
+    localStorage.setItem('theme', next);
+}
+
+// Apply theme immediately to prevent flash
+initTheme();
+
 // Initialize
 document.addEventListener('DOMContentLoaded', () => {
     // Get DOM elements after page loads
@@ -17,6 +34,8 @@ document.addEventListener('DOMContentLoaded', () => {
     courseTitles = document.getElementById('courseTitles');
     newChatBtn = document.getElementById('newChatBtn');
 
+    document.getElementById('themeToggle').addEventListener('click', toggleTheme);
+
     setupEventListeners();
     createNewSession();
     loadCourseStats();
diff --git a/frontend/style.css b/frontend/style.css
index 47b2a7fac..a7038550f 100644
--- a/frontend/style.css
+++ b/frontend/style.css
@@ -24,6 +24,113 @@
     --welcome-border: #2563eb;
 }
 
+/* Light Theme */
+[data-theme="light"] {
+    --primary-color: #2563eb;
+    --primary-hover: #1d4ed8;
+    --background: #f8fafc;
+    --surface: #ffffff;
+    --surface-hover: #f1f5f9;
+    --text-primary: #0f172a;
+    --text-secondary: #64748b;
+    --border-color: #e2e8f0;
+    --user-message: #2563eb;
+    --assistant-message: #f1f5f9;
+    --shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.08);
+    --focus-ring: rgba(37, 99, 235, 0.2);
+    --welcome-bg: #eff6ff;
+    --welcome-border: #2563eb;
+}
+
+[data-theme="light"] .message-content code {
+    background-color: rgba(0, 0, 0, 0.06);
+}
+
+[data-theme="light"] .message-content pre {
+    background-color: rgba(0, 0, 0, 0.06);
+}
+
+[data-theme="light"] .sources-collapsible {
+    background: rgba(0, 0, 0, 0.02);
+}
+
+[data-theme="light"] .sources-collapsible summary:hover {
+    background: rgba(0, 0, 0, 0.03);
+}
+
+[data-theme="light"] span.source-chip {
+    background: rgba(0, 0, 0, 0.04);
+    color: var(--text-secondary);
+}
+
+[data-theme="light"] .error-message {
+    background: rgba(239, 68, 68, 0.08);
+    color: #dc2626;
+    border-color: rgba(239, 68, 68, 0.15);
+}
+
+[data-theme="light"] .success-message {
+    background: rgba(34, 197, 94, 0.08);
+    color: #16a34a;
+    border-color: rgba(34, 197, 94, 0.15);
+}
+
+/* Theme Toggle Button */
+.theme-toggle {
+    position: fixed;
+    top: 1rem;
+    right: 1.5rem;
+    z-index: 100;
+    width: 40px;
+    height: 40px;
+    border-radius: 50%;
+    border: 1px solid var(--border-color);
+    background: var(--surface);
+    color: var(--text-secondary);
+    cursor: pointer;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    transition: all 0.3s ease;
+    box-shadow: var(--shadow);
+}
+
+.theme-toggle:hover {
+    color: var(--primary-color);
+    border-color: var(--primary-color);
+    transform: rotate(15deg);
+}
+
+.theme-toggle:focus {
+    outline: none;
+    box-shadow: 0 0 0 3px var(--focus-ring);
+}
+
+/* Show sun in dark mode, moon in light mode */
+.icon-moon { display: none; }
+.icon-sun { display: block; }
+
+[data-theme="light"] .icon-moon { display: block; }
+[data-theme="light"] .icon-sun { display: none; }
+
+/* Smooth theme transition on all themed 
elements */ +body, +.sidebar, +.chat-main, +.chat-container, +.chat-messages, +.chat-input-container, +#chatInput, +.stat-item, +.suggested-item, +.message-content, +.theme-toggle, +.new-chat-btn, +.sources-collapsible, +.source-chip { + transition: background-color 0.3s ease, color 0.3s ease, border-color 0.3s ease, box-shadow 0.3s ease; +} + /* Base Styles */ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; diff --git a/pyproject.toml b/pyproject.toml index ceb8140a3..a9f013c2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,16 @@ dependencies = [ "python-multipart==0.0.20", "python-dotenv==1.1.1", ] + +[dependency-groups] +dev = [ + "httpx>=0.27.0", + "pytest>=8.0", +] + +[tool.pytest.ini_options] +testpaths = ["backend/tests"] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore:resource_tracker:UserWarning", +] diff --git a/uv.lock b/uv.lock index 1c5694ffe..571b3ac1f 100644 --- a/uv.lock +++ b/uv.lock @@ -452,6 +452,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -984,6 +993,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/c7/5572fa4a3f45740eaab6ae86fcdf7195b55beac1371ac8c619d880cfe948/pillow-11.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:79ea0d14d3ebad43ec77ad5272e6ff9bba5b679ef73375ea760261207fa8e0aa", size = 2512835, upload-time = "2025-07-01T09:15:50.399Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "posthog" version = "5.4.0" @@ -1153,6 +1171,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] +[[package]] +name = "pytest" +version = "9.0.2" +source = { 
registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "colorama", marker = "sys_platform == 'win32'" },
+    { name = "iniconfig" },
+    { name = "packaging" },
+    { name = "pluggy" },
+    { name = "pygments" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -1507,6 +1541,11 @@ dependencies = [
     { name = "uvicorn" },
 ]
 
+[package.dev-dependencies]
+dev = [
+    { name = "pytest" },
+]
+
 [package.metadata]
 requires-dist = [
     { name = "chromadb", specifier = "==1.0.15" },
@@ -1518,6 +1557,9 @@ requires-dist = [
     { name = "uvicorn", specifier = "==0.35.0" },
 ]
 
+[package.metadata.requires-dev]
+dev = [{ name = "pytest", specifier = ">=8.0" }]
+
 [[package]]
 name = "sympy"
 version = "1.14.0"

From 60fbb69ef190f54979cf150a96b92405c7b25fcb Mon Sep 17 00:00:00 2001
From: Muhammad Sohaib
Date: Sun, 15 Feb 2026 03:05:40 +0500
Subject: [PATCH 4/4] Add OllamaSetup.exe to .gitignore

Co-Authored-By: Claude Opus 4.6
---
 .gitignore | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 41b4384b8..dbdd45c19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,4 +28,5 @@ uploads/
 
 # OS
 .DS_Store
-Thumbs.db
\ No newline at end of file
+Thumbs.db
+OllamaSetup.exe
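
Reviewer note: the sequential tool-calling loop in ai_generator.py is not
itself visible in this series, but the tests in
backend/tests/test_ai_generator.py pin down its contract. Below is a minimal
sketch of a loop that satisfies those tests. The /api/chat endpoint, the
"stream": False payload, the tool-call and "Tool error: ..." message shapes,
and MAX_TOOL_ROUNDS = 3 are taken from the tests; the function name, signature,
and everything else are assumptions, not the shipped implementation (which also
threads the retrieved context and conversation history into its messages).

```python
import httpx

MAX_TOOL_ROUNDS = 3  # after three tool rounds, one final call is made without tools


def generate_response_sketch(
    client: httpx.Client,
    base_url: str,
    model: str,
    system_prompt: str,
    query: str,
    tools: list | None = None,
    tool_executor=None,
) -> str:
    """Hypothetical loop matching TestSequentialToolCalling; not the shipped code."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    for round_num in range(MAX_TOOL_ROUNDS + 1):
        payload = {"model": model, "messages": messages, "stream": False}
        # Offer tools only while rounds remain; the last iteration omits them,
        # forcing the model to answer in plain text.
        if tools and tool_executor and round_num < MAX_TOOL_ROUNDS:
            payload["tools"] = tools

        response = client.post(f"{base_url}/api/chat", json=payload)
        response.raise_for_status()  # ConnectError/HTTPStatusError bubble up to rag_system.query()
        message = response.json().get("message", {})

        tool_calls = message.get("tool_calls")
        if not tool_calls:
            return message.get("content", "")  # missing keys degrade to ""

        messages.append(message)  # keep the assistant's tool request in the transcript
        for call in tool_calls:
            fn = call["function"]
            try:
                result = tool_executor(fn["name"], **fn["arguments"])
            except Exception as exc:
                result = f"Tool error: {exc}"  # tool failures become results, not crashes
            messages.append({"role": "tool", "content": result})
    return ""
```

Run the suite with `uv run pytest` from the project root; the
[tool.pytest.ini_options] block above points testpaths at backend/tests.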