Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/dstackai/dstack/runner/internal/shim/api"
"github.com/dstackai/dstack/runner/internal/shim/components"
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
"github.com/dstackai/dstack/runner/internal/shim/netmeter"
)

// Version is a build-time variable. The value is overridden by ldflags.
Expand Down Expand Up @@ -270,11 +271,22 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}
}

var nm *netmeter.NetMeter
nm = netmeter.New()
if err := nm.Start(ctx); err != nil {
log.Warning(ctx, "data transfer metering unavailable", "err", err)
nm = nil
} else {
log.Info(ctx, "data transfer metering started")
defer nm.Stop()
}

address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
shimServer := api.NewShimServer(
ctx, address, Version,
dockerRunner, dcgmExporter, dcgmWrapper,
runnerManager, shimManager,
nm,
)

if serviceMode {
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Reques
response.DCGM = &dcgmHealth
}
}
if s.netMeter != nil {
b := s.netMeter.Bytes()
response.DataTransferBytes = &b
}

return &response, nil
}
Expand Down
4 changes: 2 additions & 2 deletions runner/internal/shim/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) {
request := httptest.NewRequest("GET", "/api/healthcheck", nil)
responseRecorder := httptest.NewRecorder()

server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)
server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil, nil)

f := commonapi.JSONResponseHandler(server.HealthcheckHandler)
f(responseRecorder, request)
Expand All @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) {
}

