diff --git a/cmd/results.go b/cmd/results.go index 985469e..79857dc 100755 --- a/cmd/results.go +++ b/cmd/results.go @@ -29,6 +29,13 @@ func init() { resultsCmd.AddCommand(resultsQueryCmd) resultsCmd.AddCommand(resultsStatsCmd) resultsCmd.AddCommand(resultsIdentityChainsCmd) + resultsCmd.AddCommand(resultsMarkFixedCmd) + resultsCmd.AddCommand(resultsMarkVerifiedCmd) + resultsCmd.AddCommand(resultsMarkFalsePositiveCmd) + resultsCmd.AddCommand(resultsRegressionsCmd) + resultsCmd.AddCommand(resultsTimelineCmd) + resultsCmd.AddCommand(resultsNewFindingsCmd) + resultsCmd.AddCommand(resultsFixedFindingsCmd) } var resultsListCmd = &cobra.Command{ @@ -1412,11 +1419,368 @@ func getSeverityColor(severity types.Severity) func(string) string { } } +var resultsMarkFixedCmd = &cobra.Command{ + Use: "mark-fixed [finding-id]", + Short: "Mark a finding as fixed (for regression detection)", + Long: `Mark a vulnerability finding as fixed. If the same vulnerability +is detected in a future scan, it will be flagged as a regression.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + findingID := args[0] + + logger.Infow("Marking finding as fixed", + "finding_id", findingID, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + err := store.UpdateFindingStatus(GetContext(), findingID, types.FindingStatusFixed) + if err != nil { + logger.Errorw("Failed to mark finding as fixed", + "finding_id", findingID, + "error", err, + ) + return fmt.Errorf("failed to mark finding as fixed: %w", err) + } + + logger.Infow("Finding marked as fixed - regression detection enabled", + "finding_id", findingID, + "note", "If this vulnerability reappears in future scans, it will be flagged as a regression", + ) + + return nil + }, +} + +var resultsMarkVerifiedCmd = &cobra.Command{ + Use: "mark-verified [finding-id]", + Short: "Mark a finding as manually verified", + Long: `Mark a vulnerability finding as manually verified. This indicates +that a human security researcher has confirmed the vulnerability exists.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + findingID := args[0] + + unverify, _ := cmd.Flags().GetBool("unverify") + + logger.Infow("Updating finding verification status", + "finding_id", findingID, + "verified", !unverify, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + err := store.MarkFindingVerified(GetContext(), findingID, !unverify) + if err != nil { + logger.Errorw("Failed to update finding verification status", + "finding_id", findingID, + "error", err, + ) + return fmt.Errorf("failed to update verification status: %w", err) + } + + if unverify { + logger.Infow("Finding marked as unverified", + "finding_id", findingID, + ) + } else { + logger.Infow("Finding marked as verified", + "finding_id", findingID, + ) + } + + return nil + }, +} + +var resultsMarkFalsePositiveCmd = &cobra.Command{ + Use: "mark-false-positive [finding-id]", + Short: "Mark a finding as a false positive", + Long: `Mark a vulnerability finding as a false positive. 
This indicates +the vulnerability was incorrectly identified by the scanner.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + findingID := args[0] + + remove, _ := cmd.Flags().GetBool("remove") + + logger.Infow("Updating finding false positive status", + "finding_id", findingID, + "false_positive", !remove, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + err := store.MarkFindingFalsePositive(GetContext(), findingID, !remove) + if err != nil { + logger.Errorw("Failed to update finding false positive status", + "finding_id", findingID, + "error", err, + ) + return fmt.Errorf("failed to update false positive status: %w", err) + } + + if remove { + logger.Infow("Finding false positive flag removed", + "finding_id", findingID, + ) + } else { + logger.Infow("Finding marked as false positive", + "finding_id", findingID, + ) + } + + return nil + }, +} + +var resultsRegressionsCmd = &cobra.Command{ + Use: "regressions", + Short: "List vulnerabilities that were fixed and then reappeared", + Long: `Show findings that were marked as fixed but have since reappeared +in subsequent scans (regression detection).`, + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + + limit, _ := cmd.Flags().GetInt("limit") + output, _ := cmd.Flags().GetString("output") + + logger.Infow("Querying regressions", + "limit", limit, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + findings, err := store.GetRegressions(GetContext(), limit) + if err != nil { + logger.Errorw("Failed to query regressions", + "error", err, + ) + return fmt.Errorf("failed to query regressions: %w", err) + } + + if output == "json" { + jsonData, _ := json.MarshalIndent(findings, "", " ") + fmt.Println(string(jsonData)) + } else { + logger.Infow("Regressions found", + "count", len(findings), + ) + printFindings(findings) + } + + return nil + }, +} + +var resultsTimelineCmd = &cobra.Command{ + Use: "timeline [fingerprint]", + Short: "Show the full lifecycle of a specific vulnerability", + Long: `Display all instances of a vulnerability across scans to see its +complete lifecycle (when first detected, when fixed, if it reappeared).`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + fingerprint := args[0] + + output, _ := cmd.Flags().GetString("output") + + logger.Infow("Querying vulnerability timeline", + "fingerprint", fingerprint, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + findings, err := store.GetVulnerabilityTimeline(GetContext(), fingerprint) + if err != nil { + logger.Errorw("Failed to query vulnerability timeline", + "fingerprint", fingerprint, + "error", err, + ) + return fmt.Errorf("failed to query timeline: %w", err) + } + + if output == "json" { + jsonData, _ := json.MarshalIndent(findings, "", " ") + fmt.Println(string(jsonData)) + } else { + logger.Infow("Timeline entries found", + "fingerprint", fingerprint, + "count", len(findings), + ) + + if len(findings) == 0 { + logger.Infow("No findings found for this fingerprint") + return nil + } + + logger.Infow("Vulnerability Lifecycle", + "first_seen", findings[0].CreatedAt, + "last_seen", findings[len(findings)-1].CreatedAt, + "total_instances", len(findings), + ) + + printFindings(findings) + } + + 
return nil + }, +} + +var resultsNewFindingsCmd = &cobra.Command{ + Use: "new-findings", + Short: "List vulnerabilities that appeared recently", + Long: `Show findings that first appeared after a specific date. +Useful for tracking new vulnerabilities discovered over time.`, + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + + days, _ := cmd.Flags().GetInt("days") + output, _ := cmd.Flags().GetString("output") + + sinceDate := time.Now().AddDate(0, 0, -days) + + logger.Infow("Querying new findings", + "since_date", sinceDate, + "days", days, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + findings, err := store.GetNewFindings(GetContext(), sinceDate) + if err != nil { + logger.Errorw("Failed to query new findings", + "error", err, + "since_date", sinceDate, + ) + return fmt.Errorf("failed to query new findings: %w", err) + } + + if output == "json" { + jsonData, _ := json.MarshalIndent(findings, "", " ") + fmt.Println(string(jsonData)) + } else { + logger.Infow("New findings", + "count", len(findings), + "since_days", days, + ) + printFindings(findings) + } + + return nil + }, +} + +var resultsFixedFindingsCmd = &cobra.Command{ + Use: "fixed-findings", + Short: "List vulnerabilities that have been marked as fixed", + Long: `Show findings that have been marked as fixed by security researchers. +Useful for tracking remediation progress.`, + RunE: func(cmd *cobra.Command, args []string) error { + logger := GetLogger().WithComponent("results") + + limit, _ := cmd.Flags().GetInt("limit") + output, _ := cmd.Flags().GetString("output") + + logger.Infow("Querying fixed findings", + "limit", limit, + ) + + store := GetStore() + if store == nil { + return fmt.Errorf("database not initialized") + } + + findings, err := store.GetFixedFindings(GetContext(), limit) + if err != nil { + logger.Errorw("Failed to query fixed findings", + "error", err, + ) + return fmt.Errorf("failed to query fixed findings: %w", err) + } + + if output == "json" { + jsonData, _ := json.MarshalIndent(findings, "", " ") + fmt.Println(string(jsonData)) + } else { + logger.Infow("Fixed findings", + "count", len(findings), + ) + printFindings(findings) + } + + return nil + }, +} + +func printFindings(findings []types.Finding) { + if len(findings) == 0 { + fmt.Println("No findings") + return + } + + for i, finding := range findings { + fmt.Printf("\n[%d] %s - %s\n", i+1, finding.Severity, finding.Title) + fmt.Printf(" Tool: %s | Type: %s\n", finding.Tool, finding.Type) + fmt.Printf(" Status: %s | First Seen: %s\n", finding.Status, finding.FirstScanID) + fmt.Printf(" Scan: %s | Created: %s\n", finding.ScanID, finding.CreatedAt.Format("2006-01-02 15:04:05")) + if finding.Fingerprint != "" { + fmt.Printf(" Fingerprint: %s\n", finding.Fingerprint) + } + if finding.Verified { + fmt.Printf(" [VERIFIED]\n") + } + if finding.FalsePositive { + fmt.Printf(" [FALSE POSITIVE]\n") + } + } + fmt.Println() +} + func init() { // Add diff command resultsCmd.AddCommand(resultsDiffCmd) resultsDiffCmd.Flags().StringP("output", "o", "text", "Output format (text, json)") + // Add flags for mark-verified command + resultsMarkVerifiedCmd.Flags().Bool("unverify", false, "Remove verified flag") + + // Add flags for mark-false-positive command + resultsMarkFalsePositiveCmd.Flags().Bool("remove", false, "Remove false positive flag") + + // Add flags for temporal query commands + resultsRegressionsCmd.Flags().IntP("limit", "l", 50, "Maximum number 
of regressions to show") + resultsRegressionsCmd.Flags().StringP("output", "o", "text", "Output format (text, json)") + + resultsTimelineCmd.Flags().StringP("output", "o", "text", "Output format (text, json)") + + resultsNewFindingsCmd.Flags().IntP("days", "d", 7, "Number of days to look back") + resultsNewFindingsCmd.Flags().StringP("output", "o", "text", "Output format (text, json)") + + resultsFixedFindingsCmd.Flags().IntP("limit", "l", 50, "Maximum number of fixed findings to show") + resultsFixedFindingsCmd.Flags().StringP("output", "o", "text", "Output format (text, json)") + // Add history command resultsCmd.AddCommand(resultsHistoryCmd) resultsHistoryCmd.Flags().IntP("limit", "l", 50, "Maximum number of scans to show") diff --git a/internal/core/interfaces.go b/internal/core/interfaces.go index a684259..744fcc1 100755 --- a/internal/core/interfaces.go +++ b/internal/core/interfaces.go @@ -36,6 +36,18 @@ type ResultStore interface { GetFindings(ctx context.Context, scanID string) ([]types.Finding, error) GetFindingsBySeverity(ctx context.Context, severity types.Severity) ([]types.Finding, error) + // Finding status management (for lifecycle tracking and regression detection) + UpdateFindingStatus(ctx context.Context, findingID string, status types.FindingStatus) error + MarkFindingVerified(ctx context.Context, findingID string, verified bool) error + MarkFindingFalsePositive(ctx context.Context, findingID string, falsePositive bool) error + + // Temporal query methods (for historical analysis and trend detection) + GetRegressions(ctx context.Context, limit int) ([]types.Finding, error) + GetVulnerabilityTimeline(ctx context.Context, fingerprint string) ([]types.Finding, error) + GetFindingsByFingerprint(ctx context.Context, fingerprint string) ([]types.Finding, error) + GetNewFindings(ctx context.Context, sinceDate time.Time) ([]types.Finding, error) + GetFixedFindings(ctx context.Context, limit int) ([]types.Finding, error) + // Enhanced query methods QueryFindings(ctx context.Context, query FindingQuery) ([]types.Finding, error) GetFindingStats(ctx context.Context) (*FindingStats, error) diff --git a/internal/database/migrations.go b/internal/database/migrations.go index 4145255..6acc660 100644 --- a/internal/database/migrations.go +++ b/internal/database/migrations.go @@ -180,6 +180,107 @@ func GetAllMigrations() []Migration { ); `, }, + { + Version: 6, + Description: "Add database constraints and GIN indexes for performance and data integrity", + Up: ` + -- Add foreign key constraint for first_scan_id (ensures referential integrity) + -- Note: This assumes first_scan_id references scans(id) + -- Skip if constraint already exists + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'fk_findings_first_scan_id' + ) THEN + ALTER TABLE findings + ADD CONSTRAINT fk_findings_first_scan_id + FOREIGN KEY (first_scan_id) REFERENCES scans(id) ON DELETE SET NULL; + END IF; + END $$; + + -- Add check constraint for status enum (prevents invalid status values) + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'chk_findings_status' + ) THEN + ALTER TABLE findings + ADD CONSTRAINT chk_findings_status + CHECK (status IN ('new', 'active', 'fixed', 'duplicate', 'reopened')); + END IF; + END $$; + + -- Add check constraint for severity enum + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conname = 'chk_findings_severity' + ) THEN + ALTER TABLE findings + ADD CONSTRAINT chk_findings_severity + CHECK (severity IN 
('critical', 'high', 'medium', 'low', 'info')); + END IF; + END $$; + + -- Add NOT NULL constraints for critical fields + ALTER TABLE findings ALTER COLUMN fingerprint SET NOT NULL; + ALTER TABLE findings ALTER COLUMN first_scan_id SET NOT NULL; + ALTER TABLE findings ALTER COLUMN status SET NOT NULL; + + -- Add GIN indexes for JSONB columns (PostgreSQL only, enables fast JSONB queries) + -- These indexes dramatically improve queries like: metadata @> '{"key": "value"}' + CREATE INDEX IF NOT EXISTS idx_findings_metadata_gin ON findings USING GIN (metadata); + CREATE INDEX IF NOT EXISTS idx_correlation_related_findings_gin ON correlation_results USING GIN (related_findings); + CREATE INDEX IF NOT EXISTS idx_correlation_attack_path_gin ON correlation_results USING GIN (attack_path); + CREATE INDEX IF NOT EXISTS idx_correlation_metadata_gin ON correlation_results USING GIN (metadata); + + -- Add composite indexes for common query patterns + -- Regression queries: WHERE status = 'reopened' ORDER BY created_at DESC + CREATE INDEX IF NOT EXISTS idx_findings_status_created ON findings(status, created_at DESC); + + -- Timeline queries: WHERE fingerprint = ? ORDER BY created_at ASC + CREATE INDEX IF NOT EXISTS idx_findings_fingerprint_created ON findings(fingerprint, created_at ASC); + + -- Fixed findings: WHERE status = 'fixed' ORDER BY updated_at DESC + CREATE INDEX IF NOT EXISTS idx_findings_status_updated ON findings(status, updated_at DESC); + + -- Add unique constraint on fingerprint + scan_id (prevents exact duplicates within same scan) + CREATE UNIQUE INDEX IF NOT EXISTS idx_findings_fingerprint_scan_unique ON findings(fingerprint, scan_id); + + -- Add comments for documentation + COMMENT ON CONSTRAINT fk_findings_first_scan_id ON findings IS 'Ensures first_scan_id references valid scan'; + COMMENT ON CONSTRAINT chk_findings_status ON findings IS 'Enforces valid status enum values'; + COMMENT ON CONSTRAINT chk_findings_severity ON findings IS 'Enforces valid severity enum values'; + `, + Down: ` + -- Remove composite indexes + DROP INDEX IF EXISTS idx_findings_status_created; + DROP INDEX IF EXISTS idx_findings_fingerprint_created; + DROP INDEX IF EXISTS idx_findings_status_updated; + DROP INDEX IF EXISTS idx_findings_fingerprint_scan_unique; + + -- Remove GIN indexes + DROP INDEX IF EXISTS idx_findings_metadata_gin; + DROP INDEX IF EXISTS idx_correlation_related_findings_gin; + DROP INDEX IF EXISTS idx_correlation_attack_path_gin; + DROP INDEX IF EXISTS idx_correlation_metadata_gin; + + -- Remove NOT NULL constraints + ALTER TABLE findings ALTER COLUMN fingerprint DROP NOT NULL; + ALTER TABLE findings ALTER COLUMN first_scan_id DROP NOT NULL; + ALTER TABLE findings ALTER COLUMN status DROP NOT NULL; + + -- Remove check constraints + ALTER TABLE findings DROP CONSTRAINT IF EXISTS chk_findings_status; + ALTER TABLE findings DROP CONSTRAINT IF EXISTS chk_findings_severity; + + -- Remove foreign key constraint + ALTER TABLE findings DROP CONSTRAINT IF EXISTS fk_findings_first_scan_id; + `, + }, } } diff --git a/internal/database/store.go b/internal/database/store.go index 00bb0db..385958a 100755 --- a/internal/database/store.go +++ b/internal/database/store.go @@ -73,7 +73,7 @@ import ( "time" "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" + "github.com/lib/pq" "github.com/CodeMonkeyCybersecurity/shells/internal/config" "github.com/CodeMonkeyCybersecurity/shells/internal/core" @@ -652,48 +652,95 @@ func (s *sqlStore) ListScans(ctx context.Context, filter core.ScanFilter) ([]*ty // 
generateFindingFingerprint creates a hash for deduplication across scans // Fingerprint is based on: tool + type + title + target (normalized) -// Target is extracted from metadata["target"] or metadata["endpoint"] or metadata["url"] +// Target is extracted from metadata or evidence with extensive field checking func generateFindingFingerprint(finding types.Finding) string { // Extract target information from metadata target := "" if finding.Metadata != nil { - // Try common target field names - if t, ok := finding.Metadata["target"].(string); ok { + // Try common target field names in priority order + if t, ok := finding.Metadata["target"].(string); ok && t != "" { target = t - } else if ep, ok := finding.Metadata["endpoint"].(string); ok { + } else if ep, ok := finding.Metadata["endpoint"].(string); ok && ep != "" { target = ep - } else if url, ok := finding.Metadata["url"].(string); ok { + } else if url, ok := finding.Metadata["url"].(string); ok && url != "" { target = url - } else if host, ok := finding.Metadata["host"].(string); ok { + } else if host, ok := finding.Metadata["host"].(string); ok && host != "" { target = host - } else if param, ok := finding.Metadata["parameter"].(string); ok { + } else if hostname, ok := finding.Metadata["hostname"].(string); ok && hostname != "" { + target = hostname + } else if domain, ok := finding.Metadata["domain"].(string); ok && domain != "" { + target = domain + } else if ip, ok := finding.Metadata["ip"].(string); ok && ip != "" { + target = ip + } else if path, ok := finding.Metadata["path"].(string); ok && path != "" { + target = path + } else if param, ok := finding.Metadata["parameter"].(string); ok && param != "" { // For parameter-specific vulns (e.g., XSS in specific param) target = param + } else if port, ok := finding.Metadata["port"].(string); ok && port != "" { + // For port-specific vulns + target = port + } else if svc, ok := finding.Metadata["service"].(string); ok && svc != "" { + target = svc } } - // If no target in metadata, extract from evidence (first line or URL pattern) + // If no target in metadata, extract from evidence if target == "" && finding.Evidence != "" { - // Try to extract URL or endpoint from evidence - // Look for common patterns like "GET /path" or "https://..." 
evidenceLines := strings.Split(finding.Evidence, "\n") - if len(evidenceLines) > 0 { - firstLine := strings.TrimSpace(evidenceLines[0]) - // Extract HTTP method + path pattern - if strings.Contains(firstLine, "GET ") || strings.Contains(firstLine, "POST ") || - strings.Contains(firstLine, "PUT ") || strings.Contains(firstLine, "DELETE ") { - parts := strings.Fields(firstLine) + + // Try each line until we find a target + for _, line := range evidenceLines { + if target != "" { + break + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Pattern 1: HTTP method + path (e.g., "GET /api/users") + if strings.Contains(line, "GET ") || strings.Contains(line, "POST ") || + strings.Contains(line, "PUT ") || strings.Contains(line, "DELETE ") || + strings.Contains(line, "PATCH ") || strings.Contains(line, "OPTIONS ") { + parts := strings.Fields(line) if len(parts) >= 2 { target = parts[1] // The path + break } - } else if strings.HasPrefix(firstLine, "http://") || strings.HasPrefix(firstLine, "https://") { - // Extract hostname and path - if idx := strings.Index(firstLine, "://"); idx != -1 { - remaining := firstLine[idx+3:] - if slashIdx := strings.Index(remaining, "/"); slashIdx != -1 { - target = remaining[:slashIdx] + remaining[slashIdx:strings.IndexAny(remaining, "? \t")] - } else { - target = remaining + } + + // Pattern 2: Full URL (e.g., "https://example.com/path") + if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { + if idx := strings.Index(line, "://"); idx != -1 { + remaining := line[idx+3:] + if spaceIdx := strings.IndexAny(remaining, " \t"); spaceIdx != -1 { + remaining = remaining[:spaceIdx] + } + target = remaining + break + } + } + + // Pattern 3: Look for URL: prefix (e.g., "URL: https://example.com") + if strings.Contains(line, "URL:") || strings.Contains(line, "url:") { + if idx := strings.Index(strings.ToLower(line), "url:"); idx != -1 { + urlPart := strings.TrimSpace(line[idx+4:]) + if urlPart != "" { + target = urlPart + break + } + } + } + + // Pattern 4: Look for Target: prefix + if strings.Contains(line, "Target:") || strings.Contains(line, "target:") { + if idx := strings.Index(strings.ToLower(line), "target:"); idx != -1 { + targetPart := strings.TrimSpace(line[idx+7:]) + if targetPart != "" { + target = targetPart + break } } } @@ -710,13 +757,129 @@ func generateFindingFingerprint(finding types.Finding) string { // Generate SHA256 hash hash := sha256.Sum256([]byte(normalized)) - return fmt.Sprintf("%x", hash[:16]) // Use first 16 bytes (32 hex chars) + fingerprint := fmt.Sprintf("%x", hash[:16]) // Use first 16 bytes (32 hex chars) + + // EDGE CASE HANDLING: Empty target creates weaker fingerprints + // If target is empty, the fingerprint is based on tool:type:title only + // This allows deduplication of identical findings across scans, but may cause + // false positives if the same vulnerability type exists at multiple locations + // and we can't extract location from metadata or evidence. + // + // This is acceptable because: + // 1. Most scanners populate metadata["target"] or similar fields + // 2. Evidence usually contains URLs or paths we can extract + // 3. Manual verification can mark false positives via verified/false_positive fields + // 4. 
Weak fingerprints are better than no deduplication + // + // To improve: Ensure scanners populate metadata with location info + + return fingerprint +} + +// DuplicateInfo holds information about a duplicate finding +type DuplicateInfo struct { + IsDuplicate bool + FirstScanID string + PreviousStatus types.FindingStatus +} + +// batchCheckDuplicateFindings checks multiple fingerprints in a single query (performance optimization) +// Returns a map of fingerprint -> DuplicateInfo +func (s *sqlStore) batchCheckDuplicateFindings(ctx context.Context, tx *sqlx.Tx, fingerprints []string, currentScanID string) (map[string]DuplicateInfo, error) { + if len(fingerprints) == 0 { + return make(map[string]DuplicateInfo), nil + } + + var query string + var args []interface{} + + // PostgreSQL uses ANY($1) with array parameter + if s.cfg.Type == "postgres" { + query = ` + SELECT DISTINCT ON (fingerprint) + fingerprint, first_scan_id, scan_id, status + FROM findings + WHERE fingerprint = ANY($1) + ORDER BY fingerprint, created_at DESC + ` + args = []interface{}{pq.Array(fingerprints)} + } else { + // SQLite uses IN clause with placeholders + placeholders := make([]string, len(fingerprints)) + args = make([]interface{}, len(fingerprints)) + for i, fp := range fingerprints { + placeholders[i] = s.getPlaceholder(i + 1) + args[i] = fp + } + + query = fmt.Sprintf(` + SELECT fingerprint, first_scan_id, scan_id, status + FROM ( + SELECT fingerprint, first_scan_id, scan_id, status, + ROW_NUMBER() OVER (PARTITION BY fingerprint ORDER BY created_at DESC) as rn + FROM findings + WHERE fingerprint IN (%s) + ) + WHERE rn = 1 + `, strings.Join(placeholders, ",")) + } + + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to batch check duplicates: %w", err) + } + defer rows.Close() + + result := make(map[string]DuplicateInfo) + + for rows.Next() { + var fingerprint, firstScanID, scanID string + var previousStatus types.FindingStatus + + if err := rows.Scan(&fingerprint, &firstScanID, &scanID, &previousStatus); err != nil { + return nil, fmt.Errorf("failed to scan duplicate row: %w", err) + } + + // If first_scan_id is empty (old data before migration), use the scan_id we found + if firstScanID == "" { + firstScanID = scanID + } + + // Check for regression (previously fixed vulnerability reappearing) + if previousStatus == types.FindingStatusFixed { + s.logger.Errorw("REGRESSION DETECTED: Previously fixed vulnerability has reappeared", + "fingerprint", fingerprint, + "first_scan_id", firstScanID, + "last_seen_scan", scanID, + "current_scan", currentScanID, + "impact", "CRITICAL", + ) + result[fingerprint] = DuplicateInfo{ + IsDuplicate: true, + FirstScanID: firstScanID, + PreviousStatus: types.FindingStatusReopened, + } + } else { + result[fingerprint] = DuplicateInfo{ + IsDuplicate: true, + FirstScanID: firstScanID, + PreviousStatus: previousStatus, + } + } + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating duplicate rows: %w", err) + } + + return result, nil } // checkDuplicateFinding checks if a finding with the same fingerprint exists in previous scans // Returns: (isDuplicate, firstScanID, previousStatus, error) // Also detects regressions when a "fixed" vulnerability reappears -func (s *sqlStore) checkDuplicateFinding(ctx context.Context, tx *sqlx.Tx, fingerprint, currentScanID string) (bool, string, string, error) { +// NOTE: This is kept for backward compatibility but batchCheckDuplicateFindings should be preferred +func (s 
*sqlStore) checkDuplicateFinding(ctx context.Context, tx *sqlx.Tx, fingerprint, currentScanID string) (bool, string, types.FindingStatus, error) { // Get the most recent occurrence to check for regressions recentQuery := ` SELECT first_scan_id, scan_id, status @@ -726,7 +889,8 @@ func (s *sqlStore) checkDuplicateFinding(ctx context.Context, tx *sqlx.Tx, finge LIMIT 1 ` - var firstScanID, scanID, previousStatus string + var firstScanID, scanID string + var previousStatus types.FindingStatus err := tx.QueryRowContext(ctx, recentQuery, fingerprint).Scan(&firstScanID, &scanID, &previousStatus) if err == sql.ErrNoRows { // Not a duplicate - this is the first occurrence @@ -742,7 +906,7 @@ func (s *sqlStore) checkDuplicateFinding(ctx context.Context, tx *sqlx.Tx, finge } // Check for regression (previously fixed vulnerability reappearing) - if previousStatus == string(types.FindingStatusFixed) { + if previousStatus == types.FindingStatusFixed { s.logger.Errorw("REGRESSION DETECTED: Previously fixed vulnerability has reappeared", "fingerprint", fingerprint, "first_scan_id", firstScanID, @@ -750,7 +914,7 @@ func (s *sqlStore) checkDuplicateFinding(ctx context.Context, tx *sqlx.Tx, finge "current_scan", currentScanID, "impact", "CRITICAL", ) - return true, firstScanID, string(types.FindingStatusReopened), nil + return true, firstScanID, types.FindingStatusReopened, nil } return true, firstScanID, previousStatus, nil @@ -783,7 +947,7 @@ func (s *sqlStore) SaveFindings(ctx context.Context, findings []types.Finding) e // Count findings by severity for logging severityCounts := make(map[types.Severity]int) toolCounts := make(map[string]int) - statusCounts := make(map[string]int) + statusCounts := make(map[types.FindingStatus]int) duplicateCount := 0 for _, finding := range findings { @@ -839,38 +1003,62 @@ func (s *sqlStore) SaveFindings(ctx context.Context, findings []types.Finding) e insertStart := time.Now() totalRowsAffected := int64(0) + // P1 OPTIMIZATION: Batch fingerprint lookup (fixes N+1 query problem) + // Instead of querying database for each finding (100 findings = 100 queries), + // we query once for all fingerprints (100 findings = 1 query) + batchLookupStart := time.Now() + fingerprints := make([]string, len(findings)) + for i, finding := range findings { + fingerprints[i] = generateFindingFingerprint(finding) + } + + duplicateLookup, err := s.batchCheckDuplicateFindings(ctx, tx, fingerprints, scanID) + if err != nil { + s.logger.LogError(ctx, err, "database.SaveFindings.batch_check_duplicates", + "scan_id", scanID, + "findings_count", len(findings), + ) + // Initialize empty map and continue (will treat all as new findings) + duplicateLookup = make(map[string]DuplicateInfo) + } + + s.logger.LogDuration(ctx, "database.SaveFindings.batch_lookup", batchLookupStart, + "scan_id", scanID, + "fingerprints_checked", len(fingerprints), + "duplicates_found", len(duplicateLookup), + ) + for i, finding := range findings { findingStart := time.Now() - // Generate fingerprint for deduplication (includes target for uniqueness) - fingerprint := generateFindingFingerprint(finding) + // Get fingerprint (already generated in batch lookup phase) + fingerprint := fingerprints[i] - // Check if this is a duplicate from a previous scan (also detects regressions) - isDuplicate, firstScanID, previousStatus, err := s.checkDuplicateFinding(ctx, tx, fingerprint, finding.ScanID) - if err != nil { - s.logger.LogError(ctx, err, "database.SaveFindings.check_duplicate", - "finding_id", finding.ID, - "fingerprint", 
fingerprint, - ) - // Continue with insertion even if duplicate check fails - isDuplicate = false - firstScanID = finding.ScanID - previousStatus = "" + // Look up duplicate information from batch check + dupInfo, found := duplicateLookup[fingerprint] + isDuplicate := found + firstScanID := finding.ScanID + previousStatus := types.FindingStatus("") + + if found { + isDuplicate = dupInfo.IsDuplicate + firstScanID = dupInfo.FirstScanID + previousStatus = dupInfo.PreviousStatus } // Set status based on duplication and regression detection - status := string(types.FindingStatusNew) + status := types.FindingStatusNew if isDuplicate { // If previousStatus is "reopened", this is a regression - if previousStatus == string(types.FindingStatusReopened) { - status = string(types.FindingStatusReopened) + if previousStatus == types.FindingStatusReopened { + status = types.FindingStatusReopened s.logger.Warnw("Marking finding as reopened (regression)", "finding_id", finding.ID, "fingerprint", fingerprint, "first_scan_id", firstScanID, ) } else { - status = string(types.FindingStatusDuplicate) + status = types.FindingStatusDuplicate duplicateCount++ } } @@ -1107,6 +1295,303 @@ func (s *sqlStore) GetFindingsBySeverity(ctx context.Context, severity types.Sev return findings, nil } +// UpdateFindingStatus updates the status of a finding (for lifecycle tracking) +func (s *sqlStore) UpdateFindingStatus(ctx context.Context, findingID string, status types.FindingStatus) error { + ctx, span := s.logger.StartOperation(ctx, "database.UpdateFindingStatus", + "finding_id", findingID, + "status", status, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + UPDATE findings + SET status = %s, updated_at = %s + WHERE id = %s + `, s.getPlaceholder(1), s.getPlaceholder(2), s.getPlaceholder(3)) + + result, err := s.db.ExecContext(ctx, query, string(status), time.Now(), findingID) + if err != nil { + return fmt.Errorf("failed to update finding status: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("finding not found: %s", findingID) + } + + s.logger.Infow("Finding status updated", + "finding_id", findingID, + "new_status", status, + ) + + return nil +} + +// MarkFindingVerified marks a finding as manually verified or unverified +func (s *sqlStore) MarkFindingVerified(ctx context.Context, findingID string, verified bool) error { + ctx, span := s.logger.StartOperation(ctx, "database.MarkFindingVerified", + "finding_id", findingID, + "verified", verified, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + UPDATE findings + SET verified = %s, updated_at = %s + WHERE id = %s + `, s.getPlaceholder(1), s.getPlaceholder(2), s.getPlaceholder(3)) + + result, err := s.db.ExecContext(ctx, query, verified, time.Now(), findingID) + if err != nil { + return fmt.Errorf("failed to mark finding verified: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("finding not found: %s", findingID) + } + + s.logger.Infow("Finding verification status updated", + "finding_id", findingID, + "verified", verified, + ) + + return nil +} + +// MarkFindingFalsePositive marks a finding as a false positive or removes the false positive flag +func (s *sqlStore) 
MarkFindingFalsePositive(ctx context.Context, findingID string, falsePositive bool) error { + ctx, span := s.logger.StartOperation(ctx, "database.MarkFindingFalsePositive", + "finding_id", findingID, + "false_positive", falsePositive, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + UPDATE findings + SET false_positive = %s, updated_at = %s + WHERE id = %s + `, s.getPlaceholder(1), s.getPlaceholder(2), s.getPlaceholder(3)) + + result, err := s.db.ExecContext(ctx, query, falsePositive, time.Now(), findingID) + if err != nil { + return fmt.Errorf("failed to mark finding as false positive: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("finding not found: %s", findingID) + } + + s.logger.Infow("Finding false positive status updated", + "finding_id", findingID, + "false_positive", falsePositive, + ) + + return nil +} + +// GetRegressions returns findings that were marked as fixed and then reopened (regressions) +func (s *sqlStore) GetRegressions(ctx context.Context, limit int) ([]types.Finding, error) { + ctx, span := s.logger.StartOperation(ctx, "database.GetRegressions", + "limit", limit, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + SELECT id, scan_id, tool, type, severity, title, description, + evidence, solution, refs, metadata, + fingerprint, first_scan_id, status, verified, false_positive, + created_at, updated_at + FROM findings + WHERE status = %s + ORDER BY created_at DESC + LIMIT %s + `, s.getPlaceholder(1), s.getPlaceholder(2)) + + rows, err := s.db.QueryContext(ctx, query, types.FindingStatusReopened, limit) + if err != nil { + return nil, fmt.Errorf("failed to query regressions: %w", err) + } + defer rows.Close() + + return s.scanFindings(rows) +} + +// GetVulnerabilityTimeline returns all instances of a specific vulnerability across scans (full lifecycle) +func (s *sqlStore) GetVulnerabilityTimeline(ctx context.Context, fingerprint string) ([]types.Finding, error) { + ctx, span := s.logger.StartOperation(ctx, "database.GetVulnerabilityTimeline", + "fingerprint", fingerprint, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + SELECT id, scan_id, tool, type, severity, title, description, + evidence, solution, refs, metadata, + fingerprint, first_scan_id, status, verified, false_positive, + created_at, updated_at + FROM findings + WHERE fingerprint = %s + ORDER BY created_at ASC + `, s.getPlaceholder(1)) + + rows, err := s.db.QueryContext(ctx, query, fingerprint) + if err != nil { + return nil, fmt.Errorf("failed to query vulnerability timeline: %w", err) + } + defer rows.Close() + + return s.scanFindings(rows) +} + +// GetFindingsByFingerprint returns all findings with a specific fingerprint (across all scans) +func (s *sqlStore) GetFindingsByFingerprint(ctx context.Context, fingerprint string) ([]types.Finding, error) { + ctx, span := s.logger.StartOperation(ctx, "database.GetFindingsByFingerprint", + "fingerprint", fingerprint, + ) + var err error + defer func() { s.logger.EndOperation(ctx, span, err) }() + + query := fmt.Sprintf(` + SELECT id, scan_id, tool, type, severity, title, description, + evidence, solution, refs, metadata, + fingerprint, first_scan_id, status, verified, false_positive, + created_at, updated_at + FROM findings + WHERE fingerprint = %s + 
ORDER BY created_at DESC
+	`, s.getPlaceholder(1))
+
+	rows, err := s.db.QueryContext(ctx, query, fingerprint)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query findings by fingerprint: %w", err)
+	}
+	defer rows.Close()
+
+	return s.scanFindings(rows)
+}
+
+// GetNewFindings returns findings that first appeared after a specific date
+func (s *sqlStore) GetNewFindings(ctx context.Context, sinceDate time.Time) ([]types.Finding, error) {
+	ctx, span := s.logger.StartOperation(ctx, "database.GetNewFindings",
+		"since_date", sinceDate,
+	)
+	var err error
+	defer func() { s.logger.EndOperation(ctx, span, err) }()
+
+	// Get findings where first_scan_id's created_at is after sinceDate
+	query := fmt.Sprintf(`
+		SELECT f.id, f.scan_id, f.tool, f.type, f.severity, f.title, f.description,
+		       f.evidence, f.solution, f.refs, f.metadata,
+		       f.fingerprint, f.first_scan_id, f.status, f.verified, f.false_positive,
+		       f.created_at, f.updated_at
+		FROM findings f
+		INNER JOIN (
+			SELECT fingerprint, MIN(created_at) as first_seen
+			FROM findings
+			GROUP BY fingerprint
+		) first ON f.fingerprint = first.fingerprint
+		WHERE first.first_seen >= %s
+		  AND f.status = %s
+		ORDER BY first.first_seen DESC
+	`, s.getPlaceholder(1), s.getPlaceholder(2))
+
+	rows, err := s.db.QueryContext(ctx, query, sinceDate, types.FindingStatusNew)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query new findings: %w", err)
+	}
+	defer rows.Close()
+
+	return s.scanFindings(rows)
+}
+
+// GetFixedFindings returns findings that have been marked as fixed
+func (s *sqlStore) GetFixedFindings(ctx context.Context, limit int) ([]types.Finding, error) {
+	ctx, span := s.logger.StartOperation(ctx, "database.GetFixedFindings",
+		"limit", limit,
+	)
+	var err error
+	defer func() { s.logger.EndOperation(ctx, span, err) }()
+
+	query := fmt.Sprintf(`
+		SELECT id, scan_id, tool, type, severity, title, description,
+		       evidence, solution, refs, metadata,
+		       fingerprint, first_scan_id, status, verified, false_positive,
+		       created_at, updated_at
+		FROM findings
+		WHERE status = %s
+		ORDER BY updated_at DESC
+		LIMIT %s
+	`, s.getPlaceholder(1), s.getPlaceholder(2))
+
+	rows, err := s.db.QueryContext(ctx, query, types.FindingStatusFixed, limit)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query fixed findings: %w", err)
+	}
+	defer rows.Close()
+
+	return s.scanFindings(rows)
+}
+
+// scanFindings is a helper function to scan rows into Finding structs
+func (s *sqlStore) scanFindings(rows *sql.Rows) ([]types.Finding, error) {
+	findings := []types.Finding{}
+
+	for rows.Next() {
+		var finding types.Finding
+		var refsJSON, metaJSON string
+
+		err := rows.Scan(
+			&finding.ID, &finding.ScanID, &finding.Tool, &finding.Type,
+			&finding.Severity, &finding.Title, &finding.Description,
+			&finding.Evidence, &finding.Solution, &refsJSON, &metaJSON,
+			&finding.Fingerprint, &finding.FirstScanID, &finding.Status,
+			&finding.Verified, &finding.FalsePositive,
+			&finding.CreatedAt, &finding.UpdatedAt,
+		)
+		if err != nil {
+			return nil, fmt.Errorf("failed to scan finding: %w", err)
+		}
+
+		if refsJSON != "" {
+			if err := json.Unmarshal([]byte(refsJSON), &finding.References); err != nil {
+				s.logger.Warnw("Failed to unmarshal references for finding", "finding_id", finding.ID, "error", err)
+			}
+		}
+
+		if metaJSON != "" {
+			if err := json.Unmarshal([]byte(metaJSON), &finding.Metadata); err != nil {
+				s.logger.Warnw("Failed to unmarshal metadata for finding", "finding_id", finding.ID, "error", err)
+			}
+		}
+
+		findings = append(findings, finding)
+	}
+ + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating findings: %w", err) + } + + return findings, nil +} + func (s *sqlStore) GetSummary(ctx context.Context, scanID string) (*types.Summary, error) { summary := &types.Summary{ BySeverity: make(map[types.Severity]int), diff --git a/internal/database/store_test.go b/internal/database/store_test.go new file mode 100644 index 0000000..3714792 --- /dev/null +++ b/internal/database/store_test.go @@ -0,0 +1,421 @@ +package database + +import ( + "context" + "testing" + "time" + + "github.com/CodeMonkeyCybersecurity/shells/internal/logger" + "github.com/CodeMonkeyCybersecurity/shells/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test generateFindingFingerprint with various metadata fields +func TestGenerateFindingFingerprint_MetadataExtraction(t *testing.T) { + tests := []struct { + name string + finding types.Finding + expectSame bool + description string + }{ + { + name: "target field in metadata", + finding: types.Finding{ + Tool: "nmap", + Type: "open_port", + Title: "Port 443 Open", + Metadata: map[string]interface{}{ + "target": "example.com", + }, + }, + expectSame: true, + description: "Should extract target from metadata['target']", + }, + { + name: "endpoint field in metadata", + finding: types.Finding{ + Tool: "nikto", + Type: "web_vuln", + Title: "SQL Injection", + Metadata: map[string]interface{}{ + "endpoint": "/api/users", + }, + }, + expectSame: true, + description: "Should extract target from metadata['endpoint']", + }, + { + name: "url field in metadata", + finding: types.Finding{ + Tool: "burp", + Type: "xss", + Title: "Reflected XSS", + Metadata: map[string]interface{}{ + "url": "https://example.com/search", + }, + }, + expectSame: true, + description: "Should extract target from metadata['url']", + }, + { + name: "hostname field in metadata", + finding: types.Finding{ + Tool: "ssl", + Type: "cert_vuln", + Title: "Expired Certificate", + Metadata: map[string]interface{}{ + "hostname": "api.example.com", + }, + }, + expectSame: true, + description: "Should extract target from metadata['hostname']", + }, + { + name: "ip field in metadata", + finding: types.Finding{ + Tool: "nmap", + Type: "port_scan", + Title: "SSH Open", + Metadata: map[string]interface{}{ + "ip": "192.168.1.100", + }, + }, + expectSame: true, + description: "Should extract target from metadata['ip']", + }, + { + name: "empty metadata", + finding: types.Finding{ + Tool: "test", + Type: "vuln", + Title: "Test Vulnerability", + Metadata: map[string]interface{}{}, + }, + expectSame: false, + description: "Should create weak fingerprint when no target available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fingerprint := generateFindingFingerprint(tt.finding) + + // Fingerprint should never be empty + assert.NotEmpty(t, fingerprint, "Fingerprint should never be empty") + + // Fingerprint should be consistent + fingerprint2 := generateFindingFingerprint(tt.finding) + assert.Equal(t, fingerprint, fingerprint2, "Same finding should produce same fingerprint") + }) + } +} + +// Test generateFindingFingerprint with evidence parsing +func TestGenerateFindingFingerprint_EvidenceParsing(t *testing.T) { + tests := []struct { + name string + finding types.Finding + shouldMatch bool + description string + }{ + { + name: "HTTP method in evidence", + finding: types.Finding{ + Tool: "burp", + Type: "xss", + Title: "XSS Vulnerability", + Evidence: "GET /api/search?q= 
HTTP/1.1\nHost: example.com", + }, + shouldMatch: true, + description: "Should extract /api/search from evidence", + }, + { + name: "URL in evidence", + finding: types.Finding{ + Tool: "nikto", + Type: "web_vuln", + Title: "Directory Listing", + Evidence: "https://example.com/admin/\nStatus: 200 OK", + }, + shouldMatch: true, + description: "Should extract URL from evidence", + }, + { + name: "URL: prefix in evidence", + finding: types.Finding{ + Tool: "scanner", + Type: "vuln", + Title: "Vulnerability Found", + Evidence: "URL: https://api.example.com/users\nSeverity: High", + }, + shouldMatch: true, + description: "Should extract URL from URL: prefix", + }, + { + name: "Target: prefix in evidence", + finding: types.Finding{ + Tool: "scanner", + Type: "vuln", + Title: "SQL Injection", + Evidence: "Target: /api/login\nPayload: ' OR '1'='1", + }, + shouldMatch: true, + description: "Should extract target from Target: prefix", + }, + { + name: "No parseable evidence", + finding: types.Finding{ + Tool: "scanner", + Type: "vuln", + Title: "Vulnerability", + Evidence: "Some random text without URL or path", + }, + shouldMatch: false, + description: "Should create weak fingerprint when evidence can't be parsed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fingerprint := generateFindingFingerprint(tt.finding) + assert.NotEmpty(t, fingerprint, "Fingerprint should never be empty") + }) + } +} + +// Test that same vulnerability at different locations gets different fingerprints +func TestGenerateFindingFingerprint_DifferentLocations(t *testing.T) { + finding1 := types.Finding{ + Tool: "scanner", + Type: "sql_injection", + Title: "SQL Injection", + Metadata: map[string]interface{}{ + "endpoint": "/api/users", + }, + } + + finding2 := types.Finding{ + Tool: "scanner", + Type: "sql_injection", + Title: "SQL Injection", + Metadata: map[string]interface{}{ + "endpoint": "/api/products", + }, + } + + fp1 := generateFindingFingerprint(finding1) + fp2 := generateFindingFingerprint(finding2) + + assert.NotEqual(t, fp1, fp2, "Same vulnerability type at different endpoints should have different fingerprints") +} + +// Test that same vulnerability across scans gets same fingerprint +func TestGenerateFindingFingerprint_CrossScanConsistency(t *testing.T) { + // Scan 1 + finding1 := types.Finding{ + ID: "finding-1", + ScanID: "scan-1", + Tool: "nmap", + Type: "open_port", + Title: "Port 443 Open", + Metadata: map[string]interface{}{ + "target": "example.com", + "port": "443", + }, + } + + // Scan 2 (different IDs, different scan, but same vulnerability) + finding2 := types.Finding{ + ID: "finding-2", + ScanID: "scan-2", + Tool: "nmap", + Type: "open_port", + Title: "Port 443 Open", + Metadata: map[string]interface{}{ + "target": "example.com", + "port": "443", + }, + } + + fp1 := generateFindingFingerprint(finding1) + fp2 := generateFindingFingerprint(finding2) + + assert.Equal(t, fp1, fp2, "Same vulnerability across different scans should have identical fingerprints") +} + +// Test UpdateFindingStatus method +func TestUpdateFindingStatus(t *testing.T) { + // Create in-memory test database + store, cleanup := setupTestStore(t) + defer cleanup() + + ctx := context.Background() + + // Create a test scan + scan := &types.ScanRequest{ + ID: "test-scan-1", + Target: "example.com", + Type: types.ScanTypeWeb, + Status: types.ScanStatusRunning, + } + err := store.SaveScan(ctx, scan) + require.NoError(t, err) + + // Create a test finding + finding := types.Finding{ + ID: "test-finding-1", 
+ ScanID: "test-scan-1", + Tool: "test", + Type: "test_vuln", + Severity: types.SeverityHigh, + Title: "Test Vulnerability", + Description: "Test Description", + Status: types.FindingStatusNew, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = store.SaveFindings(ctx, []types.Finding{finding}) + require.NoError(t, err) + + // Update status to fixed + err = store.UpdateFindingStatus(ctx, "test-finding-1", types.FindingStatusFixed) + assert.NoError(t, err, "Should successfully update finding status") + + // Verify status was updated + findings, err := store.GetFindings(ctx, "test-scan-1") + require.NoError(t, err) + require.Len(t, findings, 1) + assert.Equal(t, types.FindingStatusFixed, findings[0].Status) + + // Try to update non-existent finding + err = store.UpdateFindingStatus(ctx, "non-existent", types.FindingStatusFixed) + assert.Error(t, err, "Should fail when finding doesn't exist") +} + +// Test MarkFindingVerified method +func TestMarkFindingVerified(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + ctx := context.Background() + + // Create test scan and finding + scan := &types.ScanRequest{ + ID: "test-scan-2", + Target: "example.com", + Type: types.ScanTypeWeb, + Status: types.ScanStatusRunning, + } + err := store.SaveScan(ctx, scan) + require.NoError(t, err) + + finding := types.Finding{ + ID: "test-finding-2", + ScanID: "test-scan-2", + Tool: "test", + Type: "test_vuln", + Severity: types.SeverityMedium, + Title: "Test Vulnerability", + Verified: false, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = store.SaveFindings(ctx, []types.Finding{finding}) + require.NoError(t, err) + + // Mark as verified + err = store.MarkFindingVerified(ctx, "test-finding-2", true) + assert.NoError(t, err) + + // Verify flag was set + findings, err := store.GetFindings(ctx, "test-scan-2") + require.NoError(t, err) + require.Len(t, findings, 1) + assert.True(t, findings[0].Verified) + + // Unmark verification + err = store.MarkFindingVerified(ctx, "test-finding-2", false) + assert.NoError(t, err) + + findings, err = store.GetFindings(ctx, "test-scan-2") + require.NoError(t, err) + assert.False(t, findings[0].Verified) +} + +// Test MarkFindingFalsePositive method +func TestMarkFindingFalsePositive(t *testing.T) { + store, cleanup := setupTestStore(t) + defer cleanup() + + ctx := context.Background() + + // Create test scan and finding + scan := &types.ScanRequest{ + ID: "test-scan-3", + Target: "example.com", + Type: types.ScanTypeWeb, + Status: types.ScanStatusRunning, + } + err := store.SaveScan(ctx, scan) + require.NoError(t, err) + + finding := types.Finding{ + ID: "test-finding-3", + ScanID: "test-scan-3", + Tool: "test", + Type: "test_vuln", + Severity: types.SeverityLow, + Title: "Test Vulnerability", + FalsePositive: false, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = store.SaveFindings(ctx, []types.Finding{finding}) + require.NoError(t, err) + + // Mark as false positive + err = store.MarkFindingFalsePositive(ctx, "test-finding-3", true) + assert.NoError(t, err) + + // Verify flag was set + findings, err := store.GetFindings(ctx, "test-scan-3") + require.NoError(t, err) + require.Len(t, findings, 1) + assert.True(t, findings[0].FalsePositive) + + // Remove false positive flag + err = store.MarkFindingFalsePositive(ctx, "test-finding-3", false) + assert.NoError(t, err) + + findings, err = store.GetFindings(ctx, "test-scan-3") + require.NoError(t, err) + assert.False(t, findings[0].FalsePositive) +} + +// Helper function to 
set up a test store +func setupTestStore(t *testing.T) (*sqlStore, func()) { + // Create logger + log, err := logger.New(logger.Config{ + Level: "info", + Format: "json", + }) + require.NoError(t, err) + + // Create in-memory SQLite database + store, err := NewResultStore(Config{ + Type: "sqlite", + Host: ":memory:", + Database: "test", + }, log) + require.NoError(t, err) + + sqlStore, ok := store.(*sqlStore) + require.True(t, ok) + + cleanup := func() { + _ = store.Close() + } + + return sqlStore, cleanup +} diff --git a/internal/orchestrator/bounty_engine.go b/internal/orchestrator/bounty_engine.go index 006ae8a..76a15d2 100644 --- a/internal/orchestrator/bounty_engine.go +++ b/internal/orchestrator/bounty_engine.go @@ -1756,6 +1756,7 @@ func (e *BugBountyEngine) ExecuteWithPipeline(ctx context.Context, target string pipeline.correlationEngine = NewCorrelationEngine( e.config, e.logger, + e.store, exploitChainer, e.enricher, ) diff --git a/internal/orchestrator/correlation.go b/internal/orchestrator/correlation.go index afb96d4..40b63a7 100644 --- a/internal/orchestrator/correlation.go +++ b/internal/orchestrator/correlation.go @@ -26,10 +26,12 @@ import ( "fmt" "time" + "github.com/CodeMonkeyCybersecurity/shells/internal/core" "github.com/CodeMonkeyCybersecurity/shells/internal/logger" "github.com/CodeMonkeyCybersecurity/shells/pkg/correlation" "github.com/CodeMonkeyCybersecurity/shells/pkg/enrichment" "github.com/CodeMonkeyCybersecurity/shells/pkg/types" + "github.com/google/uuid" ) // ExploitChain represents a sequence of vulnerabilities that combine for higher impact @@ -48,6 +50,7 @@ type ExploitChain struct { type CorrelationEngine struct { logger *logger.Logger config BugBountyConfig + store core.ResultStore exploitChainer *correlation.ExploitChainer enricher *enrichment.ResultEnricher } @@ -56,12 +59,14 @@ type CorrelationEngine struct { func NewCorrelationEngine( config BugBountyConfig, logger *logger.Logger, + store core.ResultStore, exploitChainer *correlation.ExploitChainer, enricher *enrichment.ResultEnricher, ) *CorrelationEngine { return &CorrelationEngine{ logger: logger.WithComponent("correlation"), config: config, + store: store, exploitChainer: exploitChainer, enricher: enricher, } @@ -129,6 +134,24 @@ func (c *CorrelationEngine) Execute(ctx context.Context, state *PipelineState) e // Log detected chains if len(chains) > 0 { c.logExploitChains(state.ScanID, chains) + + // P0 FIX: Save correlation results to database + if c.store != nil { + correlationResults := c.convertChainsToCorrelationResults(state.ScanID, chains) + if err := c.store.SaveCorrelationResults(ctx, correlationResults); err != nil { + c.logger.Errorw("Failed to save correlation results", + "error", err, + "scan_id", state.ScanID, + "chains_count", len(chains), + ) + // Don't fail the entire pipeline - just log the error + } else { + c.logger.Infow("Correlation results saved to database", + "scan_id", state.ScanID, + "results_saved", len(correlationResults), + ) + } + } } return nil @@ -313,3 +336,58 @@ func (c *CorrelationEngine) logExploitChains(scanID string, chains []ExploitChai "note", "Exploit chains often receive higher bounties than individual vulnerabilities", ) } + +// convertChainsToCorrelationResults converts ExploitChain objects to CorrelationResult for database persistence +func (c *CorrelationEngine) convertChainsToCorrelationResults(scanID string, chains []ExploitChain) []types.CorrelationResult { + results := make([]types.CorrelationResult, 0, len(chains)) + now := time.Now() + + for _, 
chain := range chains { + // Extract finding IDs from chain steps + relatedFindings := make([]string, 0, len(chain.Steps)) + for _, step := range chain.Steps { + relatedFindings = append(relatedFindings, step.ID) + } + + // Build attack path with step-by-step breakdown + attackPath := make([]map[string]interface{}, 0, len(chain.Steps)) + for i, step := range chain.Steps { + attackPath = append(attackPath, map[string]interface{}{ + "step": i + 1, + "finding_id": step.ID, + "type": step.Type, + "title": step.Title, + "severity": step.Severity, + "description": step.Description, + }) + } + + // Build metadata with chain-specific information + metadata := map[string]interface{}{ + "chain_name": chain.Name, + "cvss_score": chain.CVSSScore, + "impact": chain.Impact, + "remediation": chain.Remediation, + "step_count": len(chain.Steps), + } + + result := types.CorrelationResult{ + ID: uuid.New().String(), + ScanID: scanID, + InsightType: "attack_chain", + Severity: chain.Severity, + Title: chain.Name, + Description: chain.Description, + Confidence: 0.85, // High confidence for detected chains + RelatedFindings: relatedFindings, + AttackPath: attackPath, + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + } + + results = append(results, result) + } + + return results +} diff --git a/pkg/types/types.go b/pkg/types/types.go index b6e256c..241232b 100755 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -65,7 +65,7 @@ type Finding struct { Metadata map[string]interface{} `json:"metadata,omitempty"` Fingerprint string `json:"fingerprint,omitempty" db:"fingerprint"` // Hash for deduplication across scans FirstScanID string `json:"first_scan_id,omitempty" db:"first_scan_id"` // Scan ID where first detected - Status string `json:"status,omitempty" db:"status"` // new, active, fixed, duplicate, reopened + Status FindingStatus `json:"status,omitempty" db:"status"` // new, active, fixed, duplicate, reopened Verified bool `json:"verified" db:"verified"` // Manually verified FalsePositive bool `json:"false_positive" db:"false_positive"` // Marked as false positive CreatedAt time.Time `json:"created_at" db:"created_at"`
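
Reviewer usage note (not part of the patch): a minimal sketch of the intended finding lifecycle using the new ResultStore methods from this diff, assuming an already-initialized store; triageRegression is a hypothetical helper name, not code in this PR.

	package example

	import (
		"context"
		"fmt"

		"github.com/CodeMonkeyCybersecurity/shells/internal/core"
		"github.com/CodeMonkeyCybersecurity/shells/pkg/types"
	)

	// triageRegression marks a remediated finding as fixed, then (after a later
	// scan has run) lists any findings that SaveFindings re-flagged as reopened.
	func triageRegression(ctx context.Context, store core.ResultStore, findingID string) error {
		// Researcher confirms the vulnerability is remediated.
		if err := store.UpdateFindingStatus(ctx, findingID, types.FindingStatusFixed); err != nil {
			return fmt.Errorf("mark fixed: %w", err)
		}

		// On a subsequent scan, the batch duplicate check flags a reappearance
		// of the same fingerprint as FindingStatusReopened; query those here.
		regressions, err := store.GetRegressions(ctx, 10)
		if err != nil {
			return fmt.Errorf("query regressions: %w", err)
		}
		for _, f := range regressions {
			fmt.Printf("REGRESSION: %s (fingerprint %s, first seen in scan %s)\n",
				f.Title, f.Fingerprint, f.FirstScanID)
		}
		return nil
	}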
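Index usage note (also not part of the patch): the query shape that the new idx_findings_metadata_gin index in migration 6 is meant to accelerate is JSONB containment on findings.metadata, PostgreSQL only. A sketch with sqlx, where findingIDsByTarget is a hypothetical helper and the filter shape is illustrative:

	package example

	import (
		"context"
		"fmt"

		"github.com/jmoiron/sqlx"
	)

	// findingIDsByTarget returns IDs of findings whose metadata contains the
	// given target, e.g. metadata @> '{"target": "example.com"}'. The @>
	// containment operator is what the GIN index on findings.metadata serves.
	func findingIDsByTarget(ctx context.Context, db *sqlx.DB, target string) ([]string, error) {
		var ids []string
		filter := fmt.Sprintf(`{"target": %q}`, target)
		err := db.SelectContext(ctx, &ids,
			`SELECT id FROM findings WHERE metadata @> $1::jsonb`, filter)
		return ids, err
	}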