diff --git a/internal/stack/stack.go b/internal/stack/stack.go index f11fd00..82189eb 100644 --- a/internal/stack/stack.go +++ b/internal/stack/stack.go @@ -2,10 +2,12 @@ package stack import ( "fmt" + "net" "os" "os/exec" "path/filepath" "runtime" + "strconv" "strings" "github.com/ObolNetwork/obol-stack/internal/config" @@ -169,6 +171,12 @@ func Up(cfg *config.Config) error { return fmt.Errorf("failed to create data directory: %w", err) } + // Check required ports are available before creating cluster + requiredPorts := []int{80, 8080, 443, 8443} + if err := checkPortsAvailable(requiredPorts); err != nil { + return err + } + // Create cluster using k3d config with custom name fmt.Println("Creating k3d cluster...") createCmd := exec.Command( @@ -392,6 +400,39 @@ func syncDefaults(cfg *config.Config, kubeconfigPath string) error { return nil } +// checkPortsAvailable verifies that all required ports can be bound. +// Returns an actionable error if any port is already in use. +func checkPortsAvailable(ports []int) error { + var blocked []int + for _, port := range ports { + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + blocked = append(blocked, port) + continue + } + ln.Close() + } + if len(blocked) > 0 { + return fmt.Errorf( + "port(s) %s already in use\n\n"+ + "Obol Stack needs these ports for HTTP/HTTPS access.\n"+ + "Find what's using them with:\n"+ + " sudo lsof -i :%d\n\n"+ + "Then stop the conflicting service and retry 'obol stack up'.", + formatPorts(blocked), blocked[0], + ) + } + return nil +} + +func formatPorts(ports []int) string { + strs := make([]string, len(ports)) + for i, p := range ports { + strs[i] = strconv.Itoa(p) + } + return strings.Join(strs, ", ") +} + func migrateDefaultsHTTPRouteHostnames(helmfilePath string) error { data, err := os.ReadFile(helmfilePath) if err != nil { diff --git a/internal/stack/stack_test.go b/internal/stack/stack_test.go new file mode 100644 index 0000000..c3b254b --- /dev/null +++ b/internal/stack/stack_test.go @@ -0,0 +1,92 @@ +package stack + +import ( + "fmt" + "net" + "strings" + "testing" +) + +func TestCheckPortsAvailable_FreePorts(t *testing.T) { + // Use high ephemeral ports that are almost certainly free + ports := []int{19876, 19877} + if err := checkPortsAvailable(ports); err != nil { + t.Fatalf("expected no error for free ports, got: %v", err) + } +} + +func TestCheckPortsAvailable_BlockedPort(t *testing.T) { + // Bind a port to simulate a conflict + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to bind ephemeral port: %v", err) + } + defer ln.Close() + + // Extract the port number from the listener address + addr := ln.Addr().(*net.TCPAddr) + blockedPort := addr.Port + + err = checkPortsAvailable([]int{blockedPort}) + if err == nil { + t.Fatal("expected error for blocked port, got nil") + } + + portStr := fmt.Sprintf("%d", blockedPort) + if !strings.Contains(err.Error(), portStr) { + t.Errorf("error should mention blocked port %d, got: %v", blockedPort, err) + } + if !strings.Contains(err.Error(), "already in use") { + t.Errorf("error should mention 'already in use', got: %v", err) + } + if !strings.Contains(err.Error(), "sudo lsof") { + t.Errorf("error should include remediation hint, got: %v", err) + } +} + +func TestCheckPortsAvailable_MixedPorts(t *testing.T) { + // Bind one port, leave another free + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to bind ephemeral port: %v", err) + } + defer ln.Close() + + blockedPort := ln.Addr().(*net.TCPAddr).Port + + // Pick a free port by briefly binding and releasing + ln2, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to bind second ephemeral port: %v", err) + } + freePort := ln2.Addr().(*net.TCPAddr).Port + ln2.Close() + + err = checkPortsAvailable([]int{freePort, blockedPort}) + if err == nil { + t.Fatal("expected error when one port is blocked, got nil") + } + + // Should mention only the blocked port + blockedStr := fmt.Sprintf("%d", blockedPort) + if !strings.Contains(err.Error(), blockedStr) { + t.Errorf("error should mention blocked port %d, got: %v", blockedPort, err) + } +} + +func TestFormatPorts(t *testing.T) { + tests := []struct { + ports []int + expected string + }{ + {[]int{443}, "443"}, + {[]int{80, 443}, "80, 443"}, + {[]int{80, 8080, 443, 8443}, "80, 8080, 443, 8443"}, + } + for _, tt := range tests { + got := formatPorts(tt.ports) + if got != tt.expected { + t.Errorf("formatPorts(%v) = %q, want %q", tt.ports, got, tt.expected) + } + } +}