diff --git a/.env.example b/.env.example
index 18b34cb7e..e305c873f 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,3 @@
-# Copy this file to .env and add your actual API key
-ANTHROPIC_API_KEY=your-anthropic-api-key-here
\ No newline at end of file
+# Ollama settings (free local LLM - no API key needed)
+OLLAMA_BASE_URL=http://localhost:11434
+OLLAMA_MODEL=llama3.1:8b
diff --git a/.gitignore b/.gitignore
index 41b4384b8..dbdd45c19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,4 +28,5 @@ uploads/
 
 # OS
 .DS_Store
-Thumbs.db
\ No newline at end of file
+Thumbs.db
+OllamaSetup.exe
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..f5fbc4653
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,53 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Running the Application
+
+```bash
+# Install dependencies
+uv sync
+
+# Start the server (from project root)
+cd backend && uv run uvicorn app:app --reload --port 8000
+
+# Or use the shell script
+./run.sh
+```
+
+Requires Ollama running locally with a model installed (configured in `.env`).
+The app is served at http://localhost:8000 with API docs at http://localhost:8000/docs.
+
+## Environment
+
+- Python 3.13+ with `uv` package manager
+- **Always use `uv` to run commands and manage dependencies. Never use `pip` directly.**
+- `.env` file in project root with `OLLAMA_BASE_URL` and `OLLAMA_MODEL`
+- Tests live in `backend/tests/` (pytest); run them with `uv run pytest backend/tests/`
+
+## Architecture
+
+This is a RAG (Retrieval-Augmented Generation) chatbot for course materials. FastAPI backend serves both the API and the frontend static files.
+
+**Query flow:** Frontend → `app.py` (FastAPI) → `rag_system.py` (orchestrator) → `ai_generator.py` (Ollama LLM) runs a tool-calling loop → tools in `search_tools.py` query `vector_store.py` (ChromaDB) → the LLM synthesizes tool results into the response returned to the frontend.
+
+**Key design decisions:**
+- Tool-calling loop: the LLM decides which tools to call (content search or course outline), capped at 3 rounds; a plain-text fallback parser recovers tool calls that small local models emit as text instead of structured `tool_calls`
+- Two ChromaDB collections: `course_catalog` (metadata for fuzzy course name matching) and `course_content` (chunked text for semantic search)
+- Embeddings via `all-MiniLM-L6-v2` sentence-transformer
+- Session history is in-memory only (lost on restart), capped at 2 exchanges per session
+- Documents are chunked at 800 chars with 100 char overlap
+- Course documents in `docs/` are auto-loaded on server startup
+
+**Backend modules:**
+- `app.py` — API endpoints and static file serving
+- `rag_system.py` — orchestrates search → LLM → session flow
+- `vector_store.py` — ChromaDB wrapper with semantic search and course/lesson filtering
+- `ai_generator.py` — Ollama API client
+- `document_processor.py` — parses course .txt/.pdf/.docx files into structured chunks
+- `search_tools.py` — search tool abstraction with source tracking
+- `session_manager.py` — per-session conversation memory
+- `models.py` — `Course`, `Lesson`, `CourseChunk` dataclasses
+- `config.py` — loads settings from `.env`
+
+**Frontend:** plain HTML/JS/CSS in `frontend/`, uses `marked.js` for markdown rendering. No build step.
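The next file, `ai_generator.py`, is rewritten around Ollama's `/api/chat` endpoint. As a rough sketch of the round trip it implements (assuming a local Ollama at `http://localhost:11434`; the model name and the simplified tool schema here are illustrative, the latter mirroring the OpenAI-compatible definitions in `search_tools.py`):

```python
# Minimal sketch of one tool-calling round against Ollama's /api/chat,
# mirroring AIGenerator._call_ollama / _run_tool_round below. Assumes
# Ollama is running locally with a tool-capable model pulled.
import httpx

OLLAMA_BASE_URL = "http://localhost:11434"
MODEL = "llama3.1:8b"  # illustrative; any tool-capable local model

tools = [{
    "type": "function",
    "function": {
        "name": "search_course_content",
        "description": "Search course materials",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}]

messages = [
    {"role": "system", "content": "You answer questions about course materials."},
    {"role": "user", "content": "What does lesson 1 of the RAG course cover?"},
]

resp = httpx.post(
    f"{OLLAMA_BASE_URL}/api/chat",
    json={"model": MODEL, "messages": messages, "stream": False, "tools": tools},
    timeout=120.0,  # the first request may block while the model loads
)
resp.raise_for_status()
msg = resp.json()["message"]

if msg.get("tool_calls"):
    # The model asked for a tool run. Execute it, append a {"role": "tool"}
    # message with the result, and call /api/chat again -- the loop that
    # _run_tool_round caps at MAX_TOOL_ROUNDS.
    for tc in msg["tool_calls"]:
        fn = tc["function"]
        print("tool requested:", fn["name"], fn["arguments"])
else:
    print("direct answer:", msg.get("content", ""))
```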
diff --git a/backend/ai_generator.py b/backend/ai_generator.py index 0363ca90c..a3922d221 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -1,25 +1,31 @@ -import anthropic -from typing import List, Optional, Dict, Any +import httpx +import json +from typing import List, Optional, Dict, Any, Callable, Tuple + class AIGenerator: - """Handles interactions with Anthropic's Claude API for generating responses""" - - # Static system prompt to avoid rebuilding on each call - SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information. + """Handles interactions with Ollama's local LLM API for generating responses""" + + MAX_TOOL_ROUNDS = 3 -Search Tool Usage: -- Use the search tool **only** for questions about specific course content or detailed educational materials -- **One search per query maximum** -- Synthesize search results into accurate, fact-based responses -- If search yields no results, state this clearly without offering alternatives + # Static system prompt to avoid rebuilding on each call + SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content. Response Protocol: -- **General knowledge questions**: Answer using existing knowledge without searching -- **Course-specific questions**: Search first, then answer -- **No meta-commentary**: - - Provide direct answers only — no reasoning process, search explanations, or question-type analysis - - Do not mention "based on the search results" +- Answer questions using the provided course context +- If the context doesn't contain relevant information, say so clearly +- Do not mention "based on the context" or "based on the search results" +- For outline, structure, or "what lessons" queries: You MUST list EVERY lesson exactly as provided in the context using markdown bullet points. Output the course title, course link, then each lesson as a bullet point "- **Lesson N:** Title". Do NOT summarize, group, or skip any lessons. +Tool Usage: +- You have access to search tools for finding course content. Use them when the user asks a question. +- When the user asks about a SPECIFIC lesson (e.g. "lesson 5 of MCP"), use search_course_content with course_name AND lesson_number to get that lesson's actual content. Do NOT just return the course outline. +- Use get_course_outline only when the user asks to LIST or OVERVIEW all lessons in a course. +- For complex queries, break them into steps: first find the relevant course/lesson, then search for specific content. +- After receiving tool results, synthesize them into a well-structured answer. Do NOT dump raw tool output. +- Structure your final answer using markdown: use headings, bullet points, and bold for key terms. +- Only include information that directly answers the user's question — ignore irrelevant tool results. +- If tool results are empty or unhelpful, say so honestly instead of fabricating an answer. All responses must be: 1. **Brief, Concise and focused** - Get to the point quickly @@ -28,108 +34,158 @@ class AIGenerator: 4. **Example-supported** - Include relevant examples when they aid understanding Provide only the direct answer to what was asked. 
""" - - def __init__(self, api_key: str, model: str): - self.client = anthropic.Anthropic(api_key=api_key) + + def __init__(self, base_url: str, model: str): + self.base_url = base_url.rstrip("/") self.model = model - - # Pre-build base API parameters - self.base_params = { + # 120s timeout to handle first-request model loading + self.http_client = httpx.Client(timeout=120.0) + + def _call_ollama(self, messages: List[Dict], tools: Optional[List] = None) -> Dict: + """Make a POST request to the Ollama chat API.""" + payload = { "model": self.model, - "temperature": 0, - "max_tokens": 800 + "messages": messages, + "stream": False, + "options": { + "temperature": 0, + "num_predict": 800 + } } - + if tools is not None: + payload["tools"] = tools + + resp = self.http_client.post(f"{self.base_url}/api/chat", json=payload) + resp.raise_for_status() + return resp.json() + + def _execute_tool_call(self, tool_call: dict, tool_executor: Callable) -> Tuple[str, str]: + """Extract tool name/args from a tool call and execute via tool_executor.""" + func = tool_call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", {}) + try: + result = tool_executor(name, **arguments) + return name, str(result) + except Exception as e: + return name, f"Tool error: {e}" + + def _parse_tool_call_from_text(self, content: str) -> Optional[dict]: + """Try to recover a tool call that the LLM wrote as plain text. + + Small models sometimes emit JSON-like text instead of using the + structured tool_calls field. We attempt a best-effort parse. + """ + import re + if not content: + return None + + # Try to find JSON-ish blob with "name" and "parameters" + # Handle unquoted identifiers like: {"name": get_course_outline, ...} + text = content.strip() + # Quick gate: must look like it's trying to be a tool call + if '"name"' not in text and "'name'" not in text: + return None + + # Fix common malformed JSON: unquoted values after "name": + text = re.sub(r'"name"\s*:\s*([a-zA-Z_]\w*)', r'"name": "\1"', text) + # Fix single quotes to double quotes + text = text.replace("'", '"') + + try: + obj = json.loads(text) + except json.JSONDecodeError: + # Try to extract just the JSON object + match = re.search(r'\{.*\}', text, re.DOTALL) + if not match: + return None + try: + obj = json.loads(match.group()) + except json.JSONDecodeError: + return None + + name = obj.get("name") + params = obj.get("parameters") or obj.get("arguments") or {} + if not name: + return None + + return {"function": {"name": name, "arguments": params}} + + def _run_tool_round(self, messages: List[Dict], tools: List, + tool_executor: Callable, remaining_rounds: int) -> str: + """Recursively call Ollama, executing tool calls until the LLM produces text.""" + # If rounds remain, offer tools; otherwise force a text-only response + data = self._call_ollama( + messages, tools=tools if remaining_rounds > 0 else None + ) + + assistant_msg = data.get("message", {}) + tool_calls = assistant_msg.get("tool_calls") + content = assistant_msg.get("content", "") + + # Fallback: if the LLM wrote the tool call as plain text, parse it + if not tool_calls and content and remaining_rounds > 0: + recovered = self._parse_tool_call_from_text(content) + if recovered: + tool_calls = [recovered] + # Rewrite assistant_msg so context stays consistent + assistant_msg = {"role": "assistant", "content": "", "tool_calls": tool_calls} + + # Base cases: no tool calls or no remaining rounds + if not tool_calls or remaining_rounds <= 0: + return content + + # Append 
the assistant message (with its tool_calls) to context + messages.append(assistant_msg) + + # Execute each tool call and append results + for tc in tool_calls: + name, result = self._execute_tool_call(tc, tool_executor) + print(f"[Tool Call] {name}({tc.get('function', {}).get('arguments', {})}) -> {len(result)} chars") + messages.append({"role": "tool", "content": result}) + + return self._run_tool_round(messages, tools, tool_executor, remaining_rounds - 1) + def generate_response(self, query: str, + context: str = "", conversation_history: Optional[str] = None, tools: Optional[List] = None, - tool_manager=None) -> str: + tool_executor: Optional[Callable] = None) -> str: """ - Generate AI response with optional tool usage and conversation context. - + Generate AI response using provided context. + Args: query: The user's question or request + context: Retrieved course content to answer from conversation_history: Previous messages for context - tools: Available tools the AI can use - tool_manager: Manager to execute tools - + tools: Optional tool definitions for Ollama function calling + tool_executor: Callable(name, **kwargs) to execute tool calls + Returns: Generated response as string """ - - # Build system content efficiently - avoid string ops when possible - system_content = ( - f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}" - if conversation_history - else self.SYSTEM_PROMPT - ) - - # Prepare API call parameters efficiently - api_params = { - **self.base_params, - "messages": [{"role": "user", "content": query}], - "system": system_content - } - - # Add tools if available - if tools: - api_params["tools"] = tools - api_params["tool_choice"] = {"type": "auto"} - - # Get response from Claude - response = self.client.messages.create(**api_params) - - # Handle tool execution if needed - if response.stop_reason == "tool_use" and tool_manager: - return self._handle_tool_execution(response, api_params, tool_manager) - - # Return direct response - return response.content[0].text - - def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager): - """ - Handle execution of tool calls and get follow-up response. 
- - Args: - initial_response: The response containing tool use requests - base_params: Base API parameters - tool_manager: Manager to execute tools - - Returns: - Final response text after tool execution - """ - # Start with existing messages - messages = base_params["messages"].copy() - - # Add AI's tool use response - messages.append({"role": "assistant", "content": initial_response.content}) - - # Execute all tool calls and collect results - tool_results = [] - for content_block in initial_response.content: - if content_block.type == "tool_use": - tool_result = tool_manager.execute_tool( - content_block.name, - **content_block.input - ) - - tool_results.append({ - "type": "tool_result", - "tool_use_id": content_block.id, - "content": tool_result - }) - - # Add tool results as single message - if tool_results: - messages.append({"role": "user", "content": tool_results}) - - # Prepare final API call without tools - final_params = { - **self.base_params, - "messages": messages, - "system": base_params["system"] - } - - # Get final response - final_response = self.client.messages.create(**final_params) - return final_response.content[0].text \ No newline at end of file + + # Build system content + system_content = self.SYSTEM_PROMPT + if conversation_history: + system_content += f"\n\nPrevious conversation:\n{conversation_history}" + + # Build user message with context + if context: + user_content = f"Course context:\n{context}\n\nQuestion: {query}" + else: + user_content = query + + # Build messages list + messages = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": user_content} + ] + + # Tool-calling path: recursive loop + if tools is not None and tool_executor is not None: + return self._run_tool_round(messages, tools, tool_executor, self.MAX_TOOL_ROUNDS) + + # Default path: single call, no tools + data = self._call_ollama(messages) + return data.get("message", {}).get("content", "") diff --git a/backend/app.py b/backend/app.py index 5a69d741d..491e2b4f2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -40,10 +40,19 @@ class QueryRequest(BaseModel): query: str session_id: Optional[str] = None +class SourceItem(BaseModel): + """A single source with title and optional link""" + title: str + link: Optional[str] = None + class QueryResponse(BaseModel): """Response model for course queries""" answer: str - sources: List[str] + sources: List[SourceItem] + session_id: str + +class ClearSessionRequest(BaseModel): + """Request model for clearing a session""" session_id: str class CourseStats(BaseModel): @@ -73,6 +82,15 @@ async def query_documents(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +@app.post("/api/sessions/clear") +async def clear_session(request: ClearSessionRequest): + """Clear a chat session to free server memory""" + try: + rag_system.session_manager.clear_session(request.session_id) + return {"status": "ok"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + @app.get("/api/courses", response_model=CourseStats) async def get_course_stats(): """Get course analytics and statistics""" diff --git a/backend/config.py b/backend/config.py index d9f6392ef..30dc10079 100644 --- a/backend/config.py +++ b/backend/config.py @@ -8,9 +8,9 @@ @dataclass class Config: """Configuration settings for the RAG system""" - # Anthropic API settings - ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") - ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" + # Ollama settings (free 
local LLM) + OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") + OLLAMA_MODEL: str = os.getenv("OLLAMA_MODEL", "llama3.2") # Embedding model settings EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" diff --git a/backend/rag_system.py b/backend/rag_system.py index 50d848c8e..0af1b8dde 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -4,7 +4,7 @@ from vector_store import VectorStore from ai_generator import AIGenerator from session_manager import SessionManager -from search_tools import ToolManager, CourseSearchTool +from search_tools import ToolManager, CourseSearchTool, CourseOutlineTool from models import Course, Lesson, CourseChunk class RAGSystem: @@ -16,13 +16,15 @@ def __init__(self, config): # Initialize core components self.document_processor = DocumentProcessor(config.CHUNK_SIZE, config.CHUNK_OVERLAP) self.vector_store = VectorStore(config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS) - self.ai_generator = AIGenerator(config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL) + self.ai_generator = AIGenerator(config.OLLAMA_BASE_URL, config.OLLAMA_MODEL) self.session_manager = SessionManager(config.MAX_HISTORY) # Initialize search tools self.tool_manager = ToolManager() self.search_tool = CourseSearchTool(self.vector_store) + self.outline_tool = CourseOutlineTool(self.vector_store) self.tool_manager.register_tool(self.search_tool) + self.tool_manager.register_tool(self.outline_tool) def add_course_document(self, file_path: str) -> Tuple[Course, int]: """ @@ -99,44 +101,52 @@ def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> T return total_courses, total_chunks - def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]: + def _is_search_error(self, result: str) -> bool: + """Check if a search tool result is an error/empty message, not real content.""" + error_prefixes = ("Search error:", "No relevant content found", "No course found") + return result.startswith(error_prefixes) + + def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[Dict]]: """ - Process a user query using the RAG system with tool-based search. - + Process a user query using the RAG system. + + The LLM decides which tools to call via the tool-calling loop in AIGenerator. 
+ Args: query: User's question session_id: Optional session ID for conversation context - + Returns: - Tuple of (response, sources list - empty for tool-based approach) + Tuple of (response, sources list) """ - # Create prompt for the AI with clear instructions - prompt = f"""Answer this question about course materials: {query}""" - - # Get conversation history if session exists + # Reset sources from previous query + self.tool_manager.reset_sources() + + # Get conversation history history = None if session_id: history = self.session_manager.get_conversation_history(session_id) - - # Generate response using AI with tools - response = self.ai_generator.generate_response( - query=prompt, - conversation_history=history, - tools=self.tool_manager.get_tool_definitions(), - tool_manager=self.tool_manager - ) - - # Get sources from the search tool + + # Let the LLM decide which tools to call + try: + response = self.ai_generator.generate_response( + query=query, + conversation_history=history, + tools=self.tool_manager.get_tool_definitions(), + tool_executor=self.tool_manager.execute_tool, + ) + except Exception as e: + print(f"AI generation error: {e}") + sources = self.tool_manager.get_last_sources() + return "Sorry, I'm having trouble generating a response right now. Please try again.", sources + + # Collect sources from all tools used during the conversation sources = self.tool_manager.get_last_sources() - # Reset sources after retrieving them - self.tool_manager.reset_sources() - - # Update conversation history + # Update session history if session_id: self.session_manager.add_exchange(session_id, query, response) - - # Return response with sources from tool searches + return response, sources def get_course_analytics(self) -> Dict: diff --git a/backend/search_tools.py b/backend/search_tools.py index adfe82352..42dc7a39a 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Protocol +from typing import Dict, Any, Optional, List, Protocol from abc import ABC, abstractmethod from vector_store import VectorStore, SearchResults @@ -8,7 +8,7 @@ class Tool(ABC): @abstractmethod def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return OpenAI-compatible tool definition for Ollama""" pass @abstractmethod @@ -25,43 +25,56 @@ def __init__(self, vector_store: VectorStore): self.last_sources = [] # Track sources from last search def get_tool_definition(self) -> Dict[str, Any]: - """Return Anthropic tool definition for this tool""" + """Return OpenAI-compatible tool definition for Ollama""" return { - "name": "search_course_content", - "description": "Search course materials with smart course name matching and lesson filtering", - "input_schema": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in the course content" + "type": "function", + "function": { + "name": "search_course_content", + "description": "Search course materials with smart course name matching and lesson filtering", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in the course content" + }, + "course_name": { + "type": "string", + "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" + }, + "lesson_number": { + "type": "integer", + "description": "Specific lesson number to search within (e.g. 
1, 2, 3)" + } }, - "course_name": { - "type": "string", - "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" - }, - "lesson_number": { - "type": "integer", - "description": "Specific lesson number to search within (e.g. 1, 2, 3)" - } - }, - "required": ["query"] + "required": ["query"] + } } } - def execute(self, query: str, course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str: + def execute(self, query: str = "", course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str: """ Execute the search tool with given parameters. - + Args: query: What to search for course_name: Optional course filter lesson_number: Optional lesson filter - + Returns: Formatted search results or error message """ - + # Coerce lesson_number to int (LLMs sometimes send it as a string) + if lesson_number is not None: + try: + lesson_number = int(lesson_number) + except (ValueError, TypeError): + lesson_number = None + + # If query is empty but we have filters, use a generic query + if not query and (course_name or lesson_number is not None): + query = f"{course_name or ''} lesson {lesson_number or ''}".strip() + # Use the vector store's unified search interface results = self.store.search( query=query, @@ -88,31 +101,93 @@ def execute(self, query: str, course_name: Optional[str] = None, lesson_number: def _format_results(self, results: SearchResults) -> str: """Format search results with course and lesson context""" formatted = [] - sources = [] # Track sources for the UI - + sources: List[Dict[str, Any]] = [] + seen = set() # Deduplicate by (course_title, lesson_num) + for doc, meta in zip(results.documents, results.metadata): course_title = meta.get('course_title', 'unknown') lesson_num = meta.get('lesson_number') - + # Build context header header = f"[{course_title}" if lesson_num is not None: header += f" - Lesson {lesson_num}" header += "]" - - # Track source for the UI - source = course_title - if lesson_num is not None: - source += f" - Lesson {lesson_num}" - sources.append(source) - + + # Track source for the UI (deduplicated) + key = (course_title, lesson_num) + if key not in seen: + seen.add(key) + title = course_title + if lesson_num is not None: + title += f" - Lesson {lesson_num}" + link = None + if lesson_num is not None: + link = self.store.get_lesson_link(course_title, lesson_num) + if not link: + link = self.store.get_course_link(course_title) + sources.append({"title": title, "link": link}) + formatted.append(f"{header}\n{doc}") - - # Store sources for retrieval - self.last_sources = sources - + + # Accumulate sources (may be called multiple times per query) + self.last_sources.extend(sources) + return "\n\n".join(formatted) +class CourseOutlineTool(Tool): + """Tool for retrieving course outlines from the course catalog""" + + def __init__(self, vector_store: VectorStore): + self.store = vector_store + self.last_sources = [] + + def get_tool_definition(self) -> Dict[str, Any]: + return { + "type": "function", + "function": { + "name": "get_course_outline", + "description": "Get the full outline of a course including title, link, and all lessons", + "parameters": { + "type": "object", + "properties": { + "course_name": { + "type": "string", + "description": "Course title or partial name (e.g. 
'MCP', 'RAG chatbot')" + } + }, + "required": ["course_name"] + } + } + } + + def execute(self, course_name: str) -> str: + outline = self.store.get_course_outline(course_name) + if not outline: + return f"No course found matching '{course_name}'." + + title = outline["title"] + link = outline.get("course_link") or "N/A" + lessons = outline.get("lessons", []) + + # Accumulate sources (may be called multiple times per query) + self.last_sources.extend([{"title": title, "link": outline.get("course_link")}]) + + lines = [ + f"COURSE OUTLINE (list every lesson as a markdown bullet point):", + f"**Course:** {title}", + f"**Course Link:** {link}", + f"**Total Lessons:** {len(lessons)}", + "", + ] + for lesson in lessons: + num = lesson.get("lesson_number", "?") + ltitle = lesson.get("lesson_title", "Untitled") + lines.append(f"- **Lesson {num}:** {ltitle}") + + return "\n".join(lines) + + class ToolManager: """Manages available tools for the AI""" @@ -122,14 +197,16 @@ def __init__(self): def register_tool(self, tool: Tool): """Register any tool that implements the Tool interface""" tool_def = tool.get_tool_definition() - tool_name = tool_def.get("name") + # OpenAI-compatible format nests name under "function" + func_def = tool_def.get("function", {}) + tool_name = func_def.get("name") or tool_def.get("name") if not tool_name: raise ValueError("Tool must have a 'name' in its definition") self.tools[tool_name] = tool def get_tool_definitions(self) -> list: - """Get all tool definitions for Anthropic tool calling""" + """Get all tool definitions for Ollama tool calling""" return [tool.get_tool_definition() for tool in self.tools.values()] def execute_tool(self, tool_name: str, **kwargs) -> str: @@ -140,12 +217,17 @@ def execute_tool(self, tool_name: str, **kwargs) -> str: return self.tools[tool_name].execute(**kwargs) def get_last_sources(self) -> list: - """Get sources from the last search operation""" - # Check all tools for last_sources attribute + """Get sources from all tools, deduplicated by (title, link).""" + seen = set() + sources = [] for tool in self.tools.values(): - if hasattr(tool, 'last_sources') and tool.last_sources: - return tool.last_sources - return [] + if hasattr(tool, 'last_sources'): + for s in tool.last_sources: + key = (s.get("title"), s.get("link")) + if key not in seen: + seen.add(key) + sources.append(s) + return sources def reset_sources(self): """Reset sources from all tools that track sources""" diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 000000000..2df6d3693 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,185 @@ +"""Shared fixtures for RAG chatbot tests.""" + +import sys +import os +import pytest +from unittest.mock import MagicMock, patch +from dataclasses import dataclass + +# Add backend to path so imports work +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from vector_store import SearchResults +from models import Course, Lesson, CourseChunk + + +# --------------------------------------------------------------------------- +# API test fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_rag_system(): + """A fully mocked RAGSystem for API endpoint tests.""" + rag = MagicMock() + rag.query.return_value = ( + "RAG stands for Retrieval-Augmented Generation.", + [{"title": "Building RAG Chatbots", "link": 
"https://example.com/rag"}], + ) + rag.session_manager.create_session.return_value = "session_1" + rag.get_course_analytics.return_value = { + "total_courses": 2, + "course_titles": ["Building RAG Chatbots", "Intro to MCP"], + } + return rag + + +@pytest.fixture +def test_app(mock_rag_system): + """A lightweight FastAPI app with the same endpoints as app.py, + but without static-file mounting or startup document loading.""" + from fastapi import FastAPI, HTTPException + from pydantic import BaseModel + from typing import List, Optional + + api = FastAPI() + + class QueryRequest(BaseModel): + query: str + session_id: Optional[str] = None + + class SourceItem(BaseModel): + title: str + link: Optional[str] = None + + class QueryResponse(BaseModel): + answer: str + sources: List[SourceItem] + session_id: str + + class ClearSessionRequest(BaseModel): + session_id: str + + class CourseStats(BaseModel): + total_courses: int + course_titles: List[str] + + rag = mock_rag_system + + @api.post("/api/query", response_model=QueryResponse) + async def query_documents(request: QueryRequest): + try: + session_id = request.session_id + if not session_id: + session_id = rag.session_manager.create_session() + answer, sources = rag.query(request.query, session_id) + return QueryResponse(answer=answer, sources=sources, session_id=session_id) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @api.post("/api/sessions/clear") + async def clear_session(request: ClearSessionRequest): + try: + rag.session_manager.clear_session(request.session_id) + return {"status": "ok"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @api.get("/api/courses", response_model=CourseStats) + async def get_course_stats(): + try: + analytics = rag.get_course_analytics() + return CourseStats( + total_courses=analytics["total_courses"], + course_titles=analytics["course_titles"], + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return api + + +@pytest.fixture +def client(test_app): + """HTTPX TestClient (sync) for the test FastAPI app.""" + from starlette.testclient import TestClient + return TestClient(test_app) + + +@dataclass +class MockConfig: + """Minimal config for testing without .env or real services.""" + OLLAMA_BASE_URL: str = "http://localhost:11434" + OLLAMA_MODEL: str = "test-model" + EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" + CHUNK_SIZE: int = 800 + CHUNK_OVERLAP: int = 100 + MAX_RESULTS: int = 5 + MAX_HISTORY: int = 2 + CHROMA_PATH: str = "./test_chroma_db" + + +@pytest.fixture +def mock_config(): + return MockConfig() + + +@pytest.fixture +def sample_search_results(): + """Realistic search results as returned by VectorStore.search().""" + return SearchResults( + documents=[ + "Lesson 1 content: RAG stands for Retrieval-Augmented Generation. 
It combines retrieval and generation.", + "Lesson 2 content: Vector databases store embeddings for semantic search.", + ], + metadata=[ + {"course_title": "Building RAG Chatbots", "lesson_number": 1, "chunk_index": 0}, + {"course_title": "Building RAG Chatbots", "lesson_number": 2, "chunk_index": 3}, + ], + distances=[0.25, 0.42], + ) + + +@pytest.fixture +def empty_search_results(): + """Empty search results (no matching documents).""" + return SearchResults(documents=[], metadata=[], distances=[]) + + +@pytest.fixture +def error_search_results(): + """Search results with an error.""" + return SearchResults(documents=[], metadata=[], distances=[], error="Search error: collection is empty") + + +@pytest.fixture +def sample_course(): + """A sample Course object for testing.""" + return Course( + title="Building RAG Chatbots", + course_link="https://example.com/rag-course", + instructor="Test Instructor", + lessons=[ + Lesson(lesson_number=1, title="Introduction to RAG", lesson_link="https://example.com/lesson1"), + Lesson(lesson_number=2, title="Vector Databases", lesson_link="https://example.com/lesson2"), + Lesson(lesson_number=3, title="Building the Pipeline", lesson_link="https://example.com/lesson3"), + ], + ) + + +@pytest.fixture +def mock_vector_store(): + """A fully mocked VectorStore (avoids ChromaDB + embedding model).""" + store = MagicMock() + store.search.return_value = SearchResults( + documents=[ + "RAG combines retrieval with generation for better answers.", + ], + metadata=[ + {"course_title": "Building RAG Chatbots", "lesson_number": 1, "chunk_index": 0}, + ], + distances=[0.3], + ) + store.get_lesson_link.return_value = "https://example.com/lesson1" + store.get_course_link.return_value = "https://example.com/rag-course" + store.get_course_outline.return_value = None # default: no outline match + return store diff --git a/backend/tests/test_ai_generator.py b/backend/tests/test_ai_generator.py new file mode 100644 index 000000000..6718fde0a --- /dev/null +++ b/backend/tests/test_ai_generator.py @@ -0,0 +1,424 @@ +"""Tests for AIGenerator in ai_generator.py. 
+ +These tests verify that the AI generator: +- Correctly formats messages with search context for the LLM +- Includes the system prompt in every call +- Handles conversation history +- Propagates connection/HTTP errors (which cause 'query failed' in the frontend) +- Handles unexpected response shapes gracefully +""" + +import pytest +from unittest.mock import MagicMock, patch, PropertyMock +import httpx + +from ai_generator import AIGenerator + + +@pytest.fixture +def generator(): + """AIGenerator pointed at a fake Ollama URL.""" + return AIGenerator(base_url="http://localhost:11434", model="test-model") + + +@pytest.fixture +def mock_ollama_success(monkeypatch): + """Patch httpx.Client.post to return a successful Ollama response.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"role": "assistant", "content": "RAG combines retrieval with generation."} + } + mock_response.raise_for_status = MagicMock() + + mock_post = MagicMock(return_value=mock_response) + monkeypatch.setattr(httpx.Client, "post", mock_post) + return mock_post + + +# --------------------------------------------------------------------------- +# Message formatting tests +# --------------------------------------------------------------------------- + +class TestGenerateResponseFormatting: + """Verify the messages sent to Ollama are correctly structured.""" + + def test_context_is_included_in_user_message(self, generator, mock_ollama_success): + """When context is provided, it should appear in the user message.""" + generator.generate_response( + query="What is RAG?", + context="[Building RAG Chatbots - Lesson 1]\nRAG stands for Retrieval-Augmented Generation.", + ) + + # Inspect the payload sent to Ollama + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert "Course context:" in user_msg["content"] + assert "RAG stands for Retrieval-Augmented Generation" in user_msg["content"] + assert "What is RAG?" in user_msg["content"] + + def test_no_context_sends_query_only(self, generator, mock_ollama_success): + """Without context, the user message should be just the query.""" + generator.generate_response(query="Hello") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert user_msg["content"] == "Hello" + assert "Course context:" not in user_msg["content"] + + def test_system_prompt_always_present(self, generator, mock_ollama_success): + """The system prompt should always be the first message.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + system_msg = messages[0] + assert system_msg["role"] == "system" + assert "AI assistant" in system_msg["content"] + + def test_conversation_history_appended_to_system(self, generator, mock_ollama_success): + """Conversation history should be appended to the system prompt.""" + history = "User: What is RAG?\nAssistant: RAG is retrieval-augmented generation." 
+ + generator.generate_response(query="Tell me more", conversation_history=history) + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + system_msg = messages[0] + assert "Previous conversation:" in system_msg["content"] + assert "What is RAG?" in system_msg["content"] + + def test_empty_context_treated_as_no_context(self, generator, mock_ollama_success): + """An empty string context should NOT add 'Course context:' prefix.""" + generator.generate_response(query="Hello", context="") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + messages = payload["messages"] + + user_msg = next(m for m in messages if m["role"] == "user") + assert user_msg["content"] == "Hello" + + +# --------------------------------------------------------------------------- +# Response extraction tests +# --------------------------------------------------------------------------- + +class TestGenerateResponseOutput: + """Verify the response is correctly extracted from the Ollama reply.""" + + def test_returns_content_string(self, generator, mock_ollama_success): + """generate_response should return the content string from Ollama.""" + result = generator.generate_response(query="What is RAG?", context="some context") + + assert result == "RAG combines retrieval with generation." + + def test_handles_missing_message_key(self, generator, monkeypatch): + """If Ollama returns unexpected JSON, should return empty string (not crash).""" + mock_response = MagicMock() + mock_response.json.return_value = {"unexpected": "format"} + mock_response.raise_for_status = MagicMock() + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + result = generator.generate_response(query="test") + + assert result == "" + + def test_handles_missing_content_key(self, generator, monkeypatch): + """If 'message' exists but has no 'content', should return empty string.""" + mock_response = MagicMock() + mock_response.json.return_value = {"message": {"role": "assistant"}} + mock_response.raise_for_status = MagicMock() + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + result = generator.generate_response(query="test") + + assert result == "" + + +# --------------------------------------------------------------------------- +# Error handling tests — these expose the root cause of "query failed" +# --------------------------------------------------------------------------- + +class TestGenerateResponseErrors: + """Test error scenarios that cause the 'query failed' frontend error.""" + + def test_connection_error_propagates(self, generator, monkeypatch): + """If Ollama is not running, a ConnectError should propagate. + + This is the most likely cause of 'query failed' — the ai_generator + does NOT catch connection errors, so they bubble up through + rag_system.query() → app.py → HTTP 500 → frontend 'Query failed'. + """ + def raise_connect_error(*args, **kwargs): + raise httpx.ConnectError("Connection refused") + + monkeypatch.setattr(httpx.Client, "post", raise_connect_error) + + with pytest.raises(httpx.ConnectError): + generator.generate_response(query="What is RAG?", context="some context") + + def test_http_error_propagates(self, generator, monkeypatch): + """If Ollama returns a 4xx/5xx, the error should propagate. + + For example, if the model name is wrong, Ollama returns 404. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=mock_response + ) + monkeypatch.setattr(httpx.Client, "post", MagicMock(return_value=mock_response)) + + with pytest.raises(httpx.HTTPStatusError): + generator.generate_response(query="test", context="ctx") + + def test_timeout_error_propagates(self, generator, monkeypatch): + """If Ollama takes too long, a timeout error should propagate.""" + def raise_timeout(*args, **kwargs): + raise httpx.ReadTimeout("Read timed out") + + monkeypatch.setattr(httpx.Client, "post", raise_timeout) + + with pytest.raises(httpx.ReadTimeout): + generator.generate_response(query="test", context="ctx") + + def test_generator_does_not_filter_context(self, generator, mock_ollama_success): + """AIGenerator passes context as-is — filtering is rag_system's job. + + The generator is a thin wrapper around Ollama. If it receives an error + string as context, it will include it. The fix for this lives in + rag_system.query(), which now filters error strings before they + reach the generator. + """ + error_as_context = "Search error: collection is empty" + + generator.generate_response(query="What is RAG?", context=error_as_context) + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + user_msg = next(m for m in payload["messages"] if m["role"] == "user") + + # Generator doesn't filter — it passes whatever it receives + assert "Search error:" in user_msg["content"] + + +# --------------------------------------------------------------------------- +# Request structure tests +# --------------------------------------------------------------------------- + +class TestOllamaRequestFormat: + """Verify the HTTP request to Ollama is well-formed.""" + + def test_request_uses_correct_endpoint(self, generator, mock_ollama_success): + """Should POST to /api/chat.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + url = call_args.args[0] if call_args.args else call_args[0][0] + assert url == "http://localhost:11434/api/chat" + + def test_request_includes_model(self, generator, mock_ollama_success): + """Payload should include the configured model name.""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + assert payload["model"] == "test-model" + + def test_request_disables_streaming(self, generator, mock_ollama_success): + """stream should be False (we expect a single JSON response).""" + generator.generate_response(query="test") + + call_args = mock_ollama_success.call_args + payload = call_args.kwargs.get("json") or call_args[1].get("json") + assert payload["stream"] is False + + +# --------------------------------------------------------------------------- +# Helper factories for tool-calling tests +# --------------------------------------------------------------------------- + +def _make_response(message_body: dict): + """Create a mock httpx response returning the given message dict.""" + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"message": message_body} + resp.raise_for_status = MagicMock() + return resp + + +def make_tool_call_response(tool_name: str, arguments: dict): + """Mock Ollama response that requests a tool call.""" + return _make_response({ + "role": "assistant", + "content": "", + 
"tool_calls": [ + {"function": {"name": tool_name, "arguments": arguments}} + ], + }) + + +def make_text_response(content: str): + """Mock Ollama response that returns plain text (no tool calls).""" + return _make_response({ + "role": "assistant", + "content": content, + }) + + +SAMPLE_TOOLS = [{"type": "function", "function": {"name": "search", "parameters": {}}}] + + +# --------------------------------------------------------------------------- +# Sequential tool-calling tests +# --------------------------------------------------------------------------- + +class TestSequentialToolCalling: + """Verify the recursive tool-calling loop in generate_response.""" + + def test_single_tool_call_then_text(self, generator, monkeypatch): + """One tool call followed by a text response → 2 API calls, 1 execution.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "RAG"}), + make_text_response("Here is the answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="search results") + + result = generator.generate_response( + query="What is RAG?", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Here is the answer." + assert mock_post.call_count == 2 + executor.assert_called_once_with("search", query="RAG") + + def test_two_sequential_tool_calls(self, generator, monkeypatch): + """Two sequential tool calls then text → 3 API calls, 2 executions.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "course X"}), + make_tool_call_response("search", {"query": "lesson 4"}), + make_text_response("Final answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="result") + + result = generator.generate_response( + query="complex query", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Final answer." + assert mock_post.call_count == 3 + assert executor.call_count == 2 + + def test_no_tool_calls_returns_immediately(self, generator, monkeypatch): + """If the LLM returns text right away, no tool execution happens.""" + mock_post = MagicMock(side_effect=[ + make_text_response("Immediate answer."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock() + + result = generator.generate_response( + query="hi", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Immediate answer." + assert mock_post.call_count == 1 + executor.assert_not_called() + + def test_max_rounds_forces_final_call_without_tools(self, generator, monkeypatch): + """After MAX_TOOL_ROUNDS tool calls, the next call omits tools to force text.""" + # MAX_TOOL_ROUNDS is 3, so we need 3 tool-call rounds + 1 forced text + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "a"}), + make_tool_call_response("search", {"query": "b"}), + make_tool_call_response("search", {"query": "c"}), + make_text_response("Forced text."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="r") + + result = generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Forced text." 
+ # Fourth call should NOT have tools in its payload + fourth_call_payload = mock_post.call_args_list[3].kwargs.get("json") or mock_post.call_args_list[3][1].get("json") + assert "tools" not in fourth_call_payload + + def test_tool_error_passed_as_result(self, generator, monkeypatch): + """If tool_executor raises, the error string is sent as the tool result.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("bad_tool", {}), + make_text_response("Recovered."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(side_effect=ValueError("something broke")) + + result = generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + assert result == "Recovered." + # Verify the error was passed as a tool message + second_call_payload = mock_post.call_args_list[1].kwargs.get("json") or mock_post.call_args_list[1][1].get("json") + tool_msgs = [m for m in second_call_payload["messages"] if m["role"] == "tool"] + assert len(tool_msgs) == 1 + assert "Tool error:" in tool_msgs[0]["content"] + assert "something broke" in tool_msgs[0]["content"] + + def test_messages_accumulate_across_rounds(self, generator, monkeypatch): + """After 2 rounds, the final API call should contain the full conversation.""" + mock_post = MagicMock(side_effect=[ + make_tool_call_response("search", {"query": "first"}), + make_tool_call_response("search", {"query": "second"}), + make_text_response("Done."), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + executor = MagicMock(return_value="data") + + generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=executor, + ) + + final_payload = mock_post.call_args_list[2].kwargs.get("json") or mock_post.call_args_list[2][1].get("json") + msgs = final_payload["messages"] + roles = [m["role"] for m in msgs] + # system, user, assistant(tool_call), tool, assistant(tool_call), tool + assert roles == ["system", "user", "assistant", "tool", "assistant", "tool"] + + def test_tools_in_payload_when_provided(self, generator, monkeypatch): + """When tools are provided, the first API call payload should include them.""" + mock_post = MagicMock(side_effect=[ + make_text_response("answer"), + ]) + monkeypatch.setattr(httpx.Client, "post", mock_post) + + generator.generate_response( + query="q", tools=SAMPLE_TOOLS, tool_executor=MagicMock(), + ) + + first_payload = mock_post.call_args_list[0].kwargs.get("json") or mock_post.call_args_list[0][1].get("json") + assert "tools" in first_payload + assert first_payload["tools"] == SAMPLE_TOOLS + + def test_no_tools_backward_compatible(self, generator, mock_ollama_success): + """Without tools/tool_executor, behavior is unchanged — no tools in payload.""" + generator.generate_response(query="test") + + payload = mock_ollama_success.call_args.kwargs.get("json") or mock_ollama_success.call_args[1].get("json") + assert "tools" not in payload + assert mock_ollama_success.call_count == 1 diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py new file mode 100644 index 000000000..c6248f5ab --- /dev/null +++ b/backend/tests/test_api.py @@ -0,0 +1,131 @@ +"""Tests for the FastAPI API endpoints.""" + +import pytest +from unittest.mock import MagicMock + + +# ── POST /api/query ───────────────────────────────────────────────────── + + +class TestQueryEndpoint: + """Tests for the /api/query endpoint.""" + + def test_query_returns_answer_and_sources(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "What is RAG?"}) + 
assert resp.status_code == 200 + data = resp.json() + assert data["answer"] == "RAG stands for Retrieval-Augmented Generation." + assert len(data["sources"]) == 1 + assert data["sources"][0]["title"] == "Building RAG Chatbots" + + def test_query_creates_session_when_not_provided(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "Hello"}) + data = resp.json() + assert data["session_id"] == "session_1" + mock_rag_system.session_manager.create_session.assert_called_once() + + def test_query_uses_provided_session_id(self, client, mock_rag_system): + resp = client.post( + "/api/query", + json={"query": "Hello", "session_id": "existing_session"}, + ) + data = resp.json() + assert data["session_id"] == "existing_session" + mock_rag_system.session_manager.create_session.assert_not_called() + mock_rag_system.query.assert_called_once_with("Hello", "existing_session") + + def test_query_passes_query_to_rag_system(self, client, mock_rag_system): + client.post("/api/query", json={"query": "Tell me about MCP"}) + mock_rag_system.query.assert_called_once() + assert mock_rag_system.query.call_args[0][0] == "Tell me about MCP" + + def test_query_missing_query_field_returns_422(self, client): + resp = client.post("/api/query", json={}) + assert resp.status_code == 422 + + def test_query_empty_string_passes_through(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": ""}) + # FastAPI doesn't block empty strings by default; the endpoint processes it + assert resp.status_code == 200 + + def test_query_rag_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.query.side_effect = RuntimeError("Ollama unreachable") + resp = client.post("/api/query", json={"query": "anything"}) + assert resp.status_code == 500 + assert "Ollama unreachable" in resp.json()["detail"] + + def test_query_sources_with_no_link(self, client, mock_rag_system): + mock_rag_system.query.return_value = ( + "Answer", + [{"title": "Some Course"}], + ) + resp = client.post("/api/query", json={"query": "test"}) + data = resp.json() + assert data["sources"][0]["link"] is None + + def test_query_empty_sources(self, client, mock_rag_system): + mock_rag_system.query.return_value = ("No results", []) + resp = client.post("/api/query", json={"query": "obscure question"}) + data = resp.json() + assert data["sources"] == [] + + +# ── POST /api/sessions/clear ──────────────────────────────────────────── + + +class TestClearSessionEndpoint: + """Tests for the /api/sessions/clear endpoint.""" + + def test_clear_session_returns_ok(self, client, mock_rag_system): + resp = client.post( + "/api/sessions/clear", json={"session_id": "session_1"} + ) + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + def test_clear_session_calls_session_manager(self, client, mock_rag_system): + client.post("/api/sessions/clear", json={"session_id": "session_42"}) + mock_rag_system.session_manager.clear_session.assert_called_once_with( + "session_42" + ) + + def test_clear_session_missing_body_returns_422(self, client): + resp = client.post("/api/sessions/clear", json={}) + assert resp.status_code == 422 + + def test_clear_session_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.session_manager.clear_session.side_effect = RuntimeError("boom") + resp = client.post( + "/api/sessions/clear", json={"session_id": "bad"} + ) + assert resp.status_code == 500 + + +# ── GET /api/courses ──────────────────────────────────────────────────── + + +class 
TestCoursesEndpoint: + """Tests for the /api/courses endpoint.""" + + def test_courses_returns_stats(self, client, mock_rag_system): + resp = client.get("/api/courses") + assert resp.status_code == 200 + data = resp.json() + assert data["total_courses"] == 2 + assert "Building RAG Chatbots" in data["course_titles"] + assert "Intro to MCP" in data["course_titles"] + + def test_courses_empty_catalog(self, client, mock_rag_system): + mock_rag_system.get_course_analytics.return_value = { + "total_courses": 0, + "course_titles": [], + } + resp = client.get("/api/courses") + data = resp.json() + assert data["total_courses"] == 0 + assert data["course_titles"] == [] + + def test_courses_exception_returns_500(self, client, mock_rag_system): + mock_rag_system.get_course_analytics.side_effect = RuntimeError("db down") + resp = client.get("/api/courses") + assert resp.status_code == 500 diff --git a/backend/tests/test_rag_system.py b/backend/tests/test_rag_system.py new file mode 100644 index 000000000..4ead1e265 --- /dev/null +++ b/backend/tests/test_rag_system.py @@ -0,0 +1,175 @@ +"""Tests for the RAG system query pipeline in rag_system.py. + +These tests verify the tool-calling orchestration: + query() resets sources → passes tools + executor to AIGenerator → collects sources → returns response. + +All external dependencies (VectorStore, Ollama) are mocked to isolate the logic. +""" + +import pytest +from unittest.mock import MagicMock, patch, call +import httpx + +from vector_store import SearchResults + + +# --------------------------------------------------------------------------- +# Helpers to build a RAGSystem with mocked internals +# --------------------------------------------------------------------------- + +@pytest.fixture +def rag_system(mock_config, mock_vector_store): + """Build a RAGSystem with mocked VectorStore and AIGenerator.""" + with patch("rag_system.VectorStore", return_value=mock_vector_store), \ + patch("rag_system.AIGenerator") as MockAI, \ + patch("rag_system.DocumentProcessor"): + + mock_ai = MockAI.return_value + mock_ai.generate_response.return_value = "RAG is a technique that combines retrieval and generation." 
+ + from rag_system import RAGSystem + system = RAGSystem(mock_config) + + # Replace the vector store in the search tools with our mock + system.search_tool.store = mock_vector_store + system.outline_tool.store = mock_vector_store + system.ai_generator = mock_ai + + yield system + + +@pytest.fixture +def rag_system_with_ollama_error(mock_config, mock_vector_store): + """RAGSystem where the AI generator raises a connection error.""" + with patch("rag_system.VectorStore", return_value=mock_vector_store), \ + patch("rag_system.AIGenerator") as MockAI, \ + patch("rag_system.DocumentProcessor"): + + mock_ai = MockAI.return_value + mock_ai.generate_response.side_effect = httpx.ConnectError("Connection refused") + + from rag_system import RAGSystem + system = RAGSystem(mock_config) + + system.search_tool.store = mock_vector_store + system.outline_tool.store = mock_vector_store + system.ai_generator = mock_ai + + yield system + + +# --------------------------------------------------------------------------- +# Core query pipeline tests — tool-calling flow +# --------------------------------------------------------------------------- + +class TestRAGQueryPipeline: + """Test the main query() method orchestration.""" + + def test_query_passes_tool_definitions(self, rag_system): + """generate_response should receive tool definitions from the tool manager.""" + rag_system.query("What is RAG?") + + call_args = rag_system.ai_generator.generate_response.call_args + tools = call_args.kwargs.get("tools") + assert tools is not None + assert isinstance(tools, list) + assert len(tools) == 2 # search + outline tools + + # Verify tool names + names = {t["function"]["name"] for t in tools} + assert "search_course_content" in names + assert "get_course_outline" in names + + def test_query_passes_tool_executor(self, rag_system): + """generate_response should receive tool_executor bound to the tool manager.""" + rag_system.query("What is RAG?") + + call_args = rag_system.ai_generator.generate_response.call_args + executor = call_args.kwargs.get("tool_executor") + # Bound methods create new objects each access, so compare underlying function + instance + assert executor.__func__ is rag_system.tool_manager.execute_tool.__func__ + assert executor.__self__ is rag_system.tool_manager + + def test_query_returns_response_and_sources(self, rag_system): + """query() should return a (response_text, sources_list) tuple.""" + response, sources = rag_system.query("What is RAG?") + + assert response == "RAG is a technique that combines retrieval and generation." + assert isinstance(sources, list) + + def test_query_collects_sources_after_response(self, rag_system): + """Sources should be gathered from tools after the LLM finishes.""" + # Simulate the tool executor populating sources during generate_response + def fake_generate(**kwargs): + # Mimic what happens when the LLM calls the search tool + rag_system.tool_manager.execute_tool("search_course_content", query="RAG") + return "Here's what I found about RAG." + + rag_system.ai_generator.generate_response.side_effect = fake_generate + + response, sources = rag_system.query("What is RAG?") + + assert response == "Here's what I found about RAG." 
+        titles = [s["title"] for s in sources]
+        assert any("Building RAG Chatbots" in t for t in titles)
+
+    def test_query_resets_sources_before_each_query(self, rag_system):
+        """reset_sources() should be called at the start of each query."""
+        # Seed leftover sources from a "previous" query
+        rag_system.search_tool.last_sources = [{"title": "stale", "link": None}]
+
+        # generate_response is a plain mock (no tool calls), so once query()
+        # resets sources nothing repopulates them; the stale entry must be gone.
+        _, sources = rag_system.query("What is RAG?")
+        assert all(s["title"] != "stale" for s in sources)
+
+
+# ---------------------------------------------------------------------------
+# Error handling tests
+# ---------------------------------------------------------------------------
+
+class TestRAGQueryErrorHandling:
+    """Tests for error handling in the RAG pipeline."""
+
+    def test_ollama_connection_error_returns_friendly_message(self, rag_system_with_ollama_error):
+        """Ollama being down should return a friendly error instead of crashing."""
+        response, sources = rag_system_with_ollama_error.query("What is RAG?")
+
+        assert "trouble generating a response" in response
+
+
+# ---------------------------------------------------------------------------
+# Session handling tests
+# ---------------------------------------------------------------------------
+
+class TestRAGSessionHandling:
+    """Test conversation session integration."""
+
+    def test_query_without_session(self, rag_system):
+        """Query without session_id should still work (no history)."""
+        response, sources = rag_system.query("What is RAG?")
+
+        call_args = rag_system.ai_generator.generate_response.call_args
+        history = call_args.kwargs.get("conversation_history")
+        assert history is None
+
+    def test_query_with_new_session(self, rag_system):
+        """First query with a session_id should have no history."""
+        response, sources = rag_system.query("What is RAG?", session_id="session_1")
+
+        assert response is not None
+
+    def test_query_with_existing_session(self, rag_system):
+        """Second query in same session should include history from first query."""
+        rag_system.query("What is RAG?", session_id="session_1")
+        rag_system.query("Tell me more", session_id="session_1")
+
+        # Second call should have conversation history
+        calls = rag_system.ai_generator.generate_response.call_args_list
+        second_call = calls[1]
+        history = second_call.kwargs.get("conversation_history")
+        assert history is not None
+        assert "What is RAG?" in history
diff --git a/backend/tests/test_search_tools.py b/backend/tests/test_search_tools.py
new file mode 100644
index 000000000..a190c5346
--- /dev/null
+++ b/backend/tests/test_search_tools.py
@@ -0,0 +1,205 @@
+"""Tests for CourseSearchTool.execute() in search_tools.py.
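+(CourseSearchTool wraps VectorStore.search(), formats hits with course/lesson
+headers, and records each hit as a source for citation display.)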
+
+These tests mock the VectorStore to isolate the search tool logic and verify:
+- Correct result formatting
+- Error handling
+- Filter pass-through (course_name, lesson_number)
+- Source tracking and deduplication
+"""
+
+import pytest
+from search_tools import CourseSearchTool, ToolManager
+from vector_store import SearchResults
+
+
+# ---------------------------------------------------------------------------
+# CourseSearchTool.execute() tests
+# ---------------------------------------------------------------------------
+
+class TestCourseSearchToolExecute:
+    """Tests for the execute method of CourseSearchTool."""
+
+    def test_execute_returns_formatted_results(self, mock_vector_store, sample_search_results):
+        """Successful search should return formatted text with course/lesson headers."""
+        mock_vector_store.search.return_value = sample_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        result = tool.execute(query="What is RAG?")
+
+        assert "[Building RAG Chatbots - Lesson 1]" in result
+        assert "[Building RAG Chatbots - Lesson 2]" in result
+        assert "RAG stands for Retrieval-Augmented Generation" in result
+        assert "Vector databases store embeddings" in result
+
+    def test_execute_with_empty_results(self, mock_vector_store, empty_search_results):
+        """Empty results should return a 'no content found' message."""
+        mock_vector_store.search.return_value = empty_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        result = tool.execute(query="nonexistent topic")
+
+        assert "No relevant content found" in result
+
+    def test_execute_with_error(self, mock_vector_store, error_search_results):
+        """When VectorStore returns an error, execute should return the error string."""
+        mock_vector_store.search.return_value = error_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        result = tool.execute(query="anything")
+
+        assert "Search error" in result
+
+    def test_execute_passes_course_name_filter(self, mock_vector_store, sample_search_results):
+        """course_name parameter should be forwarded to VectorStore.search()."""
+        mock_vector_store.search.return_value = sample_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="RAG", course_name="Building RAG Chatbots")
+
+        mock_vector_store.search.assert_called_once_with(
+            query="RAG",
+            course_name="Building RAG Chatbots",
+            lesson_number=None,
+        )
+
+    def test_execute_passes_lesson_number_filter(self, mock_vector_store, sample_search_results):
+        """lesson_number parameter should be forwarded to VectorStore.search()."""
+        mock_vector_store.search.return_value = sample_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="RAG", lesson_number=2)
+
+        mock_vector_store.search.assert_called_once_with(
+            query="RAG",
+            course_name=None,
+            lesson_number=2,
+        )
+
+    def test_execute_passes_both_filters(self, mock_vector_store, sample_search_results):
+        """Both course_name and lesson_number should be forwarded together."""
+        mock_vector_store.search.return_value = sample_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="RAG", course_name="RAG Course", lesson_number=3)
+
+        mock_vector_store.search.assert_called_once_with(
+            query="RAG",
+            course_name="RAG Course",
+            lesson_number=3,
+        )
+
+    def test_execute_tracks_sources(self, mock_vector_store, sample_search_results):
+        """After a successful search, last_sources should contain source info."""
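+        # These title/link entries are what rag_system.query() ultimately
+        # returns to the frontend for citation display.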
+        mock_vector_store.search.return_value = sample_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="What is RAG?")
+
+        assert len(tool.last_sources) > 0
+        assert tool.last_sources[0]["title"] == "Building RAG Chatbots - Lesson 1"
+        assert tool.last_sources[0]["link"] is not None
+
+    def test_execute_deduplicates_sources(self, mock_vector_store):
+        """Multiple chunks from same course+lesson should produce one source entry."""
+        mock_vector_store.search.return_value = SearchResults(
+            documents=["chunk 1 from lesson 1", "chunk 2 from lesson 1"],
+            metadata=[
+                {"course_title": "Test Course", "lesson_number": 1, "chunk_index": 0},
+                {"course_title": "Test Course", "lesson_number": 1, "chunk_index": 1},
+            ],
+            distances=[0.1, 0.2],
+        )
+        mock_vector_store.get_lesson_link.return_value = "https://example.com/l1"
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="test")
+
+        assert len(tool.last_sources) == 1
+        assert tool.last_sources[0]["title"] == "Test Course - Lesson 1"
+
+    def test_execute_empty_results_includes_filter_info(self, mock_vector_store, empty_search_results):
+        """Empty results message should mention the active filters."""
+        mock_vector_store.search.return_value = empty_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        result = tool.execute(query="test", course_name="MCP", lesson_number=5)
+
+        assert "in course 'MCP'" in result
+        assert "in lesson 5" in result
+
+    def test_execute_error_does_not_update_sources(self, mock_vector_store, error_search_results):
+        """When search errors, last_sources should NOT be updated (stays from init)."""
+        mock_vector_store.search.return_value = error_search_results
+        tool = CourseSearchTool(mock_vector_store)
+
+        tool.execute(query="anything")
+
+        # last_sources should still be the initial empty list
+        assert tool.last_sources == []
+
+    def test_execute_handles_metadata_without_lesson_number(self, mock_vector_store):
+        """Results without lesson_number should format without ' - Lesson X'."""
+        mock_vector_store.search.return_value = SearchResults(
+            documents=["some general content"],
+            metadata=[{"course_title": "Intro Course", "chunk_index": 0}],
+            distances=[0.3],
+        )
+        tool = CourseSearchTool(mock_vector_store)
+
+        result = tool.execute(query="general")
+
+        assert "[Intro Course]" in result
+        assert "Lesson" not in result.split("]")[0]  # no lesson in header
+
+
+# ---------------------------------------------------------------------------
+# ToolManager tests
+# ---------------------------------------------------------------------------
+
+class TestToolManager:
+    """Tests for ToolManager registration and execution."""
+
+    def test_register_and_execute_tool(self, mock_vector_store, sample_search_results):
+        """Registered tools should be executable by name."""
+        mock_vector_store.search.return_value = sample_search_results
+        manager = ToolManager()
+        tool = CourseSearchTool(mock_vector_store)
+        manager.register_tool(tool)
+
+        result = manager.execute_tool("search_course_content", query="RAG")
+
+        assert result  # non-empty string
+        mock_vector_store.search.assert_called_once()
+
+    def test_execute_unknown_tool(self):
+        """Executing an unregistered tool should return an error message."""
+        manager = ToolManager()
+
+        result = manager.execute_tool("nonexistent_tool", query="test")
+
+        assert "not found" in result
+
+    def test_get_last_sources(self, mock_vector_store, sample_search_results):
+        """get_last_sources should return sources from the most recent search."""
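+        # ToolManager should pull these from the registered tool's last_sources.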
+        mock_vector_store.search.return_value = sample_search_results
+        manager = ToolManager()
+        tool = CourseSearchTool(mock_vector_store)
+        manager.register_tool(tool)
+
+        manager.execute_tool("search_course_content", query="RAG")
+        sources = manager.get_last_sources()
+
+        assert len(sources) > 0
+
+    def test_reset_sources(self, mock_vector_store, sample_search_results):
+        """reset_sources should clear sources from all tools."""
+        mock_vector_store.search.return_value = sample_search_results
+        manager = ToolManager()
+        tool = CourseSearchTool(mock_vector_store)
+        manager.register_tool(tool)
+
+        manager.execute_tool("search_course_content", query="RAG")
+        manager.reset_sources()
+
+        assert manager.get_last_sources() == []
diff --git a/backend/vector_store.py b/backend/vector_store.py
index 390abe71c..cd95f75d1 100644
--- a/backend/vector_store.py
+++ b/backend/vector_store.py
@@ -246,6 +246,30 @@ def get_course_link(self, course_title: str) -> Optional[str]:
             print(f"Error getting course link: {e}")
             return None
 
+    def get_course_outline(self, course_name: str) -> Optional[Dict[str, Any]]:
+        """Get full course outline (title, link, lessons) by fuzzy course name match"""
+        import json
+        # Resolve to exact course title
+        course_title = self._resolve_course_name(course_name)
+        if not course_title:
+            return None
+
+        try:
+            results = self.course_catalog.get(ids=[course_title])
+            if results and results['metadatas']:
+                meta = results['metadatas'][0]
+                lessons = []
+                if meta.get('lessons_json'):
+                    lessons = json.loads(meta['lessons_json'])
+                return {
+                    "title": meta.get('title', course_title),
+                    "course_link": meta.get('course_link'),
+                    "lessons": lessons
+                }
+        except Exception as e:
+            print(f"Error getting course outline: {e}")
+        # Fall through on lookup error or missing metadata
+        return None
+
     def get_lesson_link(self, course_title: str, lesson_number: int) -> Optional[str]:
         """Get lesson link for a given course title and lesson number"""
         import json
diff --git a/frontend/index.html b/frontend/index.html
index f8e25a62f..d80dcedeb 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -7,9 +7,25 @@