diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index c696bd467..95cdf448b 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -22,6 +22,7 @@ import ( "github.com/dstackai/dstack/runner/internal/shim/api" "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" + "github.com/dstackai/dstack/runner/internal/shim/netmeter" ) // Version is a build-time variable. The value is overridden by ldflags. @@ -270,11 +271,22 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } } + var nm *netmeter.NetMeter + nm = netmeter.New() + if err := nm.Start(ctx); err != nil { + log.Warning(ctx, "data transfer metering unavailable", "err", err) + nm = nil + } else { + log.Info(ctx, "data transfer metering started") + defer nm.Stop() + } + address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort) shimServer := api.NewShimServer( ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager, shimManager, + nm, ) if serviceMode { diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index b3382d0f2..37de925a6 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -47,6 +47,10 @@ func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Reques response.DCGM = &dcgmHealth } } + if s.netMeter != nil { + b := s.netMeter.Bytes() + response.DataTransferBytes = &b + } return &response, nil } diff --git a/runner/internal/shim/api/handlers_test.go b/runner/internal/shim/api/handlers_test.go index bb19ebbf1..751151c81 100644 --- a/runner/internal/shim/api/handlers_test.go +++ b/runner/internal/shim/api/handlers_test.go @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) { request := httptest.NewRequest("GET", "/api/healthcheck", nil) responseRecorder := httptest.NewRecorder() - server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil) + server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil, nil) f := commonapi.JSONResponseHandler(server.HealthcheckHandler) f(responseRecorder, request) @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) { } func TestTaskSubmit(t *testing.T) { - server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil) + server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil, nil) requestBody := `{ "id": "dummy-id", "name": "dummy-name", diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index cd0db6a20..adc272375 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -16,7 +16,8 @@ type ShutdownRequest struct { } type InstanceHealthResponse struct { - DCGM *dcgm.Health `json:"dcgm"` + DCGM *dcgm.Health `json:"dcgm"` + DataTransferBytes *int64 `json:"data_transfer_bytes,omitempty"` } type TaskListResponse struct { diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 9008aa2ef..98378bce7 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -13,6 +13,7 @@ import ( "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" + "github.com/dstackai/dstack/runner/internal/shim/netmeter" ) type TaskRunner interface { @@ -45,6 +46,8 @@ type ShimServer struct { runnerManager components.ComponentManager shimManager components.ComponentManager + netMeter *netmeter.NetMeter // may be nil if metering is unavailable + version string } @@ -52,6 +55,7 @@ func NewShimServer( ctx context.Context, address string, version string, runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper dcgm.DCGMWrapperInterface, runnerManager components.ComponentManager, shimManager components.ComponentManager, + nm *netmeter.NetMeter, ) *ShimServer { bgJobsCtx, bgJobsCancel := context.WithCancel(ctx) if dcgmWrapper != nil && reflect.ValueOf(dcgmWrapper).IsNil() { @@ -78,6 +82,8 @@ func NewShimServer( runnerManager: runnerManager, shimManager: shimManager, + netMeter: nm, + version: version, } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 6acfb27a5..ad26b590a 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -380,6 +380,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { if err := d.tasks.Update(task); err != nil { return fmt.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err) } + err = d.waitContainer(ctx, &task) } if err != nil { @@ -1228,6 +1229,10 @@ func (c *CLIArgs) DockerPorts() []int { return []int{c.Runner.HTTPPort, c.Runner.SSHPort} } +func (c *CLIArgs) RunnerHTTPPort() int { + return c.Runner.HTTPPort +} + func (c *CLIArgs) MakeRunnerDir(name string) (string, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name) if err := os.MkdirAll(runnerTemp, 0o755); err != nil { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 18f8c31fc..3723f53e3 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -123,6 +123,10 @@ func (c *dockerParametersMock) DockerPorts() []int { return []int{} } +func (c *dockerParametersMock) RunnerHTTPPort() int { + return 10999 +} + func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) { return nil, nil } diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index d50fe6e29..e7913aced 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -9,6 +9,7 @@ type DockerParameters interface { DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string DockerMounts(string) ([]mount.Mount, error) DockerPorts() []int + RunnerHTTPPort() int MakeRunnerDir(name string) (string, error) DockerPJRTDevice() string } diff --git a/runner/internal/shim/netmeter/netmeter.go b/runner/internal/shim/netmeter/netmeter.go new file mode 100644 index 000000000..43993e4a5 --- /dev/null +++ b/runner/internal/shim/netmeter/netmeter.go @@ -0,0 +1,227 @@ +package netmeter + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/dstackai/dstack/runner/internal/common/log" +) + +const ( + pollInterval = 10 * time.Second + chainName = "dstack-nm" +) + +// NetMeter monitors outbound data transfer using iptables byte counters. +// It excludes private/VPC traffic and counts only external (billable) bytes. +// The meter runs for the lifetime of the shim process (per-instance, not per-task). +type NetMeter struct { + bytes atomic.Int64 + stopCh chan struct{} + stopped chan struct{} +} + +// New creates a new NetMeter. +func New() *NetMeter { + return &NetMeter{ + stopCh: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +// Start sets up iptables rules and begins polling byte counters. +func (m *NetMeter) Start(ctx context.Context) error { + if err := checkIptables(); err != nil { + return fmt.Errorf("iptables not available: %w", err) + } + + // Clean up any orphaned chain from a previous shim process + cleanupChain(ctx) + + if err := setupChain(ctx); err != nil { + return fmt.Errorf("setup iptables chain: %w", err) + } + + go m.pollLoop(ctx) + return nil +} + +// Stop signals the poll loop to stop and cleans up iptables rules. +func (m *NetMeter) Stop() { + close(m.stopCh) + <-m.stopped +} + +// Bytes returns the cumulative external outbound byte count (thread-safe). +func (m *NetMeter) Bytes() int64 { + return m.bytes.Load() +} + +func checkIptables() error { + _, err := exec.LookPath("iptables") + return err +} + +func setupChain(ctx context.Context) error { + // Create the chain + if err := iptables(ctx, "-N", chainName); err != nil { + return fmt.Errorf("create chain: %w", err) + } + + // Add exclusion rules for private/internal traffic (these RETURN without counting) + privateCIDRs := []struct { + cidr string + comment string + }{ + {"10.0.0.0/8", "VPC/private"}, + {"172.16.0.0/12", "VPC/private"}, + {"192.168.0.0/16", "VPC/private"}, + {"169.254.0.0/16", "link-local/metadata"}, + {"127.0.0.0/8", "loopback"}, + } + for _, p := range privateCIDRs { + if err := iptables(ctx, "-A", chainName, "-d", p.cidr, "-j", "RETURN"); err != nil { + cleanupChain(ctx) + return fmt.Errorf("add exclusion rule for %s: %w", p.comment, err) + } + } + + // Add catch-all counting rule (counts all remaining = external/billable bytes) + if err := iptables(ctx, "-A", chainName, "-j", "RETURN"); err != nil { + cleanupChain(ctx) + return fmt.Errorf("add counting rule: %w", err) + } + + // Insert jump from OUTPUT chain (catches host-mode Docker and host processes) + if err := iptables(ctx, "-I", "OUTPUT", "-j", chainName); err != nil { + cleanupChain(ctx) + return fmt.Errorf("insert OUTPUT jump: %w", err) + } + + // Insert jump from FORWARD chain (catches bridge-mode Docker traffic) + if err := iptables(ctx, "-I", "FORWARD", "-j", chainName); err != nil { + cleanupChain(ctx) + return fmt.Errorf("insert FORWARD jump: %w", err) + } + + return nil +} + +func cleanupChain(ctx context.Context) { + _ = iptables(ctx, "-D", "OUTPUT", "-j", chainName) + _ = iptables(ctx, "-D", "FORWARD", "-j", chainName) + _ = iptables(ctx, "-F", chainName) + _ = iptables(ctx, "-X", chainName) +} + +func (m *NetMeter) pollLoop(ctx context.Context) { + defer close(m.stopped) + defer cleanupChain(ctx) + + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + b, err := readCounter(ctx) + if err != nil { + log.Error(ctx, "failed to read data transfer counter", "err", err) + continue + } + m.bytes.Store(b) + log.Debug(ctx, "data transfer meter poll", "bytes", b) + } + } +} + +// readCounter reads the cumulative byte count from the catch-all rule (last rule in chain). +func readCounter(ctx context.Context) (int64, error) { + output, err := iptablesOutput(ctx, "-L", chainName, "-v", "-x", "-n") + if err != nil { + return 0, err + } + return parseByteCounter(output) +} + +// parseByteCounter extracts the byte count from the last rule (catch-all counting rule) +// in the iptables -L -v -x -n output. +// +// Example output: +// +// Chain dstack-nm (1 references) +// pkts bytes target prot opt in out source destination +// 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 +// 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 +// 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 +// 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 +// 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 +// 123 456789 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +// +// The last rule (destination 0.0.0.0/0) is the catch-all; its bytes field is what we want. +func parseByteCounter(output string) (int64, error) { + lines := strings.Split(strings.TrimSpace(output), "\n") + + // Find lines that are rule entries (skip header lines) + var lastRuleLine string + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + // Skip "Chain ..." and column header lines + if strings.HasPrefix(trimmed, "Chain ") { + continue + } + if strings.HasPrefix(trimmed, "pkts") { + continue + } + lastRuleLine = trimmed + } + + if lastRuleLine == "" { + return 0, fmt.Errorf("no rules found in chain %s", chainName) + } + + // Parse the bytes field (second field in the line) + fields := strings.Fields(lastRuleLine) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected rule format: %q", lastRuleLine) + } + + byteCount, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("parse byte count %q: %w", fields[1], err) + } + + return byteCount, nil +} + +func iptables(ctx context.Context, args ...string) error { + cmd := exec.CommandContext(ctx, "iptables", args...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err) + } + return nil +} + +func iptablesOutput(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "iptables", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err) + } + return stdout.String(), nil +} diff --git a/runner/internal/shim/netmeter/netmeter_test.go b/runner/internal/shim/netmeter/netmeter_test.go new file mode 100644 index 000000000..5af00c61e --- /dev/null +++ b/runner/internal/shim/netmeter/netmeter_test.go @@ -0,0 +1,87 @@ +package netmeter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseByteCounter(t *testing.T) { + tests := []struct { + name string + output string + expected int64 + expectErr bool + }{ + { + name: "typical output with traffic", + output: `Chain dstack-nm (1 references) + pkts bytes target prot opt in out source destination + 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 123 456789 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + expected: 456789, + }, + { + name: "zero traffic", + output: `Chain dstack-nm (1 references) + pkts bytes target prot opt in out source destination + 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + expected: 0, + }, + { + name: "large byte count", + output: `Chain dstack-nm (1 references) + pkts bytes target prot opt in out source destination + 10000 5000000 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 500000 107374182400 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + expected: 107374182400, // ~100 GB + }, + { + name: "empty output", + output: "", + expectErr: true, + }, + { + name: "only headers no rules", + output: `Chain dstack-nm (1 references) + pkts bytes target prot opt in out source destination +`, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseByteCounter(tt.output) + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestNew(t *testing.T) { + nm := New() + assert.NotNil(t, nm) + assert.Equal(t, int64(0), nm.Bytes()) +} diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 11a1aca51..392f4bb51 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -356,3 +356,4 @@ class Instance(CoreModel): price: Optional[float] = None total_blocks: Optional[int] = None busy_blocks: int = 0 + data_transfer_bytes: int = 0 diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index d23d536cd..cf1998dc4 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -166,6 +166,10 @@ async def check_instance(instance_model: InstanceModel) -> ProcessResult: status=health_status, response=instance_check.health_response.json(), ) + if instance_check.health_response.data_transfer_bytes is not None: + result.instance_update_map["data_transfer_bytes"] = ( + instance_check.health_response.data_transfer_bytes + ) set_health_update( update_map=result.instance_update_map, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 34e80311f..59937a72c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -52,6 +52,7 @@ class InstanceUpdateMap(ItemUpdateMap, total=False): job_provisioning_data: str total_blocks: int busy_blocks: int + data_transfer_bytes: int deleted: bool deleted_at: UpdateMapDateTime diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py index eb1f3c8a3..7c9fea35d 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py @@ -1,6 +1,8 @@ +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import BackendError, NotYetTerminated from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER from dstack._internal.server.background.pipeline_tasks.instances.common import ( ProcessResult, @@ -10,7 +12,12 @@ ) from dstack._internal.server.models import InstanceModel from dstack._internal.server.services import backends as backends_services -from dstack._internal.server.services.instances import get_instance_provisioning_data +from dstack._internal.server.services.instances import ( + get_instance_provisioning_data, + get_instance_ssh_private_keys, +) +from dstack._internal.server.services.runner import client as runner_client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.utils.common import get_current_datetime, run_async from dstack._internal.utils.logging import get_logger @@ -39,6 +46,7 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: job_provisioning_data.backend, ) else: + await _capture_final_data_transfer_bytes(instance_model, job_provisioning_data, result) logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) try: await run_async( @@ -86,3 +94,37 @@ async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: new_status=InstanceStatus.TERMINATED, ) return result + + +async def _capture_final_data_transfer_bytes( + instance_model: InstanceModel, + jpd: JobProvisioningData, + result: ProcessResult, +) -> None: + """Best-effort final read of data_transfer_bytes before the instance is destroyed.""" + try: + health_response = await run_async( + _read_instance_health, + get_instance_ssh_private_keys(instance_model), + jpd, + None, + instance=instance_model, + ) + if ( + health_response is not False + and health_response is not None + and health_response.data_transfer_bytes is not None + ): + result.instance_update_map["data_transfer_bytes"] = health_response.data_transfer_bytes + except Exception as exc: + logger.debug( + "Failed to capture final data_transfer_bytes for %s: %s", + instance_model.name, + exc, + ) + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +def _read_instance_health(ports, *, instance): + shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + return shim_client.get_instance_health() diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 5c041dc2b..b70b37ffb 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -763,6 +763,8 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non response=instance_check.health_response.json(), ) session.add(health_check_model) + if instance_check.health_response.data_transfer_bytes is not None: + instance.data_transfer_bytes = instance_check.health_response.data_transfer_bytes _set_health(session, instance, health_status) _set_unreachable(session, instance, unreachable=not instance_check.reachable) @@ -1127,6 +1129,7 @@ async def _terminate(session: AsyncSession, instance: InstanceModel) -> None: jpd.backend, ) else: + await _capture_final_data_transfer_bytes(instance, jpd) logger.debug("Terminating runner instance %s", jpd.hostname) try: await run_async( @@ -1165,6 +1168,38 @@ async def _terminate(session: AsyncSession, instance: InstanceModel) -> None: switch_instance_status(session, instance, InstanceStatus.TERMINATED) +async def _capture_final_data_transfer_bytes( + instance: InstanceModel, jpd: JobProvisioningData +) -> None: + """Best-effort final read of data_transfer_bytes before the instance is destroyed.""" + try: + health_response = await run_async( + _read_instance_health, + get_instance_ssh_private_keys(instance), + jpd, + None, + instance=instance, + ) + if ( + health_response is not False + and health_response is not None + and health_response.data_transfer_bytes is not None + ): + instance.data_transfer_bytes = health_response.data_transfer_bytes + except Exception as exc: + logger.debug( + "Failed to capture final data_transfer_bytes for %s: %s", + instance.name, + exc, + ) + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +def _read_instance_health(ports, *, instance): + shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + return shim_client.get_instance_health() + + def _set_health(session: AsyncSession, instance: InstanceModel, health: HealthStatus) -> None: if instance.health != health: events.emit( diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_25_1200_a1b2c3d4e5f6_add_data_transfer_bytes_to_instances.py b/src/dstack/_internal/server/migrations/versions/2026/03_25_1200_a1b2c3d4e5f6_add_data_transfer_bytes_to_instances.py new file mode 100644 index 000000000..1727e1f62 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_25_1200_a1b2c3d4e5f6_add_data_transfer_bytes_to_instances.py @@ -0,0 +1,27 @@ +"""Add data_transfer_bytes to instances + +Revision ID: a1b2c3d4e5f6 +Revises: c1c2ecaee45c +Create Date: 2026-03-25 12:00:00.000000+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a1b2c3d4e5f6" +down_revision = "c1c2ecaee45c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "instances", + sa.Column("data_transfer_bytes", sa.BigInteger(), nullable=False, server_default="0"), + ) + + +def downgrade() -> None: + op.drop_column("instances", "data_transfer_bytes") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index b599c4314..156f502d7 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -776,6 +776,9 @@ class InstanceModel(PipelineModelMixin, BaseModel): """`total_blocks` uses `NULL` to mean `auto` during provisioning; once ready it is not `NULL`.""" busy_blocks: Mapped[int] = mapped_column(Integer, default=0) + data_transfer_bytes: Mapped[int] = mapped_column(BigInteger, default=0) + """Cumulative outbound data transfer bytes (external/billable traffic only).""" + jobs: Mapped[list["JobModel"]] = relationship(back_populates="instance") last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) diff --git a/src/dstack/_internal/server/schemas/instances.py b/src/dstack/_internal/server/schemas/instances.py index 8f87935b9..fd72bbf66 100644 --- a/src/dstack/_internal/server/schemas/instances.py +++ b/src/dstack/_internal/server/schemas/instances.py @@ -37,7 +37,10 @@ def get_health_status(self) -> HealthStatus: def has_health_checks(self) -> bool: if self.health_response is None: return False - return self.health_response.dcgm is not None + return ( + self.health_response.dcgm is not None + or self.health_response.data_transfer_bytes is not None + ) class GetInstanceHealthChecksRequest(CoreModel): diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 549ff7914..fe6d14cb5 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -128,6 +128,7 @@ class HealthcheckResponse(CoreModel): class InstanceHealthResponse(CoreModel): dcgm: Optional[DCGMHealthResponse] = None + data_transfer_bytes: Optional[int] = None class ShutdownRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index d54ec8b68..9dc196e16 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -241,6 +241,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: finished_at=instance_model.finished_at, total_blocks=instance_model.total_blocks, busy_blocks=instance_model.busy_blocks, + data_transfer_bytes=instance_model.data_transfer_bytes or 0, ) offer = get_instance_offer(instance_model) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 47544b921..5b7ac55e8 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1025,6 +1025,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async "price": None, "total_blocks": 1, "busy_blocks": 0, + "data_transfer_bytes": 0, } ], } @@ -1167,6 +1168,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A "price": 0.0, "total_blocks": 1, "busy_blocks": 0, + "data_transfer_bytes": 0, } ], } @@ -1347,6 +1349,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A "price": 0.0, "total_blocks": 1, "busy_blocks": 0, + "data_transfer_bytes": 0, }, { "id": SomeUUID4Str(), @@ -1382,6 +1385,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A "price": 0.0, "total_blocks": 1, "busy_blocks": 0, + "data_transfer_bytes": 0, }, ], }