func TestTaskSubmit(t *testing.T) {
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil, nil)
requestBody := `{
"id": "dummy-id",
"name": "dummy-name",
Expand Down
3 changes: 2 additions & 1 deletion runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ type ShutdownRequest struct {
}

type InstanceHealthResponse struct {
DCGM *dcgm.Health `json:"dcgm"`
DCGM *dcgm.Health `json:"dcgm"`
DataTransferBytes *int64 `json:"data_transfer_bytes,omitempty"`
}

type TaskListResponse struct {
Expand Down
6 changes: 6 additions & 0 deletions runner/internal/shim/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/dstackai/dstack/runner/internal/shim"
"github.com/dstackai/dstack/runner/internal/shim/components"
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
"github.com/dstackai/dstack/runner/internal/shim/netmeter"
)

type TaskRunner interface {
Expand Down Expand Up @@ -45,13 +46,16 @@ type ShimServer struct {
runnerManager components.ComponentManager
shimManager components.ComponentManager

netMeter *netmeter.NetMeter // may be nil if metering is unavailable

version string
}

func NewShimServer(
ctx context.Context, address string, version string,
runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper dcgm.DCGMWrapperInterface,
runnerManager components.ComponentManager, shimManager components.ComponentManager,
nm *netmeter.NetMeter,
) *ShimServer {
bgJobsCtx, bgJobsCancel := context.WithCancel(ctx)
if dcgmWrapper != nil && reflect.ValueOf(dcgmWrapper).IsNil() {
Expand All @@ -78,6 +82,8 @@ func NewShimServer(
runnerManager: runnerManager,
shimManager: shimManager,

netMeter: nm,

version: version,
}

Expand Down
5 changes: 5 additions & 0 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := d.tasks.Update(task); err != nil {
return fmt.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
}

err = d.waitContainer(ctx, &task)
}
if err != nil {
Expand Down Expand Up @@ -1228,6 +1229,10 @@ func (c *CLIArgs) DockerPorts() []int {
return []int{c.Runner.HTTPPort, c.Runner.SSHPort}
}

func (c *CLIArgs) RunnerHTTPPort() int {
return c.Runner.HTTPPort
}

func (c *CLIArgs) MakeRunnerDir(name string) (string, error) {
runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name)
if err := os.MkdirAll(runnerTemp, 0o755); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ func (c *dockerParametersMock) DockerPorts() []int {
return []int{}
}

func (c *dockerParametersMock) RunnerHTTPPort() int {
return 10999
}

func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) {
return nil, nil
}
Expand Down
1 change: 1 addition & 0 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type DockerParameters interface {
DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string
DockerMounts(string) ([]mount.Mount, error)
DockerPorts() []int
RunnerHTTPPort() int
MakeRunnerDir(name string) (string, error)
DockerPJRTDevice() string
}
Expand Down
227 changes: 227 additions & 0 deletions runner/internal/shim/netmeter/netmeter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package netmeter

import (
"bytes"
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"time"

"github.com/dstackai/dstack/runner/internal/common/log"
)

const (
pollInterval = 10 * time.Second
chainName = "dstack-nm"
)

// NetMeter monitors outbound data transfer using iptables byte counters.
// It excludes private/VPC traffic and counts only external (billable) bytes.
// The meter runs for the lifetime of the shim process (per-instance, not per-task).
type NetMeter struct {
bytes atomic.Int64
stopCh chan struct{}
stopped chan struct{}
}

// New creates a new NetMeter.
func New() *NetMeter {
return &NetMeter{
stopCh: make(chan struct{}),
stopped: make(chan struct{}),
}
}

// Start sets up iptables rules and begins polling byte counters.
func (m *NetMeter) Start(ctx context.Context) error {
if err := checkIptables(); err != nil {
return fmt.Errorf("iptables not available: %w", err)
}

// Clean up any orphaned chain from a previous shim process
cleanupChain(ctx)

if err := setupChain(ctx); err != nil {
return fmt.Errorf("setup iptables chain: %w", err)
}

go m.pollLoop(ctx)
return nil
}

// Stop signals the poll loop to stop and cleans up iptables rules.
func (m *NetMeter) Stop() {
close(m.stopCh)
<-m.stopped
}

// Bytes returns the cumulative external outbound byte count (thread-safe).
func (m *NetMeter) Bytes() int64 {
return m.bytes.Load()
}

func checkIptables() error {
_, err := exec.LookPath("iptables")
return err
}

func setupChain(ctx context.Context) error {
// Create the chain
if err := iptables(ctx, "-N", chainName); err != nil {
return fmt.Errorf("create chain: %w", err)
}

// Add exclusion rules for private/internal traffic (these RETURN without counting)
privateCIDRs := []struct {
cidr string
comment string
}{
{"10.0.0.0/8", "VPC/private"},
{"172.16.0.0/12", "VPC/private"},
{"192.168.0.0/16", "VPC/private"},
{"169.254.0.0/16", "link-local/metadata"},
{"127.0.0.0/8", "loopback"},
}
for _, p := range privateCIDRs {
if err := iptables(ctx, "-A", chainName, "-d", p.cidr, "-j", "RETURN"); err != nil {
cleanupChain(ctx)
return fmt.Errorf("add exclusion rule for %s: %w", p.comment, err)
}
}

// Add catch-all counting rule (counts all remaining = external/billable bytes)
if err := iptables(ctx, "-A", chainName, "-j", "RETURN"); err != nil {
cleanupChain(ctx)
return fmt.Errorf("add counting rule: %w", err)
}

// Insert jump from OUTPUT chain (catches host-mode Docker and host processes)
if err := iptables(ctx, "-I", "OUTPUT", "-j", chainName); err != nil {
cleanupChain(ctx)
return fmt.Errorf("insert OUTPUT jump: %w", err)
}

// Insert jump from FORWARD chain (catches bridge-mode Docker traffic)
if err := iptables(ctx, "-I", "FORWARD", "-j", chainName); err != nil {
cleanupChain(ctx)
return fmt.Errorf("insert FORWARD jump: %w", err)
}

return nil
}

func cleanupChain(ctx context.Context) {
_ = iptables(ctx, "-D", "OUTPUT", "-j", chainName)
_ = iptables(ctx, "-D", "FORWARD", "-j", chainName)
_ = iptables(ctx, "-F", chainName)
_ = iptables(ctx, "-X", chainName)
}

func (m *NetMeter) pollLoop(ctx context.Context) {
defer close(m.stopped)
defer cleanupChain(ctx)

ticker := time.NewTicker(pollInterval)
defer ticker.Stop()

for {
select {
case <-m.stopCh:
return
case <-ticker.C:
b, err := readCounter(ctx)
if err != nil {
log.Error(ctx, "failed to read data transfer counter", "err", err)
continue
}
m.bytes.Store(b)
log.Debug(ctx, "data transfer meter poll", "bytes", b)
}
}
}

// readCounter reads the cumulative byte count from the catch-all rule (last rule in chain).
func readCounter(ctx context.Context) (int64, error) {
output, err := iptablesOutput(ctx, "-L", chainName, "-v", "-x", "-n")
if err != nil {
return 0, err
}
return parseByteCounter(output)
}

// parseByteCounter extracts the byte count from the last rule (catch-all counting rule)
// in the iptables -L -v -x -n output.
//
// Example output:
//
// Chain dstack-nm (1 references)
// pkts bytes target prot opt in out source destination
// 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8
// 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12
// 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16
// 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16
// 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8
// 123 456789 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0
//
// The last rule (destination 0.0.0.0/0) is the catch-all; its bytes field is what we want.
func parseByteCounter(output string) (int64, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")

// Find lines that are rule entries (skip header lines)
var lastRuleLine string
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
// Skip "Chain ..." and column header lines
if strings.HasPrefix(trimmed, "Chain ") {
continue
}
if strings.HasPrefix(trimmed, "pkts") {
continue
}
lastRuleLine = trimmed
}

if lastRuleLine == "" {
return 0, fmt.Errorf("no rules found in chain %s", chainName)
}

// Parse the bytes field (second field in the line)
fields := strings.Fields(lastRuleLine)
if len(fields) < 2 {
return 0, fmt.Errorf("unexpected rule format: %q", lastRuleLine)
}

byteCount, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, fmt.Errorf("parse byte count %q: %w", fields[1], err)
}

return byteCount, nil
}

func iptables(ctx context.Context, args ...string) error {
cmd := exec.CommandContext(ctx, "iptables", args...)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err)
}
return nil
}

func iptablesOutput(ctx context.Context, args ...string) (string, error) {
cmd := exec.CommandContext(ctx, "iptables", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err)
}
return stdout.String(), nil
}
Loading
Loading