diff --git a/codeflow_engine/actions/__init__.py b/codeflow_engine/actions/__init__.py index 5178031..4d2faff 100644 --- a/codeflow_engine/actions/__init__.py +++ b/codeflow_engine/actions/__init__.py @@ -186,6 +186,13 @@ except ImportError: pass +# Create llm alias for backward compatibility (codeflow_engine.actions.llm) +llm = None +try: + from codeflow_engine.actions.ai_actions import llm +except ImportError: + pass + # Platform PlatformDetector: type[Any] | None = None try: @@ -223,6 +230,7 @@ "generation", "git", "issues", + "llm", # Backward compatibility alias for ai_actions.llm "maintenance", "platform", "quality", diff --git a/codeflow_engine/actions/ai_actions/llm/__init__.py b/codeflow_engine/actions/ai_actions/llm/__init__.py index b188df0..cc2614d 100644 --- a/codeflow_engine/actions/ai_actions/llm/__init__.py +++ b/codeflow_engine/actions/ai_actions/llm/__init__.py @@ -1,4 +1,4 @@ -""" +""" CODEFLOW LLM Package - Modular LLM provider system. This package provides a unified interface for multiple LLM providers including: @@ -12,7 +12,7 @@ Usage:: - from codeflow_engine.actions.llm import get_llm_provider_manager, complete_chat + from codeflow_engine.actions.ai_actions.llm import get_llm_provider_manager, complete_chat # Get a manager instance manager = get_llm_provider_manager() @@ -28,31 +28,41 @@ import os from typing import Any -# Export base classes -from codeflow_engine.actions.llm.base import BaseLLMProvider +# Export base classes from core +from codeflow_engine.core.llm import ( + BaseLLMProvider, + LLMProviderRegistry, + LLMResponse, + OpenAICompatibleProvider, +) # Export manager -from codeflow_engine.actions.llm.manager import ActionLLMProviderManager +from codeflow_engine.actions.ai_actions.llm.manager import ActionLLMProviderManager # Export providers -from codeflow_engine.actions.llm.providers import ( +from codeflow_engine.actions.ai_actions.llm.providers import ( AnthropicProvider, + AzureOpenAIProvider, GroqProvider, - MistralProvider, 
OpenAIProvider, PerplexityProvider, TogetherAIProvider, + MISTRAL_AVAILABLE, ) # Export types -from codeflow_engine.actions.llm.types import ( +from codeflow_engine.actions.ai_actions.llm.types import ( LLMConfig, LLMProviderType, - LLMResponse, Message, MessageRole, ) +# Conditionally import MistralProvider +MistralProvider = None +if MISTRAL_AVAILABLE: + from codeflow_engine.actions.ai_actions.llm.providers import MistralProvider + # Global provider manager instance _provider_manager: ActionLLMProviderManager | None = None @@ -63,7 +73,7 @@ def get_llm_provider_manager() -> ActionLLMProviderManager: Get or create the global LLM provider manager with configuration from environment variables. Returns: - LLMProviderManager: A configured instance of LLMProviderManager + ActionLLMProviderManager: A configured instance of LLMProviderManager """ global _provider_manager @@ -166,26 +176,35 @@ def complete_chat( return manager.complete(request) +# Backward compatibility alias +LLMProviderManager = ActionLLMProviderManager + + # Export all public components __all__ = [ - "AnthropicProvider", # Base classes "BaseLLMProvider", - "GroqProvider", - "LLMConfig", + "OpenAICompatibleProvider", # Manager "ActionLLMProviderManager", + "LLMProviderManager", # Backward compatibility + # Registry + "LLMProviderRegistry", + # Types + "LLMConfig", "LLMProviderType", "LLMResponse", "Message", - # Types "MessageRole", - "MistralProvider", # Providers + "AnthropicProvider", + "AzureOpenAIProvider", + "GroqProvider", + "MistralProvider", "OpenAIProvider", "PerplexityProvider", "TogetherAIProvider", - "complete_chat", # Factory functions + "complete_chat", "get_llm_provider_manager", ] diff --git a/codeflow_engine/actions/ai_actions/llm/base.py b/codeflow_engine/actions/ai_actions/llm/base.py index c0db1be..8384708 100644 --- a/codeflow_engine/actions/ai_actions/llm/base.py +++ b/codeflow_engine/actions/ai_actions/llm/base.py @@ -1,28 +1,11 @@ """ Abstract base class for LLM providers. 
-""" - -from abc import ABC, abstractmethod -import os -from typing import Any - -from codeflow_engine.actions.llm.types import LLMResponse +This module re-exports from codeflow_engine.core.llm for backwards compatibility. +New code should import directly from codeflow_engine.core.llm. +""" -class BaseLLMProvider(ABC): - """Abstract base class for LLM providers.""" - - def __init__(self, config: dict[str, Any]) -> None: - self.config = config - self.api_key = config.get("api_key") or os.getenv(config.get("api_key_env", "")) - self.base_url = config.get("base_url") - self.default_model = config.get("default_model") - self.name = config.get("name", self.__class__.__name__.lower().replace("provider", "")) - - @abstractmethod - def complete(self, request: dict[str, Any]) -> LLMResponse: - """Complete a chat conversation.""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.llm.base import BaseLLMProvider - @abstractmethod - def is_available(self) -> bool: - """Check if the provider is properly configured and available.""" +__all__ = ["BaseLLMProvider"] diff --git a/codeflow_engine/actions/ai_actions/llm/manager.py b/codeflow_engine/actions/ai_actions/llm/manager.py index 18dfd2f..09fd9ba 100644 --- a/codeflow_engine/actions/ai_actions/llm/manager.py +++ b/codeflow_engine/actions/ai_actions/llm/manager.py @@ -1,26 +1,75 @@ """ LLM Provider Manager - Manages multiple LLM providers with fallback support. + +Uses LLMProviderRegistry for dynamic provider creation (Open/Closed principle). 
""" import logging import os from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.providers import (AnthropicProvider, GroqProvider, - MistralProvider, OpenAIProvider, - PerplexityProvider, - TogetherAIProvider) -from codeflow_engine.actions.llm.providers.azure_openai import AzureOpenAIProvider -from codeflow_engine.actions.llm.types import LLMResponse +from codeflow_engine.core.llm import BaseLLMProvider, LLMProviderRegistry, LLMResponse + +# Import providers to trigger their registration +from codeflow_engine.actions.ai_actions.llm import providers as _ # noqa: F401 logger = logging.getLogger(__name__) class ActionLLMProviderManager: - """Manages multiple LLM providers with fallback support for action-driven use cases.""" + """ + Manages multiple LLM providers with fallback support for action-driven use cases. + + Uses LLMProviderRegistry for dynamic provider instantiation, following the + Open/Closed principle - new providers can be added without modifying this class. 
+ """ + + # Default provider configurations (can be extended via registry) + DEFAULT_PROVIDER_CONFIGS: dict[str, dict[str, Any]] = { + "openai": { + "api_key_env": "OPENAI_API_KEY", + "default_model": "gpt-4", + }, + "azure_openai": { + "api_key_env": "AZURE_OPENAI_API_KEY", + "default_model": "gpt-5-chat", + "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT", ""), + "api_version": "2024-02-01", + "deployment_name": os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-5-chat"), + }, + "anthropic": { + "api_key_env": "ANTHROPIC_API_KEY", + "default_model": "claude-3-sonnet-20240229", + }, + "mistral": { + "api_key_env": "MISTRAL_API_KEY", + "default_model": "mistral-large-latest", + }, + "groq": { + "api_key_env": "GROQ_API_KEY", + "default_model": "mixtral-8x7b-32768", + }, + "perplexity": { + "api_key_env": "PERPLEXITY_API_KEY", + "default_model": "llama-3.1-sonar-large-128k-online", + }, + "together": { + "api_key_env": "TOGETHER_API_KEY", + "default_model": "meta-llama/Llama-2-70b-chat-hf", + }, + } def __init__(self, config: dict[str, Any], display=None) -> None: + """ + Initialize the manager with configuration. 
+ + Args: + config: Configuration dictionary containing: + - fallback_order: List of provider names to try as fallbacks + - default_provider: Name of the default provider + - providers: Dict of provider-specific configurations + display: Optional display object for user feedback + """ self.providers: dict[str, BaseLLMProvider] = {} self.fallback_order: list[str] = config.get( "fallback_order", ["azure_openai", "openai", "anthropic", "mistral"] @@ -28,96 +77,56 @@ def __init__(self, config: dict[str, Any], display=None) -> None: self.default_provider: str = config.get("default_provider", "azure_openai") self.display = display - # Initialize providers based on configuration + # Get user-provided provider configs provider_configs: dict[str, dict[str, Any]] = config.get("providers", {}) - # Default provider configurations - default_configs: dict[str, dict[str, Any]] = { - "openai": { - "api_key_env": "OPENAI_API_KEY", - "default_model": "gpt-4", - "base_url": None, - }, - "azure_openai": { - "api_key_env": "AZURE_OPENAI_API_KEY", - "default_model": "gpt-5-chat", - "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT", "https:///"), - "api_version": "2024-02-01", - "deployment_name": os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-5-chat"), - }, - "anthropic": { - "api_key_env": "ANTHROPIC_API_KEY", - "default_model": "claude-3-sonnet-20240229", - "base_url": None, - }, - "mistral": { - "api_key_env": "MISTRAL_API_KEY", - "default_model": "mistral-large-latest", - "base_url": None, - }, - "groq": { - "api_key_env": "GROQ_API_KEY", - "default_model": "mixtral-8x7b-32768", - "base_url": None, - }, - "perplexity": { - "api_key_env": "PERPLEXITY_API_KEY", - "default_model": "llama-3.1-sonar-large-128k-online", - "base_url": None, - }, - "together": { - "api_key_env": "TOGETHER_API_KEY", - "default_model": "meta-llama/Llama-2-70b-chat-hf", - "base_url": None, - }, - } + # Initialize providers using registry (Open/Closed principle) + self._initialize_providers(provider_configs) 
- # Merge user configs with defaults - for provider_name, default_config in default_configs.items(): - user_config: dict[str, Any] = provider_configs.get(provider_name, {}) - merged_config: dict[str, Any] = {**default_config, **user_config} + def _initialize_providers(self, user_configs: dict[str, dict[str, Any]]) -> None: + """ + Initialize all registered providers. + + Args: + user_configs: User-provided configuration overrides + """ + # Get all providers from registry plus defaults + provider_names = set(LLMProviderRegistry.list_providers()) | set(self.DEFAULT_PROVIDER_CONFIGS.keys()) + + for provider_name in provider_names: + # Merge configs: defaults -> registry defaults -> user config + default_config = self.DEFAULT_PROVIDER_CONFIGS.get(provider_name, {}) + registry_config = LLMProviderRegistry.get_default_config(provider_name) + user_config = user_configs.get(provider_name, {}) + merged_config = {**default_config, **registry_config, **user_config} - # Initialize provider try: - if provider_name == "openai": - self.providers[provider_name] = OpenAIProvider(merged_config) - elif provider_name == "azure_openai": - self.providers[provider_name] = AzureOpenAIProvider(merged_config) - elif provider_name == "anthropic": - self.providers[provider_name] = AnthropicProvider(merged_config) - elif provider_name == "mistral": - self.providers[provider_name] = MistralProvider(merged_config) - elif provider_name == "groq": - self.providers[provider_name] = GroqProvider(merged_config) - elif provider_name == "perplexity": - self.providers[provider_name] = PerplexityProvider(merged_config) - elif provider_name == "together": - self.providers[provider_name] = TogetherAIProvider(merged_config) + # Create provider using registry + provider = LLMProviderRegistry.create(provider_name, merged_config) + if provider is not None: + self.providers[provider_name] = provider except Exception as e: - # Log the error for debugging - logger.debug(f"Failed to initialize {provider_name} 
provider: {e}") - - # Only show warning if this provider is in the fallback order or is the default - if ( - provider_name in self.fallback_order - or provider_name == self.default_provider - ): - if self.display: - self.display.error.show_warning( - f"Provider {provider_name} not available: {e!s}" - ) - else: - # Fallback to logger if no display is available - logger.warning( - f"Failed to initialize {provider_name} provider: {e}" - ) + self._handle_provider_init_error(provider_name, e) + + def _handle_provider_init_error(self, provider_name: str, error: Exception) -> None: + """Handle provider initialization errors with appropriate logging.""" + logger.debug(f"Failed to initialize {provider_name} provider: {error}") + + # Only show warning if this provider is important + if provider_name in self.fallback_order or provider_name == self.default_provider: + if self.display: + self.display.error.show_warning( + f"Provider {provider_name} not available: {error!s}" + ) + else: + logger.warning(f"Failed to initialize {provider_name} provider: {error}") def get_provider(self, provider_name: str) -> BaseLLMProvider | None: """ Get a provider by name. Args: - provider_name: Name of the provider to retrieve (e.g., 'openai', 'anthropic') + provider_name: Name of the provider to retrieve Returns: The provider instance if found and available, None otherwise @@ -127,7 +136,6 @@ def get_provider(self, provider_name: str) -> BaseLLMProvider | None: return provider return None - # Backward-compatibility helpers used in tests def get_llm(self, provider_name: str) -> BaseLLMProvider | None: """Alias for get_provider to satisfy older code/tests.""" return self.get_provider(provider_name) @@ -137,16 +145,15 @@ def complete(self, request: dict[str, Any]) -> LLMResponse: Complete a chat conversation using the specified or default provider with fallback. 
Args: - request: Dictionary containing the request parameters including: + request: Dictionary containing: - provider: Optional provider name to use - - model: Optional model name to use - - messages: List of message dictionaries with 'role' and 'content' + - model: Optional model name + - messages: List of message dictionaries - Other provider-specific parameters Returns: LLMResponse containing the completion response or error """ - # Make a copy of the request to avoid modifying the original request = request.copy() # Get the requested provider or use default @@ -160,23 +167,15 @@ def complete(self, request: dict[str, Any]) -> LLMResponse: # Try to get the requested provider provider = self.get_provider(provider_name) if provider is None: - # Try fallback providers - for fallback_name in self.fallback_order: - if fallback_name != provider_name: - fallback_provider = self.get_provider(fallback_name) - if fallback_provider is not None: - logger.info( - f"Using fallback provider '{fallback_name}' instead of '{provider_name}'" - ) - provider = fallback_provider - - if provider is None: - return LLMResponse.from_error( - f"Provider '{provider_name}' not found or not available, and no fallback providers available", - request.get("model") or "unknown", - ) + provider = self._get_fallback_provider(provider_name) - # Ensure required fields are present + if provider is None: + return LLMResponse.from_error( + f"Provider '{provider_name}' not found or not available, and no fallback providers available", + request.get("model") or "unknown", + ) + + # Validate request if "messages" not in request: return LLMResponse.from_error( "Missing required field 'messages' in request", @@ -188,13 +187,32 @@ def complete(self, request: dict[str, Any]) -> LLMResponse: request["model"] = provider.default_model try: - # Call the provider's complete method return provider.complete(request) except Exception as e: error_msg = f"Error calling provider '{provider_name}': {e!s}" 
logger.exception(error_msg) return LLMResponse.from_error(error_msg, request.get("model") or "unknown") + def _get_fallback_provider(self, failed_provider: str) -> BaseLLMProvider | None: + """ + Get a fallback provider when the requested one fails. + + Args: + failed_provider: Name of the provider that failed + + Returns: + A working fallback provider or None + """ + for fallback_name in self.fallback_order: + if fallback_name != failed_provider: + fallback_provider = self.get_provider(fallback_name) + if fallback_provider is not None: + logger.info( + f"Using fallback provider '{fallback_name}' instead of '{failed_provider}'" + ) + return fallback_provider + return None + def get_available_providers(self) -> list[str]: """Get list of available providers.""" return [ @@ -220,4 +238,4 @@ def get_provider_info(self) -> dict[str, Any]: # Backward compatibility alias -LLMProviderManager = ActionLLMProviderManager \ No newline at end of file +LLMProviderManager = ActionLLMProviderManager diff --git a/codeflow_engine/actions/ai_actions/llm/providers/__init__.py b/codeflow_engine/actions/ai_actions/llm/providers/__init__.py index 916e2a8..ba8e78d 100644 --- a/codeflow_engine/actions/ai_actions/llm/providers/__init__.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/__init__.py @@ -1,198 +1,118 @@ """ LLM Providers package - Individual provider implementations. + +All providers auto-register themselves with LLMProviderRegistry on import. 
""" from typing import Any -# Import base class for inline implementations -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.providers.anthropic import AnthropicProvider -from codeflow_engine.actions.llm.providers.groq import GroqProvider -from codeflow_engine.actions.llm.types import LLMResponse - - -# Optional AI providers +# Import core components +from codeflow_engine.core.llm import ( + BaseLLMProvider, + LLMProviderRegistry, + LLMResponse, + OpenAICompatibleProvider, +) + +# Import providers (they auto-register on import) +from codeflow_engine.actions.ai_actions.llm.providers.openai import OpenAIProvider +from codeflow_engine.actions.ai_actions.llm.providers.anthropic import AnthropicProvider +from codeflow_engine.actions.ai_actions.llm.providers.groq import GroqProvider +from codeflow_engine.actions.ai_actions.llm.providers.azure_openai import AzureOpenAIProvider + +# Optional providers with graceful fallback +MistralProvider = None +MISTRAL_AVAILABLE = False try: - from codeflow_engine.actions.llm.providers.mistral import MistralProvider - + from codeflow_engine.actions.ai_actions.llm.providers.mistral import MistralProvider MISTRAL_AVAILABLE = True except ImportError: - MistralProvider = None - MISTRAL_AVAILABLE = False -from codeflow_engine.actions.llm.providers.openai import OpenAIProvider + pass + +# OpenAI-compatible providers that only need a custom base URL +class PerplexityProvider(OpenAICompatibleProvider): + """ + Perplexity AI provider. -class PerplexityProvider(BaseLLMProvider): - """Perplexity AI provider.""" + Uses OpenAI-compatible API with custom base URL. 
+ """ - def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) + DEFAULT_MODEL = "llama-3.1-sonar-large-128k-online" + LIBRARY_NAME = "perplexity" + + def _initialize_client(self) -> None: + """Initialize the Perplexity client using OpenAI library.""" try: import openai self.client = openai.OpenAI( - api_key=self.api_key, base_url="https://api.perplexity.ai" + api_key=self.api_key, + base_url="https://api.perplexity.ai", ) self.available = True except ImportError: self.available = False - def complete(self, request: dict[str, Any]) -> LLMResponse: - try: - messages = request.get("messages", []) - model = ( - request.get("model", self.default_model) - or "llama-3.1-sonar-large-128k-online" - ) - max_tokens = request.get("max_tokens", 1024) - temperature = request.get("temperature", 0.7) - - # Filter out empty messages - filtered_messages = [ - {"role": m.get("role", "user"), "content": m.get("content", "")} - for m in messages - if m.get("content") - ] - - # Call the API - response = self.client.chat.completions.create( - model=str(model), - messages=filtered_messages, # type: ignore[arg-type] - max_tokens=max_tokens, - temperature=temperature, - ) - # Extract content and finish reason - content = "" - finish_reason = "stop" - if ( - hasattr(response, "choices") - and response.choices - and len(response.choices) > 0 - ): - choice = response.choices[0] - if hasattr(choice, "message") and hasattr(choice.message, "content"): - content = choice.message.content or "" - finish_reason = getattr(choice, "finish_reason", "stop") or "stop" - - # Extract usage information - usage = None - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), - "completion_tokens": getattr( - response.usage, "completion_tokens", 0 - ), - "total_tokens": getattr(response.usage, "total_tokens", 0), - } - - return LLMResponse( - content=content, - model=str(model), - finish_reason=finish_reason, - 
usage=usage, - ) - - except Exception as e: - return LLMResponse.from_error( - f"Error calling Perplexity API: {e!s}", - str(request.get("model") or "llama-3.1-sonar-large-128k-online"), - ) - - def is_available(self) -> bool: - return self.available and bool(self.api_key) +class TogetherAIProvider(OpenAICompatibleProvider): + """ + Together AI provider for open source models. + Uses OpenAI-compatible API with custom base URL. + """ -class TogetherAIProvider(BaseLLMProvider): - """Together AI provider for open source models.""" + DEFAULT_MODEL = "meta-llama/Llama-2-70b-chat-hf" + LIBRARY_NAME = "together" - def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) + def _initialize_client(self) -> None: + """Initialize the Together AI client using OpenAI library.""" try: import openai self.client = openai.OpenAI( - api_key=self.api_key, base_url="https://api.together.xyz/v1" + api_key=self.api_key, + base_url="https://api.together.xyz/v1", ) self.available = True except ImportError: self.available = False - def complete(self, request: dict[str, Any]) -> LLMResponse: - try: - messages = request.get("messages", []) - model = ( - request.get("model", self.default_model) - or "meta-llama/Llama-2-70b-chat-hf" - ) - max_tokens = request.get("max_tokens", 1024) - temperature = request.get("temperature", 0.7) - - # Filter out empty messages - filtered_messages = [ - {"role": m.get("role", "user"), "content": m.get("content", "")} - for m in messages - if m.get("content") - ] - - # Call the API - response = self.client.chat.completions.create( - model=str(model), - messages=filtered_messages, # type: ignore[arg-type] - max_tokens=max_tokens, - temperature=temperature, - ) - - # Extract content and finish reason - content = "" - finish_reason = "stop" - if ( - hasattr(response, "choices") - and response.choices - and len(response.choices) > 0 - ): - choice = response.choices[0] - if hasattr(choice, "message") and hasattr(choice.message, "content"): - 
content = choice.message.content or "" - finish_reason = getattr(choice, "finish_reason", "stop") or "stop" - - # Extract usage information - usage = None - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), - "completion_tokens": getattr( - response.usage, "completion_tokens", 0 - ), - "total_tokens": getattr(response.usage, "total_tokens", 0), - } - - return LLMResponse( - content=content, - model=str(model), - finish_reason=finish_reason, - usage=usage, - ) - except Exception as e: - return LLMResponse.from_error( - f"Error calling Together AI API: {e!s}", - str(request.get("model") or "meta-llama/Llama-2-70b-chat-hf"), - ) +# Register the inline providers +LLMProviderRegistry.register( + "perplexity", + PerplexityProvider, + default_config={ + "api_key_env": "PERPLEXITY_API_KEY", + "default_model": "llama-3.1-sonar-large-128k-online", + }, +) - def is_available(self) -> bool: - return self.available and bool(self.api_key) +LLMProviderRegistry.register( + "together", + TogetherAIProvider, + default_config={ + "api_key_env": "TOGETHER_API_KEY", + "default_model": "meta-llama/Llama-2-70b-chat-hf", + }, +) # Export all providers __all__ = [ "AnthropicProvider", + "AzureOpenAIProvider", + "BaseLLMProvider", "GroqProvider", + "LLMProviderRegistry", + "LLMResponse", + "OpenAICompatibleProvider", "OpenAIProvider", "PerplexityProvider", "TogetherAIProvider", ] # Add MistralProvider if available -if MISTRAL_AVAILABLE: +if MISTRAL_AVAILABLE and MistralProvider is not None: __all__.append("MistralProvider") diff --git a/codeflow_engine/actions/ai_actions/llm/providers/anthropic.py b/codeflow_engine/actions/ai_actions/llm/providers/anthropic.py index 6307a7a..10efcd7 100644 --- a/codeflow_engine/actions/ai_actions/llm/providers/anthropic.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/anthropic.py @@ -1,18 +1,32 @@ """ Anthropic Claude provider implementation. 
+ +Uses BaseLLMProvider with Anthropic-specific message handling. """ from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.types import LLMResponse +from codeflow_engine.core.llm import BaseLLMProvider, LLMResponse, LLMProviderRegistry +from codeflow_engine.core.llm.response import ResponseExtractor class AnthropicProvider(BaseLLMProvider): - """Anthropic Claude provider.""" + """ + Anthropic Claude provider. + + Anthropic has a different API format than OpenAI, requiring custom + message conversion and response extraction. + """ + + DEFAULT_MODEL = "claude-3-sonnet-20240229" def __init__(self, config: dict[str, Any]) -> None: super().__init__(config) + self.client: Any = None + self._initialize_client() + + def _initialize_client(self) -> None: + """Initialize the Anthropic client.""" try: import anthropic @@ -23,66 +37,69 @@ def __init__(self, config: dict[str, Any]) -> None: except ImportError: self.available = False + def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, str]]]: + """ + Convert messages to Anthropic format, extracting system prompt. 
+ + Args: + messages: List of message dicts with role and content + + Returns: + Tuple of (system_prompt, converted_messages) + """ + system_prompt = "" + converted_messages: list[dict[str, str]] = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if not content: + continue + + if role == "system": + system_prompt += content + "\n" + else: + converted_messages.append({"role": role, "content": content}) + + return system_prompt.strip(), converted_messages + def complete(self, request: dict[str, Any]) -> LLMResponse: + """Complete a chat conversation using Anthropic's API.""" + if not self.client: + return LLMResponse.from_error( + "Anthropic client not initialized", + self.get_model(request, self.DEFAULT_MODEL), + ) + try: messages = request.get("messages", []) - model = ( - request.get("model", self.default_model) or "claude-3-sonnet-20240229" - ) + model = self.get_model(request, self.DEFAULT_MODEL) max_tokens = request.get("max_tokens", 1024) temperature = request.get("temperature", 0.7) # Convert messages to Anthropic format - system_prompt = "" - converted_messages: list[dict[str, str]] = [] + system_prompt, converted_messages = self._convert_messages(messages) - for msg in messages: - role = msg.get("role", "user") - content = msg.get("content", "") - if not content: - continue + if not converted_messages: + return LLMResponse.from_error("No valid messages provided", model) - if role == "system": - system_prompt += content + "\n" - else: - converted_messages.append({"role": role, "content": content}) + # Build API call parameters + api_params: dict[str, Any] = { + "model": str(model), + "max_tokens": max_tokens, + "temperature": temperature, + "messages": converted_messages, + } - # Prepare system parameter - use NotGiven if empty - system_param = system_prompt.strip() or None + # Only include system if not empty + if system_prompt: + api_params["system"] = system_prompt # Call the API - if system_param: - 
response = self.client.messages.create( - model=str(model), - max_tokens=max_tokens, - temperature=temperature, - system=system_param, - messages=converted_messages, # type: ignore[arg-type] - ) - else: - response = self.client.messages.create( - model=str(model), - max_tokens=max_tokens, - temperature=temperature, - messages=converted_messages, # type: ignore[arg-type] - ) - - # Extract content and finish reason - content = "" - if hasattr(response, "content") and response.content: - content = "\n".join( - block.text for block in response.content if hasattr(block, "text") - ) - - finish_reason = getattr(response, "stop_reason", "stop") - usage = { - "prompt_tokens": getattr(response, "usage", {}).get("input_tokens", 0), - "completion_tokens": getattr(response, "usage", {}).get( - "output_tokens", 0 - ), - "total_tokens": getattr(response, "usage", {}).get("input_tokens", 0) - + getattr(response, "usage", {}).get("output_tokens", 0), - } + response = self.client.messages.create(**api_params) + + # Extract response using centralized utility + content, finish_reason, usage = ResponseExtractor.extract_anthropic_response(response) return LLMResponse( content=content, @@ -92,10 +109,15 @@ def complete(self, request: dict[str, Any]) -> LLMResponse: ) except Exception as e: - return LLMResponse.from_error( - f"Error calling Anthropic API: {e!s}", - str(request.get("model") or "claude-3-sonnet-20240229"), - ) - - def is_available(self) -> bool: - return self.available and bool(self.api_key) + return self._create_error_response(e, request, self.DEFAULT_MODEL) + + +# Register with the provider registry +LLMProviderRegistry.register( + "anthropic", + AnthropicProvider, + default_config={ + "api_key_env": "ANTHROPIC_API_KEY", + "default_model": "claude-3-sonnet-20240229", + }, +) diff --git a/codeflow_engine/actions/ai_actions/llm/providers/azure_openai.py b/codeflow_engine/actions/ai_actions/llm/providers/azure_openai.py index b649ddb..ae9b7d4 100644 --- 
a/codeflow_engine/actions/ai_actions/llm/providers/azure_openai.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/azure_openai.py @@ -1,29 +1,34 @@ -""" +""" Azure OpenAI Provider for CodeFlow LLM system. Supports Azure OpenAI endpoints with custom configurations. +Uses OpenAICompatibleProvider as base with Azure-specific initialization. """ import logging import os from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.types import LLMResponse - +from codeflow_engine.core.llm import OpenAICompatibleProvider, LLMResponse, LLMProviderRegistry logger = logging.getLogger(__name__) -class AzureOpenAIProvider(BaseLLMProvider): - """Azure OpenAI provider implementation.""" +class AzureOpenAIProvider(OpenAICompatibleProvider): + """ + Azure OpenAI provider implementation. - def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) - self.name = "azure_openai" - self.description = "Azure OpenAI API provider with custom endpoints" + Extends OpenAICompatibleProvider with Azure-specific configuration: + - Azure endpoint + - API version + - Deployment name + """ - # Azure-specific configuration + DEFAULT_MODEL = "gpt-35-turbo" + LIBRARY_NAME = "openai (Azure)" + + def __init__(self, config: dict[str, Any]) -> None: + # Azure-specific configuration - set before super().__init__ self.azure_endpoint = config.get("azure_endpoint") or os.getenv( "AZURE_OPENAI_ENDPOINT" ) @@ -33,95 +38,50 @@ def __init__(self, config: dict[str, Any]) -> None: # Use Azure-specific API key environment variable azure_api_key = config.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY") if azure_api_key: - self.api_key = azure_api_key + config["api_key"] = azure_api_key - self.default_model = self.deployment_name - self._client = None + # Use deployment name as default model + config["default_model"] = self.deployment_name - def is_available(self) -> bool: - """Check if Azure OpenAI is properly configured.""" - 
return bool(self.api_key and self.azure_endpoint) - - def _get_client(self): - """Get or create Azure OpenAI client.""" - if self._client is None: - try: - from openai import AzureOpenAI - - self._client = AzureOpenAI( - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - ) - logger.info( - f"Initialized Azure OpenAI client with endpoint: {self.azure_endpoint}" - ) - except ImportError: - logger.exception( - "openai package not installed. Install with: pip install openai" - ) - return None - except Exception as e: - logger.exception(f"Failed to initialize Azure OpenAI client: {e}") - return None - - return self._client - - def complete(self, request: dict[str, Any]) -> LLMResponse: - """Complete a chat conversation using Azure OpenAI.""" - client = self._get_client() - if not client: - return LLMResponse.from_error( - "Azure OpenAI client not available", - request.get("model", self.default_model), - ) + super().__init__(config) + self.name = "azure_openai" + + def _initialize_client(self) -> None: + """Initialize the Azure OpenAI client.""" + if not self.azure_endpoint: + logger.warning("Azure OpenAI endpoint not configured") + self.available = False + return try: - # Extract parameters - messages = request.get("messages", []) - model = request.get("model", self.deployment_name) - temperature = request.get("temperature", 0.7) - max_tokens = request.get("max_tokens", 1000) - - # Validate messages format - if not messages: - return LLMResponse.from_error("No messages provided", model) - - # Make the API call - response = client.chat.completions.create( - model=model, # This is the deployment name in Azure - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - ) + from openai import AzureOpenAI - # Extract response content - content = "" - finish_reason = "unknown" - usage = None - - if response.choices: - choice = response.choices[0] - content = choice.message.content or "" - finish_reason = 
choice.finish_reason or "unknown" - - if response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - - return LLMResponse( - content=content, - model=model, - finish_reason=finish_reason, - usage=usage, + self.client = AzureOpenAI( + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, ) - + self.available = True + logger.info(f"Initialized Azure OpenAI client with endpoint: {self.azure_endpoint}") + except ImportError: + logger.debug("openai package not installed") + self.available = False except Exception as e: - error_msg = f"Azure OpenAI API error: {e!s}" - logger.exception(error_msg) - return LLMResponse.from_error( - error_msg, request.get("model", self.default_model) - ) + logger.exception(f"Failed to initialize Azure OpenAI client: {e}") + self.available = False + + def is_available(self) -> bool: + """Check if Azure OpenAI is properly configured.""" + return bool(self.api_key and self.azure_endpoint and self.client) + + +# Register with the provider registry +LLMProviderRegistry.register( + "azure_openai", + AzureOpenAIProvider, + default_config={ + "api_key_env": "AZURE_OPENAI_API_KEY", + "default_model": "gpt-35-turbo", + "api_version": "2024-02-01", + }, +) diff --git a/codeflow_engine/actions/ai_actions/llm/providers/groq.py b/codeflow_engine/actions/ai_actions/llm/providers/groq.py index a356fca..664aefc 100644 --- a/codeflow_engine/actions/ai_actions/llm/providers/groq.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/groq.py @@ -1,12 +1,26 @@ +""" +Groq provider implementation. + +Uses the OpenAICompatibleProvider base class for common functionality. 
+""" + from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.types import LLMResponse +from codeflow_engine.core.llm import OpenAICompatibleProvider, LLMProviderRegistry + + +class GroqProvider(OpenAICompatibleProvider): + """ + Groq provider for fast inference. + Groq uses an OpenAI-compatible API format. + """ -class GroqProvider(BaseLLMProvider): - def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) + DEFAULT_MODEL = "mixtral-8x7b-32768" + LIBRARY_NAME = "groq" + + def _initialize_client(self) -> None: + """Initialize the Groq client.""" try: from groq import Groq @@ -15,79 +29,13 @@ def __init__(self, config: dict[str, Any]) -> None: except ImportError: self.available = False - def _convert_to_provider_messages( - self, messages: list[dict[str, Any]], provider: str - ) -> list[dict[str, Any]]: - return [ - { - "role": str(msg.get("role", "user")), - "content": str(msg.get("content", "")), - **({k: v for k, v in msg.items() if k not in {"role", "content"}}), - } - for msg in messages - if msg.get("content", "").strip() - ] - - async def _call_groq_api( - self, messages: list[dict[str, Any]], **kwargs: Any - ) -> Any: - groq_messages = self._convert_to_provider_messages(messages, "groq") - return self.client.chat.completions.create(messages=groq_messages, **kwargs) # type: ignore[arg-type] - - def complete(self, request: dict[str, Any]) -> LLMResponse: - try: - messages = request.get("messages", []) - model = request.get("model", self.default_model) or "mixtral-8x7b-32768" - max_tokens = request.get("max_tokens", 1024) - temperature = request.get("temperature", 0.7) - - filtered_messages = [ - {"role": m.get("role", "user"), "content": m.get("content", "")} - for m in messages - if m.get("content") - ] - - response = self.client.chat.completions.create( # type: ignore[arg-type] - model=str(model), - messages=filtered_messages, - max_tokens=max_tokens, - 
temperature=temperature, - ) - - content = "" - finish_reason = "stop" - if ( - hasattr(response, "choices") - and response.choices - and len(response.choices) > 0 - ): - choice = response.choices[0] - if hasattr(choice, "message") and hasattr(choice.message, "content"): - content = choice.message.content or "" - finish_reason = getattr(choice, "finish_reason", "stop") or "stop" - - usage = None - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), - "completion_tokens": getattr( - response.usage, "completion_tokens", 0 - ), - "total_tokens": getattr(response.usage, "total_tokens", 0), - } - - return LLMResponse( - content=content, - model=str(model), - finish_reason=finish_reason, - usage=usage, - ) - - except Exception as e: - return LLMResponse.from_error( - f"Error calling Groq API: {e!s}", - str(request.get("model") or "mixtral-8x7b-32768"), - ) - def is_available(self) -> bool: - return self.available and bool(self.api_key) +# Register with the provider registry +LLMProviderRegistry.register( + "groq", + GroqProvider, + default_config={ + "api_key_env": "GROQ_API_KEY", + "default_model": "mixtral-8x7b-32768", + }, +) diff --git a/codeflow_engine/actions/ai_actions/llm/providers/mistral.py b/codeflow_engine/actions/ai_actions/llm/providers/mistral.py index 7f4010d..0c33c0b 100644 --- a/codeflow_engine/actions/ai_actions/llm/providers/mistral.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/mistral.py @@ -1,56 +1,90 @@ """ Mistral AI provider implementation. -""" -from typing import Any, TYPE_CHECKING +Uses BaseLLMProvider with Mistral-specific client handling. 
+""" -# Handle optional mistralai dependency -ChatMessage: Any = None -try: - from mistralai.models.chat_completion import ChatMessage -except ImportError: - pass +from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.types import LLMResponse +from codeflow_engine.core.llm import BaseLLMProvider, LLMResponse, LLMProviderRegistry +from codeflow_engine.core.llm.response import ResponseExtractor class MistralProvider(BaseLLMProvider): - """Mistral AI provider.""" + """ + Mistral AI provider. + + Mistral has its own client library with a different API format. + """ + + DEFAULT_MODEL = "mistral-large-latest" def __init__(self, config: dict[str, Any]) -> None: super().__init__(config) + self.client: Any = None + self._chat_message_class: Any = None + self._initialize_client() + + def _initialize_client(self) -> None: + """Initialize the Mistral client.""" try: from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage self.client = MistralClient(api_key=self.api_key) + self._chat_message_class = ChatMessage self.available = True except ImportError: self.available = False + def _convert_messages(self, messages: list[dict[str, Any]]) -> list[Any]: + """ + Convert messages to Mistral ChatMessage format. 
+ + Args: + messages: List of message dicts with role and content + + Returns: + List of ChatMessage objects + """ + mistral_messages = [] + for msg in messages: + role = str(msg.get("role", "user")) + content = str(msg.get("content", "")).strip() + if content and self._chat_message_class: + mistral_messages.append( + self._chat_message_class(role=role, content=content) + ) + return mistral_messages + def complete(self, request: dict[str, Any]) -> LLMResponse: + """Complete a chat conversation using Mistral's API.""" + if not self.client: + return LLMResponse.from_error( + "Mistral client not initialized", + self.get_model(request, self.DEFAULT_MODEL), + ) + try: messages = request.get("messages", []) - model = request.get("model", self.default_model) or "mistral-large-latest" + model = self.get_model(request, self.DEFAULT_MODEL) max_tokens = request.get("max_tokens", 1024) temperature = request.get("temperature", 0.7) - # Convert input messages to correct type - mistral_messages: list[ChatMessage] = [] - for msg in messages: - role = str(msg.get("role", "user")) - content = str(msg.get("content", "")).strip() - if not content: - continue + # Convert messages to Mistral format + mistral_messages = self._convert_messages(messages) - mistral_messages.append(ChatMessage(role=role, content=content)) + if not mistral_messages: + return LLMResponse.from_error("No valid messages provided", model) - # Defensive: check if chat method exists + # Check if chat method exists chat_method = getattr(self.client, "chat", None) if not callable(chat_method): - msg = "MistralClient has no 'chat' method" - raise AttributeError(msg) + return LLMResponse.from_error( + "MistralClient has no 'chat' method", model + ) + # Call the API response = chat_method( model=str(model), messages=mistral_messages, @@ -58,44 +92,26 @@ def complete(self, request: dict[str, Any]) -> LLMResponse: max_tokens=max_tokens, ) - # Defensive: ensure response has expected attributes - content = "" - finish_reason 
= "stop" - if ( - hasattr(response, "choices") - and response.choices - and len(response.choices) > 0 - ): - choice = response.choices[0] - if hasattr(choice, "message") and hasattr(choice.message, "content"): - content = str(choice.message.content or "") - finish_reason = str(getattr(choice, "finish_reason", "stop") or "stop") - - # Extract usage information - usage = None - if hasattr(response, "usage") and response.usage is not None: - if hasattr(response.usage, "dict"): - usage = response.usage.dict() - else: - usage = { - "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), - "completion_tokens": getattr( - response.usage, "completion_tokens", 0 - ), - "total_tokens": getattr(response.usage, "total_tokens", 0), - } + # Extract response using centralized utility + content, finish_reason, usage = ResponseExtractor.extract_openai_response(response) return LLMResponse( - content=content, + content=str(content), model=str(model), - finish_reason=finish_reason, + finish_reason=str(finish_reason), usage=usage, ) - except Exception as e: - return LLMResponse.from_error( - f"Error calling Mistral API: {e!s}", - str(request.get("model") or "mistral-large-latest"), - ) - def is_available(self) -> bool: - return self.available and bool(self.api_key) + except Exception as e: + return self._create_error_response(e, request, self.DEFAULT_MODEL) + + +# Register with the provider registry +LLMProviderRegistry.register( + "mistral", + MistralProvider, + default_config={ + "api_key_env": "MISTRAL_API_KEY", + "default_model": "mistral-large-latest", + }, +) diff --git a/codeflow_engine/actions/ai_actions/llm/providers/openai.py b/codeflow_engine/actions/ai_actions/llm/providers/openai.py index d8cb952..efb50b6 100644 --- a/codeflow_engine/actions/ai_actions/llm/providers/openai.py +++ b/codeflow_engine/actions/ai_actions/llm/providers/openai.py @@ -1,18 +1,26 @@ """ OpenAI GPT provider implementation. 
+ +Uses the OpenAICompatibleProvider base class for common functionality. """ from typing import Any -from codeflow_engine.actions.llm.base import BaseLLMProvider -from codeflow_engine.actions.llm.types import LLMResponse +from codeflow_engine.core.llm import OpenAICompatibleProvider, LLMProviderRegistry + +class OpenAIProvider(OpenAICompatibleProvider): + """ + OpenAI GPT provider. -class OpenAIProvider(BaseLLMProvider): - """OpenAI GPT provider.""" + Inherits all common functionality from OpenAICompatibleProvider. + """ - def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) + DEFAULT_MODEL = "gpt-4" + LIBRARY_NAME = "openai" + + def _initialize_client(self) -> None: + """Initialize the OpenAI client.""" try: import openai @@ -21,52 +29,13 @@ def __init__(self, config: dict[str, Any]) -> None: except ImportError: self.available = False - def complete(self, request: dict[str, Any]) -> LLMResponse: - try: - # Ensure all messages have non-empty content (additional robustness) - filtered_messages = [ - msg for msg in request["messages"] if msg["content"].strip() - ] - response = self.client.chat.completions.create( - model=request["model"] or self.default_model or "gpt-4", - messages=filtered_messages, - temperature=request.get("temperature", 0.7), - max_tokens=request.get("max_tokens"), - top_p=request.get("top_p", 1.0), - frequency_penalty=request.get("frequency_penalty", 0.0), - presence_penalty=request.get("presence_penalty", 0.0), - stop=request.get("stop"), - ) - - # Defensive: ensure response has expected attributes - if hasattr(response, "choices") and hasattr(response.choices[0], "message"): - content = response.choices[0].message.content or "" - finish_reason = ( - getattr(response.choices[0], "finish_reason", "stop") or "stop" - ) - else: - content = "" - finish_reason = "stop" - - model = getattr( - response, "model", request.get("model") or self.default_model or "gpt-4" - ) - usage = ( - response.usage.dict() - if 
hasattr(response, "usage") and response.usage is not None - else None - ) - - return LLMResponse( - content=str(content), - model=str(model), - finish_reason=str(finish_reason), - usage=usage, - ) - except Exception as e: - return LLMResponse.from_error( - str(e), request.get("model") or self.default_model or "gpt-4" - ) - def is_available(self) -> bool: - return self.available and bool(self.api_key) +# Register with the provider registry +LLMProviderRegistry.register( + "openai", + OpenAIProvider, + default_config={ + "api_key_env": "OPENAI_API_KEY", + "default_model": "gpt-4", + }, +) diff --git a/codeflow_engine/actions/ai_actions/llm/types.py b/codeflow_engine/actions/ai_actions/llm/types.py index d7ab539..c75d3e5 100644 --- a/codeflow_engine/actions/ai_actions/llm/types.py +++ b/codeflow_engine/actions/ai_actions/llm/types.py @@ -1,13 +1,17 @@ """ Base types, enums, and data classes for LLM providers. + +This module maintains backwards compatibility while re-exporting LLMResponse from core. 
""" -from dataclasses import dataclass from enum import StrEnum from typing import Any from pydantic import BaseModel +# Re-export LLMResponse from core for backwards compatibility +from codeflow_engine.core.llm.response import LLMResponse + class MessageRole(StrEnum): """Role of a message in a chat conversation.""" @@ -54,17 +58,10 @@ class LLMConfig(BaseModel): stop: list[str] | None = None -@dataclass -class LLMResponse: - """Response from an LLM provider.""" - - content: str - model: str - finish_reason: str - usage: dict[str, int] | None = None - error: str | None = None - - @classmethod - def from_error(cls, error: str, model: str = "unknown") -> "LLMResponse": - """Create an error response.""" - return cls(content="", model=model, finish_reason="error", error=error) +__all__ = [ + "LLMConfig", + "LLMProviderType", + "LLMResponse", + "Message", + "MessageRole", +] diff --git a/codeflow_engine/actions/ai_linting_fixer/file_manager.py b/codeflow_engine/actions/ai_linting_fixer/file_manager.py index e69b8a6..6c39539 100644 --- a/codeflow_engine/actions/ai_linting_fixer/file_manager.py +++ b/codeflow_engine/actions/ai_linting_fixer/file_manager.py @@ -1,402 +1,237 @@ """ -File Manager Module +File Manager Module - Facade for File Operations. -This module handles file operations, backups, and safe file modifications. +This module provides a high-level interface for file operations, +composing the core file components (FileIO, BackupService, ContentValidator). 
+ +For new code, prefer using the core components directly: +- codeflow_engine.core.files.FileIO +- codeflow_engine.core.files.BackupService +- codeflow_engine.core.files.ContentValidator """ -from datetime import UTC, datetime import logging -import operator -from pathlib import Path -import shutil +from typing import Any + +from codeflow_engine.core.files import ( + BackupService, + ContentValidator, + FileIO, +) logger = logging.getLogger(__name__) class FileManager: - """Handles file operations and backups.""" + """ + Facade for file operations, backups, and validation. + + This class composes the core file components to provide + backward-compatible high-level file management. - def __init__(self, backup_directory: str | None = None): - """Initialize the file manager.""" + For new code, consider using the core components directly. + """ + + def __init__(self, backup_directory: str | None = None) -> None: + """ + Initialize the file manager. + + Args: + backup_directory: Directory to store backups + """ self.backup_directory = backup_directory or "./backups" - self._ensure_backup_directory() + self._backup_service = BackupService(self.backup_directory) + self._content_validator = ContentValidator() - def _ensure_backup_directory(self) -> None: - """Ensure the backup directory exists.""" - try: - Path(self.backup_directory).mkdir(parents=True, exist_ok=True) - except Exception as e: - logger.warning("Failed to create backup directory: %s", e) + # =================== + # Backup Operations + # =================== def create_backup(self, file_path: str) -> str | None: - """Create a backup of a file before modification.""" - try: - if not Path(self.backup_directory).exists(): - Path(self.backup_directory).mkdir(parents=True, exist_ok=True) - except Exception as e: - logger.warning("Failed to create backup directory: %s", e) - return None - - if not Path(file_path).exists(): - logger.warning("File does not exist: %s", file_path) - return None - - try: - timestamp = 
datetime.now(UTC).strftime("%Y%m%d_%H%M%S") - backup_filename = f"{Path(file_path).stem}.backup_{timestamp}" - backup_path = Path(self.backup_directory) / backup_filename - - shutil.copy2(file_path, backup_path) - logger.info("Created backup: %s", backup_path) - return str(backup_path) - - except Exception as e: - logger.exception("Failed to create backup for %s: %s", file_path, e) - return None + """ + Create a backup of a file before modification. + + Args: + file_path: Path to the file to backup + + Returns: + Backup path or None if failed + """ + backup = self._backup_service.create_backup(file_path) + return backup.backup_path if backup else None def create_backups(self, file_paths: list) -> int: - """Create backups for multiple files.""" - successful_backups = 0 - for file_path in file_paths: - backup_path = self.create_backup(file_path) - if backup_path: - successful_backups += 1 - return successful_backups + """ + Create backups for multiple files. - def write_file_safely( - self, file_path: str, content: str, backup: bool = True - ) -> bool: - """Write content to a file with optional backup.""" - backup_path = None - if backup: - backup_path = self.create_backup(file_path) - - try: - with Path(file_path).open("w", encoding="utf-8") as f: - f.write(content) - logger.info("Successfully wrote to file: %s", file_path) - return True - - except Exception as e: - logger.exception("Failed to write to file %s: %s", file_path, e) - - # Try to restore from backup if available - if backup_path and Path(backup_path).exists(): - try: - shutil.copy2(backup_path, file_path) - logger.info( - "Restored %s from backup after write failure", file_path - ) - return False - except Exception as restore_error: - logger.exception("Failed to restore from backup: %s", restore_error) - - return False + Args: + file_paths: List of file paths to backup + + Returns: + Number of successful backups + """ + return self._backup_service.create_backups(file_paths) def restore_from_backup(self, 
file_path: str, backup_path: str) -> bool: - """Restore a file from its backup.""" - try: - if not Path(backup_path).exists(): - logger.error("Backup file does not exist: %s", backup_path) - return False - - shutil.copy2(backup_path, file_path) - logger.info("Restored %s from backup: %s", file_path, backup_path) - return True - - except Exception as e: - logger.exception( - "Failed to restore %s from backup %s: %s", file_path, backup_path, e - ) - return False + """ + Restore a file from its backup. + + Args: + file_path: Target file path + backup_path: Path to the backup + + Returns: + True if successful + """ + return self._backup_service.restore(file_path, backup_path) + + def list_backups(self, file_path: str | None = None) -> list: + """ + List available backups. + + Args: + file_path: Optional filter by original file + + Returns: + List of backup info dictionaries + """ + return self._backup_service.list_backups(file_path) + + def cleanup_old_backups( + self, + max_backups: int = 10, + older_than_days: int | None = None, + ) -> int: + """ + Clean up old backup files. + + Args: + max_backups: Maximum backups to keep + older_than_days: Optional age limit + + Returns: + Number of backups removed + """ + return self._backup_service.cleanup_old_backups(max_backups, older_than_days) + + # =================== + # File I/O Operations + # =================== def read_file_safely(self, file_path: str) -> tuple[bool, str]: - """Read a file safely and return success status and content.""" - try: - with Path(file_path).open(encoding="utf-8") as f: - content = f.read() - except Exception as e: - logger.exception(f"Failed to read file {file_path}: {e}") - return False, "" - else: - return True, content + """ + Read a file safely. 
+ + Args: + file_path: Path to the file + + Returns: + Tuple of (success, content) + """ + return FileIO.read(file_path) def read_file(self, file_path: str) -> str | None: - """Read a file and return its content.""" - success, content = self.read_file_safely(file_path) - return content if success else None + """ + Read a file and return its content. + + Args: + file_path: Path to the file + + Returns: + Content or None if failed + """ + return FileIO.read_or_none(file_path) def write_file(self, file_path: str, content: str) -> bool: - """Write content to a file.""" - return self.write_file_safely(file_path, content, backup=False) + """ + Write content to a file. + + Args: + file_path: Path to the file + content: Content to write + + Returns: + True if successful + """ + return FileIO.write(file_path, content) + + def write_file_safely( + self, + file_path: str, + content: str, + backup: bool = True, + ) -> bool: + """ + Write content with optional backup. + + Args: + file_path: Path to the file + content: Content to write + backup: Whether to create backup first + + Returns: + True if successful + """ + backup_path = None + if backup and FileIO.exists(file_path): + backup_result = self._backup_service.create_backup(file_path) + backup_path = backup_result.backup_path if backup_result else None + + success = FileIO.write(file_path, content) + + if not success and backup_path: + # Try to restore from backup + self._backup_service.restore(file_path, backup_path) + logger.info("Restored %s from backup after write failure", file_path) + + return success def file_exists(self, file_path: str) -> bool: """Check if a file exists.""" - return Path(file_path).exists() + return FileIO.exists(file_path) def get_file_size(self, file_path: str) -> int: - """Get the size of a file in bytes.""" - try: - return Path(file_path).stat().st_size - except Exception as e: - logger.debug("Failed to get file size for %s: %s", file_path, e) - return 0 + """Get file size in bytes.""" + return 
FileIO.get_size(file_path) def get_file_info(self, file_path: str) -> dict: - """Get comprehensive information about a file.""" - try: - file_path_obj = Path(file_path) - if not file_path_obj.exists(): - return {"exists": False} - - stat = file_path_obj.stat() - except Exception as e: - logger.debug("Failed to get file info for %s: %s", file_path, e) - return {"exists": False, "error": str(e)} - else: - return { - "exists": True, - "size_bytes": stat.st_size, - "size_mb": stat.st_size / (1024 * 1024), - "modified_time": datetime.fromtimestamp( - stat.st_mtime, tz=UTC - ).isoformat(), - "created_time": datetime.fromtimestamp( - stat.st_ctime, tz=UTC - ).isoformat(), - "is_file": file_path_obj.is_file(), - "is_directory": file_path_obj.is_dir(), - "extension": file_path_obj.suffix, - "name": file_path_obj.name, - "stem": file_path_obj.stem, - "parent": str(file_path_obj.parent), - } - - def list_backups(self, file_path: str | None = None) -> list: - """List available backups, optionally filtered by original file.""" - try: - backup_dir = Path(self.backup_directory) - if not backup_dir.exists(): - return [] - - backups = [] - for backup_file in backup_dir.glob("*.backup_*"): - backup_info = { - "backup_path": str(backup_file), - "backup_name": backup_file.name, - "size_bytes": backup_file.stat().st_size, - "modified_time": datetime.fromtimestamp( - backup_file.stat().st_mtime - ).isoformat(), - } - - # Try to extract original filename - if ".backup_" in backup_file.name: - original_name = backup_file.name.split(".backup_")[0] - backup_info["original_name"] = original_name - - # Filter by original file if specified - if file_path and not backup_file.name.startswith( - Path(file_path).stem - ): - continue - - backups.append(backup_info) - - # Sort by modification time (newest first) - backups.sort(key=operator.itemgetter("modified_time"), reverse=True) - except Exception as e: - logger.exception("Failed to list backups: %s", e) - return [] - else: - return backups - - 
def cleanup_old_backups( - self, max_backups: int = 10, older_than_days: int | None = None - ) -> int: - """Clean up old backup files.""" - try: - backups = self.list_backups() - if len(backups) <= max_backups: - return 0 - - # Remove oldest backups beyond max_backups - backups_to_remove = backups[max_backups:] - - # Additional filtering by age if specified - if older_than_days: - cutoff_time = datetime.now(UTC).timestamp() - ( - older_than_days * 24 * 60 * 60 - ) - backups_to_remove = [ - backup - for backup in backups_to_remove - if datetime.fromisoformat(backup["modified_time"]).timestamp() - < cutoff_time - ] - - removed_count = 0 - for backup in backups_to_remove: - try: - Path(backup["backup_path"]).unlink() - logger.debug("Removed old backup: %s", backup["backup_path"]) - removed_count += 1 - except Exception as e: - logger.warning( - "Failed to remove backup %s: %s", backup["backup_path"], e - ) - - logger.info("Cleaned up %d old backup files", removed_count) - return removed_count - - except Exception as e: - logger.exception("Failed to cleanup old backups: %s", e) - return 0 - - def validate_file_content(self, content: str) -> dict[str, object]: - """Validate file content for common issues.""" - validation_result: dict[str, object] = { - "valid": True, - "issues": list[str](), - "warnings": list[str](), - } - - try: - # Check for empty content - if not content.strip(): - warnings_list = validation_result["warnings"] - self._validate_warnings_list(warnings_list) - warnings_list.append("File content is empty") - - # Check for encoding issues - try: - content.encode("utf-8") - except UnicodeEncodeError: - issues_list = validation_result["issues"] - self._validate_issues_list(issues_list) - issues_list.append("Content contains invalid UTF-8 characters") - validation_result["valid"] = False - - # Check for extremely long lines - lines = content.split("\n") - for i, line in enumerate(lines, 1): - if len(line) > 1000: # Very long lines might indicate issues - 
warnings_list = validation_result["warnings"] - self._validate_warnings_list(warnings_list) - warnings_list.append( - f"Line {i} is very long ({len(line)} characters)" - ) - - # Check for mixed line endings - if "\r\n" in content and "\n" in content: - warnings_list = validation_result["warnings"] - self._validate_warnings_list(warnings_list) - warnings_list.append("Mixed line endings detected") - - # Check for trailing whitespace - for i, line in enumerate(lines, 1): - if line.rstrip() != line: - warnings_list = validation_result["warnings"] - self._validate_warnings_list(warnings_list) - warnings_list.append(f"Line {i} has trailing whitespace") - - except Exception as e: - issues_list = validation_result["issues"] - self._validate_issues_list(issues_list) - issues_list.append(f"Validation error: {e}") - validation_result["valid"] = False - - return validation_result + """Get comprehensive file information.""" + return FileIO.get_info(file_path) def create_directory_safely(self, directory_path: str) -> bool: """Create a directory safely.""" - try: - Path(directory_path).mkdir(parents=True, exist_ok=True) - return True - except Exception as e: - logger.exception(f"Failed to create directory {directory_path}: {e}") - return False + return FileIO.mkdir(directory_path) def copy_file_safely(self, source_path: str, destination_path: str) -> bool: """Copy a file safely.""" - try: - shutil.copy2(source_path, destination_path) - return True - except Exception as e: - logger.exception(f"Failed to copy {source_path} to {destination_path}: {e}") - return False + return FileIO.copy(source_path, destination_path) def move_file_safely(self, source_path: str, destination_path: str) -> bool: """Move a file safely.""" - try: - shutil.move(source_path, destination_path) - return True - except Exception as e: - logger.exception(f"Failed to move {source_path} to {destination_path}: {e}") - return False + return FileIO.move(source_path, destination_path) def delete_file(self, 
file_path: str) -> bool: """Delete a file safely.""" - try: - if not Path(file_path).exists(): - logger.debug("File not found for deletion: %s", file_path) - return False - - Path(file_path).unlink() - return True - - except Exception as e: - logger.exception("Failed to delete file %s: %s", file_path, e) - return False - - def _validate_warnings_list(self, warnings_list: list) -> None: - """Validate that warnings_list is a list.""" - if not isinstance(warnings_list, list): - msg = f"Expected list for warnings_list, got {type(warnings_list).__name__}" - raise TypeError(msg) - - def _validate_issues_list(self, issues_list: list) -> None: - """Validate that issues_list is a list.""" - if not isinstance(issues_list, list): - msg = f"Expected list for issues_list, got {type(issues_list).__name__}" - raise TypeError(msg) - - def _validate_file_info(self, file_info: dict) -> None: - """Validate file_info structure.""" - if not isinstance(file_info, dict): - msg = f"Expected dict for file_info, got {type(file_info).__name__}" - raise TypeError(msg) - - if "warnings" not in file_info: - msg = "file_info must contain 'warnings' key" - raise ValueError(msg) - - if "issues" not in file_info: - msg = "file_info must contain 'issues' key" - raise ValueError(msg) - - self._validate_warnings_list(file_info["warnings"]) - self._validate_issues_list(file_info["issues"]) - - def _validate_file_info_list(self, file_info_list: list) -> None: - """Validate that file_info_list is a list of valid file_info dicts.""" - if not isinstance(file_info_list, list): - msg = ( - f"Expected list for file_info_list, got {type(file_info_list).__name__}" - ) - raise TypeError(msg) - - for file_info in file_info_list: - self._validate_file_info(file_info) - - def _validate_file_info_dict(self, file_info_dict: dict) -> None: - """Validate that file_info_dict is a dict of valid file_info dicts.""" - if not isinstance(file_info_dict, dict): - msg = ( - f"Expected dict for file_info_dict, got 
{type(file_info_dict).__name__}" - ) - raise TypeError(msg) - - for file_info in file_info_dict.values(): - self._validate_file_info(file_info) + return FileIO.delete(file_path) + + # =================== + # Content Validation + # =================== + + def validate_file_content(self, content: str) -> dict[str, Any]: + """ + Validate file content for common issues. + + Args: + content: Content to validate + + Returns: + Validation result dictionary + """ + result = self._content_validator.validate(content) + return { + "valid": result.valid, + "issues": result.issues, + "warnings": result.warnings, + } diff --git a/codeflow_engine/core/__init__.py b/codeflow_engine/core/__init__.py new file mode 100644 index 0000000..ba95c02 --- /dev/null +++ b/codeflow_engine/core/__init__.py @@ -0,0 +1,90 @@ +""" +CodeFlow Core Module - Shared base classes, utilities, and patterns. + +This module provides common infrastructure used across the codeflow_engine package: +- Base classes for managers, validators, and handlers +- Common patterns (Registry, Factory, etc.) 
+- Configuration utilities +- Shared utilities +""" + +from codeflow_engine.core.llm import ( + BaseLLMProvider, + LLMProviderRegistry, + LLMResponse, + OpenAICompatibleProvider, +) + +from codeflow_engine.core.managers import ( + BaseManager, + ManagerConfig, + SessionMixin, + StatsMixin, +) + +from codeflow_engine.core.validation import ( + BaseTypeValidator, + CompositeValidator, + SecurityPatterns, + ValidationResult, + ValidationSeverity, +) + +from codeflow_engine.core.config import ( + AppSettings, + BaseConfig, + ConfigLoader, + DatabaseSettings, + LLMSettings, + LoggingSettings, + env_var, + env_bool, + env_int, + env_float, + env_list, +) + +from codeflow_engine.core.files import ( + BackupService, + ContentValidator, + ContentValidationResult, + FileBackup, + FileIO, +) + +__all__ = [ + # LLM + "BaseLLMProvider", + "LLMProviderRegistry", + "LLMResponse", + "OpenAICompatibleProvider", + # Managers + "BaseManager", + "ManagerConfig", + "SessionMixin", + "StatsMixin", + # Validation + "BaseTypeValidator", + "CompositeValidator", + "SecurityPatterns", + "ValidationResult", + "ValidationSeverity", + # Configuration + "AppSettings", + "BaseConfig", + "ConfigLoader", + "DatabaseSettings", + "LLMSettings", + "LoggingSettings", + "env_var", + "env_bool", + "env_int", + "env_float", + "env_list", + # Files + "BackupService", + "ContentValidator", + "ContentValidationResult", + "FileBackup", + "FileIO", +] diff --git a/codeflow_engine/core/config/__init__.py b/codeflow_engine/core/config/__init__.py new file mode 100644 index 0000000..1b1c89d --- /dev/null +++ b/codeflow_engine/core/config/__init__.py @@ -0,0 +1,40 @@ +""" +Core Configuration Module. 
+ +Provides centralized configuration management with: +- Environment-based configuration loading +- Type-safe configuration models +- Environment variable helpers +""" + +from codeflow_engine.core.config.base import ( + BaseConfig, + ConfigLoader, + env_var, + env_bool, + env_int, + env_float, + env_list, +) +from codeflow_engine.core.config.models import ( + AppSettings, + DatabaseSettings, + LLMSettings, + LoggingSettings, +) + +__all__ = [ + # Base utilities + "BaseConfig", + "ConfigLoader", + "env_var", + "env_bool", + "env_int", + "env_float", + "env_list", + # Settings models + "AppSettings", + "DatabaseSettings", + "LLMSettings", + "LoggingSettings", +] diff --git a/codeflow_engine/core/config/base.py b/codeflow_engine/core/config/base.py new file mode 100644 index 0000000..cc0c334 --- /dev/null +++ b/codeflow_engine/core/config/base.py @@ -0,0 +1,338 @@ +""" +Base Configuration Utilities. + +Provides environment variable helpers and base configuration patterns. +""" + +import os +from abc import ABC +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, TypeVar + +import structlog + + +logger = structlog.get_logger(__name__) + +T = TypeVar("T") + + +def env_var(name: str, default: str = "") -> str: + """ + Get an environment variable. + + Args: + name: Variable name + default: Default value if not set + + Returns: + Environment variable value + """ + return os.getenv(name, default) + + +def env_bool(name: str, default: bool = False) -> bool: + """ + Get an environment variable as boolean. + + Truthy values: "1", "true", "yes", "on" (case-insensitive) + + Args: + name: Variable name + default: Default value if not set + + Returns: + Boolean value + """ + value = os.getenv(name, "").lower() + if not value: + return default + return value in ("1", "true", "yes", "on") + + +def env_int(name: str, default: int = 0) -> int: + """ + Get an environment variable as integer. 
+ + Args: + name: Variable name + default: Default value if not set or invalid + + Returns: + Integer value + """ + value = os.getenv(name, "") + if not value: + return default + try: + return int(value) + except ValueError: + logger.warning("invalid_env_int", name=name, value=value, default=default) + return default + + +def env_float(name: str, default: float = 0.0) -> float: + """ + Get an environment variable as float. + + Args: + name: Variable name + default: Default value if not set or invalid + + Returns: + Float value + """ + value = os.getenv(name, "") + if not value: + return default + try: + return float(value) + except ValueError: + logger.warning("invalid_env_float", name=name, value=value, default=default) + return default + + +def env_list(name: str, default: list[str] | None = None, separator: str = ",") -> list[str]: + """ + Get an environment variable as a list. + + Args: + name: Variable name + default: Default value if not set + separator: List item separator + + Returns: + List of strings + """ + value = os.getenv(name, "") + if not value: + return default or [] + return [item.strip() for item in value.split(separator) if item.strip()] + + +@dataclass +class BaseConfig(ABC): + """ + Base class for configuration objects. + + Provides common patterns for configuration handling. + """ + + @classmethod + def from_env(cls, prefix: str = "") -> "BaseConfig": + """ + Create configuration from environment variables. + + Override in subclasses to implement specific loading logic. + + Args: + prefix: Optional prefix for environment variable names + + Returns: + Configuration instance + """ + raise NotImplementedError("Subclasses must implement from_env") + + def to_dict(self) -> dict[str, Any]: + """ + Convert configuration to dictionary. 
+ + Returns: + Dictionary representation + """ + return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + + def merge(self, overrides: dict[str, Any]) -> "BaseConfig": + """ + Create a new config with overrides applied. + + Args: + overrides: Dictionary of values to override + + Returns: + New config instance with overrides applied + """ + current = self.to_dict() + current.update(overrides) + return type(self)(**current) + + +@dataclass +class ConfigLoader: + """ + Utility for loading configuration from multiple sources. + + Supports loading from: + - Environment variables + - TOML files (pyproject.toml) + - YAML files + - JSON files + """ + + config_paths: list[str] = field(default_factory=lambda: ["pyproject.toml", "config.yaml"]) + + def load_toml(self, path: str, section: str | None = None) -> dict[str, Any]: + """ + Load configuration from TOML file. + + Args: + path: Path to TOML file + section: Optional section to extract (e.g., "tool.codeflow") + + Returns: + Configuration dictionary + """ + if not Path(path).exists(): + return {} + + try: + import tomllib + except ImportError: + try: + import tomli as tomllib # type: ignore[import-not-found] + except ImportError: + logger.debug("toml_not_available", path=path) + return {} + + try: + with open(path, "rb") as f: + data = tomllib.load(f) + + if section: + # Navigate to nested section (e.g., "tool.codeflow") + for key in section.split("."): + data = data.get(key, {}) + if not isinstance(data, dict): + return {} + + return data + except Exception as e: + logger.warning("toml_load_failed", path=path, error=str(e)) + return {} + + def load_yaml(self, path: str) -> dict[str, Any]: + """ + Load configuration from YAML file. 
+ + Args: + path: Path to YAML file + + Returns: + Configuration dictionary + """ + if not Path(path).exists(): + return {} + + try: + import yaml + except ImportError: + logger.debug("yaml_not_available", path=path) + return {} + + try: + with open(path, encoding="utf-8") as f: + return yaml.safe_load(f) or {} + except Exception as e: + logger.warning("yaml_load_failed", path=path, error=str(e)) + return {} + + def load_json(self, path: str) -> dict[str, Any]: + """ + Load configuration from JSON file. + + Args: + path: Path to JSON file + + Returns: + Configuration dictionary + """ + if not Path(path).exists(): + return {} + + import json + + try: + with open(path, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.warning("json_load_failed", path=path, error=str(e)) + return {} + + def load(self, path: str, section: str | None = None) -> dict[str, Any]: + """ + Load configuration from file, auto-detecting format. + + Args: + path: Path to configuration file + section: Optional section to extract (for TOML) + + Returns: + Configuration dictionary + """ + path_lower = path.lower() + + if path_lower.endswith(".toml"): + return self.load_toml(path, section) + elif path_lower.endswith((".yaml", ".yml")): + return self.load_yaml(path) + elif path_lower.endswith(".json"): + return self.load_json(path) + else: + logger.warning("unknown_config_format", path=path) + return {} + + def load_merged( + self, + paths: list[str] | None = None, + section: str | None = None, + ) -> dict[str, Any]: + """ + Load and merge configuration from multiple files. + + Later files override earlier ones. 
+ + Args: + paths: List of config file paths (uses self.config_paths if None) + section: Optional section to extract + + Returns: + Merged configuration dictionary + """ + paths = paths or self.config_paths + merged: dict[str, Any] = {} + + for path in paths: + config = self.load(path, section) + merged = self._deep_merge(merged, config) + + return merged + + def _deep_merge( + self, + base: dict[str, Any], + override: dict[str, Any], + ) -> dict[str, Any]: + """ + Deep merge two dictionaries. + + Args: + base: Base dictionary + override: Override dictionary + + Returns: + Merged dictionary + """ + result = dict(base) + + for key, value in override.items(): + if ( + key in result + and isinstance(result[key], dict) + and isinstance(value, dict) + ): + result[key] = self._deep_merge(result[key], value) + else: + result[key] = value + + return result diff --git a/codeflow_engine/core/config/models.py b/codeflow_engine/core/config/models.py new file mode 100644 index 0000000..28baebc --- /dev/null +++ b/codeflow_engine/core/config/models.py @@ -0,0 +1,176 @@ +""" +Configuration Models. + +Centralized configuration models for the codeflow_engine. +Uses dataclasses for simplicity; can be migrated to Pydantic if needed. 
+""" + +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Any + +from codeflow_engine.core.config.base import ( + BaseConfig, + env_var, + env_bool, + env_int, + env_float, +) + + +class Environment(StrEnum): + """Application environment types.""" + + DEVELOPMENT = "development" + STAGING = "staging" + PRODUCTION = "production" + TESTING = "testing" + + +class LogLevel(StrEnum): + """Logging levels.""" + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + + +@dataclass +class LoggingSettings(BaseConfig): + """Logging configuration.""" + + level: str = "INFO" + format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + json_format: bool = False + + @classmethod + def from_env(cls, prefix: str = "") -> "LoggingSettings": + """Load from environment variables.""" + p = f"{prefix}_" if prefix else "" + return cls( + level=env_var(f"{p}LOG_LEVEL", "INFO").upper(), + format=env_var(f"{p}LOG_FORMAT", cls.format), + json_format=env_bool(f"{p}LOG_JSON"), + ) + + +@dataclass +class DatabaseSettings(BaseConfig): + """Database configuration.""" + + url: str = "sqlite:///:memory:" + pool_size: int = 5 + max_overflow: int = 10 + pool_timeout: int = 30 + pool_recycle: int = 3600 + pool_pre_ping: bool = True + echo: bool = False + ssl_required: bool = False + + @classmethod + def from_env(cls, prefix: str = "") -> "DatabaseSettings": + """Load from environment variables.""" + p = f"{prefix}_" if prefix else "" + return cls( + url=env_var(f"{p}DATABASE_URL", "sqlite:///:memory:"), + pool_size=env_int(f"{p}DB_POOL_SIZE", 5), + max_overflow=env_int(f"{p}DB_MAX_OVERFLOW", 10), + pool_timeout=env_int(f"{p}DB_POOL_TIMEOUT", 30), + pool_recycle=env_int(f"{p}DB_POOL_RECYCLE", 3600), + pool_pre_ping=env_bool(f"{p}DB_POOL_PRE_PING", True), + echo=env_bool(f"{p}DB_ECHO"), + ssl_required=env_bool(f"{p}DB_SSL_REQUIRED"), + ) + + +@dataclass +class LLMSettings(BaseConfig): + """LLM provider 
configuration.""" + + provider: str = "openai" + api_key: str = "" + api_key_env: str = "" + model: str = "gpt-4" + temperature: float = 0.7 + max_tokens: int = 4096 + base_url: str | None = None + + @classmethod + def from_env(cls, prefix: str = "") -> "LLMSettings": + """Load from environment variables.""" + p = f"{prefix}_" if prefix else "" + provider = env_var(f"{p}LLM_PROVIDER", "openai") + + # Determine API key env var based on provider + api_key_env_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "groq": "GROQ_API_KEY", + "mistral": "MISTRAL_API_KEY", + "azure": "AZURE_OPENAI_API_KEY", + } + api_key_env = api_key_env_map.get(provider, f"{provider.upper()}_API_KEY") + + return cls( + provider=provider, + api_key=env_var(api_key_env, ""), + api_key_env=api_key_env, + model=env_var(f"{p}LLM_MODEL", "gpt-4"), + temperature=env_float(f"{p}LLM_TEMPERATURE", 0.7), + max_tokens=env_int(f"{p}LLM_MAX_TOKENS", 4096), + base_url=env_var(f"{p}LLM_BASE_URL") or None, + ) + + +@dataclass +class AppSettings(BaseConfig): + """Main application settings.""" + + environment: Environment = Environment.DEVELOPMENT + debug: bool = False + app_name: str = "codeflow_engine" + version: str = "0.1.0" + + logging: LoggingSettings = field(default_factory=LoggingSettings) + database: DatabaseSettings = field(default_factory=DatabaseSettings) + llm: LLMSettings = field(default_factory=LLMSettings) + + custom: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_env(cls, prefix: str = "") -> "AppSettings": + """Load all settings from environment variables.""" + p = f"{prefix}_" if prefix else "" + + env_str = env_var(f"{p}ENVIRONMENT", "development").lower() + try: + environment = Environment(env_str) + except ValueError: + environment = Environment.DEVELOPMENT + + return cls( + environment=environment, + debug=env_bool(f"{p}DEBUG"), + app_name=env_var(f"{p}APP_NAME", "codeflow_engine"), + version=env_var(f"{p}VERSION", "0.1.0"), + 
logging=LoggingSettings.from_env(prefix), + database=DatabaseSettings.from_env(prefix), + llm=LLMSettings.from_env(prefix), + ) + + @property + def is_production(self) -> bool: + """Check if running in production environment.""" + return self.environment == Environment.PRODUCTION + + @property + def is_development(self) -> bool: + """Check if running in development environment.""" + return self.environment == Environment.DEVELOPMENT + + @property + def is_testing(self) -> bool: + """Check if running in testing environment.""" + return self.environment == Environment.TESTING diff --git a/codeflow_engine/core/files/__init__.py b/codeflow_engine/core/files/__init__.py new file mode 100644 index 0000000..4e47913 --- /dev/null +++ b/codeflow_engine/core/files/__init__.py @@ -0,0 +1,20 @@ +""" +Core File Operations Module. + +Provides focused file handling components following Single Responsibility Principle: +- FileIO: Basic file read/write operations +- BackupService: File backup and restore capabilities +- ContentValidator: File content validation +""" + +from codeflow_engine.core.files.io import FileIO +from codeflow_engine.core.files.backup import BackupService, FileBackup +from codeflow_engine.core.files.validator import ContentValidator, ContentValidationResult + +__all__ = [ + "FileIO", + "BackupService", + "FileBackup", + "ContentValidator", + "ContentValidationResult", +] diff --git a/codeflow_engine/core/files/backup.py b/codeflow_engine/core/files/backup.py new file mode 100644 index 0000000..8deac87 --- /dev/null +++ b/codeflow_engine/core/files/backup.py @@ -0,0 +1,259 @@ +""" +Backup Service. + +Provides file backup, restore, and cleanup operations. 
+""" + +from dataclasses import dataclass, field +from datetime import datetime, UTC +import operator +from pathlib import Path +import shutil +from typing import Any + +import structlog + +from codeflow_engine.core.files.io import FileIO + + +logger = structlog.get_logger(__name__) + + +@dataclass +class FileBackup: + """Information about a file backup.""" + + file_path: str + backup_path: str + backup_time: datetime + original_size: int + metadata: dict[str, Any] = field(default_factory=dict) + + +class BackupService: + """ + Manages file backups and restore operations. + + Provides: + - Creating timestamped backups + - Listing available backups + - Restoring from backups + - Cleanup of old backups + """ + + def __init__(self, backup_directory: str = "./backups") -> None: + """ + Initialize the backup service. + + Args: + backup_directory: Directory to store backups + """ + self.backup_directory = Path(backup_directory) + self._ensure_backup_directory() + + def _ensure_backup_directory(self) -> None: + """Ensure the backup directory exists.""" + try: + self.backup_directory.mkdir(parents=True, exist_ok=True) + except Exception as e: + logger.warning("backup_dir_create_failed", error=str(e)) + + def create_backup(self, file_path: str, prefix: str = "") -> FileBackup | None: + """ + Create a backup of a file. 
+ + Args: + file_path: Path to the file to backup + prefix: Optional prefix for backup filename + + Returns: + FileBackup info or None if failed + """ + path = Path(file_path) + if not path.exists(): + logger.warning("backup_source_not_found", file_path=file_path) + return None + + try: + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + prefix_part = f"{prefix}_" if prefix else "" + backup_filename = f"{path.stem}.{prefix_part}backup_{timestamp}{path.suffix}" + backup_path = self.backup_directory / backup_filename + + shutil.copy2(file_path, backup_path) + + backup = FileBackup( + file_path=str(path.resolve()), + backup_path=str(backup_path), + backup_time=datetime.now(UTC), + original_size=FileIO.get_size(file_path), + ) + + logger.info("backup_created", file_path=file_path, backup_path=str(backup_path)) + return backup + + except Exception as e: + logger.error("backup_failed", file_path=file_path, error=str(e)) + return None + + def create_backups(self, file_paths: list[str], prefix: str = "") -> int: + """ + Create backups for multiple files. + + Args: + file_paths: List of file paths to backup + prefix: Optional prefix for backup filenames + + Returns: + Number of successful backups + """ + successful = 0 + for file_path in file_paths: + if self.create_backup(file_path, prefix): + successful += 1 + return successful + + def restore(self, file_path: str, backup_path: str) -> bool: + """ + Restore a file from backup. 
+ + Args: + file_path: Original file path to restore to + backup_path: Path to the backup file + + Returns: + True if successful + """ + if not Path(backup_path).exists(): + logger.error("backup_not_found", backup_path=backup_path) + return False + + try: + shutil.copy2(backup_path, file_path) + logger.info("file_restored", file_path=file_path, backup_path=backup_path) + return True + except Exception as e: + logger.error( + "restore_failed", + file_path=file_path, + backup_path=backup_path, + error=str(e), + ) + return False + + def list_backups(self, file_path: str | None = None) -> list[dict[str, Any]]: + """ + List available backups. + + Args: + file_path: Optional filter by original file + + Returns: + List of backup info dictionaries, sorted newest first + """ + try: + if not self.backup_directory.exists(): + return [] + + backups = [] + for backup_file in self.backup_directory.glob("*.backup_*"): + try: + stat = backup_file.stat() + backup_info = { + "backup_path": str(backup_file), + "backup_name": backup_file.name, + "size_bytes": stat.st_size, + "modified_time": datetime.fromtimestamp( + stat.st_mtime, tz=UTC + ).isoformat(), + } + + # Extract original filename + name = backup_file.name + if ".backup_" in name: + original_stem = name.split(".backup_")[0] + # Remove any prefix (e.g., "session_") + parts = original_stem.rsplit(".", 1) + backup_info["original_stem"] = parts[0] if parts else original_stem + + # Filter by file_path if specified + if file_path: + file_stem = Path(file_path).stem + if not backup_info["original_stem"].endswith(file_stem): + continue + + backups.append(backup_info) + except Exception: + continue + + # Sort by modification time (newest first) + backups.sort(key=operator.itemgetter("modified_time"), reverse=True) + return backups + + except Exception as e: + logger.error("list_backups_failed", error=str(e)) + return [] + + def get_latest_backup(self, file_path: str) -> str | None: + """ + Get the path to the latest backup for a 
file. + + Args: + file_path: Original file path + + Returns: + Path to latest backup or None + """ + backups = self.list_backups(file_path) + return backups[0]["backup_path"] if backups else None + + def cleanup_old_backups( + self, + max_backups: int = 10, + older_than_days: int | None = None, + ) -> int: + """ + Clean up old backup files. + + Args: + max_backups: Maximum number of backups to keep + older_than_days: Optional age limit in days + + Returns: + Number of backups removed + """ + try: + backups = self.list_backups() + if len(backups) <= max_backups: + return 0 + + # Identify backups to remove + backups_to_remove = backups[max_backups:] + + # Filter by age if specified + if older_than_days: + cutoff_time = datetime.now(UTC).timestamp() - (older_than_days * 24 * 60 * 60) + backups_to_remove = [ + backup for backup in backups_to_remove + if datetime.fromisoformat(backup["modified_time"]).timestamp() < cutoff_time + ] + + removed = 0 + for backup in backups_to_remove: + try: + Path(backup["backup_path"]).unlink() + logger.debug("backup_removed", backup_path=backup["backup_path"]) + removed += 1 + except Exception as e: + logger.warning( + "backup_remove_failed", + backup_path=backup["backup_path"], + error=str(e), + ) + + logger.info("backups_cleaned", removed=removed) + return removed + + except Exception as e: + logger.error("cleanup_failed", error=str(e)) + return 0 diff --git a/codeflow_engine/core/files/io.py b/codeflow_engine/core/files/io.py new file mode 100644 index 0000000..c85a0a7 --- /dev/null +++ b/codeflow_engine/core/files/io.py @@ -0,0 +1,273 @@ +""" +File I/O Operations. + +Provides safe file reading and writing operations with consistent error handling. +""" + +from datetime import datetime, UTC +from pathlib import Path +import shutil +from typing import Any + +import structlog + + +logger = structlog.get_logger(__name__) + + +class FileIO: + """ + Handles basic file I/O operations with consistent error handling. 
+ + Provides safe methods for reading, writing, copying, and moving files. + Does NOT handle backups - use BackupService for that. + """ + + @staticmethod + def read(file_path: str, encoding: str = "utf-8") -> tuple[bool, str]: + """ + Read a file safely. + + Args: + file_path: Path to the file + encoding: File encoding (default: utf-8) + + Returns: + Tuple of (success, content) + """ + try: + with Path(file_path).open(encoding=encoding) as f: + content = f.read() + return True, content + except Exception as e: + logger.warning("file_read_failed", file_path=file_path, error=str(e)) + return False, "" + + @staticmethod + def read_or_none(file_path: str, encoding: str = "utf-8") -> str | None: + """ + Read a file, returning None on failure. + + Args: + file_path: Path to the file + encoding: File encoding (default: utf-8) + + Returns: + File content or None + """ + success, content = FileIO.read(file_path, encoding) + return content if success else None + + @staticmethod + def write( + file_path: str, + content: str, + encoding: str = "utf-8", + create_dirs: bool = False, + ) -> bool: + """ + Write content to a file safely. + + Args: + file_path: Path to the file + content: Content to write + encoding: File encoding (default: utf-8) + create_dirs: Whether to create parent directories + + Returns: + True if successful, False otherwise + """ + try: + path = Path(file_path) + if create_dirs: + path.parent.mkdir(parents=True, exist_ok=True) + + with path.open("w", encoding=encoding) as f: + f.write(content) + + logger.debug("file_written", file_path=file_path, size=len(content)) + return True + except Exception as e: + logger.error("file_write_failed", file_path=file_path, error=str(e)) + return False + + @staticmethod + def exists(file_path: str) -> bool: + """ + Check if a file exists. 
+ + Args: + file_path: Path to check + + Returns: + True if file exists + """ + return Path(file_path).exists() + + @staticmethod + def is_file(file_path: str) -> bool: + """ + Check if path is a file. + + Args: + file_path: Path to check + + Returns: + True if path is a file + """ + return Path(file_path).is_file() + + @staticmethod + def is_dir(file_path: str) -> bool: + """ + Check if path is a directory. + + Args: + file_path: Path to check + + Returns: + True if path is a directory + """ + return Path(file_path).is_dir() + + @staticmethod + def get_size(file_path: str) -> int: + """ + Get file size in bytes. + + Args: + file_path: Path to the file + + Returns: + File size in bytes, 0 if file doesn't exist + """ + try: + return Path(file_path).stat().st_size + except Exception: + return 0 + + @staticmethod + def get_info(file_path: str) -> dict[str, Any]: + """ + Get comprehensive file information. + + Args: + file_path: Path to the file + + Returns: + Dictionary with file information + """ + try: + path = Path(file_path) + if not path.exists(): + return {"exists": False} + + stat = path.stat() + return { + "exists": True, + "size_bytes": stat.st_size, + "size_mb": stat.st_size / (1024 * 1024), + "modified_time": datetime.fromtimestamp(stat.st_mtime, tz=UTC).isoformat(), + "created_time": datetime.fromtimestamp(stat.st_ctime, tz=UTC).isoformat(), + "is_file": path.is_file(), + "is_directory": path.is_dir(), + "extension": path.suffix, + "name": path.name, + "stem": path.stem, + "parent": str(path.parent), + } + except Exception as e: + logger.debug("file_info_failed", file_path=file_path, error=str(e)) + return {"exists": False, "error": str(e)} + + @staticmethod + def copy(source_path: str, destination_path: str) -> bool: + """ + Copy a file safely. 
+ + Args: + source_path: Source file path + destination_path: Destination file path + + Returns: + True if successful + """ + try: + shutil.copy2(source_path, destination_path) + logger.debug("file_copied", source=source_path, destination=destination_path) + return True + except Exception as e: + logger.error( + "file_copy_failed", + source=source_path, + destination=destination_path, + error=str(e), + ) + return False + + @staticmethod + def move(source_path: str, destination_path: str) -> bool: + """ + Move a file safely. + + Args: + source_path: Source file path + destination_path: Destination file path + + Returns: + True if successful + """ + try: + shutil.move(source_path, destination_path) + logger.debug("file_moved", source=source_path, destination=destination_path) + return True + except Exception as e: + logger.error( + "file_move_failed", + source=source_path, + destination=destination_path, + error=str(e), + ) + return False + + @staticmethod + def delete(file_path: str) -> bool: + """ + Delete a file safely. + + Args: + file_path: Path to delete + + Returns: + True if successful or file doesn't exist + """ + try: + path = Path(file_path) + if not path.exists(): + return True + + path.unlink() + logger.debug("file_deleted", file_path=file_path) + return True + except Exception as e: + logger.error("file_delete_failed", file_path=file_path, error=str(e)) + return False + + @staticmethod + def mkdir(directory_path: str, parents: bool = True) -> bool: + """ + Create a directory safely. 
+ + Args: + directory_path: Directory path to create + parents: Whether to create parent directories + + Returns: + True if successful + """ + try: + Path(directory_path).mkdir(parents=parents, exist_ok=True) + return True + except Exception as e: + logger.error("mkdir_failed", directory_path=directory_path, error=str(e)) + return False diff --git a/codeflow_engine/core/files/validator.py b/codeflow_engine/core/files/validator.py new file mode 100644 index 0000000..e18d683 --- /dev/null +++ b/codeflow_engine/core/files/validator.py @@ -0,0 +1,195 @@ +""" +Content Validator. + +Validates file content for common issues and encoding problems. +""" + +from dataclasses import dataclass, field +from typing import Any + +import structlog + + +logger = structlog.get_logger(__name__) + + +@dataclass +class ContentValidationResult: + """Result of content validation.""" + + valid: bool = True + issues: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +class ContentValidator: + """ + Validates file content for common issues. + + Checks for: + - Empty content + - Encoding issues + - Long lines + - Mixed line endings + - Trailing whitespace + """ + + # Default thresholds + MAX_LINE_LENGTH = 1000 + WARN_LINE_LENGTH = 500 + + def __init__( + self, + max_line_length: int = MAX_LINE_LENGTH, + warn_line_length: int = WARN_LINE_LENGTH, + check_trailing_whitespace: bool = True, + check_mixed_line_endings: bool = True, + ) -> None: + """ + Initialize the content validator. 
+ + Args: + max_line_length: Maximum allowed line length (issues if exceeded) + warn_line_length: Line length that triggers a warning + check_trailing_whitespace: Whether to check for trailing whitespace + check_mixed_line_endings: Whether to check for mixed line endings + """ + self.max_line_length = max_line_length + self.warn_line_length = warn_line_length + self.check_trailing_whitespace = check_trailing_whitespace + self.check_mixed_line_endings = check_mixed_line_endings + + def validate(self, content: str) -> ContentValidationResult: + """ + Validate file content. + + Args: + content: The content to validate + + Returns: + ContentValidationResult with validation outcome + """ + result = ContentValidationResult() + + # Check for empty content + if not content.strip(): + result.warnings.append("Content is empty") + result.metadata["is_empty"] = True + return result + + # Check encoding + self._check_encoding(content, result) + if not result.valid: + return result + + lines = content.split("\n") + result.metadata["line_count"] = len(lines) + + # Check line lengths + self._check_line_lengths(lines, result) + + # Check line endings + if self.check_mixed_line_endings: + self._check_line_endings(content, result) + + # Check trailing whitespace + if self.check_trailing_whitespace: + self._check_trailing_whitespace(lines, result) + + return result + + def _check_encoding(self, content: str, result: ContentValidationResult) -> None: + """Check for encoding issues.""" + try: + content.encode("utf-8") + except UnicodeEncodeError: + result.issues.append("Content contains invalid UTF-8 characters") + result.valid = False + + def _check_line_lengths( + self, + lines: list[str], + result: ContentValidationResult, + ) -> None: + """Check for long lines.""" + long_lines = [] + very_long_lines = [] + + for i, line in enumerate(lines, 1): + line_len = len(line) + if line_len > self.max_line_length: + very_long_lines.append(i) + elif line_len > self.warn_line_length: + 
long_lines.append(i) + + if very_long_lines: + result.warnings.append( + f"Lines exceeding {self.max_line_length} chars: {very_long_lines[:5]}" + + (f" (+{len(very_long_lines) - 5} more)" if len(very_long_lines) > 5 else "") + ) + + if long_lines: + result.metadata["long_lines"] = long_lines[:10] + + def _check_line_endings(self, content: str, result: ContentValidationResult) -> None: + """Check for mixed line endings.""" + has_crlf = "\r\n" in content + # Check for standalone CR or LF after removing CRLF + content_without_crlf = content.replace("\r\n", "") + has_lf = "\n" in content_without_crlf + has_cr = "\r" in content_without_crlf + + line_ending_types = sum([has_crlf, has_lf, has_cr]) + if line_ending_types > 1: + result.warnings.append("Mixed line endings detected (CRLF/LF/CR)") + result.metadata["mixed_line_endings"] = True + + # Record the detected line ending type + if has_crlf and not has_lf and not has_cr: + result.metadata["line_ending"] = "CRLF" + elif has_lf and not has_crlf and not has_cr: + result.metadata["line_ending"] = "LF" + elif has_cr and not has_crlf and not has_lf: + result.metadata["line_ending"] = "CR" + else: + result.metadata["line_ending"] = "mixed" + + def _check_trailing_whitespace( + self, + lines: list[str], + result: ContentValidationResult, + ) -> None: + """Check for trailing whitespace.""" + lines_with_trailing = [] + + for i, line in enumerate(lines, 1): + if line and line != line.rstrip(): + lines_with_trailing.append(i) + + if lines_with_trailing: + count = len(lines_with_trailing) + result.metadata["trailing_whitespace_lines"] = count + if count > 10: + result.warnings.append(f"{count} lines have trailing whitespace") + + def validate_for_write(self, content: str, strict: bool = False) -> tuple[bool, str]: + """ + Validate content before writing. 
+ + Args: + content: Content to validate + strict: If True, warnings become blocking issues + + Returns: + Tuple of (is_valid, message) + """ + result = self.validate(content) + + if not result.valid: + return False, "; ".join(result.issues) + + if strict and result.warnings: + return False, "; ".join(result.warnings) + + return True, "Content is valid" diff --git a/codeflow_engine/core/llm/__init__.py b/codeflow_engine/core/llm/__init__.py new file mode 100644 index 0000000..1037326 --- /dev/null +++ b/codeflow_engine/core/llm/__init__.py @@ -0,0 +1,22 @@ +""" +Core LLM Module - Base classes and utilities for LLM providers. + +This module provides: +- BaseLLMProvider: Abstract base class for all LLM providers +- OpenAICompatibleProvider: Template for OpenAI API-compatible providers +- LLMProviderRegistry: Dynamic provider registration (Open/Closed principle) +- ResponseExtractor: Common response extraction utilities (DRY) +""" + +from codeflow_engine.core.llm.base import BaseLLMProvider +from codeflow_engine.core.llm.registry import LLMProviderRegistry +from codeflow_engine.core.llm.response import LLMResponse, ResponseExtractor +from codeflow_engine.core.llm.openai_compatible import OpenAICompatibleProvider + +__all__ = [ + "BaseLLMProvider", + "LLMProviderRegistry", + "LLMResponse", + "OpenAICompatibleProvider", + "ResponseExtractor", +] diff --git a/codeflow_engine/core/llm/base.py b/codeflow_engine/core/llm/base.py new file mode 100644 index 0000000..2b163f2 --- /dev/null +++ b/codeflow_engine/core/llm/base.py @@ -0,0 +1,104 @@ +""" +Abstract base class for LLM providers. + +Provides common functionality and defines the interface that all providers must implement. +""" + +import logging +import os +from abc import ABC, abstractmethod +from typing import Any + +from codeflow_engine.core.llm.response import LLMResponse + +logger = logging.getLogger(__name__) + + +class BaseLLMProvider(ABC): + """ + Abstract base class for LLM providers. 
+ + This class defines the interface and provides common functionality for all LLM providers. + + Attributes: + config: Configuration dictionary for the provider + api_key: API key for authentication + base_url: Optional custom base URL for the API + default_model: Default model to use for completions + name: Human-readable name of the provider + available: Whether the provider is properly initialized + """ + + def __init__(self, config: dict[str, Any]) -> None: + """ + Initialize the provider with configuration. + + Args: + config: Configuration dictionary containing: + - api_key: Direct API key (optional) + - api_key_env: Environment variable name for API key + - base_url: Custom API base URL (optional) + - default_model: Default model name + - name: Provider name (optional, defaults to class name) + """ + self.config = config + self.api_key = config.get("api_key") or os.getenv(config.get("api_key_env", "")) + self.base_url = config.get("base_url") + self.default_model = config.get("default_model") + self.name = config.get("name", self.__class__.__name__.lower().replace("provider", "")) + self.available = False # Subclasses set this during initialization + + @abstractmethod + def complete(self, request: dict[str, Any]) -> LLMResponse: + """ + Complete a chat conversation. + + Args: + request: Request dictionary containing: + - messages: List of message dicts with 'role' and 'content' + - model: Model name (optional, uses default_model if not specified) + - temperature: Sampling temperature (optional) + - max_tokens: Maximum tokens in response (optional) + - Additional provider-specific parameters + + Returns: + LLMResponse with the completion result or error + """ + + def is_available(self) -> bool: + """ + Check if the provider is properly configured and available. 
+ + Returns: + True if the provider can accept requests + """ + return self.available and bool(self.api_key) + + def get_model(self, request: dict[str, Any], fallback: str = "unknown") -> str: + """ + Get the model to use from request, config, or fallback. + + Args: + request: Request dictionary that may contain 'model' key + fallback: Fallback model name if none specified + + Returns: + Model name to use + """ + return request.get("model") or self.default_model or fallback + + def _create_error_response(self, error: Exception | str, request: dict[str, Any], fallback_model: str = "unknown") -> LLMResponse: + """ + Create an error response with consistent formatting. + + Args: + error: The error that occurred + request: The original request + fallback_model: Fallback model name for the response + + Returns: + LLMResponse with error information + """ + error_msg = str(error) if isinstance(error, Exception) else error + model = self.get_model(request, fallback_model) + return LLMResponse.from_error(f"Error calling {self.name} API: {error_msg}", model) diff --git a/codeflow_engine/core/llm/openai_compatible.py b/codeflow_engine/core/llm/openai_compatible.py new file mode 100644 index 0000000..9c577b8 --- /dev/null +++ b/codeflow_engine/core/llm/openai_compatible.py @@ -0,0 +1,185 @@ +""" +OpenAI-Compatible Provider Template. + +This module provides a base class for providers that use the OpenAI API format, +eliminating duplication across OpenAI, Groq, Perplexity, Together AI, and other +compatible providers. +""" + +import logging +from typing import Any + +from codeflow_engine.core.llm.base import BaseLLMProvider +from codeflow_engine.core.llm.response import LLMResponse, ResponseExtractor + +logger = logging.getLogger(__name__) + + +class OpenAICompatibleProvider(BaseLLMProvider): + """ + Base class for providers that use the OpenAI API format. 
+ + This template class implements the common patterns shared by: + - OpenAI + - Azure OpenAI + - Groq + - Perplexity + - Together AI + - Any other OpenAI-compatible API + + Subclasses only need to implement: + - _initialize_client(): Set up the specific client + - Optionally override _get_default_model() and other hooks + """ + + # Class-level defaults that subclasses can override + DEFAULT_MODEL: str = "gpt-4" + LIBRARY_NAME: str = "openai" # For error messages + CLIENT_CLASS_PATH: str = "openai.OpenAI" # For documentation + + def __init__(self, config: dict[str, Any]) -> None: + """Initialize the provider.""" + super().__init__(config) + self.client: Any = None + self._initialize_client() + + def _initialize_client(self) -> None: + """ + Initialize the API client. + + Subclasses should override this to set up their specific client. + Must set self.available = True on success, False on failure. + """ + try: + import openai + + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + self.available = True + except ImportError: + logger.debug(f"{self.LIBRARY_NAME} package not installed") + self.available = False + + def _get_default_model(self) -> str: + """Get the default model for this provider.""" + return self.default_model or self.DEFAULT_MODEL + + def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, str]]: + """ + Prepare messages for the API call. + + Override in subclasses that need custom message formatting. + + Args: + messages: Raw messages from the request + + Returns: + Formatted messages ready for the API + """ + return ResponseExtractor.filter_messages(messages) + + def _make_api_call( + self, + messages: list[dict[str, str]], + model: str, + temperature: float, + max_tokens: int | None, + **kwargs: Any, + ) -> Any: + """ + Make the actual API call. + + Override in subclasses that need custom API call logic. 
+ + Args: + messages: Prepared messages + model: Model to use + temperature: Sampling temperature + max_tokens: Maximum tokens in response + **kwargs: Additional parameters + + Returns: + Raw API response + """ + call_params: dict[str, Any] = { + "model": str(model), + "messages": messages, + "temperature": temperature, + } + + if max_tokens is not None: + call_params["max_tokens"] = max_tokens + + # Add any additional parameters from kwargs + for key in ["top_p", "frequency_penalty", "presence_penalty", "stop"]: + if key in kwargs and kwargs[key] is not None: + call_params[key] = kwargs[key] + + return self.client.chat.completions.create(**call_params) + + def _extract_response(self, response: Any, model: str) -> LLMResponse: + """ + Extract the LLMResponse from the API response. + + Override in subclasses that have non-standard response formats. + + Args: + response: Raw API response + model: Model used for the request + + Returns: + LLMResponse object + """ + content, finish_reason, usage = ResponseExtractor.extract_openai_response(response) + + return LLMResponse( + content=str(content), + model=str(getattr(response, "model", model)), + finish_reason=str(finish_reason), + usage=usage, + ) + + def complete(self, request: dict[str, Any]) -> LLMResponse: + """ + Complete a chat conversation using the OpenAI-compatible API. 
+ + Args: + request: Request dictionary with messages and optional parameters + + Returns: + LLMResponse with the completion result or error + """ + if not self.client: + return LLMResponse.from_error( + f"{self.name} client not initialized", + self.get_model(request, self._get_default_model()), + ) + + try: + messages = request.get("messages", []) + model = self.get_model(request, self._get_default_model()) + temperature = request.get("temperature", 0.7) + max_tokens = request.get("max_tokens") + + # Prepare messages (can be overridden by subclasses) + prepared_messages = self._prepare_messages(messages) + + if not prepared_messages: + return LLMResponse.from_error("No valid messages provided", model) + + # Make the API call + response = self._make_api_call( + messages=prepared_messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + top_p=request.get("top_p"), + frequency_penalty=request.get("frequency_penalty"), + presence_penalty=request.get("presence_penalty"), + stop=request.get("stop"), + ) + + # Extract and return the response + return self._extract_response(response, model) + + except Exception as e: + return self._create_error_response(e, request, self._get_default_model()) diff --git a/codeflow_engine/core/llm/registry.py b/codeflow_engine/core/llm/registry.py new file mode 100644 index 0000000..603a264 --- /dev/null +++ b/codeflow_engine/core/llm/registry.py @@ -0,0 +1,167 @@ +""" +LLM Provider Registry - Dynamic provider registration following Open/Closed principle. + +This registry allows new providers to be added without modifying existing code. +""" + +import logging +from typing import Any, TypeVar + +from codeflow_engine.core.llm.base import BaseLLMProvider + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseLLMProvider) + + +class LLMProviderRegistry: + """ + Registry for LLM providers following the Open/Closed principle. + + Allows providers to be registered dynamically without modifying manager code. 
+ + Usage: + # Register a provider + LLMProviderRegistry.register("openai", OpenAIProvider) + + # Create a provider instance + provider = LLMProviderRegistry.create("openai", config) + + # Get all registered providers + providers = LLMProviderRegistry.get_all() + """ + + _providers: dict[str, type[BaseLLMProvider]] = {} + _default_configs: dict[str, dict[str, Any]] = {} + + @classmethod + def register( + cls, + name: str, + provider_class: type[BaseLLMProvider], + default_config: dict[str, Any] | None = None, + ) -> None: + """ + Register a provider class with the registry. + + Args: + name: Unique identifier for the provider (e.g., 'openai', 'anthropic') + provider_class: The provider class to register + default_config: Optional default configuration for this provider + """ + cls._providers[name.lower()] = provider_class + if default_config: + cls._default_configs[name.lower()] = default_config + logger.debug(f"Registered LLM provider: {name}") + + @classmethod + def unregister(cls, name: str) -> bool: + """ + Remove a provider from the registry. + + Args: + name: The provider name to remove + + Returns: + True if the provider was removed, False if not found + """ + name_lower = name.lower() + if name_lower in cls._providers: + del cls._providers[name_lower] + cls._default_configs.pop(name_lower, None) + return True + return False + + @classmethod + def create(cls, name: str, config: dict[str, Any] | None = None) -> BaseLLMProvider | None: + """ + Create a provider instance by name. 
+ + Args: + name: The provider name to instantiate + config: Configuration to pass to the provider (merged with defaults) + + Returns: + Provider instance or None if provider not found + """ + name_lower = name.lower() + provider_class = cls._providers.get(name_lower) + + if provider_class is None: + logger.warning(f"Provider '{name}' not found in registry") + return None + + # Merge default config with provided config + default_config = cls._default_configs.get(name_lower, {}) + merged_config = {**default_config, **(config or {})} + + try: + return provider_class(merged_config) + except Exception as e: + logger.exception(f"Failed to create provider '{name}': {e}") + return None + + @classmethod + def get_provider_class(cls, name: str) -> type[BaseLLMProvider] | None: + """ + Get a provider class by name without instantiating it. + + Args: + name: The provider name + + Returns: + Provider class or None if not found + """ + return cls._providers.get(name.lower()) + + @classmethod + def get_all(cls) -> dict[str, type[BaseLLMProvider]]: + """ + Get all registered providers. + + Returns: + Dictionary mapping provider names to their classes + """ + return cls._providers.copy() + + @classmethod + def get_default_config(cls, name: str) -> dict[str, Any]: + """ + Get the default configuration for a provider. + + Args: + name: The provider name + + Returns: + Default configuration dictionary (empty if none registered) + """ + return cls._default_configs.get(name.lower(), {}).copy() + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if a provider is registered. + + Args: + name: The provider name to check + + Returns: + True if the provider is registered + """ + return name.lower() in cls._providers + + @classmethod + def list_providers(cls) -> list[str]: + """ + List all registered provider names. 
+ + Returns: + List of provider names + """ + return list(cls._providers.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registered providers. Mainly useful for testing.""" + cls._providers.clear() + cls._default_configs.clear() diff --git a/codeflow_engine/core/llm/response.py b/codeflow_engine/core/llm/response.py new file mode 100644 index 0000000..e970664 --- /dev/null +++ b/codeflow_engine/core/llm/response.py @@ -0,0 +1,119 @@ +""" +LLM Response types and extraction utilities. + +This module centralizes response handling to eliminate duplication across providers. +""" + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class LLMResponse: + """Response from an LLM provider.""" + + content: str + model: str + finish_reason: str + usage: dict[str, int] | None = None + error: str | None = None + + @classmethod + def from_error(cls, error: str, model: str = "unknown") -> "LLMResponse": + """Create an error response.""" + return cls(content="", model=model, finish_reason="error", error=error) + + +class ResponseExtractor: + """ + Utility class for extracting responses from various LLM API formats. + + Centralizes response extraction logic to eliminate duplication (DRY principle). + """ + + @staticmethod + def extract_openai_response(response: Any, default_model: str = "unknown") -> tuple[str, str, dict[str, int] | None]: + """ + Extract content, finish_reason, and usage from OpenAI-compatible responses. + + Works with: OpenAI, Azure OpenAI, Groq, Perplexity, Together AI. 
+ + Args: + response: The API response object + default_model: Fallback model name if not in response + + Returns: + Tuple of (content, finish_reason, usage) + """ + content = "" + finish_reason = "stop" + usage = None + + # Extract content and finish_reason from choices + if hasattr(response, "choices") and response.choices and len(response.choices) > 0: + choice = response.choices[0] + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = choice.message.content or "" + finish_reason = getattr(choice, "finish_reason", "stop") or "stop" + + # Extract usage information + if hasattr(response, "usage") and response.usage: + if hasattr(response.usage, "dict"): + usage = response.usage.dict() + else: + usage = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "completion_tokens": getattr(response.usage, "completion_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + + return content, finish_reason, usage + + @staticmethod + def extract_anthropic_response(response: Any) -> tuple[str, str, dict[str, int] | None]: + """ + Extract content, finish_reason, and usage from Anthropic responses. 
+ + Args: + response: The Anthropic API response object + + Returns: + Tuple of (content, finish_reason, usage) + """ + content = "" + if hasattr(response, "content") and response.content: + content = "\n".join( + block.text for block in response.content if hasattr(block, "text") + ) + + finish_reason = getattr(response, "stop_reason", "stop") + + usage = None + if hasattr(response, "usage"): + response_usage = response.usage + input_tokens = getattr(response_usage, "input_tokens", 0) if hasattr(response_usage, "input_tokens") else response_usage.get("input_tokens", 0) if isinstance(response_usage, dict) else 0 + output_tokens = getattr(response_usage, "output_tokens", 0) if hasattr(response_usage, "output_tokens") else response_usage.get("output_tokens", 0) if isinstance(response_usage, dict) else 0 + usage = { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + } + + return content, finish_reason, usage + + @staticmethod + def filter_messages(messages: list[dict[str, Any]]) -> list[dict[str, str]]: + """ + Filter and normalize messages, removing empty content. + + Args: + messages: List of message dictionaries + + Returns: + Filtered list of messages with role and content + """ + return [ + {"role": msg.get("role", "user"), "content": msg.get("content", "")} + for msg in messages + if msg.get("content", "").strip() + ] diff --git a/codeflow_engine/core/managers/__init__.py b/codeflow_engine/core/managers/__init__.py new file mode 100644 index 0000000..fe50187 --- /dev/null +++ b/codeflow_engine/core/managers/__init__.py @@ -0,0 +1,20 @@ +""" +Core Manager Framework. + +Provides base classes and utilities for building manager components +with consistent patterns for logging, configuration, and lifecycle. 
+""" + +from codeflow_engine.core.managers.base import ( + BaseManager, + ManagerConfig, + SessionMixin, + StatsMixin, +) + +__all__ = [ + "BaseManager", + "ManagerConfig", + "SessionMixin", + "StatsMixin", +] diff --git a/codeflow_engine/core/managers/base.py b/codeflow_engine/core/managers/base.py new file mode 100644 index 0000000..4d89e16 --- /dev/null +++ b/codeflow_engine/core/managers/base.py @@ -0,0 +1,328 @@ +""" +Base Manager Framework. + +Provides common patterns for manager classes including: +- Structured logging +- Configuration management +- Session handling +- Statistics collection +""" + +from abc import ABC +from dataclasses import dataclass, field +from datetime import datetime, UTC +from typing import Any, TypeVar + +import structlog + + +T = TypeVar("T", bound="ManagerConfig") + + +@dataclass +class ManagerConfig: + """ + Base configuration for managers. + + Subclass this to add manager-specific configuration options. + """ + + name: str = "manager" + enabled: bool = True + log_level: str = "INFO" + metadata: dict[str, Any] = field(default_factory=dict) + + def merge(self: T, overrides: dict[str, Any]) -> T: + """ + Create a new config with overrides applied. + + Args: + overrides: Dictionary of values to override + + Returns: + New config instance with overrides applied + """ + current = {k: v for k, v in self.__dict__.items()} + current.update(overrides) + return type(self)(**current) + + +class BaseManager(ABC): + """ + Abstract base class for manager components. + + Provides: + - Structured logging via structlog + - Configuration handling + - Lifecycle management (startup/shutdown) + + Usage: + class MyManager(BaseManager): + def __init__(self, config: MyConfig | None = None): + super().__init__(config or MyConfig()) + + def _on_startup(self) -> None: + # Custom initialization + pass + """ + + def __init__(self, config: ManagerConfig | None = None) -> None: + """ + Initialize the manager. 
+ + Args: + config: Optional configuration object + """ + self._config = config or ManagerConfig() + self._logger = structlog.get_logger(self._config.name) + self._started = False + self._start_time: datetime | None = None + + @property + def config(self) -> ManagerConfig: + """Get the manager configuration.""" + return self._config + + @property + def logger(self) -> structlog.stdlib.BoundLogger: + """Get the bound logger for this manager.""" + return self._logger + + @property + def is_started(self) -> bool: + """Check if the manager has been started.""" + return self._started + + @property + def uptime_seconds(self) -> float: + """Get the uptime in seconds since startup.""" + if not self._start_time: + return 0.0 + return (datetime.now(UTC) - self._start_time).total_seconds() + + def startup(self) -> None: + """ + Start the manager. + + Calls _on_startup() for subclass initialization. + """ + if self._started: + self._logger.warning("manager_already_started", name=self._config.name) + return + + self._logger.info("manager_starting", name=self._config.name) + self._start_time = datetime.now(UTC) + self._on_startup() + self._started = True + self._logger.info("manager_started", name=self._config.name) + + def shutdown(self) -> None: + """ + Shutdown the manager. + + Calls _on_shutdown() for subclass cleanup. + """ + if not self._started: + return + + self._logger.info("manager_shutting_down", name=self._config.name) + self._on_shutdown() + self._started = False + self._logger.info("manager_shutdown", name=self._config.name) + + def _on_startup(self) -> None: + """ + Hook for subclass startup initialization. + + Override in subclasses to perform custom startup logic. + """ + pass + + def _on_shutdown(self) -> None: + """ + Hook for subclass shutdown cleanup. + + Override in subclasses to perform custom cleanup logic. + """ + pass + + +class SessionMixin: + """ + Mixin providing session management capabilities. 
+ + Add this to managers that need to track operations within sessions. + """ + + def __init__(self) -> None: + """Initialize session tracking.""" + self._sessions: dict[str, dict[str, Any]] = {} + self._current_session: str | None = None + + @property + def current_session_id(self) -> str | None: + """Get the current session ID.""" + return self._current_session + + @property + def active_sessions(self) -> list[str]: + """Get list of active session IDs.""" + return [ + sid for sid, data in self._sessions.items() + if data.get("is_active", False) + ] + + def start_session(self, session_id: str, metadata: dict[str, Any] | None = None) -> None: + """ + Start a new session. + + Args: + session_id: Unique identifier for the session + metadata: Optional metadata to attach to the session + """ + self._sessions[session_id] = { + "start_time": datetime.now(UTC), + "is_active": True, + "metadata": metadata or {}, + "data": {}, + } + self._current_session = session_id + + def end_session(self, session_id: str | None = None) -> None: + """ + End a session. + + Args: + session_id: Session to end, defaults to current session + """ + sid = session_id or self._current_session + if sid and sid in self._sessions: + self._sessions[sid]["is_active"] = False + self._sessions[sid]["end_time"] = datetime.now(UTC) + + if self._current_session == sid: + self._current_session = None + + def get_session_data(self, session_id: str | None = None) -> dict[str, Any]: + """ + Get session data. + + Args: + session_id: Session to get data for, defaults to current session + + Returns: + Session data dictionary + """ + sid = session_id or self._current_session + if sid and sid in self._sessions: + return self._sessions[sid].get("data", {}) + return {} + + def set_session_data(self, key: str, value: Any, session_id: str | None = None) -> None: + """ + Set session data. 
+ + Args: + key: Data key + value: Data value + session_id: Session to set data for, defaults to current session + """ + sid = session_id or self._current_session + if sid and sid in self._sessions: + self._sessions[sid].setdefault("data", {})[key] = value + + +class StatsMixin: + """ + Mixin providing statistics collection capabilities. + + Add this to managers that need to track operational metrics. + """ + + def __init__(self) -> None: + """Initialize statistics tracking.""" + self._stats: dict[str, int | float] = {} + self._stats_history: list[dict[str, Any]] = [] + + def increment_stat(self, name: str, amount: int = 1) -> None: + """ + Increment a counter statistic. + + Args: + name: Statistic name + amount: Amount to increment by + """ + self._stats[name] = self._stats.get(name, 0) + amount + + def set_stat(self, name: str, value: int | float) -> None: + """ + Set a statistic value. + + Args: + name: Statistic name + value: Value to set + """ + self._stats[name] = value + + def get_stat(self, name: str, default: int | float = 0) -> int | float: + """ + Get a statistic value. + + Args: + name: Statistic name + default: Default value if stat doesn't exist + + Returns: + Statistic value + """ + return self._stats.get(name, default) + + def get_all_stats(self) -> dict[str, int | float]: + """ + Get all statistics. + + Returns: + Dictionary of all statistics + """ + return self._stats.copy() + + def record_event(self, event_type: str, data: dict[str, Any] | None = None) -> None: + """ + Record an event in statistics history. + + Args: + event_type: Type of event + data: Optional event data + """ + self._stats_history.append({ + "timestamp": datetime.now(UTC).isoformat(), + "event_type": event_type, + "data": data or {}, + }) + + def get_stats_history( + self, + event_type: str | None = None, + limit: int | None = None, + ) -> list[dict[str, Any]]: + """ + Get statistics history. 
+ + Args: + event_type: Optional filter by event type + limit: Optional limit on number of events + + Returns: + List of recorded events + """ + history = self._stats_history + if event_type: + history = [e for e in history if e["event_type"] == event_type] + if limit: + history = history[-limit:] + return history + + def clear_stats(self) -> None: + """Clear all statistics and history.""" + self._stats.clear() + self._stats_history.clear() diff --git a/codeflow_engine/core/validation/__init__.py b/codeflow_engine/core/validation/__init__.py new file mode 100644 index 0000000..2182dfb --- /dev/null +++ b/codeflow_engine/core/validation/__init__.py @@ -0,0 +1,29 @@ +""" +Core Validation Module - Base classes and utilities for input validation. + +This module provides: +- SecurityPatterns: Centralized security threat patterns (DRY) +- ValidationResult: Standard validation result structure +- BaseTypeValidator: Abstract base for type-specific validators +- CompositeValidator: Composition-based validator (SOLID principles) +""" + +from codeflow_engine.core.validation.patterns import SecurityPatterns +from codeflow_engine.core.validation.result import ( + ValidationResult, + ValidationSeverity, + merge_validation_results, + update_severity, +) +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.composite import CompositeValidator + +__all__ = [ + "BaseTypeValidator", + "CompositeValidator", + "SecurityPatterns", + "ValidationResult", + "ValidationSeverity", + "merge_validation_results", + "update_severity", +] diff --git a/codeflow_engine/core/validation/base.py b/codeflow_engine/core/validation/base.py new file mode 100644 index 0000000..37fe519 --- /dev/null +++ b/codeflow_engine/core/validation/base.py @@ -0,0 +1,85 @@ +""" +Base Type Validator. + +This module provides the abstract base class for type-specific validators, +following the Single Responsibility Principle. 
+""" + +from abc import ABC, abstractmethod +from typing import Any + +from codeflow_engine.core.validation.result import ValidationResult, ValidationSeverity +from codeflow_engine.core.validation.patterns import SecurityPatterns, DEFAULT_SECURITY_PATTERNS + + +class BaseTypeValidator(ABC): + """ + Abstract base class for type-specific validators. + + Each validator is responsible for validating a single type of data, + following the Single Responsibility Principle. + + Subclasses must implement: + - can_validate(): Check if this validator handles the given type + - validate(): Perform the validation + + Attributes: + security_patterns: Security patterns to use for threat detection + """ + + def __init__(self, security_patterns: SecurityPatterns | None = None) -> None: + """ + Initialize the validator. + + Args: + security_patterns: Optional custom security patterns. Uses defaults if not provided. + """ + self.security_patterns = security_patterns or DEFAULT_SECURITY_PATTERNS + + @abstractmethod + def can_validate(self, value: Any) -> bool: + """ + Check if this validator can handle the given value type. + + Args: + value: The value to check + + Returns: + True if this validator can validate the value + """ + + @abstractmethod + def validate(self, key: str, value: Any) -> ValidationResult: + """ + Validate the given value. + + Args: + key: The key/name of the field being validated + value: The value to validate + + Returns: + ValidationResult with validation outcome and sanitized data + """ + + def _check_security_threats(self, key: str, value: str) -> ValidationResult: + """ + Check for common security threats in a string value. + + This is a helper method that subclasses can use for threat detection. 
+ + Args: + key: The field key for error messages + value: The string value to check + + Returns: + ValidationResult with threat detection outcome + """ + has_threat, threat_type = self.security_patterns.check_all_threats(value) + + if has_threat: + return ValidationResult.failure( + f"Potential {threat_type} detected in '{key}'", + ValidationSeverity.CRITICAL, + ) + + return ValidationResult.success() diff --git a/codeflow_engine/core/validation/composite.py b/codeflow_engine/core/validation/composite.py new file mode 100644 index 0000000..2ee4f5f --- /dev/null +++ b/codeflow_engine/core/validation/composite.py @@ -0,0 +1,226 @@ +""" +Composite Validator. + +This module implements the Composite pattern for validation, +allowing multiple type-specific validators to work together. +""" + +import re +from typing import Any + +import structlog + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ( + ValidationResult, + ValidationSeverity, + update_severity, +) +from codeflow_engine.core.validation.patterns import SecurityPatterns, DEFAULT_SECURITY_PATTERNS + + +logger = structlog.get_logger(__name__) + + +# Constants +MAX_KEY_LENGTH = 100 +SAFE_KEY_PATTERN = re.compile(r"^[a-zA-Z0-9_\-\.]+$") + + +class CompositeValidator: + """ + Composite validator that delegates to type-specific validators. + + This class implements the Composite pattern, allowing multiple validators + to be composed together while providing a unified interface. + + Following the Open/Closed Principle, new validators can be registered + without modifying this class. 
+ + Usage: + validator = CompositeValidator() + validator.register(StringTypeValidator()) + validator.register(ArrayTypeValidator()) + + result = validator.validate_input({"name": "John", "age": 30}) + """ + + def __init__( + self, + security_patterns: SecurityPatterns | None = None, + validators: list[BaseTypeValidator] | None = None, + ) -> None: + """ + Initialize the composite validator. + + Args: + security_patterns: Optional custom security patterns + validators: Optional list of type validators to register + """ + self.security_patterns = security_patterns or DEFAULT_SECURITY_PATTERNS + self._validators: list[BaseTypeValidator] = validators or [] + + def register(self, validator: BaseTypeValidator) -> "CompositeValidator": + """ + Register a type validator. + + Args: + validator: The validator to register + + Returns: + Self for method chaining + """ + self._validators.append(validator) + return self + + def unregister(self, validator_type: type[BaseTypeValidator]) -> bool: + """ + Unregister all validators of a specific type. + + Args: + validator_type: The type of validator to remove + + Returns: + True if any validators were removed + """ + original_count = len(self._validators) + self._validators = [ + v for v in self._validators if not isinstance(v, validator_type) + ] + return len(self._validators) < original_count + + def validate_input( + self, + data: dict[str, Any], + schema: type | None = None, + ) -> ValidationResult: + """ + Validate input data comprehensively. 
+ + Args: + data: Dictionary of data to validate + schema: Optional Pydantic model for schema validation + + Returns: + ValidationResult with validation outcome + """ + result = ValidationResult(is_valid=True) + sanitized_data: dict[str, Any] = {} + + try: + for key, value in data.items(): + # Validate key name first + if not self._is_safe_key(key): + result.add_error(f"Invalid key name: {key}", ValidationSeverity.HIGH) + continue + + # Validate the value + value_result = self._validate_value(key, value) + self._merge_result(result, value_result) + + if value_result.is_valid and value_result.sanitized_data is not None: + # Extract the actual value from the wrapper + sanitized_data[key] = self._unwrap_sanitized(value_result.sanitized_data) + + # Apply schema validation if provided + if schema and result.is_valid: + result = self._apply_schema(schema, sanitized_data, result) + else: + result.sanitized_data = sanitized_data + + # Log validation results + self._log_validation_result(result, data) + + return result + + except Exception: + logger.exception("Input validation error") + return ValidationResult.failure( + "Validation system error", + ValidationSeverity.CRITICAL, + ) + + def _validate_value(self, key: str, value: Any) -> ValidationResult: + """ + Validate a single value using registered validators. 
+ + Args: + key: The key name for error messages + value: The value to validate + + Returns: + ValidationResult from the appropriate validator + """ + # Find the first validator that can handle this type + for validator in self._validators: + if validator.can_validate(value): + return validator.validate(key, value) + + # No validator found - pass through with wrapper + return ValidationResult.success({"value": value}) + + def _is_safe_key(self, key: str) -> bool: + """Check if a key name is safe for use.""" + return bool(SAFE_KEY_PATTERN.match(key)) and len(key) <= MAX_KEY_LENGTH + + def _merge_result(self, target: ValidationResult, source: ValidationResult) -> None: + """Merge a source result into the target.""" + if not source.is_valid: + target.is_valid = False + target.errors.extend(source.errors) + target.warnings.extend(source.warnings) + target.severity = update_severity(target.severity, source.severity) + + def _unwrap_sanitized(self, sanitized_data: dict[str, Any]) -> Any: + """ + Unwrap sanitized data from validator wrappers. + + Validators may wrap simple values in {"value": x} format. + This unwraps them for the final result. 
+ """ + if ( + isinstance(sanitized_data, dict) + and len(sanitized_data) == 1 + and "value" in sanitized_data + ): + return sanitized_data["value"] + if isinstance(sanitized_data, dict) and "items" in sanitized_data: + return sanitized_data["items"] + return sanitized_data + + def _apply_schema( + self, + schema: type, + sanitized_data: dict[str, Any], + current_result: ValidationResult, + ) -> ValidationResult: + """Apply Pydantic schema validation.""" + try: + validated = schema(**sanitized_data) + if hasattr(validated, "dict"): + current_result.sanitized_data = validated.dict() + elif hasattr(validated, "model_dump"): + current_result.sanitized_data = validated.model_dump() + else: + current_result.sanitized_data = sanitized_data + except Exception as e: + current_result.add_error( + f"Schema validation failed: {e!s}", + ValidationSeverity.HIGH, + ) + + return current_result + + def _log_validation_result( + self, result: ValidationResult, data: dict[str, Any] + ) -> None: + """Log the validation result.""" + if not result.is_valid: + logger.warning( + "Input validation failed", + errors=result.errors, + severity=result.severity.value, + data_keys=list(data.keys()), + ) + else: + logger.debug("Input validation passed", data_keys=list(data.keys())) diff --git a/codeflow_engine/core/validation/patterns.py b/codeflow_engine/core/validation/patterns.py new file mode 100644 index 0000000..a3a02ea --- /dev/null +++ b/codeflow_engine/core/validation/patterns.py @@ -0,0 +1,130 @@ +""" +Centralized Security Patterns. + +This module consolidates all security threat detection patterns in one place +to eliminate duplication across validators (DRY principle). +""" + +import re +from dataclasses import dataclass, field +from typing import Pattern + + +@dataclass +class SecurityPatterns: + """ + Centralized repository of security threat patterns. + + All validators should use these patterns instead of defining their own. + This ensures consistency and makes updates easier. 
+
+    Attributes:
+        sql_injection: Patterns to detect SQL injection attempts
+        xss: Patterns to detect Cross-Site Scripting attempts
+        command_injection: Patterns to detect command injection attempts
+        path_traversal: Patterns to detect path traversal attempts
+    """
+
+    # SQL Injection patterns
+    sql_injection: list[str] = field(default_factory=lambda: [
+        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|EXECUTE|UNION|SCRIPT)\b)",
+        r"(\b(OR|AND)\b\s+\d+\s*=\s*\d+)",
+        r"(\b(OR|AND)\b\s+['\"]\w+['\"]\s*=\s*['\"]\w+['\"])",
+        r"(--|\b(COMMENT|REM)\b)",
+        r"(\b(WAITFOR|DELAY)\b)",
+        r"(\b(BENCHMARK|SLEEP)\b)",
+        r"(\bUNION\s+SELECT\b)",
+    ])
+
+    # XSS patterns
+    xss: list[str] = field(default_factory=lambda: [
+        r"<script[^>]*>.*?</script>",
+        r"javascript:",
+        r"on\w+\s*=",
+        r"<iframe[^>]*>",
+        r"<object[^>]*>",
+        r"<embed[^>]*>",
+        r"<form[^>]*>",
+        r"<input[^>]*>",
+        r"<link[^>]*>",
+        r"<meta[^>]*>",
+    ])
+
+    # Command injection patterns
+    command_injection: list[str] = field(default_factory=lambda: [
+        r"[;&|`$(){}[\]]",
+        r"\b(cat|ls|pwd|whoami|id|uname|ps|top|kill|rm|cp|mv|chmod|chown)\b",
+        r"\b(netcat|nc|telnet|ssh|scp|wget|curl|ftp|sftp)\b",
+        r"\b(bash|sh|zsh|fish|powershell|cmd|command)\b",
+        r"(>|>>|<|\|)",
+    ])
+
+    # Path traversal patterns
+    path_traversal: list[str] = field(default_factory=lambda: [
+        r"\.\./",
+        r"\.\.\\",
+        r"%2e%2e/",
+        r"%2e%2e\\",
+    ])
+
+    _compiled_patterns: dict[str, list[Pattern[str]]] = field(
+        default_factory=dict, init=False, repr=False
+    )
+
+    def __post_init__(self) -> None:
+        """Compile patterns for better performance."""
+        self._compile_patterns()
+
+    def _compile_patterns(self) -> None:
+        """Pre-compile regex patterns for performance."""
+        self._compiled_patterns = {
+            "sql_injection": [re.compile(p, re.IGNORECASE) for p in self.sql_injection],
+            "xss": [re.compile(p, re.IGNORECASE) for p in self.xss],
+            "command_injection": [re.compile(p) for p in self.command_injection],
+            "path_traversal": [re.compile(p, re.IGNORECASE) for p in self.path_traversal],
+        }
+
+    def check_sql_injection(self,
value: str) -> bool: + """Check if value contains SQL injection patterns.""" + return self._check_patterns("sql_injection", value) + + def check_xss(self, value: str) -> bool: + """Check if value contains XSS patterns.""" + return self._check_patterns("xss", value) + + def check_command_injection(self, value: str) -> bool: + """Check if value contains command injection patterns.""" + return self._check_patterns("command_injection", value) + + def check_path_traversal(self, value: str) -> bool: + """Check if value contains path traversal patterns.""" + return self._check_patterns("path_traversal", value) + + def _check_patterns(self, pattern_type: str, value: str) -> bool: + """Check if value matches any pattern of the given type.""" + patterns = self._compiled_patterns.get(pattern_type, []) + return any(pattern.search(value) for pattern in patterns) + + def check_all_threats(self, value: str) -> tuple[bool, str | None]: + """ + Check for all security threats. + + Args: + value: The string to check + + Returns: + Tuple of (has_threat, threat_type) where threat_type is None if no threat + """ + if self.check_sql_injection(value): + return True, "SQL injection" + if self.check_xss(value): + return True, "XSS" + if self.check_command_injection(value): + return True, "command injection" + if self.check_path_traversal(value): + return True, "path traversal" + return False, None + + +# Global default instance for common use +DEFAULT_SECURITY_PATTERNS = SecurityPatterns() diff --git a/codeflow_engine/core/validation/result.py b/codeflow_engine/core/validation/result.py new file mode 100644 index 0000000..6592880 --- /dev/null +++ b/codeflow_engine/core/validation/result.py @@ -0,0 +1,125 @@ +""" +Validation Result types and utilities. + +This module provides the standard ValidationResult class and helper functions +for consistent validation result handling across all validators. 
+""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class ValidationSeverity(Enum): + """Severity levels for validation issues.""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +# Severity ordering for comparisons +SEVERITY_ORDER = { + ValidationSeverity.LOW: 0, + ValidationSeverity.MEDIUM: 1, + ValidationSeverity.HIGH: 2, + ValidationSeverity.CRITICAL: 3, +} + + +class ValidationResult(BaseModel): + """ + Result of input validation. + + Attributes: + is_valid: Whether the validation passed + errors: List of error messages + warnings: List of warning messages + sanitized_data: Sanitized version of the validated data + severity: The highest severity of any issue found + """ + + is_valid: bool + errors: list[str] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) + sanitized_data: dict[str, Any] | None = None + severity: ValidationSeverity = ValidationSeverity.LOW + + @classmethod + def success(cls, sanitized_data: dict[str, Any] | None = None) -> "ValidationResult": + """Create a successful validation result.""" + return cls(is_valid=True, sanitized_data=sanitized_data) + + @classmethod + def failure( + cls, + error: str, + severity: ValidationSeverity = ValidationSeverity.MEDIUM, + ) -> "ValidationResult": + """Create a failed validation result.""" + return cls(is_valid=False, errors=[error], severity=severity) + + def add_error(self, error: str, severity: ValidationSeverity | None = None) -> None: + """Add an error to the result and update validity.""" + self.is_valid = False + self.errors.append(error) + if severity: + self.severity = update_severity(self.severity, severity) + + def add_warning(self, warning: str) -> None: + """Add a warning to the result.""" + self.warnings.append(warning) + + +def update_severity( + current: ValidationSeverity, new: ValidationSeverity +) -> ValidationSeverity: + """ + Update severity to the higher of the two values. 
+ + This helper function eliminates the duplicated severity comparison logic + that was spread across multiple validators. + + Args: + current: The current severity level + new: The new severity level to consider + + Returns: + The higher of the two severity levels + """ + if SEVERITY_ORDER[new] > SEVERITY_ORDER[current]: + return new + return current + + +def merge_validation_results(results: list[ValidationResult]) -> ValidationResult: + """ + Merge multiple validation results into one. + + Args: + results: List of ValidationResult objects to merge + + Returns: + A single ValidationResult combining all results + """ + if not results: + return ValidationResult(is_valid=True) + + merged = ValidationResult(is_valid=True) + + for result in results: + if not result.is_valid: + merged.is_valid = False + merged.errors.extend(result.errors) + merged.warnings.extend(result.warnings) + merged.severity = update_severity(merged.severity, result.severity) + + # Merge sanitized data + if result.sanitized_data: + if merged.sanitized_data is None: + merged.sanitized_data = {} + merged.sanitized_data.update(result.sanitized_data) + + return merged diff --git a/codeflow_engine/core/validation/validators/__init__.py b/codeflow_engine/core/validation/validators/__init__.py new file mode 100644 index 0000000..8728d20 --- /dev/null +++ b/codeflow_engine/core/validation/validators/__init__.py @@ -0,0 +1,20 @@ +""" +Type-Specific Validators. + +This module provides validators for specific data types, each following +the Single Responsibility Principle. 
+""" + +from codeflow_engine.core.validation.validators.string_validator import StringTypeValidator +from codeflow_engine.core.validation.validators.number_validator import NumberTypeValidator +from codeflow_engine.core.validation.validators.array_validator import ArrayTypeValidator +from codeflow_engine.core.validation.validators.object_validator import ObjectTypeValidator +from codeflow_engine.core.validation.validators.file_validator import FileTypeValidator + +__all__ = [ + "ArrayTypeValidator", + "FileTypeValidator", + "NumberTypeValidator", + "ObjectTypeValidator", + "StringTypeValidator", +] diff --git a/codeflow_engine/core/validation/validators/array_validator.py b/codeflow_engine/core/validation/validators/array_validator.py new file mode 100644 index 0000000..974e66c --- /dev/null +++ b/codeflow_engine/core/validation/validators/array_validator.py @@ -0,0 +1,113 @@ +""" +Array Type Validator. + +Validates array/list inputs and their elements. +""" + +from typing import Any, Callable + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ( + ValidationResult, + ValidationSeverity, + update_severity, +) +from codeflow_engine.core.validation.patterns import SecurityPatterns + + +class ArrayTypeValidator(BaseTypeValidator): + """ + Validator for array/list values. + + Performs: + - Length validation + - Element validation (using provided element validator) + """ + + def __init__( + self, + max_length: int = 1000, + element_validator: Callable[[str, Any], ValidationResult] | None = None, + security_patterns: SecurityPatterns | None = None, + ) -> None: + """ + Initialize the array validator. 
+ + Args: + max_length: Maximum allowed array length + element_validator: Optional validator function for array elements + security_patterns: Optional custom security patterns + """ + super().__init__(security_patterns) + self.max_length = max_length + self._element_validator = element_validator + + def set_element_validator( + self, validator: Callable[[str, Any], ValidationResult] + ) -> None: + """ + Set the validator function for array elements. + + This is used by CompositeValidator to inject the composite's + _validate_value method for recursive validation. + + Args: + validator: Function that takes (key, value) and returns ValidationResult + """ + self._element_validator = validator + + def can_validate(self, value: Any) -> bool: + """Check if this validator handles the value type.""" + return isinstance(value, (list, tuple)) + + def validate(self, key: str, value: Any) -> ValidationResult: + """Validate an array value.""" + if not isinstance(value, (list, tuple)): + return ValidationResult.failure( + f"Expected array for '{key}', got {type(value).__name__}", + ValidationSeverity.MEDIUM, + ) + + # Length validation + if len(value) > self.max_length: + return ValidationResult.failure( + f"Array too long for key '{key}': {len(value)} > {self.max_length}", + ValidationSeverity.MEDIUM, + ) + + result = ValidationResult(is_valid=True) + sanitized_array: list[Any] = [] + + for i, item in enumerate(value): + item_key = f"{key}[{i}]" + + # Validate element + if self._element_validator: + item_result = self._element_validator(item_key, item) + else: + # Default: pass through + item_result = ValidationResult.success({"value": item}) + + if not item_result.is_valid: + result.is_valid = False + result.errors.extend(item_result.errors) + result.warnings.extend(item_result.warnings) + result.severity = update_severity(result.severity, item_result.severity) + else: + # Unwrap single-value dict wrappers + sanitized_value = self._unwrap_value(item_result.sanitized_data) + 
sanitized_array.append(sanitized_value) + + if result.is_valid: + result.sanitized_data = {"items": sanitized_array} + + return result + + @staticmethod + def _unwrap_value(data: dict[str, Any] | None) -> Any: + """Unwrap single-value dict wrappers from element validation.""" + if data is None: + return None + if isinstance(data, dict) and len(data) == 1 and "value" in data: + return data["value"] + return data diff --git a/codeflow_engine/core/validation/validators/file_validator.py b/codeflow_engine/core/validation/validators/file_validator.py new file mode 100644 index 0000000..2b8b8c6 --- /dev/null +++ b/codeflow_engine/core/validation/validators/file_validator.py @@ -0,0 +1,139 @@ +""" +File Type Validator. + +Validates file uploads for size, extension, and content safety. +""" + +import html +from pathlib import Path +from typing import Any + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ValidationResult, ValidationSeverity +from codeflow_engine.core.validation.patterns import SecurityPatterns + + +class FileTypeValidator(BaseTypeValidator): + """ + Validator for file upload data. + + Performs: + - File size validation + - File extension validation + - Content validation for text files + """ + + # Default allowed extensions + DEFAULT_ALLOWED_EXTENSIONS = {".txt", ".json", ".yaml", ".yml", ".md"} + # Text extensions that should have content validated + TEXT_EXTENSIONS = {".txt", ".json", ".yaml", ".yml", ".md"} + # Default max file size (10MB) + DEFAULT_MAX_SIZE = 10 * 1024 * 1024 + + def __init__( + self, + allowed_extensions: set[str] | None = None, + max_size: int | None = None, + security_patterns: SecurityPatterns | None = None, + ) -> None: + """ + Initialize the file validator. 
+ + Args: + allowed_extensions: Set of allowed file extensions (with dots) + max_size: Maximum file size in bytes + security_patterns: Optional custom security patterns + """ + super().__init__(security_patterns) + self.allowed_extensions = ( + allowed_extensions + if allowed_extensions is not None + else self.DEFAULT_ALLOWED_EXTENSIONS + ) + self.max_size = max_size if max_size is not None else self.DEFAULT_MAX_SIZE + + def can_validate(self, value: Any) -> bool: + """ + Check if this validator handles the value type. + + File uploads are expected as tuples of (filename, content). + """ + if isinstance(value, tuple) and len(value) == 2: + filename, content = value + return isinstance(filename, str) and isinstance(content, bytes) + return False + + def validate(self, key: str, value: Any) -> ValidationResult: + """Validate a file upload value.""" + if not self.can_validate(value): + return ValidationResult.failure( + f"Expected file upload tuple (filename, content) for '{key}'", + ValidationSeverity.MEDIUM, + ) + + filename, content = value + return self.validate_file_upload(filename, content) + + def validate_file_upload( + self, + filename: str, + content: bytes, + max_size: int | None = None, + ) -> ValidationResult: + """ + Validate a file upload. 
+ + Args: + filename: The name of the uploaded file + content: The file content as bytes + max_size: Optional override for max file size + + Returns: + ValidationResult with validation outcome + """ + effective_max_size = max_size if max_size is not None else self.max_size + + # File size validation + if len(content) > effective_max_size: + return ValidationResult.failure( + f"File too large: {len(content)} > {effective_max_size}", + ValidationSeverity.MEDIUM, + ) + + # File extension validation + file_ext = Path(filename).suffix.lower() + if file_ext not in self.allowed_extensions: + return ValidationResult.failure( + f"File extension not allowed: {file_ext}", + ValidationSeverity.HIGH, + ) + + # Content validation for text files + if file_ext in self.TEXT_EXTENSIONS: + content_result = self._validate_text_content(content) + if not content_result.is_valid: + return content_result + + return ValidationResult.success({ + "filename": html.escape(filename), + "content": content, + "size": len(content), + "extension": file_ext, + }) + + def _validate_text_content(self, content: bytes) -> ValidationResult: + """Validate text file content for security threats.""" + try: + content_str = content.decode("utf-8") + except UnicodeDecodeError: + return ValidationResult.failure( + "File contains invalid UTF-8 encoding", + ValidationSeverity.HIGH, + ) + + # Check for security threats + threat_result = self._check_security_threats("file_content", content_str) + if not threat_result.is_valid: + return threat_result + + return ValidationResult.success() diff --git a/codeflow_engine/core/validation/validators/number_validator.py b/codeflow_engine/core/validation/validators/number_validator.py new file mode 100644 index 0000000..1785839 --- /dev/null +++ b/codeflow_engine/core/validation/validators/number_validator.py @@ -0,0 +1,77 @@ +""" +Number Type Validator. + +Validates numeric inputs for safe ranges. 
+""" + +from typing import Any + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ValidationResult, ValidationSeverity +from codeflow_engine.core.validation.patterns import SecurityPatterns + + +class NumberTypeValidator(BaseTypeValidator): + """ + Validator for numeric values (int and float). + + Performs: + - Integer range validation (within 32-bit signed range) + - Float range validation (within safe float range) + """ + + # Safe ranges for numeric types + INT_MIN = -(2**31) + INT_MAX = 2**31 - 1 + FLOAT_ABS_MAX = 1e308 + + def __init__( + self, + int_min: int | None = None, + int_max: int | None = None, + float_abs_max: float | None = None, + security_patterns: SecurityPatterns | None = None, + ) -> None: + """ + Initialize the number validator. + + Args: + int_min: Minimum integer value (defaults to -(2^31)) + int_max: Maximum integer value (defaults to 2^31-1) + float_abs_max: Maximum absolute float value (defaults to 1e308) + security_patterns: Optional custom security patterns + """ + super().__init__(security_patterns) + self.int_min = int_min if int_min is not None else self.INT_MIN + self.int_max = int_max if int_max is not None else self.INT_MAX + self.float_abs_max = float_abs_max if float_abs_max is not None else self.FLOAT_ABS_MAX + + def can_validate(self, value: Any) -> bool: + """Check if this validator handles the value type.""" + return isinstance(value, (int, float)) and not isinstance(value, bool) + + def validate(self, key: str, value: Any) -> ValidationResult: + """Validate a numeric value.""" + if not isinstance(value, (int, float)) or isinstance(value, bool): + return ValidationResult.failure( + f"Expected number for '{key}', got {type(value).__name__}", + ValidationSeverity.MEDIUM, + ) + + # Integer range validation + if isinstance(value, int): + if value < self.int_min or value > self.int_max: + return ValidationResult.failure( + f"Integer out of safe range for key 
'{key}': {value}", + ValidationSeverity.MEDIUM, + ) + + # Float range validation + if isinstance(value, float): + if abs(value) > self.float_abs_max: + return ValidationResult.failure( + f"Float out of safe range for key '{key}': {value}", + ValidationSeverity.MEDIUM, + ) + + return ValidationResult.success({"value": value}) diff --git a/codeflow_engine/core/validation/validators/object_validator.py b/codeflow_engine/core/validation/validators/object_validator.py new file mode 100644 index 0000000..0a06ba9 --- /dev/null +++ b/codeflow_engine/core/validation/validators/object_validator.py @@ -0,0 +1,122 @@ +""" +Object/Dict Type Validator. + +Validates dictionary/object inputs and their nested values. +""" + +import re +from typing import Any, Callable + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ( + ValidationResult, + ValidationSeverity, + update_severity, +) +from codeflow_engine.core.validation.patterns import SecurityPatterns + + +# Constants +MAX_KEY_LENGTH = 100 +SAFE_KEY_PATTERN = re.compile(r"^[a-zA-Z0-9_\-\.]+$") + + +class ObjectTypeValidator(BaseTypeValidator): + """ + Validator for dictionary/object values. + + Performs: + - Key name validation + - Nested value validation (using provided value validator) + """ + + def __init__( + self, + value_validator: Callable[[str, Any], ValidationResult] | None = None, + security_patterns: SecurityPatterns | None = None, + ) -> None: + """ + Initialize the object validator. + + Args: + value_validator: Optional validator function for nested values + security_patterns: Optional custom security patterns + """ + super().__init__(security_patterns) + self._value_validator = value_validator + + def set_value_validator( + self, validator: Callable[[str, Any], ValidationResult] + ) -> None: + """ + Set the validator function for nested values. 
+ + This is used by CompositeValidator to inject the composite's + _validate_value method for recursive validation. + + Args: + validator: Function that takes (key, value) and returns ValidationResult + """ + self._value_validator = validator + + def can_validate(self, value: Any) -> bool: + """Check if this validator handles the value type.""" + return isinstance(value, dict) + + def validate(self, key: str, value: Any) -> ValidationResult: + """Validate a dictionary value.""" + if not isinstance(value, dict): + return ValidationResult.failure( + f"Expected object for '{key}', got {type(value).__name__}", + ValidationSeverity.MEDIUM, + ) + + result = ValidationResult(is_valid=True) + sanitized_object: dict[str, Any] = {} + + for obj_key, obj_value in value.items(): + # Validate nested key name + if not self._is_safe_key(str(obj_key)): + result.add_error( + f"Invalid nested key name: {key}.{obj_key}", + ValidationSeverity.HIGH, + ) + continue + + nested_key = f"{key}.{obj_key}" + + # Validate nested value + if self._value_validator: + obj_result = self._value_validator(nested_key, obj_value) + else: + # Default: pass through + obj_result = ValidationResult.success({"value": obj_value}) + + if not obj_result.is_valid: + result.is_valid = False + result.errors.extend(obj_result.errors) + result.warnings.extend(obj_result.warnings) + result.severity = update_severity(result.severity, obj_result.severity) + else: + # Unwrap single-value dict wrappers + sanitized_value = self._unwrap_value(obj_result.sanitized_data) + sanitized_object[obj_key] = sanitized_value + + if result.is_valid: + result.sanitized_data = sanitized_object + + return result + + @staticmethod + def _is_safe_key(key: str) -> bool: + """Check if a key name is safe.""" + return bool(SAFE_KEY_PATTERN.match(key)) and len(key) <= MAX_KEY_LENGTH + + @staticmethod + def _unwrap_value(data: dict[str, Any] | None) -> Any: + """Unwrap single-value dict wrappers from value validation.""" + if data is None: + 
return None + if isinstance(data, dict) and len(data) == 1 and "value" in data: + return data["value"] + return data diff --git a/codeflow_engine/core/validation/validators/string_validator.py b/codeflow_engine/core/validation/validators/string_validator.py new file mode 100644 index 0000000..cdd852e --- /dev/null +++ b/codeflow_engine/core/validation/validators/string_validator.py @@ -0,0 +1,104 @@ +""" +String Type Validator. + +Validates string inputs for security threats and format requirements. +""" + +import html +import re +from typing import Any + +from codeflow_engine.core.validation.base import BaseTypeValidator +from codeflow_engine.core.validation.result import ValidationResult, ValidationSeverity +from codeflow_engine.core.validation.patterns import SecurityPatterns + + +class StringTypeValidator(BaseTypeValidator): + """ + Validator for string values. + + Performs: + - Length validation + - Security threat detection (SQL injection, XSS, command injection) + - HTML entity sanitization + - Format validation (email, URL) + """ + + def __init__( + self, + max_length: int = 1000, + security_patterns: SecurityPatterns | None = None, + ) -> None: + """ + Initialize the string validator. 
+ + Args: + max_length: Maximum allowed string length + security_patterns: Optional custom security patterns + """ + super().__init__(security_patterns) + self.max_length = max_length + + def can_validate(self, value: Any) -> bool: + """Check if this validator handles the value type.""" + return isinstance(value, str) + + def validate(self, key: str, value: Any) -> ValidationResult: + """Validate a string value.""" + if not isinstance(value, str): + return ValidationResult.failure( + f"Expected string for '{key}', got {type(value).__name__}", + ValidationSeverity.MEDIUM, + ) + + # Length validation + if len(value) > self.max_length: + return ValidationResult.failure( + f"String too long for key '{key}': {len(value)} > {self.max_length}", + ValidationSeverity.MEDIUM, + ) + + # Security threat check + threat_result = self._check_security_threats(key, value) + if not threat_result.is_valid: + return threat_result + + # Sanitize HTML entities + sanitized_value = html.escape(value) + + # Format validation for special keys + format_result = self._validate_format(key, sanitized_value) + if not format_result.is_valid: + return format_result + + return ValidationResult.success({"value": sanitized_value}) + + def _validate_format(self, key: str, value: str) -> ValidationResult: + """Validate format for special field types.""" + key_lower = key.lower() + + if "email" in key_lower and not self._is_valid_email(value): + return ValidationResult.failure( + f"Invalid email format in '{key}'", + ValidationSeverity.MEDIUM, + ) + + if "url" in key_lower and not self._is_valid_url(value): + return ValidationResult.failure( + f"Invalid URL format in '{key}'", + ValidationSeverity.MEDIUM, + ) + + return ValidationResult.success() + + @staticmethod + def _is_valid_email(email: str) -> bool: + """Validate email format.""" + email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + return bool(re.match(email_pattern, email)) + + @staticmethod + def _is_valid_url(url: str) -> 
bool: + """Validate URL format.""" + url_pattern = r"^https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/.*)?$" + return bool(re.match(url_pattern, url)) diff --git a/codeflow_engine/security/validation_models.py b/codeflow_engine/security/validation_models.py index 8cac7ed..ea148f4 100644 --- a/codeflow_engine/security/validation_models.py +++ b/codeflow_engine/security/validation_models.py @@ -1,23 +1,14 @@ -from enum import Enum -from typing import Any +""" +Validation Models - Backwards compatibility module. -from pydantic import BaseModel, Field +This module re-exports from codeflow_engine.core.validation for backwards compatibility. +New code should import directly from codeflow_engine.core.validation. +""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.result import ( + ValidationResult, + ValidationSeverity, +) -class ValidationSeverity(Enum): - """Severity levels for validation issues.""" - - LOW = "low" - MEDIUM = "medium" - HIGH = "high" - CRITICAL = "critical" - - -class ValidationResult(BaseModel): - """Result of input validation.""" - - is_valid: bool - errors: list[str] = Field(default_factory=list) - warnings: list[str] = Field(default_factory=list) - sanitized_data: dict[str, Any] | None = None - severity: ValidationSeverity = ValidationSeverity.LOW +__all__ = ["ValidationResult", "ValidationSeverity"] diff --git a/codeflow_engine/security/validators/__init__.py b/codeflow_engine/security/validators/__init__.py index 2e0f9f5..beeed32 100644 --- a/codeflow_engine/security/validators/__init__.py +++ b/codeflow_engine/security/validators/__init__.py @@ -1,4 +1,58 @@ +""" +Security Validators package - Backwards compatibility module. + +This module provides backwards compatibility for existing code. +New code should import directly from codeflow_engine.core.validation. 
+""" + +# Main enterprise validator from codeflow_engine.security.validators.base import EnterpriseInputValidator +# Re-export core validation framework +from codeflow_engine.core.validation import ( + BaseTypeValidator, + CompositeValidator, + SecurityPatterns, + ValidationResult, + ValidationSeverity, +) + +# Re-export type validators +from codeflow_engine.core.validation.validators import ( + ArrayTypeValidator, + FileTypeValidator, + NumberTypeValidator, + ObjectTypeValidator, + StringTypeValidator, +) + +# Legacy aliases +from codeflow_engine.security.validators.string_validator import StringValidator +from codeflow_engine.security.validators.number_validator import NumberValidator +from codeflow_engine.security.validators.array_validator import ArrayValidator +from codeflow_engine.security.validators.object_validator import ObjectValidator +from codeflow_engine.security.validators.file_validator import FileValidator + -__all__ = ["EnterpriseInputValidator"] +__all__ = [ + # Main validator + "EnterpriseInputValidator", + # Core framework + "BaseTypeValidator", + "CompositeValidator", + "SecurityPatterns", + "ValidationResult", + "ValidationSeverity", + # Type validators + "ArrayTypeValidator", + "FileTypeValidator", + "NumberTypeValidator", + "ObjectTypeValidator", + "StringTypeValidator", + # Legacy aliases + "ArrayValidator", + "FileValidator", + "NumberValidator", + "ObjectValidator", + "StringValidator", +] diff --git a/codeflow_engine/security/validators/array_validator.py b/codeflow_engine/security/validators/array_validator.py index 24fa21e..b779323 100644 --- a/codeflow_engine/security/validators/array_validator.py +++ b/codeflow_engine/security/validators/array_validator.py @@ -1,57 +1,16 @@ -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity +""" +Array Validator - Backwards compatibility module. +This module re-exports from codeflow_engine.core.validation for backwards compatibility. 
+New code should import directly from codeflow_engine.core.validation.validators. +""" -class ArrayValidator: - """Array validation functionality.""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.validators import ArrayTypeValidator +from codeflow_engine.core.validation import ValidationResult, ValidationSeverity - # Expected to be provided by composite validator - max_array_length: int = 1000 - def _validate_value(self, key: str, value): # type: ignore[override] - from codeflow_engine.security.validation_models import ValidationResult +# Legacy alias for backwards compatibility +ArrayValidator = ArrayTypeValidator - return ValidationResult(is_valid=True, sanitized_data={"value": value}) - - def _validate_array(self, key: str, value: list | tuple) -> ValidationResult: - """Validate array input.""" - result = ValidationResult(is_valid=True) - - # Length validation - if len(value) > self.max_array_length: # type: ignore[attr-defined] - result.errors.append( - f"Array too long for key '{key}': {len(value)} > {self.max_array_length}" - ) - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - sanitized_array = [] - for i, item in enumerate(value): - item_result = self._validate_value(f"{key}[{i}]", item) # type: ignore[attr-defined] - if not item_result.is_valid: - result.errors.extend(item_result.errors) - result.warnings.extend(item_result.warnings) - result.is_valid = False - - # Update severity - if item_result.severity.value == "critical": - result.severity = ValidationSeverity.CRITICAL - elif ( - item_result.severity.value == "high" - and result.severity != ValidationSeverity.CRITICAL - ): - result.severity = ValidationSeverity.HIGH - elif ( - isinstance(item_result.sanitized_data, dict) - and "value" in item_result.sanitized_data - and len(item_result.sanitized_data) == 1 - ): - sanitized_array.append(item_result.sanitized_data["value"]) - else: - 
sanitized_array.append(item_result.sanitized_data) - - # Wrap array in a dictionary to satisfy type constraints - if result.is_valid: - result.sanitized_data = {"items": sanitized_array} - - return result +__all__ = ["ArrayValidator", "ArrayTypeValidator", "ValidationResult", "ValidationSeverity"] diff --git a/codeflow_engine/security/validators/base.py b/codeflow_engine/security/validators/base.py index 4f3500f..28b2358 100644 --- a/codeflow_engine/security/validators/base.py +++ b/codeflow_engine/security/validators/base.py @@ -1,150 +1,188 @@ +""" +Enterprise Input Validator. + +This module provides the EnterpriseInputValidator class that composes +multiple type validators using the new core validation framework. + +The implementation follows composition over inheritance, using the +CompositeValidator pattern from codeflow_engine.core.validation. +""" + from typing import Any -# mypy: disable-error-code=misc import structlog -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity -from codeflow_engine.security.validators.array_validator import ArrayValidator -from codeflow_engine.security.validators.file_validator import FileValidator -from codeflow_engine.security.validators.number_validator import NumberValidator -from codeflow_engine.security.validators.object_validator import ObjectValidator -from codeflow_engine.security.validators.string_validator import StringValidator +# Import from core validation framework +from codeflow_engine.core.validation import ( + CompositeValidator, + SecurityPatterns, + ValidationResult, + ValidationSeverity, +) +from codeflow_engine.core.validation.validators import ( + ArrayTypeValidator, + FileTypeValidator, + NumberTypeValidator, + ObjectTypeValidator, + StringTypeValidator, +) logger = structlog.get_logger(__name__) -class EnterpriseInputValidator( - StringValidator, ArrayValidator, ObjectValidator, NumberValidator, FileValidator -): - """Enterprise-grade input validation and 
sanitization.""" - - # Security pattern overrides for enterprise rules are set per-instance in __init__ - - def __init__(self): - self.max_string_length = 10000 - self.max_array_length = 1000 - self.allowed_file_extensions = {".txt", ".json", ".yaml", ".yml", ".md"} - - # Override base string validator patterns with stricter enterprise rules - self.SQL_INJECTION_PATTERNS = [ - r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)", - r"(--|#|/\*|\*/)", - r"(\b(OR|AND)\s+\d+\s*=\s*\d+)", - r"(\bUNION\s+SELECT\b)", - ] - - self.XSS_PATTERNS = [ - r"]*>.*?", - r"javascript:", - r"on\w+\s*=", - r"]*>.*?", - r"]*>.*?", - r"]*>.*?", - ] - - self.COMMAND_INJECTION_PATTERNS = [ - r"[;&|`$(){}[\]\\]", - r"\b(rm|del|format|shutdown|reboot|halt)\b", - r"(>|>>|<|\|)", - ] +class EnterpriseInputValidator: + """ + Enterprise-grade input validation and sanitization. + + This class uses the CompositeValidator pattern from the core validation + framework, providing a unified interface while leveraging type-specific + validators. + + The validator automatically: + - Validates key names for safety + - Detects and blocks security threats (SQL injection, XSS, etc.) + - Sanitizes input data + - Validates against optional Pydantic schemas + """ + + def __init__( + self, + max_string_length: int = 10000, + max_array_length: int = 1000, + allowed_file_extensions: set[str] | None = None, + ) -> None: + """ + Initialize the enterprise validator. 
+ + Args: + max_string_length: Maximum allowed string length + max_array_length: Maximum allowed array length + allowed_file_extensions: Set of allowed file extensions + """ + self.max_string_length = max_string_length + self.max_array_length = max_array_length + self.allowed_file_extensions = allowed_file_extensions or { + ".txt", ".json", ".yaml", ".yml", ".md" + } + + # Create custom security patterns with stricter enterprise rules + self._security_patterns = SecurityPatterns( + sql_injection=[ + r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)", + r"(--|#|/\*|\*/)", + r"(\b(OR|AND)\s+\d+\s*=\s*\d+)", + r"(\bUNION\s+SELECT\b)", + ], + xss=[ + r"]*>.*?", + r"javascript:", + r"on\w+\s*=", + r"]*>.*?", + r"]*>.*?", + r"]*>.*?", + ], + command_injection=[ + r"[;&|`$(){}[\]\\]", + r"\b(rm|del|format|shutdown|reboot|halt)\b", + r"(>|>>|<|\|)", + ], + ) + + # Initialize type validators + string_validator = StringTypeValidator( + max_length=max_string_length, + security_patterns=self._security_patterns, + ) + number_validator = NumberTypeValidator( + security_patterns=self._security_patterns, + ) + array_validator = ArrayTypeValidator( + max_length=max_array_length, + security_patterns=self._security_patterns, + ) + object_validator = ObjectTypeValidator( + security_patterns=self._security_patterns, + ) + file_validator = FileTypeValidator( + allowed_extensions=self.allowed_file_extensions, + security_patterns=self._security_patterns, + ) + + # Create composite validator with all type validators + self._composite = CompositeValidator( + security_patterns=self._security_patterns, + validators=[ + string_validator, + number_validator, + array_validator, + object_validator, + file_validator, + ], + ) + + # Wire up recursive validation for nested structures + array_validator.set_element_validator(self._composite._validate_value) + object_validator.set_value_validator(self._composite._validate_value) def validate_input( self, data: dict[str, Any], schema: 
type | None = None ) -> ValidationResult: - """Comprehensive input validation.""" - result = ValidationResult(is_valid=True) - sanitized_data = {} - - try: - for key, value in data.items(): - # Validate key name - if not self._is_safe_key(key): - result.errors.append(f"Invalid key name: {key}") - result.severity = ValidationSeverity.HIGH - result.is_valid = False - continue - - # Validate and sanitize value - validation_result = self._validate_value(key, value) - - if not validation_result.is_valid: - result.errors.extend(validation_result.errors) - result.warnings.extend(validation_result.warnings) - result.is_valid = False - - # Update severity to highest found - if validation_result.severity.value == "critical": - result.severity = ValidationSeverity.CRITICAL - elif ( - validation_result.severity.value == "high" - and result.severity != ValidationSeverity.CRITICAL - ): - result.severity = ValidationSeverity.HIGH - elif ( - validation_result.severity.value == "medium" - and result.severity - not in [ - ValidationSeverity.CRITICAL, - ValidationSeverity.HIGH, - ] - ): - result.severity = ValidationSeverity.MEDIUM - else: - sanitized_data[key] = validation_result.sanitized_data - - # Apply schema validation if provided - if schema and result.is_valid: - try: - # Use the schema with the dictionary unpacked - validated_data = schema(**sanitized_data) - if hasattr(validated_data, "dict"): - result.sanitized_data = validated_data.dict() - else: - result.sanitized_data = sanitized_data - except Exception as e: - result.errors.append(f"Schema validation failed: {e!s}") - result.is_valid = False - result.severity = ValidationSeverity.HIGH - else: - result.sanitized_data = sanitized_data - - # Log validation results - if not result.is_valid: - logger.warning( - "Input validation failed", - errors=result.errors, - severity=result.severity.value, - data_keys=list(data.keys()), - ) - else: - logger.debug("Input validation passed", data_keys=list(data.keys())) - - return result 
- - except Exception: - logger.exception("Input validation error") - return ValidationResult( - is_valid=False, - errors=["Validation system error"], - severity=ValidationSeverity.CRITICAL, - ) + """ + Comprehensive input validation. + + Args: + data: Dictionary of input data to validate + schema: Optional Pydantic model for schema validation + + Returns: + ValidationResult with validation outcome and sanitized data + """ + return self._composite.validate_input(data, schema) + + def validate_file_upload( + self, filename: str, content: bytes, max_size: int = 10 * 1024 * 1024 + ) -> ValidationResult: + """ + Validate file upload. + + Args: + filename: The uploaded file name + content: The file content as bytes + max_size: Maximum allowed file size in bytes + + Returns: + ValidationResult with validation outcome + """ + file_validator = FileTypeValidator( + allowed_extensions=self.allowed_file_extensions, + max_size=max_size, + security_patterns=self._security_patterns, + ) + return file_validator.validate_file_upload(filename, content) + + # Backward compatibility methods for existing code + + def _is_safe_key(self, key: str) -> bool: + """Check if key name is safe. Delegates to composite validator.""" + return self._composite._is_safe_key(key) def _validate_value(self, key: str, value: Any) -> ValidationResult: - """Validate individual value based on its type.""" - result = ValidationResult(is_valid=True) - - if isinstance(value, str): - result = self._validate_string(key, value) - elif isinstance(value, list | tuple): - result = self._validate_array(key, value) - elif isinstance(value, dict): - result = self._validate_object(key, value) - elif isinstance(value, int | float): - result = self._validate_number(key, value) - else: - result.sanitized_data = { - "value": value - } # Wrap in dict to satisfy type constraints - - return result + """Validate individual value. 
Delegates to composite validator.""" + return self._composite._validate_value(key, value) + + # Legacy pattern access for tests/backward compatibility + @property + def SQL_INJECTION_PATTERNS(self) -> list[str]: + """Get SQL injection patterns for backward compatibility.""" + return self._security_patterns.sql_injection + + @property + def XSS_PATTERNS(self) -> list[str]: + """Get XSS patterns for backward compatibility.""" + return self._security_patterns.xss + + @property + def COMMAND_INJECTION_PATTERNS(self) -> list[str]: + """Get command injection patterns for backward compatibility.""" + return self._security_patterns.command_injection diff --git a/codeflow_engine/security/validators/file_validator.py b/codeflow_engine/security/validators/file_validator.py index 6aa242c..12875be 100644 --- a/codeflow_engine/security/validators/file_validator.py +++ b/codeflow_engine/security/validators/file_validator.py @@ -1,62 +1,16 @@ -import html -from pathlib import Path +""" +File Validator - Backwards compatibility module. -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity +This module re-exports from codeflow_engine.core.validation for backwards compatibility. +New code should import directly from codeflow_engine.core.validation.validators. 
+""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.validators import FileTypeValidator +from codeflow_engine.core.validation import ValidationResult, ValidationSeverity -class FileValidator: - """File validation functionality.""" - # Expected to be configured by composite validator - allowed_file_extensions: set[str] = {".txt", ".json", ".yaml", ".yml", ".md"} +# Legacy alias for backwards compatibility +FileValidator = FileTypeValidator - def _validate_string(self, key: str, value: str): # type: ignore[override] - from codeflow_engine.security.validation_models import ValidationResult - - return ValidationResult(is_valid=True, sanitized_data={"value": value}) - - def validate_file_upload( - self, filename: str, content: bytes, max_size: int = 10 * 1024 * 1024 - ) -> ValidationResult: - """Validate file upload.""" - result = ValidationResult(is_valid=True) - - # File size validation - if len(content) > max_size: - result.errors.append(f"File too large: {len(content)} > {max_size}") - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - # File extension validation - file_ext = Path(filename).suffix.lower() - if file_ext not in self.allowed_file_extensions: # type: ignore[attr-defined] - result.errors.append(f"File extension not allowed: {file_ext}") - result.severity = ValidationSeverity.HIGH - result.is_valid = False - return result - - # Content validation for text files - if file_ext in {".txt", ".json", ".yaml", ".yml", ".md"}: - try: - content_str = content.decode("utf-8") - content_result = self._validate_string("file_content", content_str) # type: ignore[attr-defined] - if not content_result.is_valid: - result.errors.extend(content_result.errors) - result.severity = content_result.severity - result.is_valid = False - return result - except UnicodeDecodeError: - result.errors.append("File contains invalid UTF-8 encoding") - result.severity = ValidationSeverity.HIGH - 
result.is_valid = False - return result - - result.sanitized_data = { - "filename": html.escape(filename), - "content": content, - "size": len(content), - "extension": file_ext, - } - - return result +__all__ = ["FileValidator", "FileTypeValidator", "ValidationResult", "ValidationSeverity"] diff --git a/codeflow_engine/security/validators/number_validator.py b/codeflow_engine/security/validators/number_validator.py index 7e438a0..94bb3a3 100644 --- a/codeflow_engine/security/validators/number_validator.py +++ b/codeflow_engine/security/validators/number_validator.py @@ -1,30 +1,16 @@ -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity +""" +Number Validator - Backwards compatibility module. +This module re-exports from codeflow_engine.core.validation for backwards compatibility. +New code should import directly from codeflow_engine.core.validation.validators. +""" -class NumberValidator: - """Number validation functionality.""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.validators import NumberTypeValidator +from codeflow_engine.core.validation import ValidationResult, ValidationSeverity - def _validate_number(self, key: str, value: float) -> ValidationResult: - """Validate numeric input.""" - result = ValidationResult(is_valid=True) - # Range validation - if isinstance(value, int) and (value < -(2**31) or value > 2**31 - 1): - result.errors.append(f"Integer out of safe range for key '{key}': {value}") - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result +# Legacy alias for backwards compatibility +NumberValidator = NumberTypeValidator - if isinstance(value, float): - FLOAT_ABS_MAX = 1e308 - if abs(value) > FLOAT_ABS_MAX: - result.errors.append( - f"Float out of safe range for key '{key}': {value}" - ) - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - # Wrap numeric value in a dictionary to satisfy type 
constraints - result.sanitized_data = {"value": value} - return result +__all__ = ["NumberValidator", "NumberTypeValidator", "ValidationResult", "ValidationSeverity"] diff --git a/codeflow_engine/security/validators/object_validator.py b/codeflow_engine/security/validators/object_validator.py index 0fbb7db..de209c3 100644 --- a/codeflow_engine/security/validators/object_validator.py +++ b/codeflow_engine/security/validators/object_validator.py @@ -1,54 +1,16 @@ -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity +""" +Object Validator - Backwards compatibility module. +This module re-exports from codeflow_engine.core.validation for backwards compatibility. +New code should import directly from codeflow_engine.core.validation.validators. +""" -class ObjectValidator: - """Object validation functionality.""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.validators import ObjectTypeValidator +from codeflow_engine.core.validation import ValidationResult, ValidationSeverity - # These helper methods are expected to be provided by mixins on the concrete class - def _is_safe_key(self, key: str) -> bool: # type: ignore[override] - return True - def _validate_value(self, key: str, value): # type: ignore[override] - from codeflow_engine.security.validation_models import ValidationResult +# Legacy alias for backwards compatibility +ObjectValidator = ObjectTypeValidator - return ValidationResult(is_valid=True, sanitized_data={"value": value}) - - def _validate_object(self, key: str, value: dict) -> ValidationResult: - """Validate object input.""" - result = ValidationResult(is_valid=True) - sanitized_object = {} - - for obj_key, obj_value in value.items(): - if not self._is_safe_key(obj_key): - result.errors.append(f"Invalid nested key name: {key}.{obj_key}") - result.severity = ValidationSeverity.HIGH - result.is_valid = False - continue - - obj_result = self._validate_value(f"{key}.{obj_key}", 
obj_value) - if not obj_result.is_valid: - result.errors.extend(obj_result.errors) - result.warnings.extend(obj_result.warnings) - result.is_valid = False - - # Update severity - if obj_result.severity.value == "critical": - result.severity = ValidationSeverity.CRITICAL - elif ( - obj_result.severity.value == "high" - and result.severity != ValidationSeverity.CRITICAL - ): - result.severity = ValidationSeverity.HIGH - elif ( - isinstance(obj_result.sanitized_data, dict) - and "value" in obj_result.sanitized_data - and len(obj_result.sanitized_data) == 1 - ): - sanitized_object[obj_key] = obj_result.sanitized_data["value"] - else: - sanitized_object[obj_key] = obj_result.sanitized_data - - if result.is_valid: - result.sanitized_data = sanitized_object - - return result +__all__ = ["ObjectValidator", "ObjectTypeValidator", "ValidationResult", "ValidationSeverity"] diff --git a/codeflow_engine/security/validators/string_validator.py b/codeflow_engine/security/validators/string_validator.py index dd3e6a4..d8e24e4 100644 --- a/codeflow_engine/security/validators/string_validator.py +++ b/codeflow_engine/security/validators/string_validator.py @@ -1,138 +1,16 @@ -import html -import re +""" +String Validator - Backwards compatibility module. -from codeflow_engine.security.validation_models import ValidationResult, ValidationSeverity +This module re-exports from codeflow_engine.core.validation for backwards compatibility. +New code should import directly from codeflow_engine.core.validation.validators. 
+""" +# Re-export from core for backwards compatibility +from codeflow_engine.core.validation.validators import StringTypeValidator +from codeflow_engine.core.validation import ValidationResult, ValidationSeverity -# Constants -MAX_KEY_LENGTH = 100 +# Legacy alias for backwards compatibility +StringValidator = StringTypeValidator -class StringValidator: - """String validation functionality.""" - - def __init__(self, max_string_length: int = 1000): - self.max_string_length = max_string_length - self.SQL_INJECTION_PATTERNS = [ - r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|EXECUTE|UNION|SCRIPT)\b)", - r"(\b(OR|AND)\b\s+\d+\s*=\s*\d+)", - r"(\b(OR|AND)\b\s+['\"]\w+['\"]\s*=\s*['\"]\w+['\"])", - r"(--|\b(COMMENT|REM)\b)", - r"(\b(WAITFOR|DELAY)\b)", - r"(\b(BENCHMARK|SLEEP)\b)", - ] - self.XSS_PATTERNS = [ - r"]*>.*?", - r"javascript:", - r"on\w+\s*=", - r"]*>", - r"]*>", - r"]*>", - r"]*>", - r"]*>", - r"]*>", - r"]*>", - ] - self.COMMAND_INJECTION_PATTERNS = [ - r"[;&|`$(){}[\]]", - r"\b(cat|ls|pwd|whoami|id|uname|ps|top|kill|rm|cp|mv|chmod|chown)\b", - r"\b(netcat|nc|telnet|ssh|scp|wget|curl|ftp|sftp)\b", - r"\b(bash|sh|zsh|fish|powershell|cmd|command)\b", - ] - - def _validate_string(self, key: str, value: str) -> ValidationResult: - """Validate string input for security threats.""" - result = ValidationResult(is_valid=True) - - # Check length validation - if len(value) > self.max_string_length: - result.errors.append( - f"String too long for key '{key}': {len(value)} > {self.max_string_length}" - ) - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - # Check for security threats - threat_check = self._check_security_threats(key, value) - if not threat_check.is_valid: - return threat_check - - # Sanitize HTML entities - sanitized_value = html.escape(value) - - # Check format validation - format_check = self._check_format_validation(key, sanitized_value) - if not format_check.is_valid: - return format_check - - # 
Wrap string value in a dictionary to satisfy type constraints - result.sanitized_data = {"value": sanitized_value} - return result - - def _check_security_threats(self, key: str, value: str) -> ValidationResult: - """Check for security threats in the string.""" - result = ValidationResult(is_valid=True) - - # SQL Injection detection - for pattern in self.SQL_INJECTION_PATTERNS: - if re.search(pattern, value, re.IGNORECASE): - result.errors.append(f"Potential SQL injection detected in '{key}'") - result.severity = ValidationSeverity.CRITICAL - result.is_valid = False - return result - - # XSS detection - for pattern in self.XSS_PATTERNS: - if re.search(pattern, value, re.IGNORECASE): - result.errors.append(f"Potential XSS attack detected in '{key}'") - result.severity = ValidationSeverity.CRITICAL - result.is_valid = False - return result - - # Command injection detection - for pattern in self.COMMAND_INJECTION_PATTERNS: - if re.search(pattern, value): - result.errors.append(f"Potential command injection detected in '{key}'") - result.severity = ValidationSeverity.CRITICAL - result.is_valid = False - return result - - return result - - def _check_format_validation( - self, key: str, sanitized_value: str - ) -> ValidationResult: - """Check format validation for specific contexts.""" - result = ValidationResult(is_valid=True) - - # Additional sanitization for specific contexts - if "email" in key.lower() and not self._is_valid_email(sanitized_value): - result.errors.append(f"Invalid email format in '{key}'") - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - if "url" in key.lower() and not self._is_valid_url(sanitized_value): - result.errors.append(f"Invalid URL format in '{key}'") - result.severity = ValidationSeverity.MEDIUM - result.is_valid = False - return result - - return result - - def _is_safe_key(self, key: str) -> bool: - """Check if key name is safe.""" - # Allow alphanumeric, underscore, hyphen, and dot - 
safe_pattern = r"^[a-zA-Z0-9_\-\.]+$" - return bool(re.match(safe_pattern, key)) and len(key) <= MAX_KEY_LENGTH - - def _is_valid_email(self, email: str) -> bool: - """Validate email format.""" - email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" - return bool(re.match(email_pattern, email)) - - def _is_valid_url(self, url: str) -> bool: - """Validate URL format.""" - url_pattern = r"^https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/.*)?$" - return bool(re.match(url_pattern, url)) +__all__ = ["StringValidator", "StringTypeValidator", "ValidationResult", "ValidationSeverity"]