Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions internal/stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package stack

import (
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"

"github.com/ObolNetwork/obol-stack/internal/config"
Expand Down Expand Up @@ -169,6 +171,12 @@ func Up(cfg *config.Config) error {
return fmt.Errorf("failed to create data directory: %w", err)
}

// Check required ports are available before creating cluster
requiredPorts := []int{80, 8080, 443, 8443}
if err := checkPortsAvailable(requiredPorts); err != nil {
return err
}

// Create cluster using k3d config with custom name
fmt.Println("Creating k3d cluster...")
createCmd := exec.Command(
Expand Down Expand Up @@ -392,6 +400,39 @@ func syncDefaults(cfg *config.Config, kubeconfigPath string) error {
return nil
}

// checkPortsAvailable verifies that all required ports can be bound.
// Returns an actionable error if any port is already in use.
func checkPortsAvailable(ports []int) error {
var blocked []int
for _, port := range ports {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
blocked = append(blocked, port)
continue
}
ln.Close()
}
if len(blocked) > 0 {
return fmt.Errorf(
"port(s) %s already in use\n\n"+
"Obol Stack needs these ports for HTTP/HTTPS access.\n"+
"Find what's using them with:\n"+
" sudo lsof -i :%d\n\n"+
"Then stop the conflicting service and retry 'obol stack up'.",
formatPorts(blocked), blocked[0],
)
}
return nil
}

func formatPorts(ports []int) string {
strs := make([]string, len(ports))
for i, p := range ports {
strs[i] = strconv.Itoa(p)
}
return strings.Join(strs, ", ")
}

func migrateDefaultsHTTPRouteHostnames(helmfilePath string) error {
data, err := os.ReadFile(helmfilePath)
if err != nil {
Expand Down
92 changes: 92 additions & 0 deletions internal/stack/stack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package stack

import (
"fmt"
"net"
"strings"
"testing"
)

func TestCheckPortsAvailable_FreePorts(t *testing.T) {
// Use high ephemeral ports that are almost certainly free
ports := []int{19876, 19877}
if err := checkPortsAvailable(ports); err != nil {
t.Fatalf("expected no error for free ports, got: %v", err)
}
}

func TestCheckPortsAvailable_BlockedPort(t *testing.T) {
// Bind a port to simulate a conflict
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to bind ephemeral port: %v", err)
}
defer ln.Close()

// Extract the port number from the listener address
addr := ln.Addr().(*net.TCPAddr)
blockedPort := addr.Port

err = checkPortsAvailable([]int{blockedPort})
if err == nil {
t.Fatal("expected error for blocked port, got nil")
}

portStr := fmt.Sprintf("%d", blockedPort)
if !strings.Contains(err.Error(), portStr) {
t.Errorf("error should mention blocked port %d, got: %v", blockedPort, err)
}
if !strings.Contains(err.Error(), "already in use") {
t.Errorf("error should mention 'already in use', got: %v", err)
}
if !strings.Contains(err.Error(), "sudo lsof") {
t.Errorf("error should include remediation hint, got: %v", err)
}
}

func TestCheckPortsAvailable_MixedPorts(t *testing.T) {
// Bind one port, leave another free
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to bind ephemeral port: %v", err)
}
defer ln.Close()

blockedPort := ln.Addr().(*net.TCPAddr).Port

// Pick a free port by briefly binding and releasing
ln2, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to bind second ephemeral port: %v", err)
}
freePort := ln2.Addr().(*net.TCPAddr).Port
ln2.Close()

err = checkPortsAvailable([]int{freePort, blockedPort})
if err == nil {
t.Fatal("expected error when one port is blocked, got nil")
}

// Should mention only the blocked port
blockedStr := fmt.Sprintf("%d", blockedPort)
if !strings.Contains(err.Error(), blockedStr) {
t.Errorf("error should mention blocked port %d, got: %v", blockedPort, err)
}
}

func TestFormatPorts(t *testing.T) {
tests := []struct {
ports []int
expected string
}{
{[]int{443}, "443"},
{[]int{80, 443}, "80, 443"},
{[]int{80, 8080, 443, 8443}, "80, 8080, 443, 8443"},
}
for _, tt := range tests {
got := formatPorts(tt.ports)
if got != tt.expected {
t.Errorf("formatPorts(%v) = %q, want %q", tt.ports, got, tt.expected)
}
}
}
Loading