From 7a2802eabad9d06daff7fc33e5241b8c8417ff07 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:29:32 +0100 Subject: [PATCH 01/39] Add K8s GPU metrics collection design spec Three-source fallback chain: CloudWatch Container Insights, DCGM exporter scrape, and Prometheus query. Per-node fallback with new ruleK8sLowGPUUtil analysis rule. --- .../2026-04-19-k8s-gpu-metrics-design.md | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 docs/specs/2026-04-19-k8s-gpu-metrics-design.md diff --git a/docs/specs/2026-04-19-k8s-gpu-metrics-design.md b/docs/specs/2026-04-19-k8s-gpu-metrics-design.md new file mode 100644 index 0000000..ed8c3d7 --- /dev/null +++ b/docs/specs/2026-04-19-k8s-gpu-metrics-design.md @@ -0,0 +1,149 @@ +# K8s GPU Metrics Collection + +## Goal + +Collect GPU utilization metrics for Kubernetes GPU nodes discovered by gpuaudit, using a per-node fallback chain of three sources: CloudWatch Container Insights, DCGM exporter scrape, and Prometheus query. Enable utilization-based waste detection for K8s GPU nodes (currently limited to allocation-based detection only). + +## Architecture + +Three metrics sources, tried in priority order **per node** (stop at the first source that returns data for a given node): + +1. **CloudWatch Container Insights** — AWS API call, no in-cluster access needed beyond what we already have. +2. **DCGM exporter scrape** — probe port 9400 on dcgm-exporter pods via K8s API proxy. +3. **Prometheus query** — query a user-configured Prometheus endpoint for historical GPU metrics. + +All three populate the same existing fields: `GPUInstance.AvgGPUUtilization` and `GPUInstance.AvgGPUMemUtilization`. + +## Data Flow + +``` +1. AWS scan → ScanResult (EC2, SageMaker, EKS) +2. K8s scan → []GPUInstance (nodes + allocation) +3. Enrich K8s GPU metrics (fallback chain): + a. CloudWatch Container Insights (if AWS creds available, !skipMetrics) + b. DCGM scrape via K8s API proxy (for nodes still missing metrics) + c. Prometheus query (for remaining nodes, if --prom-url or --prom-endpoint set) +4. AnalyzeAll on K8s instances +5. Merge into result +``` + +Steps 3a through 3c each skip nodes that already have `AvgGPUUtilization` populated by a prior step. + +## Source 1: CloudWatch Container Insights + +Requires the CloudWatch Observability EKS add-on to be installed in the cluster. If not installed, the query returns empty (not an error) and we fall through. + +**Metrics queried:** +- `node_gpu_utilization` (Average) — maps to `AvgGPUUtilization` +- `node_gpu_memory_utilization` (Average) — maps to `AvgGPUMemUtilization` + +**Namespace:** `ContainerInsights` + +**Dimensions:** `ClusterName` + `InstanceId` + +**Implementation:** New function `EnrichK8sGPUMetrics(ctx, client CloudWatchClient, instances []GPUInstance, clusterName string, window MetricWindow)` in `internal/providers/aws/cloudwatch.go`, following the same pattern as `EnrichEC2Metrics` and `EnrichSageMakerMetrics`. + +**Prerequisites per node:** The node must have an EC2 instance ID (extracted from `providerID`). Non-AWS nodes are skipped for this source. + +**Wiring:** Called from `main.go` after the K8s scan returns instances, passing the CloudWatch client from the AWS config. Only called when AWS credentials are available and `!skipMetrics`. + +## Source 2: DCGM Exporter Scrape + +Auto-detected, no user configuration needed. + +**Discovery:** List pods across all namespaces matching labels `app=nvidia-dcgm-exporter` or `app.kubernetes.io/name=dcgm-exporter`. If no pods found, log `"DCGM exporter not detected, skipping"` and fall through to Prometheus. + +**Scraping:** For each GPU node still missing metrics, find the dcgm-exporter pod on that node (match by `pod.Spec.NodeName`), then scrape `/metrics` on port 9400 via the K8s API proxy (`ProxyGet`). + +**Metrics parsed:** +- `DCGM_FI_DEV_GPU_UTIL` — maps to `AvgGPUUtilization` +- `DCGM_FI_DEV_MEM_COPY_UTIL` — maps to `AvgGPUMemUtilization` + +These are point-in-time values, not historical averages. The analysis rule's confidence (0.85 vs 0.9) accounts for this lower fidelity. + +**Prometheus text format parsing:** Use `prometheus/common/expfmt` to parse the scrape response. + +**K8s client extension:** Add `ProxyGet(ctx, namespace, podName, port, path string) ([]byte, error)` to the `K8sClient` interface. Wraps `clientset.CoreV1().Pods(ns).ProxyGet()`. + +**Stderr output:** +``` + Probing DCGM exporter on GPU nodes... + DCGM: got GPU metrics for 3 of 5 remaining nodes +``` + +## Source 3: Prometheus Query + +Only attempted when `--prom-url` or `--prom-endpoint` is provided. No auto-discovery. + +**CLI flags:** +- `--prom-url` — full URL to a Prometheus-compatible API (e.g., `https://prometheus.corp.example.com`, AMP endpoint, Grafana Cloud). Hit directly via HTTP. +- `--prom-endpoint` — in-cluster service as `namespace/service:port` (e.g., `monitoring/prometheus:9090`). Proxied through the K8s API server. + +These flags are mutually exclusive. Error if both are set. + +**Query:** Batch all remaining nodes into one PromQL query: +``` +avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"node1|node2|..."}[7d]) +``` +And similarly for `DCGM_FI_DEV_MEM_COPY_UTIL`. + +**API:** HTTP GET to `/api/v1/query`, parse the standard Prometheus JSON response. No client library — plain `net/http` for direct URLs, K8s API proxy for in-cluster endpoints. + +**Stderr output:** +``` + Querying Prometheus at monitoring/prometheus:9090... + Prometheus: got GPU metrics for 2 of 3 remaining nodes +``` + +## Analysis Rule + +New rule `ruleK8sLowGPUUtil` in `internal/analysis/rules.go`: + +- **Source filter:** `SourceK8sNode` only +- **Guard:** `AvgGPUUtilization != nil` (skip nodes where no metrics were collected) +- **Threshold:** average GPU utilization < 10% +- **Signal type:** `low_utilization` +- **Severity:** Critical +- **Confidence:** 0.85 +- **Recommendation:** "GPU utilization averaging X%. Consider bin-packing more workloads, downsizing, or removing from the node pool." +- **Savings estimate:** `MonthlyCost * 0.8` (same rough estimate as SageMaker equivalent) + +**Interplay with `ruleK8sUnallocatedGPU`:** Both rules can fire on the same node. Unallocated detects zero pod scheduling (allocation-based). Low-util detects pods that are scheduled but barely using the GPU (utilization-based). Different problems, different fixes. + +## File Changes + +- **Modify:** `internal/providers/aws/cloudwatch.go` — add `EnrichK8sGPUMetrics()` +- **Create:** `internal/providers/k8s/metrics.go` — DCGM scraping, Prometheus querying, fallback orchestration +- **Create:** `internal/providers/k8s/metrics_test.go` — tests for DCGM and Prometheus paths +- **Modify:** `internal/providers/k8s/discover.go` — extend `K8sClient` interface with `ProxyGet` (DCGM pod discovery uses existing `ListPods` with label selector) +- **Modify:** `internal/providers/k8s/scanner.go` — wire metrics enrichment into the K8s scan, accept new options +- **Modify:** `internal/analysis/rules.go` — add `ruleK8sLowGPUUtil` +- **Modify:** `internal/analysis/rules_test.go` — tests for the new rule +- **Modify:** `cmd/gpuaudit/main.go` — add `--prom-url` and `--prom-endpoint` flags, wire CloudWatch enrichment for K8s instances + +## Error Handling + +- **CloudWatch returns empty:** Not an error. Container Insights add-on probably not installed. Fall through to DCGM. +- **No EC2 instance ID on a node:** Skip CW enrichment for that node (non-AWS or providerID not set). +- **No dcgm-exporter pods found:** Log on stderr, fall through to Prometheus. +- **DCGM scrape fails for a node:** Warn on stderr, continue with other nodes. Don't fail the scan. +- **Prometheus endpoint unreachable:** Warn on stderr, continue without metrics for remaining nodes. +- **Both `--prom-url` and `--prom-endpoint` set:** Return an error at flag validation time. + +## New Dependencies + +- `prometheus/common/expfmt` — for parsing Prometheus text format from DCGM exporter scrapes. Small, well-established library. + +## IAM Policy + +No new IAM permissions required. `EnrichK8sGPUMetrics` uses the existing `cloudwatch:GetMetricData` permission already in the IAM policy output. + +## RBAC + +The K8s API proxy calls (`ProxyGet` to pods) require the `pods/proxy` resource permission. For DCGM scraping: +``` +- apiGroups: [""] + resources: ["pods/proxy"] + verbs: ["get"] +``` +This should be documented and added to any RBAC guide. From d271797577340abc2e9e765c7d3ba196e6d288a9 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:41:11 +0100 Subject: [PATCH 02/39] Add K8s GPU metrics collection implementation plan --- .../plans/2026-04-19-k8s-gpu-metrics.md | 1394 +++++++++++++++++ 1 file changed, 1394 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md diff --git a/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md b/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md new file mode 100644 index 0000000..14c7f2c --- /dev/null +++ b/docs/superpowers/plans/2026-04-19-k8s-gpu-metrics.md @@ -0,0 +1,1394 @@ +# K8s GPU Metrics Collection Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Collect GPU utilization metrics for Kubernetes GPU nodes via a per-node fallback chain (CloudWatch Container Insights → DCGM exporter → Prometheus), and add a utilization-based waste detection rule. + +**Architecture:** Three metrics sources tried in priority order per node, all populating the existing `AvgGPUUtilization` and `AvgGPUMemUtilization` fields on `GPUInstance`. A new analysis rule `ruleK8sLowGPUUtil` flags nodes with GPU utilization < 10%. The fallback chain is wired in `main.go` between K8s discovery and analysis. + +**Tech Stack:** Go, AWS SDK v2 (CloudWatch), client-go (K8s API proxy), prometheus/common/expfmt (Prometheus text parsing), net/http (Prometheus API) + +--- + +## File Structure + +| File | Responsibility | +|------|---------------| +| `internal/providers/aws/cloudwatch.go` | Add `EnrichK8sGPUMetrics()` — CloudWatch Container Insights queries | +| `internal/providers/aws/cloudwatch_test.go` | Tests for `EnrichK8sGPUMetrics()` (new file) | +| `internal/providers/k8s/discover.go` | Extend `K8sClient` interface with `ProxyGet` | +| `internal/providers/k8s/scanner.go` | Extend `ScanOptions` with Prometheus config, export `BuildClientPublic` | +| `internal/providers/k8s/metrics.go` | DCGM scraping, Prometheus querying, fallback orchestration (new file) | +| `internal/providers/k8s/metrics_test.go` | Tests for DCGM and Prometheus paths (new file) | +| `internal/analysis/rules.go` | Add `ruleK8sLowGPUUtil` | +| `internal/analysis/rules_test.go` | Tests for new rule | +| `cmd/gpuaudit/main.go` | Add `--prom-url`, `--prom-endpoint` flags; wire CW enrichment for K8s instances | + +--- + +### Task 1: CloudWatch Container Insights Enrichment + +**Files:** +- Create: `internal/providers/aws/cloudwatch_test.go` +- Modify: `internal/providers/aws/cloudwatch.go:60-80` + +This task adds `EnrichK8sGPUMetrics()` following the exact same pattern as the existing `EnrichEC2Metrics()` and `EnrichSageMakerMetrics()` functions. It queries the `ContainerInsights` namespace for `node_gpu_utilization` and `node_gpu_memory_utilization`. + +- [ ] **Step 1: Write the failing tests** + +Create `internal/providers/aws/cloudwatch_test.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + cwtypes "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockCloudWatchClient struct { + output *cloudwatch.GetMetricDataOutput + err error +} + +func (m *mockCloudWatchClient) GetMetricData(ctx context.Context, params *cloudwatch.GetMetricDataInput, optFns ...func(*cloudwatch.Options)) (*cloudwatch.GetMetricDataOutput, error) { + if m.err != nil { + return nil, m.err + } + return m.output, nil +} + +func TestEnrichK8sGPUMetrics_PopulatesUtilization(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []cwtypes.MetricDataResult{ + {Id: aws.String("gpu_util_i_abc123"), Values: []float64{45.0, 50.0, 55.0}}, + {Id: aws.String("gpu_mem_i_abc123"), Values: []float64{30.0, 35.0, 40.0}}, + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "ml-cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected avg GPU util 50.0, got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 35.0 { + t.Errorf("expected avg GPU mem util 35.0, got %f", *instances[0].AvgGPUMemUtilization) + } +} + +func TestEnrichK8sGPUMetrics_SkipsNonK8sNodes(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-ec2", Source: models.SourceEC2}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for non-K8s instance") + } +} + +func TestEnrichK8sGPUMetrics_SkipsNodesWithoutInstanceID(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "node-hostname", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for node without EC2 instance ID") + } +} + +func TestEnrichK8sGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + AvgGPUUtilization: &gpuUtil, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Errorf("expected existing value 75.0 to be preserved, got %f", *instances[0].AvgGPUUtilization) + } +} + +func TestEnrichK8sGPUMetrics_HandlesAPIError(t *testing.T) { + client := &mockCloudWatchClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util after API error") + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/providers/aws/ -run TestEnrichK8sGPUMetrics -v` +Expected: FAIL — `EnrichK8sGPUMetrics` not defined + +- [ ] **Step 3: Implement EnrichK8sGPUMetrics** + +Add to `internal/providers/aws/cloudwatch.go`, after the `EnrichSageMakerMetrics` function (after line 80): + +```go +// EnrichK8sGPUMetrics populates GPU utilization metrics on K8s nodes using CloudWatch Container Insights. +func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances []models.GPUInstance, clusterName string, window MetricWindow) { + type nodeRef struct { + index int + instanceID string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if !strings.HasPrefix(inst.InstanceID, "i-") { + continue + } + nodes = append(nodes, nodeRef{index: i, instanceID: inst.InstanceID}) + } + if len(nodes) == 0 { + return + } + + now := time.Now() + start := now.Add(-window.Duration) + + clusterDim := cwtypes.Dimension{ + Name: aws.String("ClusterName"), + Value: aws.String(clusterName), + } + + for _, node := range nodes { + instanceDim := cwtypes.Dimension{ + Name: aws.String("InstanceId"), + Value: aws.String(node.instanceID), + } + + safeID := strings.ReplaceAll(node.instanceID, "-", "_") + + queries := []cwtypes.MetricDataQuery{ + metricQuery2("gpu_util_"+safeID, "ContainerInsights", "node_gpu_utilization", "Average", window.Period, clusterDim, instanceDim), + metricQuery2("gpu_mem_"+safeID, "ContainerInsights", "node_gpu_memory_utilization", "Average", window.Period, clusterDim, instanceDim), + } + + results, err := fetchMetrics(ctx, client, queries, start, now) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable for %s: %v\n", node.instanceID, err) + continue + } + + instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] + instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + } +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `go test ./internal/providers/aws/ -run TestEnrichK8sGPUMetrics -v` +Expected: PASS (all 5 tests) + +- [ ] **Step 5: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 6: Commit** + +```bash +git add internal/providers/aws/cloudwatch.go internal/providers/aws/cloudwatch_test.go +git commit -m "Add EnrichK8sGPUMetrics for CloudWatch Container Insights GPU metrics" +``` + +--- + +### Task 2: Extend K8sClient Interface with ProxyGet + +**Files:** +- Modify: `internal/providers/k8s/discover.go:24-27` +- Modify: `internal/providers/k8s/scanner.go:91-101` +- Modify: `internal/providers/k8s/discover_test.go:19-30` + +This task adds `ProxyGet` to the `K8sClient` interface and updates the mock and wrapper. This is needed for both DCGM scraping (Task 3) and Prometheus in-cluster queries (Task 4). + +- [ ] **Step 1: Add ProxyGet to the K8sClient interface** + +In `internal/providers/k8s/discover.go`, change the `K8sClient` interface (lines 24-27) from: + +```go +type K8sClient interface { + ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) + ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) +} +``` + +to: + +```go +type K8sClient interface { + ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) + ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) + ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) +} +``` + +- [ ] **Step 2: Implement ProxyGet on k8sClientWrapper** + +In `internal/providers/k8s/scanner.go`, add this method after the `ListPods` method (after line 101): + +```go +func (w *k8sClientWrapper) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return w.clientset.CoreV1().Pods(namespace).ProxyGet("http", podName, port, path, nil).DoRaw(ctx) +} +``` + +- [ ] **Step 3: Add ProxyGet to the mock in tests** + +In `internal/providers/k8s/discover_test.go`, change the `mockK8sClient` struct (lines 19-22) from: + +```go +type mockK8sClient struct { + nodes *corev1.NodeList + pods *corev1.PodList +} +``` + +to: + +```go +type mockK8sClient struct { + nodes *corev1.NodeList + pods *corev1.PodList + proxyData map[string][]byte + proxyErr error +} +``` + +And add the method after `ListPods` (after line 30): + +```go +func (m *mockK8sClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + if m.proxyErr != nil { + return nil, m.proxyErr + } + key := fmt.Sprintf("%s/%s:%s%s", namespace, podName, port, path) + if data, ok := m.proxyData[key]; ok { + return data, nil + } + return nil, fmt.Errorf("no mock data for %s", key) +} +``` + +- [ ] **Step 4: Run tests to verify nothing is broken** + +Run: `go test ./internal/providers/k8s/ -v` +Expected: All existing tests pass + +- [ ] **Step 5: Commit** + +```bash +git add internal/providers/k8s/discover.go internal/providers/k8s/scanner.go internal/providers/k8s/discover_test.go +git commit -m "Add ProxyGet to K8sClient interface for pod API proxy" +``` + +--- + +### Task 3: DCGM Exporter Scraping + +**Files:** +- Create: `internal/providers/k8s/metrics.go` +- Create: `internal/providers/k8s/metrics_test.go` + +This task implements DCGM exporter auto-discovery and metric scraping. It discovers dcgm-exporter pods by label, matches them to GPU nodes, scrapes `/metrics` on port 9400, and parses `DCGM_FI_DEV_GPU_UTIL` and `DCGM_FI_DEV_MEM_COPY_UTIL`. + +- [ ] **Step 1: Add the `prometheus/common` dependency** + +Run: `go get github.com/prometheus/common@latest` + +This will also pull in `github.com/prometheus/client_model` (needed for `dto.MetricFamily`). + +- [ ] **Step 2: Write the failing tests** + +Create `internal/providers/k8s/metrics_test.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "context" + "fmt" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +func dcgmPod(name, namespace, nodeName string) corev1.Pod { + return corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "dcgm-exporter", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + } +} + +const sampleDCGMMetrics = `# HELP DCGM_FI_DEV_GPU_UTIL GPU utilization. +# TYPE DCGM_FI_DEV_GPU_UTIL gauge +DCGM_FI_DEV_GPU_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 42.0 +DCGM_FI_DEV_GPU_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 38.0 +# HELP DCGM_FI_DEV_MEM_COPY_UTIL GPU memory utilization. +# TYPE DCGM_FI_DEV_MEM_COPY_UTIL gauge +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 55.0 +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 60.0 +` + +func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 40.0 { + t.Errorf("expected avg GPU util 40.0 (average of 42 and 38), got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 57.5 { + t.Errorf("expected avg GPU mem util 57.5 (average of 55 and 60), got %f", *instances[0].AvgGPUMemUtilization) + } + if enriched != 1 { + t.Errorf("expected 1 enriched node, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Error("should not overwrite existing utilization") + } + if enriched != 0 { + t.Errorf("expected 0 enriched nodes, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_NoDCGMPods(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{Items: []corev1.Pod{}}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil when no DCGM pods") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_HandlesScrapeError(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyErr: fmt.Errorf("connection refused"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil after scrape error") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestParseDCGMMetrics(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte(sampleDCGMMetrics)) + + if gpuUtil == nil { + t.Fatal("expected gpu util") + } + if *gpuUtil != 40.0 { + t.Errorf("expected 40.0, got %f", *gpuUtil) + } + if memUtil == nil { + t.Fatal("expected mem util") + } + if *memUtil != 57.5 { + t.Errorf("expected 57.5, got %f", *memUtil) + } +} + +func TestParseDCGMMetrics_EmptyInput(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte("")) + if gpuUtil != nil || memUtil != nil { + t.Error("expected nil for empty input") + } +} +``` + +- [ ] **Step 3: Run tests to verify they fail** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichDCGM|TestParseDCGM" -v` +Expected: FAIL — functions not defined + +- [ ] **Step 4: Implement DCGM metrics enrichment** + +Create `internal/providers/k8s/metrics.go`: + +```go +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "bytes" + "context" + "fmt" + "os" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +// EnrichDCGMMetrics discovers dcgm-exporter pods and scrapes GPU metrics for K8s nodes +// that don't already have AvgGPUUtilization populated. Returns the number of nodes enriched. +func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance) int { + needsMetrics := make(map[string]int) + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + needsMetrics[inst.InstanceID] = i + } + if len(needsMetrics) == 0 { + return 0 + } + + dcgmPods, err := findDCGMPods(ctx, client) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: could not list DCGM exporter pods: %v\n", err) + return 0 + } + if len(dcgmPods) == 0 { + fmt.Fprintf(os.Stderr, " DCGM exporter not detected, skipping\n") + return 0 + } + + fmt.Fprintf(os.Stderr, " Probing DCGM exporter on GPU nodes...\n") + + enriched := 0 + for _, pod := range dcgmPods { + idx, ok := needsMetrics[pod.Spec.NodeName] + if !ok { + continue + } + + data, err := client.ProxyGet(ctx, pod.Namespace, pod.Name, "9400", "/metrics") + if err != nil { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed for %s: %v\n", pod.Spec.NodeName, err) + continue + } + + gpuUtil, memUtil := parseDCGMMetrics(data) + if gpuUtil != nil { + instances[idx].AvgGPUUtilization = gpuUtil + instances[idx].AvgGPUMemUtilization = memUtil + enriched++ + } + } + + fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, len(needsMetrics)) + return enriched +} + +func findDCGMPods(ctx context.Context, client K8sClient) ([]corev1.Pod, error) { + podList, err := client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app.kubernetes.io/name=dcgm-exporter", + }) + if err != nil { + return nil, err + } + if len(podList.Items) > 0 { + return runningPods(podList.Items), nil + } + + podList, err = client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app=nvidia-dcgm-exporter", + }) + if err != nil { + return nil, err + } + return runningPods(podList.Items), nil +} + +func runningPods(pods []corev1.Pod) []corev1.Pod { + var result []corev1.Pod + for _, p := range pods { + if p.Status.Phase == corev1.PodRunning { + result = append(result, p) + } + } + return result +} + +func parseDCGMMetrics(data []byte) (gpuUtil, memUtil *float64) { + parser := expfmt.TextParser{} + families, err := parser.TextToMetricFamilies(bytes.NewReader(data)) + if err != nil { + return nil, nil + } + + gpuUtil = avgMetricValue(families["DCGM_FI_DEV_GPU_UTIL"]) + memUtil = avgMetricValue(families["DCGM_FI_DEV_MEM_COPY_UTIL"]) + return gpuUtil, memUtil +} + +func avgMetricValue(family *dto.MetricFamily) *float64 { + if family == nil || len(family.Metric) == 0 { + return nil + } + sum := 0.0 + count := 0 + for _, m := range family.Metric { + if m.Gauge != nil && m.Gauge.Value != nil { + sum += *m.Gauge.Value + count++ + } + } + if count == 0 { + return nil + } + avg := sum / float64(count) + return &avg +} +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichDCGM|TestParseDCGM" -v` +Expected: PASS (all 6 tests) + +- [ ] **Step 6: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 7: Commit** + +```bash +git add internal/providers/k8s/metrics.go internal/providers/k8s/metrics_test.go go.mod go.sum +git commit -m "Add DCGM exporter scraping for K8s GPU metrics" +``` + +--- + +### Task 4: Prometheus Query Enrichment + +**Files:** +- Modify: `internal/providers/k8s/metrics.go` +- Modify: `internal/providers/k8s/metrics_test.go` + +This task adds the Prometheus query path — the third fallback. It supports both direct URL (`--prom-url`) and in-cluster service endpoint (`--prom-endpoint`), querying `avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"..."}[7d])`. + +- [ ] **Step 1: Write the failing tests** + +Add to `internal/providers/k8s/metrics_test.go`: + +```go +import ( + "net/http" + "net/http/httptest" + "strings" +) +``` + +Add these test functions: + +```go +func TestEnrichPrometheusMetrics_PopulatesFromDirectURL(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "65.5"]}, + {"metric": {"node": "i-node2"}, "value": [1700000000, "30.0"]} + ] + } + }` + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/query" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + query := r.URL.Query().Get("query") + if !strings.Contains(query, "DCGM_FI_DEV_GPU_UTIL") { + t.Errorf("unexpected query: %s", query) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(promResponse)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + {InstanceID: "i-node2", Source: models.SourceK8sNode, Name: "cluster/i-node2"}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 2 { + t.Errorf("expected 2 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 65.5 { + t.Errorf("expected node1 GPU util 65.5, got %v", instances[0].AvgGPUUtilization) + } + if instances[1].AvgGPUUtilization == nil || *instances[1].AvgGPUUtilization != 30.0 { + t.Errorf("expected node2 GPU util 30.0, got %v", instances[1].AvgGPUUtilization) + } +} + +func TestEnrichPrometheusMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 80.0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_NoOptions(t *testing.T) { + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, PrometheusOptions{}) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_InClusterEndpoint(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "50.0"]} + ] + } + }` + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{}, + proxyData: map[string][]byte{ + "monitoring/prometheus:9090/api/v1/query": []byte(promResponse), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + opts := PrometheusOptions{Endpoint: "monitoring/prometheus:9090"} + + enriched := EnrichPrometheusMetrics(context.Background(), client, instances, opts) + + if enriched != 1 { + t.Errorf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected 50.0, got %v", instances[0].AvgGPUUtilization) + } +} + +func TestParsePrometheusEndpoint(t *testing.T) { + tests := []struct { + input string + namespace string + service string + port string + wantErr bool + }{ + {"monitoring/prometheus:9090", "monitoring", "prometheus", "9090", false}, + {"kube-system/thanos-query:10902", "kube-system", "thanos-query", "10902", false}, + {"invalid", "", "", "", true}, + {"ns/svc", "", "", "", true}, + } + for _, tt := range tests { + ns, svc, port, err := parsePrometheusEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parsePrometheusEndpoint(%q): err=%v, wantErr=%v", tt.input, err, tt.wantErr) + continue + } + if ns != tt.namespace || svc != tt.service || port != tt.port { + t.Errorf("parsePrometheusEndpoint(%q) = (%q,%q,%q), want (%q,%q,%q)", + tt.input, ns, svc, port, tt.namespace, tt.service, tt.port) + } + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichPrometheus|TestParsePrometheus" -v` +Expected: FAIL — functions not defined + +- [ ] **Step 3: Implement Prometheus metrics enrichment** + +Add to `internal/providers/k8s/metrics.go` (additional imports at the top): + +```go +import ( + "encoding/json" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) +``` + +Add these types and functions: + +```go +// PrometheusOptions configures how to reach a Prometheus-compatible API. +type PrometheusOptions struct { + URL string + Endpoint string +} + +// EnrichPrometheusMetrics queries a Prometheus endpoint for GPU utilization metrics +// for K8s nodes that don't already have AvgGPUUtilization populated. +func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance, opts PrometheusOptions) int { + if opts.URL == "" && opts.Endpoint == "" { + return 0 + } + + type nodeRef struct { + index int + name string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + nodes = append(nodes, nodeRef{index: i, name: inst.InstanceID}) + } + if len(nodes) == 0 { + return 0 + } + + source := opts.URL + if source == "" { + source = opts.Endpoint + } + fmt.Fprintf(os.Stderr, " Querying Prometheus at %s...\n", source) + + nodeNames := make([]string, len(nodes)) + for i, n := range nodes { + nodeNames[i] = n.name + } + nodeRegex := strings.Join(nodeNames, "|") + + gpuResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d])`, nodeRegex)) + memResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d])`, nodeRegex)) + + enriched := 0 + for _, node := range nodes { + if val, ok := gpuResults[node.name]; ok { + instances[node.index].AvgGPUUtilization = &val + if memVal, ok := memResults[node.name]; ok { + instances[node.index].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + + fmt.Fprintf(os.Stderr, " Prometheus: got GPU metrics for %d of %d remaining nodes\n", enriched, len(nodes)) + return enriched +} + +func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query string) map[string]float64 { + var data []byte + var err error + + if opts.URL != "" { + data, err = queryPrometheusHTTP(ctx, opts.URL, query) + } else { + data, err = queryPrometheusProxy(ctx, client, opts.Endpoint, query) + } + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) + return nil + } + + return parsePrometheusResponse(data) +} + +func queryPrometheusHTTP(ctx context.Context, baseURL, query string) ([]byte, error) { + u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query)) + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return io.ReadAll(resp.Body) +} + +func queryPrometheusProxy(ctx context.Context, client K8sClient, endpoint, query string) ([]byte, error) { + ns, svc, port, err := parsePrometheusEndpoint(endpoint) + if err != nil { + return nil, err + } + path := fmt.Sprintf("/api/v1/query?query=%s", url.QueryEscape(query)) + return client.ProxyGet(ctx, ns, svc, port, path) +} + +func parsePrometheusEndpoint(endpoint string) (namespace, service, port string, err error) { + slashIdx := strings.Index(endpoint, "/") + if slashIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + namespace = endpoint[:slashIdx] + rest := endpoint[slashIdx+1:] + colonIdx := strings.LastIndex(rest, ":") + if colonIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + service = rest[:colonIdx] + port = rest[colonIdx+1:] + return namespace, service, port, nil +} + +func parsePrometheusResponse(data []byte) map[string]float64 { + var resp struct { + Status string `json:"status"` + Data struct { + ResultType string `json:"resultType"` + Result []struct { + Metric map[string]string `json:"metric"` + Value []json.RawMessage `json:"value"` + } `json:"result"` + } `json:"data"` + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil + } + if resp.Status != "success" { + return nil + } + + results := make(map[string]float64) + for _, r := range resp.Data.Result { + node := r.Metric["node"] + if node == "" || len(r.Value) < 2 { + continue + } + var valStr string + if err := json.Unmarshal(r.Value[1], &valStr); err != nil { + continue + } + val, err := strconv.ParseFloat(valStr, 64) + if err != nil { + continue + } + results[node] = val + } + return results +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `go test ./internal/providers/k8s/ -run "TestEnrichPrometheus|TestParsePrometheus" -v` +Expected: PASS (all 5 tests) + +- [ ] **Step 5: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 6: Commit** + +```bash +git add internal/providers/k8s/metrics.go internal/providers/k8s/metrics_test.go +git commit -m "Add Prometheus query enrichment for K8s GPU metrics" +``` + +--- + +### Task 5: K8s Low GPU Utilization Analysis Rule + +**Files:** +- Modify: `internal/analysis/rules.go` +- Modify: `internal/analysis/rules_test.go` + +- [ ] **Step 1: Write the failing tests** + +Add to `internal/analysis/rules_test.go`: + +```go +func TestRuleK8sLowGPUUtil_FlagsLowUtilization(t *testing.T) { + inst := models.GPUInstance{ + InstanceID: "i-node1", + Source: models.SourceK8sNode, + State: "ready", + InstanceType: "g5.xlarge", + GPUModel: "A10G", + GPUCount: 1, + GPUAllocated: 1, + MonthlyCost: 734, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 1 { + t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals)) + } + if inst.WasteSignals[0].Type != "low_utilization" { + t.Errorf("expected low_utilization, got %s", inst.WasteSignals[0].Type) + } + if inst.WasteSignals[0].Severity != models.SeverityCritical { + t.Errorf("expected critical, got %s", inst.WasteSignals[0].Severity) + } + if inst.WasteSignals[0].Confidence != 0.85 { + t.Errorf("expected confidence 0.85, got %f", inst.WasteSignals[0].Confidence) + } + if len(inst.Recommendations) != 1 { + t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations)) + } + if inst.Recommendations[0].MonthlySavings != 734*0.8 { + t.Errorf("expected savings %.0f, got %f", 734*0.8, inst.Recommendations[0].MonthlySavings) + } +} + +func TestRuleK8sLowGPUUtil_SkipsNonK8s(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceEC2, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for EC2 instance") + } +} + +func TestRuleK8sLowGPUUtil_SkipsNoMetrics(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when metrics unavailable") + } +} + +func TestRuleK8sLowGPUUtil_SkipsHighUtilization(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + AvgGPUUtilization: ptr(45.0), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for well-utilized GPU") + } +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `go test ./internal/analysis/ -run TestRuleK8sLowGPUUtil -v` +Expected: FAIL — `ruleK8sLowGPUUtil` not defined + +- [ ] **Step 3: Implement the rule** + +In `internal/analysis/rules.go`, add `ruleK8sLowGPUUtil` to the rules slice inside `analyzeInstance()` (line 23-31). The full slice should be: + +```go + rules := []func(*models.GPUInstance){ + ruleIdle, + ruleOversizedGPU, + rulePricingMismatch, + ruleStale, + ruleSageMakerLowUtil, + ruleSageMakerOversized, + ruleK8sUnallocatedGPU, + ruleSpotEligible, + ruleK8sLowGPUUtil, + } +``` + +Then add the rule function at the end of the file: + +```go +// Rule 9: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). +func ruleK8sLowGPUUtil(inst *models.GPUInstance) { + if inst.Source != models.SourceK8sNode { + return + } + if inst.AvgGPUUtilization == nil { + return + } + if *inst.AvgGPUUtilization >= 10 { + return + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "low_utilization", + Severity: models.SeverityCritical, + Confidence: 0.85, + Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%%. GPUs are allocated but barely used.", *inst.AvgGPUUtilization), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionDownsize, + Description: fmt.Sprintf("GPU utilization averaging %.1f%%. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: inst.MonthlyCost * 0.2, + MonthlySavings: inst.MonthlyCost * 0.8, + SavingsPercent: 80, + Risk: models.RiskMedium, + }) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `go test ./internal/analysis/ -run TestRuleK8sLowGPUUtil -v` +Expected: PASS (all 4 tests) + +- [ ] **Step 5: Run full test suite** + +Run: `go test ./...` +Expected: All tests pass + +- [ ] **Step 6: Commit** + +```bash +git add internal/analysis/rules.go internal/analysis/rules_test.go +git commit -m "Add ruleK8sLowGPUUtil for utilization-based K8s GPU waste detection" +``` + +--- + +### Task 6: Wire Everything into CLI and Scan Flow + +**Files:** +- Modify: `cmd/gpuaudit/main.go` +- Modify: `internal/providers/k8s/scanner.go` + +This task adds the `--prom-url` and `--prom-endpoint` CLI flags, passes them through to the K8s scan, wires CloudWatch Container Insights enrichment, and orchestrates the fallback chain in `main.go`. + +- [ ] **Step 1: Extend K8s ScanOptions** + +In `internal/providers/k8s/scanner.go`, change the `ScanOptions` struct (lines 20-23) from: + +```go +type ScanOptions struct { + Kubeconfig string + Context string +} +``` + +to: + +```go +type ScanOptions struct { + Kubeconfig string + Context string + PromURL string + PromEndpoint string +} +``` + +- [ ] **Step 2: Export BuildClient** + +Add to `internal/providers/k8s/scanner.go` after the existing `buildClient` function: + +```go +func BuildClientPublic(kubeconfigPath, contextName string) (K8sClient, string, error) { + return buildClient(kubeconfigPath, contextName) +} +``` + +- [ ] **Step 3: Add CLI flags** + +In `cmd/gpuaudit/main.go`, add the flag variables after `scanKubeContext` (around line 51): + +```go + scanPromURL string + scanPromEndpoint string +``` + +Add the flag registrations inside the first `init()` function, after the `--kube-context` flag (after line 73): + +```go + scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics (e.g., https://prometheus.corp.example.com)") + scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") +``` + +- [ ] **Step 4: Add flag validation and wiring in runScan** + +In `cmd/gpuaudit/main.go`, in the `runScan` function, add validation after `ctx := context.Background()` (line 84): + +```go + if scanPromURL != "" && scanPromEndpoint != "" { + return fmt.Errorf("--prom-url and --prom-endpoint are mutually exclusive") + } +``` + +Then modify the K8s scan section. Replace the block starting with `// Kubernetes API scan` (around lines 107-119) with: + +```go + // Kubernetes API scan + if !scanSkipK8s { + k8sOpts := k8sprovider.ScanOptions{ + Kubeconfig: scanKubeconfig, + Context: scanKubeContext, + PromURL: scanPromURL, + PromEndpoint: scanPromEndpoint, + } + k8sInstances, err := k8sprovider.Scan(ctx, k8sOpts) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Kubernetes scan failed: %v\n", err) + } else if len(k8sInstances) > 0 { + if !scanSkipMetrics { + enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts) + } + analysis.AnalyzeAll(k8sInstances) + result.Instances = append(result.Instances, k8sInstances...) + result.Summary = awsprovider.BuildSummary(result.Instances) + } + } +``` + +- [ ] **Step 5: Add the enrichK8sGPUMetrics helper function** + +Add this function at the bottom of `cmd/gpuaudit/main.go`: + +```go +func enrichK8sGPUMetrics(ctx context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions) { + // Source 1: CloudWatch Container Insights + if len(instances) > 0 && instances[0].ClusterName != "" { + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if awsOpts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(awsOpts.Profile)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) + if err == nil { + region := instances[0].Region + if region == "" { + region = "us-east-1" + } + cfg.Region = region + cwClient := cloudwatch.NewFromConfig(cfg) + fmt.Fprintf(os.Stderr, " Enriching K8s GPU metrics via CloudWatch Container Insights...\n") + awsprovider.EnrichK8sGPUMetrics(ctx, cwClient, instances, instances[0].ClusterName, awsprovider.DefaultMetricWindow) + + enriched := 0 + for _, inst := range instances { + if inst.AvgGPUUtilization != nil { + enriched++ + } + } + fmt.Fprintf(os.Stderr, " CloudWatch: got GPU metrics for %d of %d nodes\n", enriched, len(instances)) + } + } + + // Count remaining + remaining := 0 + for _, inst := range instances { + if inst.AvgGPUUtilization == nil { + remaining++ + } + } + + // Source 2: DCGM exporter scrape + if remaining > 0 { + client, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + k8sprovider.EnrichDCGMMetrics(ctx, client, instances) + } + + remaining = 0 + for _, inst := range instances { + if inst.AvgGPUUtilization == nil { + remaining++ + } + } + } + + // Source 3: Prometheus query + if remaining > 0 && (k8sOpts.PromURL != "" || k8sOpts.PromEndpoint != "") { + var client k8sprovider.K8sClient + if k8sOpts.PromEndpoint != "" { + c, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + client = c + } + } + promOpts := k8sprovider.PrometheusOptions{ + URL: k8sOpts.PromURL, + Endpoint: k8sOpts.PromEndpoint, + } + k8sprovider.EnrichPrometheusMetrics(ctx, client, instances, promOpts) + } +} +``` + +You will need to add the `"github.com/aws/aws-sdk-go-v2/service/cloudwatch"` import to `main.go` if it's not already present. + +- [ ] **Step 6: Run build and full test suite** + +Run: `go build ./... && go test ./...` +Expected: Build succeeds, all tests pass + +- [ ] **Step 7: Commit** + +```bash +git add cmd/gpuaudit/main.go internal/providers/k8s/scanner.go +git commit -m "Wire K8s GPU metrics fallback chain into CLI scan flow" +``` From 1c2d3d837d4c3a4a672abc09a8454c39eab72906 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:45:15 +0100 Subject: [PATCH 03/39] Add EnrichK8sGPUMetrics for CloudWatch Container Insights GPU metrics --- internal/providers/aws/cloudwatch.go | 57 ++++++++++ internal/providers/aws/cloudwatch_test.go | 125 ++++++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 internal/providers/aws/cloudwatch_test.go diff --git a/internal/providers/aws/cloudwatch.go b/internal/providers/aws/cloudwatch.go index 819261c..b9d1978 100644 --- a/internal/providers/aws/cloudwatch.go +++ b/internal/providers/aws/cloudwatch.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -79,6 +80,62 @@ func EnrichSageMakerMetrics(ctx context.Context, client CloudWatchClient, instan return nil } +// EnrichK8sGPUMetrics populates GPU utilization metrics on K8s nodes using CloudWatch Container Insights. +func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances []models.GPUInstance, clusterName string, window MetricWindow) { + type nodeRef struct { + index int + instanceID string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if !strings.HasPrefix(inst.InstanceID, "i-") { + continue + } + nodes = append(nodes, nodeRef{index: i, instanceID: inst.InstanceID}) + } + if len(nodes) == 0 { + return + } + + now := time.Now() + start := now.Add(-window.Duration) + + clusterDim := cwtypes.Dimension{ + Name: aws.String("ClusterName"), + Value: aws.String(clusterName), + } + + for _, node := range nodes { + instanceDim := cwtypes.Dimension{ + Name: aws.String("InstanceId"), + Value: aws.String(node.instanceID), + } + + safeID := strings.ReplaceAll(node.instanceID, "-", "_") + + queries := []cwtypes.MetricDataQuery{ + metricQuery2("gpu_util_"+safeID, "ContainerInsights", "node_gpu_utilization", "Average", window.Period, clusterDim, instanceDim), + metricQuery2("gpu_mem_"+safeID, "ContainerInsights", "node_gpu_memory_utilization", "Average", window.Period, clusterDim, instanceDim), + } + + results, err := fetchMetrics(ctx, client, queries, start, now) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable for %s: %v\n", node.instanceID, err) + continue + } + + instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] + instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + } +} + func getEC2Metrics(ctx context.Context, client CloudWatchClient, instanceID string, window MetricWindow) (map[string]*float64, error) { now := time.Now() start := now.Add(-window.Duration) diff --git a/internal/providers/aws/cloudwatch_test.go b/internal/providers/aws/cloudwatch_test.go new file mode 100644 index 0000000..6dd1d8f --- /dev/null +++ b/internal/providers/aws/cloudwatch_test.go @@ -0,0 +1,125 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + cwtypes "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockCloudWatchClient struct { + output *cloudwatch.GetMetricDataOutput + err error +} + +func (m *mockCloudWatchClient) GetMetricData(ctx context.Context, params *cloudwatch.GetMetricDataInput, optFns ...func(*cloudwatch.Options)) (*cloudwatch.GetMetricDataOutput, error) { + if m.err != nil { + return nil, m.err + } + return m.output, nil +} + +func TestEnrichK8sGPUMetrics_PopulatesUtilization(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{ + MetricDataResults: []cwtypes.MetricDataResult{ + {Id: aws.String("gpu_util_i_abc123"), Values: []float64{45.0, 50.0, 55.0}}, + {Id: aws.String("gpu_mem_i_abc123"), Values: []float64{30.0, 35.0, 40.0}}, + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "ml-cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected avg GPU util 50.0, got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 35.0 { + t.Errorf("expected avg GPU mem util 35.0, got %f", *instances[0].AvgGPUMemUtilization) + } +} + +func TestEnrichK8sGPUMetrics_SkipsNonK8sNodes(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-ec2", Source: models.SourceEC2}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for non-K8s instance") + } +} + +func TestEnrichK8sGPUMetrics_SkipsNodesWithoutInstanceID(t *testing.T) { + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + {InstanceID: "node-hostname", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util for node without EC2 instance ID") + } +} + +func TestEnrichK8sGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockCloudWatchClient{ + output: &cloudwatch.GetMetricDataOutput{}, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceK8sNode, + AvgGPUUtilization: &gpuUtil, + }, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Errorf("expected existing value 75.0 to be preserved, got %f", *instances[0].AvgGPUUtilization) + } +} + +func TestEnrichK8sGPUMetrics_HandlesAPIError(t *testing.T) { + client := &mockCloudWatchClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-abc123", Source: models.SourceK8sNode}, + } + + EnrichK8sGPUMetrics(context.Background(), client, instances, "cluster", DefaultMetricWindow) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil GPU util after API error") + } +} From 96155c152e2d987d9bbf16be8e32a30d4c26a91f Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:46:52 +0100 Subject: [PATCH 04/39] Add ProxyGet to K8sClient interface for pod API proxy --- internal/providers/k8s/discover.go | 1 + internal/providers/k8s/discover_test.go | 17 +++++++++++++++-- internal/providers/k8s/scanner.go | 4 ++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/internal/providers/k8s/discover.go b/internal/providers/k8s/discover.go index 6df9ef0..14fe00c 100644 --- a/internal/providers/k8s/discover.go +++ b/internal/providers/k8s/discover.go @@ -24,6 +24,7 @@ const gpuResourceName corev1.ResourceName = "nvidia.com/gpu" type K8sClient interface { ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) ListPods(ctx context.Context, namespace string, opts metav1.ListOptions) (*corev1.PodList, error) + ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) } // DiscoverGPUNodes finds Kubernetes nodes with GPU capacity and reports their allocation. diff --git a/internal/providers/k8s/discover_test.go b/internal/providers/k8s/discover_test.go index 9d0cff1..016c9df 100644 --- a/internal/providers/k8s/discover_test.go +++ b/internal/providers/k8s/discover_test.go @@ -17,8 +17,10 @@ import ( ) type mockK8sClient struct { - nodes *corev1.NodeList - pods *corev1.PodList + nodes *corev1.NodeList + pods *corev1.PodList + proxyData map[string][]byte + proxyErr error } func (m *mockK8sClient) ListNodes(ctx context.Context, opts metav1.ListOptions) (*corev1.NodeList, error) { @@ -29,6 +31,17 @@ func (m *mockK8sClient) ListPods(ctx context.Context, namespace string, opts met return m.pods, nil } +func (m *mockK8sClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + if m.proxyErr != nil { + return nil, m.proxyErr + } + key := fmt.Sprintf("%s/%s:%s%s", namespace, podName, port, path) + if data, ok := m.proxyData[key]; ok { + return data, nil + } + return nil, fmt.Errorf("no mock data for %s", key) +} + func gpuNode(name, instanceType string, gpuCount int, ready bool, created time.Time) corev1.Node { readyStatus := corev1.ConditionFalse if ready { diff --git a/internal/providers/k8s/scanner.go b/internal/providers/k8s/scanner.go index 67634f3..edea338 100644 --- a/internal/providers/k8s/scanner.go +++ b/internal/providers/k8s/scanner.go @@ -100,6 +100,10 @@ func (w *k8sClientWrapper) ListPods(ctx context.Context, namespace string, opts return w.clientset.CoreV1().Pods(namespace).List(ctx, opts) } +func (w *k8sClientWrapper) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return w.clientset.CoreV1().Pods(namespace).ProxyGet("http", podName, port, path, nil).DoRaw(ctx) +} + func defaultKubeconfig() string { home, err := os.UserHomeDir() if err != nil { From 84a4a9b156ab43586703d0e6703614d16260a9a8 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:50:49 +0100 Subject: [PATCH 05/39] Add DCGM exporter scraping for K8s GPU metrics Discovers dcgm-exporter pods via label selectors and scrapes their Prometheus metrics endpoint via kubectl proxy to populate GPU and GPU memory utilization on K8s node instances. Skips nodes that already have utilization data and gracefully handles scrape errors. --- go.mod | 17 ++- go.sum | 42 +++--- internal/providers/k8s/metrics.go | 132 +++++++++++++++++++ internal/providers/k8s/metrics_test.go | 172 +++++++++++++++++++++++++ 4 files changed, 338 insertions(+), 25 deletions(-) create mode 100644 internal/providers/k8s/metrics.go create mode 100644 internal/providers/k8s/metrics_test.go diff --git a/go.mod b/go.mod index b86d582..e6bceb9 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/eks v1.82.0 github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 + github.com/prometheus/client_model v0.6.2 + github.com/prometheus/common v0.67.5 github.com/spf13/cobra v1.10.2 k8s.io/api v0.32.3 k8s.io/apimachinery v0.32.3 @@ -39,7 +41,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -52,13 +54,14 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/term v0.25.0 // indirect - golang.org/x/text v0.19.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/term v0.38.0 // indirect + golang.org/x/text v0.32.0 // indirect golang.org/x/time v0.7.0 // indirect - google.golang.org/protobuf v1.35.1 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index c4d6139..08691a8 100644 --- a/go.sum +++ b/go.sum @@ -65,8 +65,8 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6 github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -107,6 +107,10 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -121,12 +125,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -137,38 +143,38 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= -golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go new file mode 100644 index 0000000..c487a45 --- /dev/null +++ b/internal/providers/k8s/metrics.go @@ -0,0 +1,132 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "bytes" + "context" + "fmt" + "os" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + "github.com/prometheus/common/model" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +// EnrichDCGMMetrics discovers dcgm-exporter pods and scrapes GPU metrics for K8s nodes +// that don't already have AvgGPUUtilization populated. Returns the number of nodes enriched. +func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance) int { + needsMetrics := make(map[string]int) + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + needsMetrics[inst.InstanceID] = i + } + if len(needsMetrics) == 0 { + return 0 + } + + dcgmPods, err := findDCGMPods(ctx, client) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: could not list DCGM exporter pods: %v\n", err) + return 0 + } + if len(dcgmPods) == 0 { + fmt.Fprintf(os.Stderr, " DCGM exporter not detected, skipping\n") + return 0 + } + + fmt.Fprintf(os.Stderr, " Probing DCGM exporter on GPU nodes...\n") + + enriched := 0 + for _, pod := range dcgmPods { + idx, ok := needsMetrics[pod.Spec.NodeName] + if !ok { + continue + } + + data, err := client.ProxyGet(ctx, pod.Namespace, pod.Name, "9400", "/metrics") + if err != nil { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed for %s: %v\n", pod.Spec.NodeName, err) + continue + } + + gpuUtil, memUtil := parseDCGMMetrics(data) + if gpuUtil != nil { + instances[idx].AvgGPUUtilization = gpuUtil + instances[idx].AvgGPUMemUtilization = memUtil + enriched++ + } + } + + fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, len(needsMetrics)) + return enriched +} + +func findDCGMPods(ctx context.Context, client K8sClient) ([]corev1.Pod, error) { + podList, err := client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app.kubernetes.io/name=dcgm-exporter", + }) + if err != nil { + return nil, err + } + if len(podList.Items) > 0 { + return runningPods(podList.Items), nil + } + + podList, err = client.ListPods(ctx, "", metav1.ListOptions{ + LabelSelector: "app=nvidia-dcgm-exporter", + }) + if err != nil { + return nil, err + } + return runningPods(podList.Items), nil +} + +func runningPods(pods []corev1.Pod) []corev1.Pod { + var result []corev1.Pod + for _, p := range pods { + if p.Status.Phase == corev1.PodRunning { + result = append(result, p) + } + } + return result +} + +func parseDCGMMetrics(data []byte) (gpuUtil, memUtil *float64) { + parser := expfmt.NewTextParser(model.LegacyValidation) + families, err := parser.TextToMetricFamilies(bytes.NewReader(data)) + if err != nil { + return nil, nil + } + + gpuUtil = avgMetricValue(families["DCGM_FI_DEV_GPU_UTIL"]) + memUtil = avgMetricValue(families["DCGM_FI_DEV_MEM_COPY_UTIL"]) + return gpuUtil, memUtil +} + +func avgMetricValue(family *dto.MetricFamily) *float64 { + if family == nil || len(family.Metric) == 0 { + return nil + } + sum := 0.0 + count := 0 + for _, m := range family.Metric { + if m.Gauge != nil && m.Gauge.Value != nil { + sum += *m.Gauge.Value + count++ + } + } + if count == 0 { + return nil + } + avg := sum / float64(count) + return &avg +} diff --git a/internal/providers/k8s/metrics_test.go b/internal/providers/k8s/metrics_test.go new file mode 100644 index 0000000..01103cd --- /dev/null +++ b/internal/providers/k8s/metrics_test.go @@ -0,0 +1,172 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package k8s + +import ( + "context" + "fmt" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gpuaudit/cli/internal/models" +) + +func dcgmPod(name, namespace, nodeName string) corev1.Pod { + return corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "dcgm-exporter", + }, + }, + Spec: corev1.PodSpec{ + NodeName: nodeName, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + }, + } +} + +const sampleDCGMMetrics = `# HELP DCGM_FI_DEV_GPU_UTIL GPU utilization. +# TYPE DCGM_FI_DEV_GPU_UTIL gauge +DCGM_FI_DEV_GPU_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 42.0 +DCGM_FI_DEV_GPU_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 38.0 +# HELP DCGM_FI_DEV_MEM_COPY_UTIL GPU memory utilization. +# TYPE DCGM_FI_DEV_MEM_COPY_UTIL gauge +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="0",UUID="GPU-abc",device="nvidia0",modelName="NVIDIA A10G",Hostname="node1"} 55.0 +DCGM_FI_DEV_MEM_COPY_UTIL{gpu="1",UUID="GPU-def",device="nvidia1",modelName="NVIDIA A10G",Hostname="node1"} 60.0 +` + +func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization == nil { + t.Fatal("expected GPU utilization to be populated") + } + if *instances[0].AvgGPUUtilization != 40.0 { + t.Errorf("expected avg GPU util 40.0 (average of 42 and 38), got %f", *instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil { + t.Fatal("expected GPU memory utilization to be populated") + } + if *instances[0].AvgGPUMemUtilization != 57.5 { + t.Errorf("expected avg GPU mem util 57.5 (average of 55 and 60), got %f", *instances[0].AvgGPUMemUtilization) + } + if enriched != 1 { + t.Errorf("expected 1 enriched node, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 75.0 + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyData: map[string][]byte{ + "gpu-operator/dcgm-exporter-abc:9400/metrics": []byte(sampleDCGMMetrics), + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if *instances[0].AvgGPUUtilization != 75.0 { + t.Error("should not overwrite existing utilization") + } + if enriched != 0 { + t.Errorf("expected 0 enriched nodes, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_NoDCGMPods(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{Items: []corev1.Pod{}}, + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil when no DCGM pods") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichDCGMMetrics_HandlesScrapeError(t *testing.T) { + client := &mockK8sClient{ + nodes: &corev1.NodeList{}, + pods: &corev1.PodList{ + Items: []corev1.Pod{ + dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + }, + }, + proxyErr: fmt.Errorf("connection refused"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichDCGMMetrics(context.Background(), client, instances) + + if instances[0].AvgGPUUtilization != nil { + t.Error("expected nil after scrape error") + } + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestParseDCGMMetrics(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte(sampleDCGMMetrics)) + + if gpuUtil == nil { + t.Fatal("expected gpu util") + } + if *gpuUtil != 40.0 { + t.Errorf("expected 40.0, got %f", *gpuUtil) + } + if memUtil == nil { + t.Fatal("expected mem util") + } + if *memUtil != 57.5 { + t.Errorf("expected 57.5, got %f", *memUtil) + } +} + +func TestParseDCGMMetrics_EmptyInput(t *testing.T) { + gpuUtil, memUtil := parseDCGMMetrics([]byte("")) + if gpuUtil != nil || memUtil != nil { + t.Error("expected nil for empty input") + } +} From 0f460c4c4d4169d2883180b2fce6fa1c4f3a15d8 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:54:16 +0100 Subject: [PATCH 06/39] Add Prometheus query enrichment for K8s GPU metrics --- internal/providers/k8s/metrics.go | 160 +++++++++++++++++++++++++ internal/providers/k8s/metrics_test.go | 142 ++++++++++++++++++++++ 2 files changed, 302 insertions(+) diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go index c487a45..ef47470 100644 --- a/internal/providers/k8s/metrics.go +++ b/internal/providers/k8s/metrics.go @@ -6,8 +6,14 @@ package k8s import ( "bytes" "context" + "encoding/json" "fmt" + "io" + "net/http" + "net/url" "os" + "strconv" + "strings" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" @@ -130,3 +136,157 @@ func avgMetricValue(family *dto.MetricFamily) *float64 { avg := sum / float64(count) return &avg } + +// PrometheusOptions configures how to reach a Prometheus-compatible API. +type PrometheusOptions struct { + URL string + Endpoint string +} + +// EnrichPrometheusMetrics queries a Prometheus endpoint for GPU utilization metrics +// for K8s nodes that don't already have AvgGPUUtilization populated. +func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances []models.GPUInstance, opts PrometheusOptions) int { + if opts.URL == "" && opts.Endpoint == "" { + return 0 + } + + type nodeRef struct { + index int + name string + } + var nodes []nodeRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { + continue + } + nodes = append(nodes, nodeRef{index: i, name: inst.InstanceID}) + } + if len(nodes) == 0 { + return 0 + } + + source := opts.URL + if source == "" { + source = opts.Endpoint + } + fmt.Fprintf(os.Stderr, " Querying Prometheus at %s...\n", source) + + nodeNames := make([]string, len(nodes)) + for i, n := range nodes { + nodeNames[i] = n.name + } + nodeRegex := strings.Join(nodeNames, "|") + + gpuResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d])`, nodeRegex)) + memResults := queryPrometheus(ctx, client, opts, + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d])`, nodeRegex)) + + enriched := 0 + for _, node := range nodes { + if val, ok := gpuResults[node.name]; ok { + instances[node.index].AvgGPUUtilization = &val + if memVal, ok := memResults[node.name]; ok { + instances[node.index].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + + fmt.Fprintf(os.Stderr, " Prometheus: got GPU metrics for %d of %d remaining nodes\n", enriched, len(nodes)) + return enriched +} + +func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query string) map[string]float64 { + var data []byte + var err error + + if opts.URL != "" { + data, err = queryPrometheusHTTP(ctx, opts.URL, query) + } else { + data, err = queryPrometheusProxy(ctx, client, opts.Endpoint, query) + } + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) + return nil + } + + return parsePrometheusResponse(data) +} + +func queryPrometheusHTTP(ctx context.Context, baseURL, query string) ([]byte, error) { + u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query)) + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return io.ReadAll(resp.Body) +} + +func queryPrometheusProxy(ctx context.Context, client K8sClient, endpoint, query string) ([]byte, error) { + ns, svc, port, err := parsePrometheusEndpoint(endpoint) + if err != nil { + return nil, err + } + path := fmt.Sprintf("/api/v1/query?query=%s", url.QueryEscape(query)) + return client.ProxyGet(ctx, ns, svc, port, path) +} + +func parsePrometheusEndpoint(endpoint string) (namespace, service, port string, err error) { + slashIdx := strings.Index(endpoint, "/") + if slashIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + namespace = endpoint[:slashIdx] + rest := endpoint[slashIdx+1:] + colonIdx := strings.LastIndex(rest, ":") + if colonIdx < 1 { + return "", "", "", fmt.Errorf("invalid endpoint format %q, expected namespace/service:port", endpoint) + } + service = rest[:colonIdx] + port = rest[colonIdx+1:] + return namespace, service, port, nil +} + +func parsePrometheusResponse(data []byte) map[string]float64 { + var resp struct { + Status string `json:"status"` + Data struct { + ResultType string `json:"resultType"` + Result []struct { + Metric map[string]string `json:"metric"` + Value []json.RawMessage `json:"value"` + } `json:"result"` + } `json:"data"` + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil + } + if resp.Status != "success" { + return nil + } + + results := make(map[string]float64) + for _, r := range resp.Data.Result { + node := r.Metric["node"] + if node == "" || len(r.Value) < 2 { + continue + } + var valStr string + if err := json.Unmarshal(r.Value[1], &valStr); err != nil { + continue + } + val, err := strconv.ParseFloat(valStr, 64) + if err != nil { + continue + } + results[node] = val + } + return results +} diff --git a/internal/providers/k8s/metrics_test.go b/internal/providers/k8s/metrics_test.go index 01103cd..329d7eb 100644 --- a/internal/providers/k8s/metrics_test.go +++ b/internal/providers/k8s/metrics_test.go @@ -6,6 +6,9 @@ package k8s import ( "context" "fmt" + "net/http" + "net/http/httptest" + "strings" "testing" corev1 "k8s.io/api/core/v1" @@ -170,3 +173,142 @@ func TestParseDCGMMetrics_EmptyInput(t *testing.T) { t.Error("expected nil for empty input") } } + +func TestEnrichPrometheusMetrics_PopulatesFromDirectURL(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "65.5"]}, + {"metric": {"node": "i-node2"}, "value": [1700000000, "30.0"]} + ] + } + }` + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/query" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + query := r.URL.Query().Get("query") + if !strings.Contains(query, "DCGM_FI_DEV_GPU_UTIL") && !strings.Contains(query, "DCGM_FI_DEV_MEM_COPY_UTIL") { + t.Errorf("unexpected query: %s", query) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(promResponse)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + {InstanceID: "i-node2", Source: models.SourceK8sNode, Name: "cluster/i-node2"}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 2 { + t.Errorf("expected 2 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 65.5 { + t.Errorf("expected node1 GPU util 65.5, got %v", instances[0].AvgGPUUtilization) + } + if instances[1].AvgGPUUtilization == nil || *instances[1].AvgGPUUtilization != 30.0 { + t.Errorf("expected node2 GPU util 30.0, got %v", instances[1].AvgGPUUtilization) + } +} + +func TestEnrichPrometheusMetrics_SkipsAlreadyEnriched(t *testing.T) { + gpuUtil := 80.0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"success","data":{"resultType":"vector","result":[]}}`)) + })) + defer server.Close() + + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + } + opts := PrometheusOptions{URL: server.URL} + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, opts) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_NoOptions(t *testing.T) { + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + + enriched := EnrichPrometheusMetrics(context.Background(), nil, instances, PrometheusOptions{}) + + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichPrometheusMetrics_InClusterEndpoint(t *testing.T) { + promResponse := `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "i-node1"}, "value": [1700000000, "50.0"]} + ] + } + }` + instances := []models.GPUInstance{ + {InstanceID: "i-node1", Source: models.SourceK8sNode}, + } + opts := PrometheusOptions{Endpoint: "monitoring/prometheus:9090"} + + // Use a custom client that returns promResponse for any ProxyGet to monitoring/prometheus + customClient := &promMockClient{response: []byte(promResponse)} + + enriched := EnrichPrometheusMetrics(context.Background(), customClient, instances, opts) + + if enriched != 1 { + t.Errorf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 50.0 { + t.Errorf("expected 50.0, got %v", instances[0].AvgGPUUtilization) + } +} + +// promMockClient is a specialized mock that always returns a fixed response for ProxyGet. +type promMockClient struct { + mockK8sClient + response []byte +} + +func (m *promMockClient) ProxyGet(ctx context.Context, namespace, podName, port, path string) ([]byte, error) { + return m.response, nil +} + +func TestParsePrometheusEndpoint(t *testing.T) { + tests := []struct { + input string + namespace string + service string + port string + wantErr bool + }{ + {"monitoring/prometheus:9090", "monitoring", "prometheus", "9090", false}, + {"kube-system/thanos-query:10902", "kube-system", "thanos-query", "10902", false}, + {"invalid", "", "", "", true}, + {"ns/svc", "", "", "", true}, + } + for _, tt := range tests { + ns, svc, port, err := parsePrometheusEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parsePrometheusEndpoint(%q): err=%v, wantErr=%v", tt.input, err, tt.wantErr) + continue + } + if ns != tt.namespace || svc != tt.service || port != tt.port { + t.Errorf("parsePrometheusEndpoint(%q) = (%q,%q,%q), want (%q,%q,%q)", + tt.input, ns, svc, port, tt.namespace, tt.service, tt.port) + } + } +} From d605cb489a0f87abf529ec0e928da3a2144fcff9 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 22:56:16 +0100 Subject: [PATCH 07/39] Add ruleK8sLowGPUUtil for utilization-based K8s GPU waste detection --- internal/analysis/rules.go | 30 +++++++++++++ internal/analysis/rules_test.go | 75 +++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index f91bcbe..8b03d7b 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -28,6 +28,7 @@ func analyzeInstance(inst *models.GPUInstance) { ruleSageMakerLowUtil, ruleSageMakerOversized, ruleK8sUnallocatedGPU, + ruleK8sLowGPUUtil, } for _, rule := range rules { rule(inst) @@ -347,3 +348,32 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { }) } } + +// Rule 8: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). +func ruleK8sLowGPUUtil(inst *models.GPUInstance) { + if inst.Source != models.SourceK8sNode { + return + } + if inst.AvgGPUUtilization == nil { + return + } + if *inst.AvgGPUUtilization >= 10 { + return + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "low_utilization", + Severity: models.SeverityCritical, + Confidence: 0.85, + Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%%. GPUs are allocated but barely used.", *inst.AvgGPUUtilization), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionDownsize, + Description: fmt.Sprintf("GPU utilization averaging %.1f%%. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: inst.MonthlyCost * 0.2, + MonthlySavings: inst.MonthlyCost * 0.8, + SavingsPercent: 80, + Risk: models.RiskMedium, + }) +} diff --git a/internal/analysis/rules_test.go b/internal/analysis/rules_test.go index d8d264d..c1d6223 100644 --- a/internal/analysis/rules_test.go +++ b/internal/analysis/rules_test.go @@ -259,3 +259,78 @@ func TestAnalyzeAll_ComputesSavings(t *testing.T) { t.Errorf("expected no signals for healthy instance, got %d", len(instances[1].WasteSignals)) } } + +func TestRuleK8sLowGPUUtil_FlagsLowUtilization(t *testing.T) { + inst := models.GPUInstance{ + InstanceID: "i-node1", + Source: models.SourceK8sNode, + State: "ready", + InstanceType: "g5.xlarge", + GPUModel: "A10G", + GPUCount: 1, + GPUAllocated: 1, + MonthlyCost: 734, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 1 { + t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals)) + } + if inst.WasteSignals[0].Type != "low_utilization" { + t.Errorf("expected low_utilization, got %s", inst.WasteSignals[0].Type) + } + if inst.WasteSignals[0].Severity != models.SeverityCritical { + t.Errorf("expected critical, got %s", inst.WasteSignals[0].Severity) + } + if inst.WasteSignals[0].Confidence != 0.85 { + t.Errorf("expected confidence 0.85, got %f", inst.WasteSignals[0].Confidence) + } + if len(inst.Recommendations) != 1 { + t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations)) + } + if inst.Recommendations[0].MonthlySavings != 734*0.8 { + t.Errorf("expected savings %.0f, got %f", 734*0.8, inst.Recommendations[0].MonthlySavings) + } +} + +func TestRuleK8sLowGPUUtil_SkipsNonK8s(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceEC2, + AvgGPUUtilization: ptr(3.5), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for EC2 instance") + } +} + +func TestRuleK8sLowGPUUtil_SkipsNoMetrics(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when metrics unavailable") + } +} + +func TestRuleK8sLowGPUUtil_SkipsHighUtilization(t *testing.T) { + inst := models.GPUInstance{ + Source: models.SourceK8sNode, + State: "ready", + AvgGPUUtilization: ptr(45.0), + } + + ruleK8sLowGPUUtil(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for well-utilized GPU") + } +} From 54dc0ce724ae5a2ad8b1e5293885bc1d47b440af Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 23:02:04 +0100 Subject: [PATCH 08/39] Wire K8s GPU metrics fallback chain into CLI scan flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add --prom-url and --prom-endpoint flags (mutually exclusive) for Prometheus GPU metrics. Orchestrate the 3-source fallback chain (CloudWatch Container Insights → DCGM scrape → Prometheus) between K8s discovery and analysis. --- cmd/gpuaudit/main.go | 83 +++++++++++++++++++++++++++++-- internal/providers/k8s/scanner.go | 11 +++- 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index ce8d61e..2aca4b5 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -13,12 +13,15 @@ import ( "github.com/spf13/cobra" - "github.com/gpuaudit/cli/internal/models" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + "github.com/gpuaudit/cli/internal/analysis" - awsprovider "github.com/gpuaudit/cli/internal/providers/aws" - k8sprovider "github.com/gpuaudit/cli/internal/providers/k8s" + "github.com/gpuaudit/cli/internal/models" "github.com/gpuaudit/cli/internal/output" "github.com/gpuaudit/cli/internal/pricing" + awsprovider "github.com/gpuaudit/cli/internal/providers/aws" + k8sprovider "github.com/gpuaudit/cli/internal/providers/k8s" ) var version = "dev" @@ -49,6 +52,8 @@ var ( scanSkipCosts bool scanKubeconfig string scanKubeContext string + scanPromURL string + scanPromEndpoint string scanExcludeTags []string scanMinUptimeDays int ) @@ -71,6 +76,8 @@ func init() { scanCmd.Flags().BoolVar(&scanSkipCosts, "skip-costs", false, "Skip Cost Explorer data enrichment") scanCmd.Flags().StringVar(&scanKubeconfig, "kubeconfig", "", "Path to kubeconfig file (default: ~/.kube/config)") scanCmd.Flags().StringVar(&scanKubeContext, "kube-context", "", "Kubernetes context to use (default: current context)") + scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics (e.g., https://prometheus.corp.example.com)") + scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") @@ -81,6 +88,10 @@ func init() { } func runScan(cmd *cobra.Command, args []string) error { + if scanPromURL != "" && scanPromEndpoint != "" { + return fmt.Errorf("--prom-url and --prom-endpoint are mutually exclusive") + } + ctx := context.Background() opts := awsprovider.DefaultScanOptions() @@ -106,13 +117,18 @@ func runScan(cmd *cobra.Command, args []string) error { // Kubernetes API scan if !scanSkipK8s { k8sOpts := k8sprovider.ScanOptions{ - Kubeconfig: scanKubeconfig, - Context: scanKubeContext, + Kubeconfig: scanKubeconfig, + Context: scanKubeContext, + PromURL: scanPromURL, + PromEndpoint: scanPromEndpoint, } k8sInstances, err := k8sprovider.Scan(ctx, k8sOpts) if err != nil { fmt.Fprintf(os.Stderr, " warning: Kubernetes scan failed: %v\n", err) } else if len(k8sInstances) > 0 { + if !scanSkipMetrics { + enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts) + } analysis.AnalyzeAll(k8sInstances) result.Instances = append(result.Instances, k8sInstances...) result.Summary = awsprovider.BuildSummary(result.Instances) @@ -300,3 +316,60 @@ func parseExcludeTags(raw []string) map[string]string { } return tags } + +func enrichK8sGPUMetrics(ctx context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions) { + // Source 1: CloudWatch Container Insights + if len(instances) > 0 && instances[0].ClusterName != "" { + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if awsOpts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(awsOpts.Profile)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) + if err == nil { + region := instances[0].Region + if region == "" { + region = "us-east-1" + } + cfg.Region = region + cwClient := cloudwatch.NewFromConfig(cfg) + fmt.Fprintf(os.Stderr, " Enriching K8s GPU metrics via CloudWatch Container Insights...\n") + awsprovider.EnrichK8sGPUMetrics(ctx, cwClient, instances, instances[0].ClusterName, awsprovider.DefaultMetricWindow) + } + } + + // Source 2: DCGM exporter scrape + remaining := 0 + for _, inst := range instances { + if inst.Source == models.SourceK8sNode && inst.AvgGPUUtilization == nil { + remaining++ + } + } + if remaining > 0 { + client, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + k8sprovider.EnrichDCGMMetrics(ctx, client, instances) + } + } + + // Source 3: Prometheus query + remaining = 0 + for _, inst := range instances { + if inst.Source == models.SourceK8sNode && inst.AvgGPUUtilization == nil { + remaining++ + } + } + if remaining > 0 && (k8sOpts.PromURL != "" || k8sOpts.PromEndpoint != "") { + var client k8sprovider.K8sClient + if k8sOpts.PromEndpoint != "" { + c, _, err := k8sprovider.BuildClientPublic(k8sOpts.Kubeconfig, k8sOpts.Context) + if err == nil { + client = c + } + } + promOpts := k8sprovider.PrometheusOptions{ + URL: k8sOpts.PromURL, + Endpoint: k8sOpts.PromEndpoint, + } + k8sprovider.EnrichPrometheusMetrics(ctx, client, instances, promOpts) + } +} diff --git a/internal/providers/k8s/scanner.go b/internal/providers/k8s/scanner.go index edea338..c35ef88 100644 --- a/internal/providers/k8s/scanner.go +++ b/internal/providers/k8s/scanner.go @@ -19,8 +19,10 @@ import ( // ScanOptions controls Kubernetes GPU scanning. type ScanOptions struct { - Kubeconfig string - Context string + Kubeconfig string + Context string + PromURL string + PromEndpoint string } // Scan discovers GPU nodes in Kubernetes clusters accessible via kubeconfig. @@ -47,6 +49,11 @@ func Scan(ctx context.Context, opts ScanOptions) ([]models.GPUInstance, error) { return instances, nil } +// BuildClientPublic builds a K8s client and returns the cluster name. +func BuildClientPublic(kubeconfigPath, contextName string) (K8sClient, string, error) { + return buildClient(kubeconfigPath, contextName) +} + func buildClient(kubeconfigPath, contextName string) (K8sClient, string, error) { loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() if kubeconfigPath != "" { From c93e0f7feb93b5531aa51a01832ae8a6ac994bdc Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 23:09:03 +0100 Subject: [PATCH 09/39] Fix DCGM node matching and CW error spam MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DCGM enrichment matched pods to instances by InstanceID, but pod.Spec.NodeName is the K8s hostname (e.g. ip-10-22-1-100.ec2.internal) while InstanceID is the EC2 ID (i-0671...). Add K8sNodeName field to GPUInstance and use it for DCGM matching. Also stop retrying CW queries after the first error — all nodes will get the same AccessDenied when credentials aren't available. --- internal/models/models.go | 1 + internal/providers/aws/cloudwatch.go | 15 +++++++++++---- internal/providers/k8s/discover.go | 1 + internal/providers/k8s/metrics.go | 6 +++++- internal/providers/k8s/metrics_test.go | 12 ++++++------ 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index 0fd6557..8e99dbd 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -66,6 +66,7 @@ type GPUInstance struct { // Kubernetes (populated for k8s-node source) ClusterName string `json:"cluster_name,omitempty"` + K8sNodeName string `json:"k8s_node_name,omitempty"` GPUAllocated int `json:"gpu_allocated,omitempty"` // State diff --git a/internal/providers/aws/cloudwatch.go b/internal/providers/aws/cloudwatch.go index b9d1978..ab06d3e 100644 --- a/internal/providers/aws/cloudwatch.go +++ b/internal/providers/aws/cloudwatch.go @@ -112,6 +112,7 @@ func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances Value: aws.String(clusterName), } + enriched := 0 for _, node := range nodes { instanceDim := cwtypes.Dimension{ Name: aws.String("InstanceId"), @@ -127,12 +128,18 @@ func EnrichK8sGPUMetrics(ctx context.Context, client CloudWatchClient, instances results, err := fetchMetrics(ctx, client, queries, start, now) if err != nil { - fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable for %s: %v\n", node.instanceID, err) - continue + fmt.Fprintf(os.Stderr, " warning: Container Insights metrics unavailable: %v\n", err) + break } - instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] - instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + if results["gpu_util_"+safeID] != nil { + instances[node.index].AvgGPUUtilization = results["gpu_util_"+safeID] + instances[node.index].AvgGPUMemUtilization = results["gpu_mem_"+safeID] + enriched++ + } + } + if enriched > 0 { + fmt.Fprintf(os.Stderr, " CloudWatch: got GPU metrics for %d of %d nodes\n", enriched, len(nodes)) } } diff --git a/internal/providers/k8s/discover.go b/internal/providers/k8s/discover.go index 14fe00c..e3316c0 100644 --- a/internal/providers/k8s/discover.go +++ b/internal/providers/k8s/discover.go @@ -164,6 +164,7 @@ func nodeToGPUInstance(node corev1.Node, gpuPods []corev1.Pod, clusterName strin Name: fmt.Sprintf("%s/%s", clusterName, hostname), Tags: tags, ClusterName: clusterName, + K8sNodeName: node.Name, GPUAllocated: gpuAllocated, InstanceType: instanceType, GPUModel: gpuModel, diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go index ef47470..5275347 100644 --- a/internal/providers/k8s/metrics.go +++ b/internal/providers/k8s/metrics.go @@ -33,7 +33,11 @@ func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { continue } - needsMetrics[inst.InstanceID] = i + key := inst.K8sNodeName + if key == "" { + key = inst.InstanceID + } + needsMetrics[key] = i } if len(needsMetrics) == 0 { return 0 diff --git a/internal/providers/k8s/metrics_test.go b/internal/providers/k8s/metrics_test.go index 329d7eb..4d7e851 100644 --- a/internal/providers/k8s/metrics_test.go +++ b/internal/providers/k8s/metrics_test.go @@ -50,7 +50,7 @@ func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { nodes: &corev1.NodeList{}, pods: &corev1.PodList{ Items: []corev1.Pod{ - dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + dcgmPod("dcgm-exporter-abc", "gpu-operator", "ip-10-22-1-100.ec2.internal"), }, }, proxyData: map[string][]byte{ @@ -58,7 +58,7 @@ func TestEnrichDCGMMetrics_PopulatesUtilization(t *testing.T) { }, } instances := []models.GPUInstance{ - {InstanceID: "i-node1", Source: models.SourceK8sNode, Name: "cluster/i-node1"}, + {InstanceID: "i-abc123", K8sNodeName: "ip-10-22-1-100.ec2.internal", Source: models.SourceK8sNode, Name: "cluster/ip-10-22-1-100"}, } enriched := EnrichDCGMMetrics(context.Background(), client, instances) @@ -86,7 +86,7 @@ func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { nodes: &corev1.NodeList{}, pods: &corev1.PodList{ Items: []corev1.Pod{ - dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + dcgmPod("dcgm-exporter-abc", "gpu-operator", "node1"), }, }, proxyData: map[string][]byte{ @@ -94,7 +94,7 @@ func TestEnrichDCGMMetrics_SkipsAlreadyEnriched(t *testing.T) { }, } instances := []models.GPUInstance{ - {InstanceID: "i-node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, + {InstanceID: "i-abc123", K8sNodeName: "node1", Source: models.SourceK8sNode, AvgGPUUtilization: &gpuUtil}, } enriched := EnrichDCGMMetrics(context.Background(), client, instances) @@ -131,13 +131,13 @@ func TestEnrichDCGMMetrics_HandlesScrapeError(t *testing.T) { nodes: &corev1.NodeList{}, pods: &corev1.PodList{ Items: []corev1.Pod{ - dcgmPod("dcgm-exporter-abc", "gpu-operator", "i-node1"), + dcgmPod("dcgm-exporter-abc", "gpu-operator", "node1"), }, }, proxyErr: fmt.Errorf("connection refused"), } instances := []models.GPUInstance{ - {InstanceID: "i-node1", Source: models.SourceK8sNode}, + {InstanceID: "node1", Source: models.SourceK8sNode}, } enriched := EnrichDCGMMetrics(context.Background(), client, instances) From 2640b88f571a14f512fa9d49efe38c7f6762f29a Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 23:20:02 +0100 Subject: [PATCH 10/39] Fix DCGM scrape spam and Prometheus node name mismatch DCGM: stop spamming per-node warnings when scrapes fail consistently (likely RBAC). Log one warning, bail after 3 consecutive failures. Prometheus: use K8sNodeName (the actual K8s hostname) in the PromQL node=~ regex instead of InstanceID (EC2 ID). The Prometheus node label matches K8s hostnames, not EC2 instance IDs. --- internal/providers/k8s/metrics.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go index 5275347..180b8e1 100644 --- a/internal/providers/k8s/metrics.go +++ b/internal/providers/k8s/metrics.go @@ -56,6 +56,7 @@ func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models fmt.Fprintf(os.Stderr, " Probing DCGM exporter on GPU nodes...\n") enriched := 0 + scrapeErrors := 0 for _, pod := range dcgmPods { idx, ok := needsMetrics[pod.Spec.NodeName] if !ok { @@ -64,7 +65,14 @@ func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models data, err := client.ProxyGet(ctx, pod.Namespace, pod.Name, "9400", "/metrics") if err != nil { - fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed for %s: %v\n", pod.Spec.NodeName, err) + scrapeErrors++ + if scrapeErrors == 1 { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failed: %v\n", err) + } + if scrapeErrors >= 3 { + fmt.Fprintf(os.Stderr, " warning: DCGM scrape failing consistently, skipping remaining nodes\n") + break + } continue } @@ -73,6 +81,7 @@ func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models instances[idx].AvgGPUUtilization = gpuUtil instances[idx].AvgGPUMemUtilization = memUtil enriched++ + scrapeErrors = 0 } } @@ -164,7 +173,11 @@ func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances [] if inst.Source != models.SourceK8sNode || inst.AvgGPUUtilization != nil { continue } - nodes = append(nodes, nodeRef{index: i, name: inst.InstanceID}) + name := inst.K8sNodeName + if name == "" { + name = inst.InstanceID + } + nodes = append(nodes, nodeRef{index: i, name: name}) } if len(nodes) == 0 { return 0 From b08c0259216c2afce1f9bd812a337e373dc28846 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 23:26:23 +0100 Subject: [PATCH 11/39] Include time window in low GPU utilization recommendation text --- internal/analysis/rules.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index 8b03d7b..a676f7f 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -365,11 +365,11 @@ func ruleK8sLowGPUUtil(inst *models.GPUInstance) { Type: "low_utilization", Severity: models.SeverityCritical, Confidence: 0.85, - Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%%. GPUs are allocated but barely used.", *inst.AvgGPUUtilization), + Evidence: fmt.Sprintf("K8s GPU node utilization averaging %.1f%% over the past 7 days. GPUs are allocated but barely used.", *inst.AvgGPUUtilization), }) inst.Recommendations = append(inst.Recommendations, models.Recommendation{ Action: models.ActionDownsize, - Description: fmt.Sprintf("GPU utilization averaging %.1f%%. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), + Description: fmt.Sprintf("GPU utilization averaging %.1f%% over the past 7 days. Consider bin-packing more workloads, downsizing, or removing from the node pool.", *inst.AvgGPUUtilization), CurrentMonthlyCost: inst.MonthlyCost, RecommendedMonthlyCost: inst.MonthlyCost * 0.2, MonthlySavings: inst.MonthlyCost * 0.8, From e846cc8088995ffd1e892c4c6e24adf2a7b873a3 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 23:41:41 +0100 Subject: [PATCH 12/39] Skip CW enrichment when AWS creds unavailable, reduce DCGM noise --- cmd/gpuaudit/main.go | 10 ++++++---- internal/providers/k8s/metrics.go | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 2aca4b5..08b3e3c 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -104,8 +104,10 @@ func runScan(cmd *cobra.Command, args []string) error { opts.ExcludeTags = parseExcludeTags(scanExcludeTags) opts.MinUptimeDays = scanMinUptimeDays + awsAvailable := true result, err := awsprovider.Scan(ctx, opts) if err != nil { + awsAvailable = false if scanSkipK8s { return fmt.Errorf("scan failed: %w", err) } @@ -127,7 +129,7 @@ func runScan(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, " warning: Kubernetes scan failed: %v\n", err) } else if len(k8sInstances) > 0 { if !scanSkipMetrics { - enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts) + enrichK8sGPUMetrics(ctx, k8sInstances, k8sOpts, opts, awsAvailable) } analysis.AnalyzeAll(k8sInstances) result.Instances = append(result.Instances, k8sInstances...) @@ -317,9 +319,9 @@ func parseExcludeTags(raw []string) map[string]string { return tags } -func enrichK8sGPUMetrics(ctx context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions) { - // Source 1: CloudWatch Container Insights - if len(instances) > 0 && instances[0].ClusterName != "" { +func enrichK8sGPUMetrics(ctx context.Context, instances []models.GPUInstance, k8sOpts k8sprovider.ScanOptions, awsOpts awsprovider.ScanOptions, awsAvailable bool) { + // Source 1: CloudWatch Container Insights (skip if AWS creds unavailable) + if awsAvailable && len(instances) > 0 && instances[0].ClusterName != "" { cfgOpts := []func(*awsconfig.LoadOptions) error{} if awsOpts.Profile != "" { cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(awsOpts.Profile)) diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go index 180b8e1..4a587c2 100644 --- a/internal/providers/k8s/metrics.go +++ b/internal/providers/k8s/metrics.go @@ -85,7 +85,9 @@ func EnrichDCGMMetrics(ctx context.Context, client K8sClient, instances []models } } - fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, len(needsMetrics)) + if enriched > 0 { + fmt.Fprintf(os.Stderr, " DCGM: got GPU metrics for %d of %d remaining nodes\n", enriched, len(needsMetrics)) + } return enriched } From ff5b7f5b0f5f6bf14e167eff8956249b44181e63 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:03:26 +0100 Subject: [PATCH 13/39] Add diff package with Compare function and tests Compares two scan results by instance ID. Detects added, removed, and changed instances across 6 fields (instance type, pricing model, cost, state, GPU allocation, waste severity). Computes cost deltas. --- internal/diff/diff.go | 171 +++++++++++++++++++++++++++++ internal/diff/diff_test.go | 219 +++++++++++++++++++++++++++++++++++++ 2 files changed, 390 insertions(+) create mode 100644 internal/diff/diff.go create mode 100644 internal/diff/diff_test.go diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000..7d74430 --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,171 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package diff compares two scan results and reports what changed. +package diff + +import ( + "fmt" + + "github.com/gpuaudit/cli/internal/models" +) + +// DiffResult holds the comparison between two scan results. +type DiffResult struct { + OldTimestamp string `json:"old_timestamp"` + NewTimestamp string `json:"new_timestamp"` + Added []models.GPUInstance `json:"added,omitempty"` + Removed []models.GPUInstance `json:"removed,omitempty"` + Changed []InstanceDiff `json:"changed,omitempty"` + UnchangedCount int `json:"unchanged_count"` + CostSummary CostDelta `json:"cost_summary"` +} + +// InstanceDiff describes what changed for a single instance between scans. +type InstanceDiff struct { + InstanceID string `json:"instance_id"` + Old models.GPUInstance `json:"old"` + New models.GPUInstance `json:"new"` + CostDelta float64 `json:"cost_delta"` + Changes []string `json:"changes"` +} + +// CostDelta summarizes the financial impact of changes between scans. +type CostDelta struct { + OldTotalMonthlyCost float64 `json:"old_total_monthly_cost"` + NewTotalMonthlyCost float64 `json:"new_total_monthly_cost"` + CostChange float64 `json:"cost_change"` + OldTotalWaste float64 `json:"old_total_waste"` + NewTotalWaste float64 `json:"new_total_waste"` + WasteChange float64 `json:"waste_change"` + AddedCost float64 `json:"added_cost"` + RemovedSavings float64 `json:"removed_savings"` +} + +// Compare computes the diff between two scan results, matching instances by ID. +func Compare(old, new *models.ScanResult) *DiffResult { + oldMap := make(map[string]models.GPUInstance, len(old.Instances)) + for _, inst := range old.Instances { + oldMap[inst.InstanceID] = inst + } + + newMap := make(map[string]models.GPUInstance, len(new.Instances)) + for _, inst := range new.Instances { + newMap[inst.InstanceID] = inst + } + + result := &DiffResult{ + OldTimestamp: old.Timestamp.Format("2006-01-02 15:04 UTC"), + NewTimestamp: new.Timestamp.Format("2006-01-02 15:04 UTC"), + } + + // Find removed and changed + for id, oldInst := range oldMap { + newInst, exists := newMap[id] + if !exists { + result.Removed = append(result.Removed, oldInst) + continue + } + changes := diffInstance(oldInst, newInst) + if len(changes) > 0 { + result.Changed = append(result.Changed, InstanceDiff{ + InstanceID: id, + Old: oldInst, + New: newInst, + CostDelta: newInst.MonthlyCost - oldInst.MonthlyCost, + Changes: changes, + }) + } else { + result.UnchangedCount++ + } + } + + // Find added + for id, newInst := range newMap { + if _, exists := oldMap[id]; !exists { + result.Added = append(result.Added, newInst) + } + } + + // Cost summary + result.CostSummary = computeCostDelta(old, new, result) + + return result +} + +func diffInstance(old, new models.GPUInstance) []string { + var changes []string + + if old.InstanceType != new.InstanceType { + changes = append(changes, fmt.Sprintf("Instance type: %s → %s", old.InstanceType, new.InstanceType)) + } + if old.PricingModel != new.PricingModel { + changes = append(changes, fmt.Sprintf("Pricing: %s → %s", old.PricingModel, new.PricingModel)) + } + if old.MonthlyCost != new.MonthlyCost { + delta := new.MonthlyCost - old.MonthlyCost + changes = append(changes, fmt.Sprintf("Cost: $%.0f → $%.0f (%s/mo)", old.MonthlyCost, new.MonthlyCost, fmtDelta(delta))) + } + if old.State != new.State { + changes = append(changes, fmt.Sprintf("State: %s → %s", old.State, new.State)) + } + if old.GPUAllocated != new.GPUAllocated { + changes = append(changes, fmt.Sprintf("GPU allocated: %d → %d", old.GPUAllocated, new.GPUAllocated)) + } + if maxSeverityStr(old.WasteSignals) != maxSeverityStr(new.WasteSignals) { + oldSev := maxSeverityStr(old.WasteSignals) + newSev := maxSeverityStr(new.WasteSignals) + if oldSev == "" { + oldSev = "(none)" + } + if newSev == "" { + newSev = "(none)" + } + changes = append(changes, fmt.Sprintf("Severity: %s → %s", oldSev, newSev)) + } + + return changes +} + +func maxSeverityStr(signals []models.WasteSignal) string { + max := models.Severity("") + for _, s := range signals { + if s.Severity == models.SeverityCritical { + return string(models.SeverityCritical) + } + if s.Severity == models.SeverityWarning { + max = models.SeverityWarning + } + if s.Severity == models.SeverityInfo && max == "" { + max = models.SeverityInfo + } + } + return string(max) +} + +func fmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +func computeCostDelta(old, new *models.ScanResult, diff *DiffResult) CostDelta { + cd := CostDelta{ + OldTotalMonthlyCost: old.Summary.TotalMonthlyCost, + NewTotalMonthlyCost: new.Summary.TotalMonthlyCost, + CostChange: new.Summary.TotalMonthlyCost - old.Summary.TotalMonthlyCost, + OldTotalWaste: old.Summary.TotalEstimatedWaste, + NewTotalWaste: new.Summary.TotalEstimatedWaste, + WasteChange: new.Summary.TotalEstimatedWaste - old.Summary.TotalEstimatedWaste, + } + + for _, inst := range diff.Added { + cd.AddedCost += inst.MonthlyCost + } + for _, inst := range diff.Removed { + cd.RemovedSavings += inst.MonthlyCost + } + + return cd +} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go new file mode 100644 index 0000000..35d4f1f --- /dev/null +++ b/internal/diff/diff_test.go @@ -0,0 +1,219 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package diff + +import ( + "testing" + "time" + + "github.com/gpuaudit/cli/internal/models" +) + +func scanResult(instances ...models.GPUInstance) *models.ScanResult { + return &models.ScanResult{ + Timestamp: time.Date(2026, 4, 8, 12, 0, 0, 0, time.UTC), + Instances: instances, + Summary: models.ScanSummary{ + TotalInstances: len(instances), + TotalMonthlyCost: sumMonthlyCost(instances), + TotalEstimatedWaste: sumWaste(instances), + }, + } +} + +func sumMonthlyCost(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.MonthlyCost + } + return total +} + +func sumWaste(instances []models.GPUInstance) float64 { + var total float64 + for _, inst := range instances { + total += inst.EstimatedSavings + } + return total +} + +func inst(id string, monthlyCost float64) models.GPUInstance { + return models.GPUInstance{ + InstanceID: id, + InstanceType: "g6e.16xlarge", + GPUModel: "L40S", + GPUCount: 1, + MonthlyCost: monthlyCost, + HourlyCost: monthlyCost / 730, + State: "ready", + Source: models.SourceK8sNode, + PricingModel: "on-demand", + } +} + +func TestCompare_AddedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 1 { + t.Fatalf("expected 1 added, got %d", len(result.Added)) + } + if result.Added[0].InstanceID != "i-bbb" { + t.Errorf("expected added instance i-bbb, got %s", result.Added[0].InstanceID) + } + if result.CostSummary.AddedCost != 3000 { + t.Errorf("expected added cost 3000, got %.0f", result.CostSummary.AddedCost) + } +} + +func TestCompare_RemovedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750)) + + result := Compare(old, new) + + if len(result.Removed) != 1 { + t.Fatalf("expected 1 removed, got %d", len(result.Removed)) + } + if result.Removed[0].InstanceID != "i-bbb" { + t.Errorf("expected removed instance i-bbb, got %s", result.Removed[0].InstanceID) + } + if result.CostSummary.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", result.CostSummary.RemovedSavings) + } +} + +func TestCompare_CostChange(t *testing.T) { + old := scanResult(inst("i-aaa", 6750)) + new := scanResult(inst("i-aaa", 4200)) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + if result.Changed[0].CostDelta != -2550 { + t.Errorf("expected cost delta -2550, got %.0f", result.Changed[0].CostDelta) + } + found := false + for _, c := range result.Changed[0].Changes { + if c == "Cost: $6750 → $4200 (-$2550/mo)" { + found = true + } + } + if !found { + t.Errorf("expected cost change string, got %v", result.Changed[0].Changes) + } +} + +func TestCompare_AllFieldChanges(t *testing.T) { + oldInst := inst("i-aaa", 6750) + oldInst.InstanceType = "g6e.16xlarge" + oldInst.PricingModel = "on-demand" + oldInst.State = "ready" + oldInst.GPUAllocated = 0 + oldInst.WasteSignals = []models.WasteSignal{{Severity: models.SeverityCritical}} + + newInst := inst("i-aaa", 4200) + newInst.InstanceType = "g6e.12xlarge" + newInst.PricingModel = "reserved" + newInst.State = "not-ready" + newInst.GPUAllocated = 2 + newInst.WasteSignals = nil + + old := scanResult(oldInst) + new := scanResult(newInst) + + result := Compare(old, new) + + if len(result.Changed) != 1 { + t.Fatalf("expected 1 changed, got %d", len(result.Changed)) + } + + changes := result.Changed[0].Changes + expected := []string{ + "Instance type: g6e.16xlarge → g6e.12xlarge", + "Pricing: on-demand → reserved", + "Cost: $6750 → $4200 (-$2550/mo)", + "State: ready → not-ready", + "GPU allocated: 0 → 2", + "Severity: critical → (none)", + } + if len(changes) != len(expected) { + t.Fatalf("expected %d changes, got %d: %v", len(expected), len(changes), changes) + } + for i, exp := range expected { + if changes[i] != exp { + t.Errorf("change[%d]: expected %q, got %q", i, exp, changes[i]) + } + } +} + +func TestCompare_UnchangedInstances(t *testing.T) { + old := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + new := scanResult(inst("i-aaa", 6750), inst("i-bbb", 3000)) + + result := Compare(old, new) + + if len(result.Added) != 0 { + t.Errorf("expected 0 added, got %d", len(result.Added)) + } + if len(result.Removed) != 0 { + t.Errorf("expected 0 removed, got %d", len(result.Removed)) + } + if len(result.Changed) != 0 { + t.Errorf("expected 0 changed, got %d", len(result.Changed)) + } + if result.UnchangedCount != 2 { + t.Errorf("expected 2 unchanged, got %d", result.UnchangedCount) + } +} + +func TestCompare_CostSummary(t *testing.T) { + oldA := inst("i-aaa", 6750) + oldA.EstimatedSavings = 6750 + oldB := inst("i-bbb", 3000) + + newA := inst("i-aaa", 6750) + newA.EstimatedSavings = 6750 + newC := inst("i-ccc", 2000) + + old := scanResult(oldA, oldB) + new := scanResult(newA, newC) + + result := Compare(old, new) + + cs := result.CostSummary + if cs.OldTotalMonthlyCost != 9750 { + t.Errorf("expected old total 9750, got %.0f", cs.OldTotalMonthlyCost) + } + if cs.NewTotalMonthlyCost != 8750 { + t.Errorf("expected new total 8750, got %.0f", cs.NewTotalMonthlyCost) + } + if cs.CostChange != -1000 { + t.Errorf("expected cost change -1000, got %.0f", cs.CostChange) + } + if cs.RemovedSavings != 3000 { + t.Errorf("expected removed savings 3000, got %.0f", cs.RemovedSavings) + } + if cs.AddedCost != 2000 { + t.Errorf("expected added cost 2000, got %.0f", cs.AddedCost) + } +} + +func TestCompare_EmptyScans(t *testing.T) { + old := scanResult() + new := scanResult() + + result := Compare(old, new) + + if len(result.Added) != 0 || len(result.Removed) != 0 || len(result.Changed) != 0 { + t.Errorf("expected no changes for empty scans") + } + if result.UnchangedCount != 0 { + t.Errorf("expected 0 unchanged, got %d", result.UnchangedCount) + } +} From 1eed3b8184d11589810d56689c0d393eb14b76db Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:04:33 +0100 Subject: [PATCH 14/39] Add diff table and JSON output formatters --- internal/output/diff.go | 145 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 internal/output/diff.go diff --git a/internal/output/diff.go b/internal/output/diff.go new file mode 100644 index 0000000..2bd5753 --- /dev/null +++ b/internal/output/diff.go @@ -0,0 +1,145 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package output + +import ( + "encoding/json" + "fmt" + "io" + "sort" + "strings" + + "github.com/gpuaudit/cli/internal/diff" + "github.com/gpuaudit/cli/internal/models" +) + +// FormatDiffTable writes a human-readable diff report. +func FormatDiffTable(w io.Writer, d *diff.DiffResult) { + fmt.Fprintf(w, "\n gpuaudit diff — %s → %s\n\n", d.OldTimestamp, d.NewTimestamp) + + cs := d.CostSummary + + oldCount := len(d.Removed) + len(d.Changed) + d.UnchangedCount + newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount + + // Cost summary box + fmt.Fprintf(w, " ┌──────────────────────────────────────────────────────────┐\n") + fmt.Fprintf(w, " │ Cost Delta │\n") + fmt.Fprintf(w, " ├──────────────────────────────────────────────────────────┤\n") + fmt.Fprintf(w, " │ Monthly spend: $%-9.0f → $%-9.0f (%s)%s│\n", + cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, + diffFmtDelta(cs.CostChange), diffPad(cs.CostChange)) + fmt.Fprintf(w, " │ Estimated waste: $%-9.0f → $%-9.0f (%s)%s│\n", + cs.OldTotalWaste, cs.NewTotalWaste, + diffFmtDelta(cs.WasteChange), diffPad(cs.WasteChange)) + fmt.Fprintf(w, " │ Instances: %-3d → %-3d (-%d removed, +%d added)%s│\n", + oldCount, newCount, len(d.Removed), len(d.Added), + diffPadInstances(oldCount, newCount, len(d.Removed), len(d.Added))) + fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n") + + // Removed + if len(d.Removed) > 0 { + sortInstancesByCost(d.Removed) + fmt.Fprintf(w, "\n REMOVED — %d instance(s), -$%.0f/mo\n\n", len(d.Removed), cs.RemovedSavings) + printDiffInstanceTable(w, d.Removed) + } + + // Added + if len(d.Added) > 0 { + sortInstancesByCost(d.Added) + fmt.Fprintf(w, "\n ADDED — %d instance(s), +$%.0f/mo\n\n", len(d.Added), cs.AddedCost) + printDiffInstanceTable(w, d.Added) + } + + // Changed + if len(d.Changed) > 0 { + fmt.Fprintf(w, "\n CHANGED — %d instance(s)\n\n", len(d.Changed)) + fmt.Fprintf(w, " %-36s %s\n", "Instance", "Change") + fmt.Fprintf(w, " %s %s\n", strings.Repeat("─", 36), strings.Repeat("─", 50)) + for _, c := range d.Changed { + name := c.New.Name + if name == "" { + name = c.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + for i, change := range c.Changes { + if i == 0 { + fmt.Fprintf(w, " %-36s %s\n", name, change) + } else { + fmt.Fprintf(w, " %-36s %s\n", "", change) + } + } + } + fmt.Fprintln(w) + } + + // Unchanged + if d.UnchangedCount > 0 { + fmt.Fprintf(w, " UNCHANGED — %d instance(s)\n\n", d.UnchangedCount) + } +} + +func printDiffInstanceTable(w io.Writer, instances []models.GPUInstance) { + fmt.Fprintf(w, " %-36s %-26s %10s\n", "Instance", "Type", "Monthly") + fmt.Fprintf(w, " %s %s %s\n", + strings.Repeat("─", 36), strings.Repeat("─", 26), strings.Repeat("─", 10)) + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." + } + fmt.Fprintf(w, " %-36s %-26s $%9.0f\n", name, typeDesc, inst.MonthlyCost) + } +} + +func sortInstancesByCost(instances []models.GPUInstance) { + sort.Slice(instances, func(i, j int) bool { + return instances[i].MonthlyCost > instances[j].MonthlyCost + }) +} + +func diffFmtDelta(v float64) string { + if v >= 0 { + return fmt.Sprintf("+$%.0f", v) + } + return fmt.Sprintf("-$%.0f", -v) +} + +// diffPad returns spaces to align the summary box closing border. +func diffPad(delta float64) string { + s := diffFmtDelta(delta) + // The content before the delta is ~44 chars, delta is variable, need to fill to col 59 + used := 44 + len(s) + 2 // +2 for parens + target := 59 + if used >= target { + return "" + } + return strings.Repeat(" ", target-used) +} + +func diffPadInstances(oldCount, newCount, removed, added int) string { + content := fmt.Sprintf(" │ Instances: %-3d → %-3d (-%d removed, +%d added)", + oldCount, newCount, removed, added) + if len(content) >= 59 { + return "" + } + return strings.Repeat(" ", 59-len(content)) +} + +// FormatDiffJSON writes the diff result as pretty-printed JSON. +func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(d) +} From 567222998d9bdafd7ebbc3a8ed4cd39a3d5b0fdc Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 01:05:37 +0100 Subject: [PATCH 15/39] Add diff subcommand to compare two scan results gpuaudit diff old.json new.json [--format table|json] Closes #5 --- cmd/gpuaudit/main.go | 49 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 08b3e3c..a057c1c 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatch" "github.com/gpuaudit/cli/internal/analysis" + "github.com/gpuaudit/cli/internal/diff" "github.com/gpuaudit/cli/internal/models" "github.com/gpuaudit/cli/internal/output" "github.com/gpuaudit/cli/internal/pricing" @@ -58,6 +59,17 @@ var ( scanMinUptimeDays int ) +// --- diff command --- + +var diffFormat string + +var diffCmd = &cobra.Command{ + Use: "diff ", + Short: "Compare two scan results and show what changed", + Args: cobra.ExactArgs(2), + RunE: runDiff, +} + var scanCmd = &cobra.Command{ Use: "scan", Short: "Scan AWS account for GPU waste", @@ -81,7 +93,10 @@ func init() { scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") + diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") + rootCmd.AddCommand(scanCmd) + rootCmd.AddCommand(diffCmd) rootCmd.AddCommand(pricingCmd) rootCmd.AddCommand(iamPolicyCmd) rootCmd.AddCommand(versionCmd) @@ -162,6 +177,40 @@ func runScan(cmd *cobra.Command, args []string) error { return nil } +func runDiff(cmd *cobra.Command, args []string) error { + old, err := loadScanResult(args[0]) + if err != nil { + return fmt.Errorf("loading old scan: %w", err) + } + new, err := loadScanResult(args[1]) + if err != nil { + return fmt.Errorf("loading new scan: %w", err) + } + + result := diff.Compare(old, new) + + switch strings.ToLower(diffFormat) { + case "json": + return output.FormatDiffJSON(os.Stdout, result) + default: + output.FormatDiffTable(os.Stdout, result) + } + + return nil +} + +func loadScanResult(path string) (*models.ScanResult, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var result models.ScanResult + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + return &result, nil +} + // --- pricing command --- var pricingGPU string From 2e95784c59923f4fa97cca6a6dc05a6495ff0238 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 09:15:45 +0100 Subject: [PATCH 16/39] Fix box alignment in diff table output --- internal/output/diff.go | 55 ++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/internal/output/diff.go b/internal/output/diff.go index 2bd5753..db2f7c9 100644 --- a/internal/output/diff.go +++ b/internal/output/diff.go @@ -24,19 +24,18 @@ func FormatDiffTable(w io.Writer, d *diff.DiffResult) { newCount := len(d.Added) + len(d.Changed) + d.UnchangedCount // Cost summary box - fmt.Fprintf(w, " ┌──────────────────────────────────────────────────────────┐\n") - fmt.Fprintf(w, " │ Cost Delta │\n") - fmt.Fprintf(w, " ├──────────────────────────────────────────────────────────┤\n") - fmt.Fprintf(w, " │ Monthly spend: $%-9.0f → $%-9.0f (%s)%s│\n", - cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, - diffFmtDelta(cs.CostChange), diffPad(cs.CostChange)) - fmt.Fprintf(w, " │ Estimated waste: $%-9.0f → $%-9.0f (%s)%s│\n", - cs.OldTotalWaste, cs.NewTotalWaste, - diffFmtDelta(cs.WasteChange), diffPad(cs.WasteChange)) - fmt.Fprintf(w, " │ Instances: %-3d → %-3d (-%d removed, +%d added)%s│\n", - oldCount, newCount, len(d.Removed), len(d.Added), - diffPadInstances(oldCount, newCount, len(d.Removed), len(d.Added))) - fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n") + boxWidth := 58 // inner width between │ markers + boxLine := strings.Repeat("─", boxWidth) + fmt.Fprintf(w, " ┌%s┐\n", boxLine) + writeBoxLine(w, "Cost Delta", boxWidth) + fmt.Fprintf(w, " ├%s┤\n", boxLine) + writeBoxLine(w, fmt.Sprintf("Monthly spend: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalMonthlyCost, cs.NewTotalMonthlyCost, diffFmtDelta(cs.CostChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Estimated waste: $%-9.0f → $%-9.0f (%s)", + cs.OldTotalWaste, cs.NewTotalWaste, diffFmtDelta(cs.WasteChange)), boxWidth) + writeBoxLine(w, fmt.Sprintf("Instances: %d → %d (-%d removed, +%d added)", + oldCount, newCount, len(d.Removed), len(d.Added)), boxWidth) + fmt.Fprintf(w, " └%s┘\n", boxLine) // Removed if len(d.Removed) > 0 { @@ -109,6 +108,15 @@ func sortInstancesByCost(instances []models.GPUInstance) { }) } +func writeBoxLine(w io.Writer, content string, width int) { + // Pad content to fill the box width (with 2-char margin on each side) + inner := width - 4 // 2 spaces on each side + if len(content) > inner { + content = content[:inner] + } + fmt.Fprintf(w, " │ %-*s │\n", inner, content) +} + func diffFmtDelta(v float64) string { if v >= 0 { return fmt.Sprintf("+$%.0f", v) @@ -116,27 +124,6 @@ func diffFmtDelta(v float64) string { return fmt.Sprintf("-$%.0f", -v) } -// diffPad returns spaces to align the summary box closing border. -func diffPad(delta float64) string { - s := diffFmtDelta(delta) - // The content before the delta is ~44 chars, delta is variable, need to fill to col 59 - used := 44 + len(s) + 2 // +2 for parens - target := 59 - if used >= target { - return "" - } - return strings.Repeat(" ", target-used) -} - -func diffPadInstances(oldCount, newCount, removed, added int) string { - content := fmt.Sprintf(" │ Instances: %-3d → %-3d (-%d removed, +%d added)", - oldCount, newCount, removed, added) - if len(content) >= 59 { - return "" - } - return strings.Repeat(" ", 59-len(content)) -} - // FormatDiffJSON writes the diff result as pretty-printed JSON. func FormatDiffJSON(w io.Writer, d *diff.DiffResult) error { enc := json.NewEncoder(w) From 01b45124f7f8f51847117885bd2934e1968356e3 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 12:33:54 +0100 Subject: [PATCH 17/39] Fix misleading idle duration in K8s GPU node recommendations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recommendation said "No GPU pods scheduled for X days" but X was the node's total uptime, not the idle duration. We don't know when the node became idle — only that it currently has zero GPU pods. Changed wording to "Node up X days with 0 GPU pods scheduled." --- internal/analysis/rules.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index a676f7f..b975a1c 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -318,11 +318,11 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { Type: "idle", Severity: models.SeverityCritical, Confidence: 0.9, - Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs for %.0f+ hours.", inst.GPUCount, inst.UptimeHours), + Evidence: fmt.Sprintf("GPU node has %d GPU(s) but no pods requesting GPUs. Node up for %d days.", inst.GPUCount, int(inst.UptimeHours/24)), }) inst.Recommendations = append(inst.Recommendations, models.Recommendation{ Action: models.ActionTerminate, - Description: fmt.Sprintf("No GPU pods scheduled on this node for %d days. Remove from node pool or scale down.", int(inst.UptimeHours/24)), + Description: fmt.Sprintf("Node up %d days with 0 GPU pods scheduled. Remove from node pool or scale down.", int(inst.UptimeHours/24)), CurrentMonthlyCost: inst.MonthlyCost, MonthlySavings: inst.MonthlyCost, SavingsPercent: 100, From c9ee92dc47f817027aa86b577171f46cfb592510 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Wed, 15 Apr 2026 14:36:26 +0100 Subject: [PATCH 18/39] Update README with K8s scanning, diff command, and current output format --- README.md | 120 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 86 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index f3c05dc..c2396c7 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,39 @@ # gpuaudit -Scan your AWS account for GPU waste and get actionable recommendations to cut your cloud spend. +Scan your cloud for GPU waste and get actionable recommendations to cut your spend. ``` -$ gpuaudit scan --profile ml-prod +$ gpuaudit scan --skip-eks - GPU Fleet Summary - Total GPU instances: 14 - Total monthly GPU spend: $47,832 - Estimated monthly waste: $18,240 (38%) + Found 103 GPU nodes across 111 nodes in gpu-cluster - CRITICAL (3 instances, $8,940/mo potential savings) + gpuaudit — GPU Cost Audit for AWS + Account: 123456789012 | Regions: us-east-1 | Duration: 4.2s - i-0a1b2c3d4e g5.12xlarge (4x A10G) $4,380/mo Idle — no activity for 18 days → terminate - i-9f8e7d6c5b p4d.24xlarge (8x A100) $23,652/mo Idle — <1% CPU for 6 days → terminate - sagemaker:asr ml.g6.48xlarge (8x L40S) $9,490/mo GPU util avg 8% → downsize to ml.g5.xlarge + ┌──────────────────────────────────────────────────────────┐ + │ GPU Fleet Summary │ + ├──────────────────────────────────────────────────────────┤ + │ Total GPU instances: 103 │ + │ Total monthly GPU spend: $365155 │ + │ Estimated monthly waste: $23408 ( 6%) │ + └──────────────────────────────────────────────────────────┘ + + CRITICAL — 4 instance(s), $21728/mo potential savings + + Instance Type Monthly Signal Recommendation + ──────────────────────────────────── ────────────────────────── ──────── ──────────────── ────────────────────────────────────────────── + gpu-cluster/ip-10-15-255-248 g6e.16xlarge (1× L40S) $ 6752 idle Node up 13 days with 0 GPU pods scheduled. + gpu-cluster/ip-10-22-250-15 g6e.16xlarge (1× L40S) $ 6752 idle Node up 1 days with 0 GPU pods scheduled. + ... ``` +## What it scans + +- **EC2** — GPU instances (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) with CloudWatch metrics +- **SageMaker** — Endpoints with GPU utilization and invocation metrics +- **EKS** — Managed GPU node groups via the AWS EKS API +- **Kubernetes** — GPU nodes and pod allocation via the Kubernetes API (Karpenter, self-managed, any CNI) + ## What it detects - **Idle GPU instances** — running but doing nothing (low CPU + near-zero network for 24+ hours) @@ -25,6 +42,7 @@ $ gpuaudit scan --profile ml-prod - **Stale instances** — non-production instances running 90+ days - **SageMaker low utilization** — endpoints with <10% GPU utilization - **SageMaker oversized** — endpoints using <30% GPU memory on multi-GPU instances +- **K8s unallocated GPUs** — nodes with GPU capacity but no pods requesting GPUs ## Install @@ -36,7 +54,7 @@ Or build from source: ```bash git clone https://github.com/gpuaudit/cli.git -cd gpuaudit +cd cli go build -o gpuaudit ./cmd/gpuaudit ``` @@ -49,22 +67,57 @@ gpuaudit scan # Specific profile and region gpuaudit scan --profile production --region us-east-1 +# Kubernetes cluster scan (uses KUBECONFIG or ~/.kube/config) +gpuaudit scan --skip-eks + +# Specific kubeconfig and context +gpuaudit scan --kubeconfig ~/.kube/config --kube-context gpu-cluster + # JSON output for automation -gpuaudit scan --format json --output report.json +gpuaudit scan --format json -o report.json -# Markdown for docs/PRs -gpuaudit scan --format markdown +# Compare two scans to see what changed +gpuaudit diff old-report.json new-report.json # Slack Block Kit payload (pipe to webhook) -gpuaudit scan --format slack --output - | curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK - -# Skip CloudWatch metrics (faster, less accurate) -gpuaudit scan --skip-metrics +gpuaudit scan --format slack -o - | \ + curl -X POST -H 'Content-Type: application/json' -d @- $SLACK_WEBHOOK -# Skip SageMaker scanning +# Skip specific scanners +gpuaudit scan --skip-metrics # faster, less accurate gpuaudit scan --skip-sagemaker +gpuaudit scan --skip-eks # skip AWS EKS API (use --skip-k8s for Kubernetes API) +gpuaudit scan --skip-k8s ``` +## Comparing scans + +Save scan results as JSON, then diff them later: + +```bash +gpuaudit scan --format json -o scan-apr-08.json +# ... time passes, changes happen ... +gpuaudit scan --format json -o scan-apr-15.json +gpuaudit diff scan-apr-08.json scan-apr-15.json +``` + +``` + gpuaudit diff — 2026-04-08 12:00 UTC → 2026-04-15 12:00 UTC + + ┌──────────────────────────────────────────────────────────┐ + │ Cost Delta │ + ├──────────────────────────────────────────────────────────┤ + │ Monthly spend: $372000 → $365155 (-$6845) │ + │ Estimated waste: $189000 → $23408 (-$165592) │ + │ Instances: 116 → 103 (-13 removed, +0 added) │ + └──────────────────────────────────────────────────────────┘ + + REMOVED — 13 instance(s), -$6845/mo + ... +``` + +Matches instances by ID. Reports added, removed, and changed instances with per-field diffs (instance type, pricing model, cost, state, GPU allocation, waste severity). + ## IAM permissions gpuaudit is read-only. It never modifies your infrastructure. Generate the minimal IAM policy: @@ -73,7 +126,7 @@ gpuaudit is read-only. It never modifies your infrastructure. Generate the minim gpuaudit iam-policy ``` -This outputs a JSON policy requiring only `Describe*`, `List*`, `Get*` permissions for EC2, SageMaker, CloudWatch, Cost Explorer, and Pricing APIs. +For Kubernetes scanning, gpuaudit needs `get`/`list` on `nodes` and `pods` cluster-wide. ## GPU pricing reference @@ -83,8 +136,7 @@ gpuaudit pricing # Filter by GPU model gpuaudit pricing --gpu H100 -gpuaudit pricing --gpu A10G -gpuaudit pricing --gpu T4 +gpuaudit pricing --gpu L4 ``` ## Output formats @@ -92,18 +144,17 @@ gpuaudit pricing --gpu T4 | Format | Flag | Use case | |---|---|---| | Table | `--format table` (default) | Terminal viewing | -| JSON | `--format json` | Automation, CI/CD pipelines | +| JSON | `--format json` | Automation, CI/CD, `gpuaudit diff` | | Markdown | `--format markdown` | PRs, wikis, docs | | Slack | `--format slack` | Slack webhook integration | ## How it works -1. **Discovery** — Scans EC2 and SageMaker across multiple regions for GPU instance families (g4dn, g5, g6, g6e, p4d, p4de, p5, inf2, trn1) +1. **Discovery** — Scans EC2, SageMaker, EKS node groups, and Kubernetes API across multiple regions for GPU resources 2. **Metrics** — Collects 7-day CloudWatch metrics: CPU, network I/O for EC2; GPU utilization, GPU memory, invocations for SageMaker -3. **Analysis** — Applies 6 waste detection rules with severity levels (critical/warning) -4. **Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings - -Regions scanned by default: us-east-1, us-east-2, us-west-2, eu-west-1, eu-west-2, eu-central-1, ap-southeast-1, ap-northeast-1, ap-south-1. +3. **K8s allocation** — Lists pods requesting `nvidia.com/gpu` resources and maps them to nodes +4. **Analysis** — Applies 7 waste detection rules with severity levels (critical/warning/info) +5. **Recommendations** — Generates specific actions (terminate, downsize, switch pricing) with estimated monthly savings ## Project structure @@ -113,21 +164,22 @@ gpuaudit/ ├── internal/ │ ├── models/ Core data types (GPUInstance, WasteSignal, Recommendation) │ ├── pricing/ Bundled GPU pricing database (40+ instance types) -│ ├── analysis/ Waste detection rules engine -│ ├── output/ Formatters (table, JSON, markdown, Slack) -│ └── providers/aws/ EC2, SageMaker, CloudWatch, scanner orchestrator +│ ├── analysis/ Waste detection rules engine (7 rules) +│ ├── diff/ Scan comparison logic +│ ├── output/ Formatters (table, JSON, markdown, Slack, diff) +│ └── providers/ +│ ├── aws/ EC2, SageMaker, EKS, CloudWatch, Cost Explorer +│ └── k8s/ Kubernetes API GPU node/pod discovery └── LICENSE Apache 2.0 ``` ## Roadmap -- [ ] AWS Cost Explorer integration (actual vs projected spend) -- [ ] EKS GPU pod discovery +- [ ] DCGM GPU metrics via Kubernetes (actual GPU utilization, not just allocation) - [ ] SageMaker training job analysis - [ ] Multi-account (AWS Organizations) scanning - [ ] GCP + Azure support - [ ] GitHub Action for scheduled scans -- [ ] Historical scan comparison (`gpuaudit diff`) ## License From cd27862c434b6fec15eda4bddc8653a78a9da496 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:03:27 +0100 Subject: [PATCH 19/39] Add multi-target scanning design spec Covers CLI flags (--targets, --role, --org), architecture for parallel cross-account scanning via STS AssumeRole, output changes with per-target sub-summaries, and IAM role setup docs (Terraform + CloudFormation). --- ...2026-04-18-multi-target-scanning-design.md | 374 ++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 docs/specs/2026-04-18-multi-target-scanning-design.md diff --git a/docs/specs/2026-04-18-multi-target-scanning-design.md b/docs/specs/2026-04-18-multi-target-scanning-design.md new file mode 100644 index 0000000..9c2fd34 --- /dev/null +++ b/docs/specs/2026-04-18-multi-target-scanning-design.md @@ -0,0 +1,374 @@ +# Multi-Target Scanning + +**Date:** April 18, 2026 +**Status:** Draft + +--- + +## Summary + +Add the ability to scan multiple AWS accounts (and eventually GCP projects / Azure subscriptions) in a single `gpuaudit scan` invocation. Uses STS AssumeRole to obtain credentials for each target, scans them all in parallel, and merges results into a single flat output with per-target sub-summaries. + +Zero breaking changes — existing single-account behavior is the default. + +--- + +## CLI Interface + +### New flags on `gpuaudit scan` + +| Flag | Type | Description | +|------|------|-------------| +| `--targets` | `[]string` | Comma-separated list of account IDs to scan | +| `--role` | `string` | IAM role name to assume in each target (required with `--targets` or `--org`) | +| `--org` | `bool` | Auto-discover all accounts from AWS Organizations | +| `--external-id` | `string` | STS external ID for cross-account role assumption (optional) | +| `--skip-self` | `bool` | Exclude the caller's own account from the scan | + +### Constraints + +- `--targets` and `--org` are mutually exclusive. +- `--role` is required when `--targets` or `--org` is set. +- No `--targets` or `--org` means scan the caller's account only (current behavior, no changes). +- The caller's own account is included by default unless `--skip-self` is set. + +### Examples + +```bash +# Current behavior (unchanged) +gpuaudit scan + +# Scan 3 specific accounts +gpuaudit scan --targets 111111111111,222222222222,333333333333 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Org scan, exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID for extra security +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Flag naming rationale + +Flags use provider-neutral names (`--targets` not `--accounts`, `--role` not `--assume-role`) so that when GCP and Azure support lands, the same flags work: targets are project IDs or subscription IDs, role is a service account or principal name. No renaming, no backward-compatibility concerns. + +--- + +## Architecture + +### New file: `internal/providers/aws/multiaccount.go` + +Contains: + +- `Target` struct: `{AccountID string, Config aws.Config}` +- `ResolveTargets(ctx, cfg, opts) ([]Target, []TargetError)`: + - No `--targets`/`--org`: returns caller's account with existing config. + - `--targets`: calls `sts:AssumeRole` for each account ID, returns credentials. Failed assumptions are collected as `TargetError`, not fatal. + - `--org`: calls `organizations:ListAccounts`, filters to active accounts, then assumes role in each. + - Caller's own account is included (with original config, no AssumeRole needed) unless `--skip-self`. +- `TargetError` struct: `{AccountID string, Err error}` + +### Changes to `ScanOptions` + +```go +type ScanOptions struct { + // ... existing fields ... + Targets []string // account IDs to scan + Role string // role name to assume + ExternalID string // STS external ID + OrgScan bool // auto-discover from Organizations + SkipSelf bool // exclude caller's account +} +``` + +### Changes to `Scan()` + +Current flow: +``` +load config → get account ID → scan regions in parallel → merge → analyze → output +``` + +New flow: +``` +load config → ResolveTargets() → for each target (parallel): + for each region (parallel): + scanRegion(ctx, target.Config, target.AccountID, region, opts) +→ merge all instances into flat list +→ filter, analyze, enrich (unchanged) +→ BuildSummary (global + per-target sub-summaries) +→ output +``` + +All targets are scanned in parallel. Within each target, all regions are scanned in parallel (same as today). + +### Error handling: best-effort + +- `ResolveTargets` returns both successful targets and a list of `TargetError`s. +- Scan continues for all resolvable targets. +- Per-region errors within a target are handled as today (warn and continue). +- Target-level errors are surfaced in the output (see Output section). +- Exit code: 0 = success, non-zero if all targets failed. + +### Unchanged components + +- Analysis rules — operate per-instance, already provider-agnostic. +- Diff command — matches by `instance_id`, globally unique across accounts. +- `GPUInstance` model — already has `AccountID` field. +- Pricing database — account-independent. + +--- + +## Model Changes + +### `ScanResult` + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` // caller's account (kept for backward compat) + Targets []string `json:"targets,omitempty"` // NEW: all scanned target IDs + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` // NEW: per-target breakdown + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` // NEW: failed targets +} + +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +New fields use `omitempty` — single-account scans produce identical JSON to today. + +--- + +## Output Changes + +### Table + +When multiple targets are present, two additions: + +1. **"By Target" summary table** after the global summary: + +``` + By Target + ┌──────────────┬───────────┬───────────┬───────────┬───────┐ + │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │ + ├──────────────┼───────────┼───────────┼───────────┼───────┤ + │ 111111111111 │ 31 │ $142,000 │ $38,000 │ 27% │ + │ 222222222222 │ 12 │ $35,400 │ $4,200 │ 12% │ + └──────────────┴───────────┴───────────┴───────────┴───────┘ +``` + +2. **"Target" column** in instance detail tables. + +Single-target scans look identical to today. + +### JSON + +New `targets`, `target_summaries`, and `target_errors` fields as shown in the model above. Omitted when empty. + +### Markdown + +Per-target summary section added when multiple targets present. + +### Slack + +Per-target summary block added when multiple targets present. + +### Errors + +When targets fail, a warnings section appears in all formats: + +``` + Warnings + ✗ 444444444444 — AssumeRole failed: AccessDenied + ✗ 555555555555 — role "gpuaudit-reader" not found in account +``` + +--- + +## IAM Policy Updates + +### `gpuaudit iam-policy` additions + +Add two new statements to the generated policy: + +```json +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader" +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*" +} +``` + +These are printed as a separate "Multi-Account Permissions" section in the `iam-policy` output, with a comment explaining they're only needed for `--targets` or `--org` scanning. Always included in the output — users can ignore them if they only scan a single account. + +--- + +## Cross-Account Role Setup + +### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +variable "external_id" { + description = "External ID for AssumeRole (optional but recommended)" + type = string + default = "" +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + Condition = var.external_id != "" ? { + StringEquals = { "sts:ExternalId" = var.external_id } + } : {} + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Sid = "EC2ReadOnly" + Effect = "Allow" + Action = ["ec2:DescribeInstances", "ec2:DescribeInstanceTypes", "ec2:DescribeRegions"] + Resource = "*" + }, + { + Sid = "SageMakerReadOnly" + Effect = "Allow" + Action = ["sagemaker:ListEndpoints", "sagemaker:DescribeEndpoint", "sagemaker:DescribeEndpointConfig"] + Resource = "*" + }, + { + Sid = "EKSReadOnly" + Effect = "Allow" + Action = ["eks:ListClusters", "eks:ListNodegroups", "eks:DescribeNodegroup"] + Resource = "*" + }, + { + Sid = "CloudWatchReadOnly" + Effect = "Allow" + Action = ["cloudwatch:GetMetricData", "cloudwatch:GetMetricStatistics", "cloudwatch:ListMetrics"] + Resource = "*" + }, + { + Sid = "CostExplorerReadOnly" + Effect = "Allow" + Action = ["ce:GetCostAndUsage", "ce:GetReservationUtilization", "ce:GetSavingsPlansUtilization"] + Resource = "*" + }, + { + Sid = "PricingReadOnly" + Effect = "Allow" + Action = ["pricing:GetProducts"] + Resource = "*" + } + ] + }) +} +``` + +### CloudFormation (for StackSet deployment across all accounts) + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Description: gpuaudit cross-account reader role + +Parameters: + ManagementAccountId: + Type: String + Description: Account ID where gpuaudit runs + ExternalId: + Type: String + Description: External ID for AssumeRole + Default: "" + +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` + +Recommended deployment: use CloudFormation StackSets to deploy the role to all member accounts from the management account. + +--- + +## Testing + +- **Unit tests for `ResolveTargets`**: mock STS and Organizations clients, verify correct target list for each mode (explicit, org, skip-self, mixed failures). +- **Unit tests for `BuildSummary`**: verify per-target summaries compute correctly with instances from multiple accounts. +- **Unit tests for output formatters**: verify "By Target" table and Target column appear only when multiple targets present. +- **Integration test pattern**: test the full `Scan` flow with mocked AWS clients for 2-3 accounts, verify merged output. From b8c6a35a24b657ba28f4055b768b253b6f3adff4 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:11:03 +0100 Subject: [PATCH 20/39] Add multi-target scanning implementation plan --- .../plans/2026-04-18-multi-target-scanning.md | 1537 +++++++++++++++++ 1 file changed, 1537 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-18-multi-target-scanning.md diff --git a/docs/superpowers/plans/2026-04-18-multi-target-scanning.md b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md new file mode 100644 index 0000000..ebde5e3 --- /dev/null +++ b/docs/superpowers/plans/2026-04-18-multi-target-scanning.md @@ -0,0 +1,1537 @@ +# Multi-Target Scanning Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Enable gpuaudit to scan multiple AWS accounts in a single invocation via STS AssumeRole, with optional Organizations auto-discovery. + +**Architecture:** New `multiaccount.go` handles target resolution (explicit list or Organizations API) and credential assumption. The existing `Scan()` function is refactored to accept multiple targets and scan them all in parallel. Output formatters gain per-target summary sections when multiple targets are present. All new fields use `omitempty` so single-account scans produce identical output to today. + +**Tech Stack:** Go 1.24, AWS SDK v2 (STS, Organizations), cobra CLI, standard library testing + +--- + +## File Map + +| File | Action | Responsibility | +|------|--------|---------------| +| `internal/providers/aws/multiaccount.go` | Create | `Target` struct, `ResolveTargets()`, `TargetError` type, STS AssumeRole + Organizations list | +| `internal/providers/aws/multiaccount_test.go` | Create | Tests for `ResolveTargets()` with mock STS/Org clients | +| `internal/models/models.go` | Modify | Add `TargetSummary`, `TargetErrorInfo` types; add new fields to `ScanResult` | +| `internal/providers/aws/scanner.go` | Modify | Refactor `Scan()` to use `ResolveTargets()` and scan all targets in parallel | +| `cmd/gpuaudit/main.go` | Modify | Add `--targets`, `--role`, `--org`, `--external-id`, `--skip-self` flags; wire into `ScanOptions` | +| `internal/providers/aws/summary.go` | Create | Extract `BuildSummary` from scanner.go, add `BuildTargetSummaries()` | +| `internal/providers/aws/summary_test.go` | Create | Tests for per-target summary computation | +| `internal/output/table.go` | Modify | Add "By Target" summary table and "Target" column when multiple targets | +| `internal/output/markdown.go` | Modify | Add per-target summary section when multiple targets | +| `internal/output/slack.go` | Modify | Add per-target summary block when multiple targets | +| `go.mod` | Modify | Add `organizations` SDK dependency | + +--- + +### Task 1: Add model types for multi-target results + +**Files:** +- Modify: `internal/models/models.go` + +- [ ] **Step 1: Add `TargetSummary` and `TargetErrorInfo` types and new `ScanResult` fields** + +Add to `internal/models/models.go` after the `ScanSummary` struct: + +```go +// TargetSummary provides per-target aggregate statistics. +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} +``` + +Add three new fields to `ScanResult`: + +```go +type ScanResult struct { + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` +} +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success (new types are additive, omitempty means no output change) + +- [ ] **Step 3: Run existing tests to confirm nothing broke** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/models/models.go +git commit -m "Add TargetSummary and TargetErrorInfo model types for multi-target scanning" +``` + +--- + +### Task 2: Extract `BuildSummary` and add `BuildTargetSummaries` + +**Files:** +- Create: `internal/providers/aws/summary.go` +- Create: `internal/providers/aws/summary_test.go` +- Modify: `internal/providers/aws/scanner.go` (remove `BuildSummary` — it moves to summary.go) + +- [ ] **Step 1: Write the failing test for `BuildTargetSummaries`** + +Create `internal/providers/aws/summary_test.go`: + +```go +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + // Find each target + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} +``` + +- [ ] **Step 2: Run the test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestBuildTargetSummaries -v` +Expected: FAIL (function not defined) + +- [ ] **Step 3: Create `summary.go` with `BuildSummary` (moved from scanner.go) and `BuildTargetSummaries`** + +Create `internal/providers/aws/summary.go`: + +```go +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + maxSeverity := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSeverity = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { + maxSeverity = models.SeverityWarning + } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { + maxSeverity = models.SeverityInfo + } + } + + switch maxSeverity { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. +func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + maxSev := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSev = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { + maxSev = models.SeverityWarning + } + } + switch maxSev { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} +``` + +- [ ] **Step 4: Remove `BuildSummary` and `matchesExcludeTags` from `scanner.go`** + +In `internal/providers/aws/scanner.go`, delete the `BuildSummary` function (lines 235-272) and keep `matchesExcludeTags`. The `BuildSummary` is now in `summary.go`. No import changes needed since both files are in the same package. + +- [ ] **Step 5: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./... -v` +Expected: all pass, including the new `TestBuildTargetSummaries_*` tests + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/summary.go internal/providers/aws/summary_test.go internal/providers/aws/scanner.go +git commit -m "Extract BuildSummary to summary.go and add BuildTargetSummaries" +``` + +--- + +### Task 3: Implement `ResolveTargets` with STS AssumeRole + +**Files:** +- Create: `internal/providers/aws/multiaccount.go` +- Create: `internal/providers/aws/multiaccount_test.go` +- Modify: `go.mod` (add organizations dependency) + +- [ ] **Step 1: Add the Organizations SDK dependency** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go get github.com/aws/aws-sdk-go-v2/service/organizations` + +- [ ] **Step 2: Write failing tests for `ResolveTargets`** + +Create `internal/providers/aws/multiaccount_test.go`: + +```go +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +type mockSTSClient struct { + identity *sts.GetCallerIdentityOutput + roles map[string]*sts.AssumeRoleOutput // keyed by account ID + failAccts map[string]error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return m.identity, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from ARN: arn:aws:iam:::role/ + arn := aws.ToString(params.RoleArn) + // Simple extraction: find the account ID between the 4th and 5th colons + acct := "" + colons := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + rest := arn[i+1:] + for j, r := range rest { + if r == ':' { + acct = rest[:j] + break + } + } + break + } + } + } + if err, ok := m.failAccts[acct]; ok { + return nil, err + } + if out, ok := m.roles[acct]; ok { + return out, nil + } + return nil, fmt.Errorf("no role for account %s", acct) +} + +type mockOrgClient struct { + accounts []orgtypes.Account +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) { + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +func TestResolveTargets_NoTargets_ReturnsSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, ScanOptions{}) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target, got %d", len(targets)) + } + if targets[0].AccountID != "999999999999" { + t.Errorf("expected account 999999999999, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + "222222222222": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK2"), SecretAccessKey: aws.String("SK2"), SessionToken: aws.String("ST2"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + // 2 explicit + self = 3 + if len(targets) != 3 { + t.Fatalf("expected 3 targets (2 explicit + self), got %d", len(targets)) + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (skip self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + failAccts: map[string]error{ + "222222222222": fmt.Errorf("AccessDenied"), + }, + } + + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "gpuaudit-reader", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, nil, opts) + + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "222222222222" { + t.Errorf("expected error for 222222222222, got %s", errs[0].AccountID) + } + if len(targets) != 1 { + t.Fatalf("expected 1 successful target, got %d", len(targets)) + } +} + +func TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + identity: &sts.GetCallerIdentityOutput{Account: aws.String("999999999999")}, + roles: map[string]*sts.AssumeRoleOutput{ + "111111111111": {Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AK1"), SecretAccessKey: aws.String("SK1"), SessionToken: aws.String("ST1"), + }}, + }, + } + + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("999999999999"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + }, + } + + opts := ScanOptions{ + OrgScan: true, + Role: "gpuaudit-reader", + } + + targets, errs := ResolveTargets(context.Background(), aws.Config{}, stsClient, orgClient, opts) + + // 999 (self, no assume) + 111 (assumed) = 2 targets; 333 is suspended so skipped + // Note: 999 is self so not assumed; 111 is assumed successfully + if len(targets) != 2 { + t.Fatalf("expected 2 targets (self + 1 active non-self), got %d", len(targets)) + } + if len(errs) != 0 { + t.Fatalf("expected no errors, got %v", errs) + } +} +``` + +- [ ] **Step 3: Run test to verify it fails** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: FAIL (function and types not defined) + +- [ ] **Step 4: Implement `multiaccount.go`** + +Create `internal/providers/aws/multiaccount.go`: + +```go +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// Target represents a resolved scan target with its credentials. +type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. +type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. +type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount := aws.ToString(identity.Account) + + // No multi-target flags: return self only + if len(opts.Targets) == 0 && !opts.OrgScan { + return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + // Determine account IDs to scan + var accountIDs []string + if opts.OrgScan { + discovered, err := discoverOrgAccounts(ctx, orgClient) + if err != nil { + return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", err)}} + } + accountIDs = discovered + } else { + accountIDs = opts.Targets + } + + var targets []Target + var targetErrors []TargetError + + // Include self unless skipped + if !opts.SkipSelf { + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + } + + // Assume role in each non-self account + for _, acctID := range accountIDs { + if acctID == selfAccount { + continue // already included as self (or skipped) + } + + cfg, err := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if err != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: err}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return targets, targetErrors +} + +func discoverOrgAccounts(ctx context.Context, client OrgClient) ([]string, error) { + var accounts []string + var nextToken *string + + for { + out, err := client.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accounts = append(accounts, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accounts, nil +} + +func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: &roleArn, + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = &externalID + } + + out, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleArn, err) + } + + creds := out.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} +``` + +- [ ] **Step 5: Fix the test import — add `ststypes` import** + +The tests reference `ststypes.Credentials`. Add this import to `multiaccount_test.go`: + +```go +import ( + // ... existing imports ... + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) +``` + +- [ ] **Step 6: Run tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./internal/providers/aws/ -run TestResolveTargets -v` +Expected: all pass + +- [ ] **Step 7: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 8: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/multiaccount.go internal/providers/aws/multiaccount_test.go go.mod go.sum +git commit -m "Add ResolveTargets with STS AssumeRole and Organizations discovery" +``` + +--- + +### Task 4: Refactor `Scan()` for multi-target parallel scanning + +**Files:** +- Modify: `internal/providers/aws/scanner.go` + +- [ ] **Step 1: Add multi-target fields to `ScanOptions`** + +In `internal/providers/aws/scanner.go`, add to the `ScanOptions` struct: + +```go +type ScanOptions struct { + Profile string + Regions []string + MetricWindow MetricWindow + SkipMetrics bool + SkipSageMaker bool + SkipEKS bool + SkipCosts bool + ExcludeTags map[string]string + MinUptimeDays int + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool +} +``` + +- [ ] **Step 2: Refactor `Scan()` to use `ResolveTargets` and scan all targets in parallel** + +Replace the `Scan` function in `scanner.go` with: + +```go +func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { + start := time.Now() + + // Load AWS config + cfgOpts := []func(*awsconfig.LoadOptions) error{} + if opts.Profile != "" { + cfgOpts = append(cfgOpts, awsconfig.WithSharedConfigProfile(opts.Profile)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, cfgOpts...) + if err != nil { + return nil, fmt.Errorf("loading AWS config: %w", err) + } + + // Resolve targets + stsClient := sts.NewFromConfig(cfg) + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved") + } + + // Report target errors + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: target %s: %v\n", te.AccountID, te.Err) + } + + fmt.Fprintf(os.Stderr, " Scanning %d target(s)...\n", len(targets)) + + // Determine regions to scan + regions := opts.Regions + if len(regions) == 0 { + regions, err = getGPURegions(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("listing regions: %w", err) + } + } + + fmt.Fprintf(os.Stderr, " Scanning %d regions per target for GPU instances...\n", len(regions)) + + // Scan all targets in parallel + type targetResult struct { + accountID string + instances []models.GPUInstance + regions []string + } + + resultsCh := make(chan targetResult, len(targets)) + var wg sync.WaitGroup + + for _, target := range targets { + wg.Add(1) + go func(t Target) { + defer wg.Done() + instances, scannedRegions := scanTarget(ctx, t, regions, opts) + resultsCh <- targetResult{ + accountID: t.AccountID, + instances: instances, + regions: scannedRegions, + } + }(target) + } + + go func() { + wg.Wait() + close(resultsCh) + }() + + var allInstances []models.GPUInstance + regionSet := make(map[string]bool) + callerAccount := "" + if len(targets) > 0 { + callerAccount = targets[0].AccountID + } + + for res := range resultsCh { + allInstances = append(allInstances, res.instances...) + for _, r := range res.regions { + regionSet[r] = true + } + } + + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + + // Filter by excluded tags + if len(opts.ExcludeTags) > 0 { + filtered := allInstances[:0] + excluded := 0 + for _, inst := range allInstances { + if matchesExcludeTags(inst.Tags, opts.ExcludeTags) { + excluded++ + continue + } + filtered = append(filtered, inst) + } + allInstances = filtered + if excluded > 0 { + fmt.Fprintf(os.Stderr, " Excluded %d instance(s) by tag filter.\n", excluded) + } + } + + // Run analysis + analysis.AnalyzeAll(allInstances) + + // Suppress signals below minimum uptime threshold + if opts.MinUptimeDays > 0 { + minHours := float64(opts.MinUptimeDays) * 24 + for i := range allInstances { + inst := &allInstances[i] + if inst.UptimeHours >= minHours { + continue + } + inst.WasteSignals = nil + inst.Recommendations = nil + inst.EstimatedSavings = 0 + } + } + + // Build summaries + summary := BuildSummary(allInstances) + + result := &models.ScanResult{ + Timestamp: start, + AccountID: callerAccount, + Regions: scannedRegions, + ScanDuration: time.Since(start).Round(time.Millisecond).String(), + Instances: allInstances, + Summary: summary, + } + + // Add multi-target metadata + if len(targets) > 1 || len(targetErrors) > 0 { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account. +func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: %s/%s: %v\n", target.AccountID, res.region, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (per-target, since CE is account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: %s cost enrichment: %v\n", target.AccountID, err) + } + } + + return allInstances, scannedRegions +} +``` + +- [ ] **Step 3: Add the organizations import to scanner.go** + +Add to the import block: + +```go +"github.com/aws/aws-sdk-go-v2/service/organizations" +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 6: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/providers/aws/scanner.go +git commit -m "Refactor Scan() for parallel multi-target scanning" +``` + +--- + +### Task 5: Wire CLI flags into scan command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add flag variables and register flags** + +Add the new flag variables alongside the existing scan flags: + +```go +var ( + // ... existing flags ... + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool +) +``` + +In the `init()` function, add after the existing `scanCmd.Flags` calls: + +```go +scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") +scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") +scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") +scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") +scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") +scanCmd.MarkFlagsMutuallyExclusive("targets", "org") +``` + +- [ ] **Step 2: Wire flags into `ScanOptions` in `runScan`** + +In the `runScan` function, add the new fields to the opts construction: + +```go +opts := awsprovider.DefaultScanOptions() +opts.Profile = scanProfile +opts.Regions = scanRegions +opts.SkipMetrics = scanSkipMetrics +opts.SkipSageMaker = scanSkipSageMaker +opts.SkipEKS = scanSkipEKS +opts.SkipCosts = scanSkipCosts +opts.ExcludeTags = parseExcludeTags(scanExcludeTags) +opts.MinUptimeDays = scanMinUptimeDays +opts.Targets = scanTargets +opts.Role = scanRole +opts.ExternalID = scanExternalID +opts.OrgScan = scanOrg +opts.SkipSelf = scanSkipSelf +``` + +- [ ] **Step 3: Add validation — `--role` required with `--targets` or `--org`** + +Add at the top of `runScan`, before creating opts: + +```go +if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") +} +``` + +- [ ] **Step 4: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 5: Verify CLI help shows new flags** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: new flags visible in help text + +- [ ] **Step 6: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 7: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add --targets, --role, --org, --external-id, --skip-self flags to scan command" +``` + +--- + +### Task 6: Update table formatter for multi-target output + +**Files:** +- Modify: `internal/output/table.go` + +- [ ] **Step 1: Add "By Target" summary table to `FormatTable`** + +In `internal/output/table.go`, add a new function and call it from `FormatTable` after the summary box: + +```go +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + // Target errors + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} +``` + +In `FormatTable`, add the call after the summary box and before the "No GPU instances" check: + +```go +// ... after the summary box closing line ... + +printTargetSummary(w, result) + +if s.TotalInstances == 0 { +``` + +- [ ] **Step 2: Add "Target" column to `printInstanceTable` when multi-target** + +Modify `printInstanceTable` to accept and use target info. Since the formatter doesn't know if it's multi-target from just the instance slice, pass the result: + +Change the call sites in `FormatTable` from: +```go +printInstanceTable(w, critical) +``` +to: +```go +multiTarget := len(result.TargetSummaries) > 1 +printInstanceTable(w, critical, multiTarget) +``` + +Update `printInstanceTable`: + +```go +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } + + for _, inst := range instances { + name := inst.Name + if name == "" { + name = inst.InstanceID + } + if len(name) > 34 { + name = name[:31] + "..." + } + + gpuDesc := fmt.Sprintf("%d× %s", inst.GPUCount, inst.GPUModel) + typeDesc := fmt.Sprintf("%s (%s)", inst.InstanceType, gpuDesc) + if len(typeDesc) > 26 { + typeDesc = typeDesc[:23] + "..." + } + + signal := "" + if len(inst.WasteSignals) > 0 { + signal = inst.WasteSignals[0].Type + } + + rec := "" + if len(inst.Recommendations) > 0 { + rec = inst.Recommendations[0].Description + } + + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } + } + fmt.Fprintln(w) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/table.go +git commit -m "Add per-target summary table and target column to table formatter" +``` + +--- + +### Task 7: Update markdown and Slack formatters for multi-target output + +**Files:** +- Modify: `internal/output/markdown.go` +- Modify: `internal/output/slack.go` + +- [ ] **Step 1: Add per-target section to markdown formatter** + +In `internal/output/markdown.go`, add after the Summary table (after the `s.HealthyCount` line and before the "No GPU instances" check): + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) +} + +if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) +} +``` + +Also add a "Target" column to the Findings table when multi-target. Change the table header and row formatting: + +```go +if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") +} else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") +} + +for _, inst := range result.Instances { + // ... existing name/signal/rec/savings formatting ... + + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } +} +``` + +- [ ] **Step 2: Add per-target block to Slack formatter** + +In `internal/output/slack.go`, in `FormatSlack`, add after the summary block and divider: + +```go +// Per-target breakdown +if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) +} + +// Target errors +if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines = append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) +} +``` + +- [ ] **Step 3: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 4: Run all tests** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./...` +Expected: all pass + +- [ ] **Step 5: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add internal/output/markdown.go internal/output/slack.go +git commit -m "Add per-target summaries to markdown and Slack formatters" +``` + +--- + +### Task 8: Update `iam-policy` command + +**Files:** +- Modify: `cmd/gpuaudit/main.go` + +- [ ] **Step 1: Add cross-account and Organizations statements to `iam-policy` output** + +In `cmd/gpuaudit/main.go`, in the `iamPolicyCmd` Run function, add two new statements to the policy `Statement` slice: + +```go +{ + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", +}, +{ + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", +}, +``` + +Add a comment before encoding: + +```go +fmt.Fprintln(os.Stdout, "// The last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") +``` + +- [ ] **Step 2: Verify build passes** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go build ./...` +Expected: success + +- [ ] **Step 3: Verify output looks correct** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit iam-policy` +Expected: JSON policy with the two new statements appended + +- [ ] **Step 4: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add cmd/gpuaudit/main.go +git commit -m "Add cross-account and Organizations permissions to iam-policy output" +``` + +--- + +### Task 9: Update README with multi-target documentation + +**Files:** +- Modify: `README.md` + +- [ ] **Step 1: Add multi-account scanning section to README** + +Add a new section after the existing usage documentation: + +```markdown +## Multi-Account Scanning + +Scan multiple AWS accounts in a single invocation using STS AssumeRole. + +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. + +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. + +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` +``` + +- [ ] **Step 2: Commit** + +```bash +cd /Users/smaksimov/Work/0cloud/gpuaudit +git add README.md +git commit -m "Add multi-account scanning docs to README" +``` + +--- + +### Task 10: End-to-end verification + +**Files:** None (verification only) + +- [ ] **Step 1: Run full test suite** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go test ./... -v` +Expected: all pass + +- [ ] **Step 2: Run go vet** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go vet ./...` +Expected: no issues + +- [ ] **Step 3: Verify CLI help** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --help` +Expected: all new flags visible (--targets, --role, --org, --external-id, --skip-self) + +- [ ] **Step 4: Verify mutual exclusivity** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 --org --role test 2>&1` +Expected: error about mutually exclusive flags + +- [ ] **Step 5: Verify --role validation** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --targets 111 2>&1` +Expected: error "role is required when using --targets or --org" + +- [ ] **Step 6: Verify single-account scan still works (no regression)** + +Run: `cd /Users/smaksimov/Work/0cloud/gpuaudit && go run ./cmd/gpuaudit scan --skip-metrics --skip-sagemaker --skip-eks --skip-k8s --skip-costs 2>&1` +Expected: runs normally, output unchanged from before this feature From 408399ef8b3382ddb6bd63b266ae29eac5656280 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:13:34 +0100 Subject: [PATCH 21/39] Add TargetSummary and TargetErrorInfo model types for multi-target scanning --- internal/models/models.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index 8e99dbd..c523838 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -118,12 +118,15 @@ type Recommendation struct { // ScanResult holds the complete output of a gpuaudit scan. type ScanResult struct { - Timestamp time.Time `json:"timestamp"` - AccountID string `json:"account_id"` - Regions []string `json:"regions"` - ScanDuration string `json:"scan_duration"` - Instances []GPUInstance `json:"instances"` - Summary ScanSummary `json:"summary"` + Timestamp time.Time `json:"timestamp"` + AccountID string `json:"account_id"` + Targets []string `json:"targets,omitempty"` + Regions []string `json:"regions"` + ScanDuration string `json:"scan_duration"` + Instances []GPUInstance `json:"instances"` + Summary ScanSummary `json:"summary"` + TargetSummaries []TargetSummary `json:"target_summaries,omitempty"` + TargetErrors []TargetErrorInfo `json:"target_errors,omitempty"` } // ScanSummary provides aggregate statistics for a scan. @@ -138,5 +141,22 @@ type ScanSummary struct { HealthyCount int `json:"healthy_count"` } +// TargetSummary provides per-target aggregate statistics. +type TargetSummary struct { + Target string `json:"target"` + TotalInstances int `json:"total_instances"` + TotalMonthlyCost float64 `json:"total_monthly_cost"` + TotalEstimatedWaste float64 `json:"total_estimated_waste"` + WastePercent float64 `json:"waste_percent"` + CriticalCount int `json:"critical_count"` + WarningCount int `json:"warning_count"` +} + +// TargetErrorInfo describes a target that failed to scan. +type TargetErrorInfo struct { + Target string `json:"target"` + Error string `json:"error"` +} + // Ptr is a convenience helper for creating pointer values in tests and literals. func Ptr[T any](v T) *T { return &v } From e856adb45e0f59c4f0b60fd73423cfdc81d3f7f0 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:15:35 +0100 Subject: [PATCH 22/39] Extract BuildSummary to summary.go and add BuildTargetSummaries --- internal/providers/aws/scanner.go | 40 ----------- internal/providers/aws/summary.go | 96 ++++++++++++++++++++++++++ internal/providers/aws/summary_test.go | 89 ++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 40 deletions(-) create mode 100644 internal/providers/aws/summary.go create mode 100644 internal/providers/aws/summary_test.go diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index d8d5921..e28cbab 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -231,46 +231,6 @@ func getGPURegions(ctx context.Context, cfg aws.Config) ([]string, error) { }, nil } -// BuildSummary computes aggregate statistics for a set of GPU instances. -func BuildSummary(instances []models.GPUInstance) models.ScanSummary { - s := models.ScanSummary{ - TotalInstances: len(instances), - } - - for _, inst := range instances { - s.TotalMonthlyCost += inst.MonthlyCost - s.TotalEstimatedWaste += inst.EstimatedSavings - - maxSeverity := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSeverity = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { - maxSeverity = models.SeverityWarning - } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { - maxSeverity = models.SeverityInfo - } - } - - switch maxSeverity { - case models.SeverityCritical: - s.CriticalCount++ - case models.SeverityWarning: - s.WarningCount++ - case models.SeverityInfo: - s.InfoCount++ - default: - s.HealthyCount++ - } - } - - if s.TotalMonthlyCost > 0 { - s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 - } - - return s -} - func matchesExcludeTags(instanceTags map[string]string, excludes map[string]string) bool { for k, v := range excludes { if instanceTags[k] == v { diff --git a/internal/providers/aws/summary.go b/internal/providers/aws/summary.go new file mode 100644 index 0000000..bae351a --- /dev/null +++ b/internal/providers/aws/summary.go @@ -0,0 +1,96 @@ +package aws + +import ( + "sort" + + "github.com/gpuaudit/cli/internal/models" +) + +// BuildSummary computes aggregate statistics for a set of GPU instances. +func BuildSummary(instances []models.GPUInstance) models.ScanSummary { + s := models.ScanSummary{ + TotalInstances: len(instances), + } + + for _, inst := range instances { + s.TotalMonthlyCost += inst.MonthlyCost + s.TotalEstimatedWaste += inst.EstimatedSavings + + maxSeverity := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSeverity = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { + maxSeverity = models.SeverityWarning + } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { + maxSeverity = models.SeverityInfo + } + } + + switch maxSeverity { + case models.SeverityCritical: + s.CriticalCount++ + case models.SeverityWarning: + s.WarningCount++ + case models.SeverityInfo: + s.InfoCount++ + default: + s.HealthyCount++ + } + } + + if s.TotalMonthlyCost > 0 { + s.WastePercent = (s.TotalEstimatedWaste / s.TotalMonthlyCost) * 100 + } + + return s +} + +// BuildTargetSummaries computes per-target breakdowns from a flat instance list. +func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary { + if len(instances) == 0 { + return nil + } + + byTarget := make(map[string][]models.GPUInstance) + for _, inst := range instances { + byTarget[inst.AccountID] = append(byTarget[inst.AccountID], inst) + } + + summaries := make([]models.TargetSummary, 0, len(byTarget)) + for target, insts := range byTarget { + ts := models.TargetSummary{ + Target: target, + TotalInstances: len(insts), + } + for _, inst := range insts { + ts.TotalMonthlyCost += inst.MonthlyCost + ts.TotalEstimatedWaste += inst.EstimatedSavings + + maxSev := models.Severity("") + for _, sig := range inst.WasteSignals { + if sig.Severity == models.SeverityCritical { + maxSev = models.SeverityCritical + } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { + maxSev = models.SeverityWarning + } + } + switch maxSev { + case models.SeverityCritical: + ts.CriticalCount++ + case models.SeverityWarning: + ts.WarningCount++ + } + } + if ts.TotalMonthlyCost > 0 { + ts.WastePercent = (ts.TotalEstimatedWaste / ts.TotalMonthlyCost) * 100 + } + summaries = append(summaries, ts) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].TotalMonthlyCost > summaries[j].TotalMonthlyCost + }) + + return summaries +} diff --git a/internal/providers/aws/summary_test.go b/internal/providers/aws/summary_test.go new file mode 100644 index 0000000..b429e39 --- /dev/null +++ b/internal/providers/aws/summary_test.go @@ -0,0 +1,89 @@ +package aws + +import ( + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestBuildTargetSummaries_MultipleAccounts(t *testing.T) { + instances := []models.GPUInstance{ + { + AccountID: "111111111111", + MonthlyCost: 1000, + EstimatedSavings: 500, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityCritical}}, + }, + { + AccountID: "111111111111", + MonthlyCost: 2000, + EstimatedSavings: 0, + }, + { + AccountID: "222222222222", + MonthlyCost: 3000, + EstimatedSavings: 1000, + WasteSignals: []models.WasteSignal{{Severity: models.SeverityWarning}}, + }, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 2 { + t.Fatalf("expected 2 target summaries, got %d", len(summaries)) + } + + var s1, s2 *models.TargetSummary + for i := range summaries { + switch summaries[i].Target { + case "111111111111": + s1 = &summaries[i] + case "222222222222": + s2 = &summaries[i] + } + } + + if s1 == nil || s2 == nil { + t.Fatal("missing target summaries") + } + + if s1.TotalInstances != 2 { + t.Errorf("acct1: expected 2 instances, got %d", s1.TotalInstances) + } + if s1.TotalMonthlyCost != 3000 { + t.Errorf("acct1: expected $3000 cost, got $%.0f", s1.TotalMonthlyCost) + } + if s1.TotalEstimatedWaste != 500 { + t.Errorf("acct1: expected $500 waste, got $%.0f", s1.TotalEstimatedWaste) + } + if s1.CriticalCount != 1 { + t.Errorf("acct1: expected 1 critical, got %d", s1.CriticalCount) + } + + if s2.TotalInstances != 1 { + t.Errorf("acct2: expected 1 instance, got %d", s2.TotalInstances) + } + if s2.WarningCount != 1 { + t.Errorf("acct2: expected 1 warning, got %d", s2.WarningCount) + } +} + +func TestBuildTargetSummaries_SingleAccount(t *testing.T) { + instances := []models.GPUInstance{ + {AccountID: "111111111111", MonthlyCost: 1000}, + } + + summaries := BuildTargetSummaries(instances) + + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } +} + +func TestBuildTargetSummaries_Empty(t *testing.T) { + summaries := BuildTargetSummaries(nil) + + if len(summaries) != 0 { + t.Fatalf("expected 0 summaries for nil input, got %d", len(summaries)) + } +} From d1c79f3f5b8fc946c3f9901cb77fb100b132da35 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:20:17 +0100 Subject: [PATCH 23/39] Implement ResolveTargets with STS AssumeRole for multi-account scanning Add ResolveTargets function that resolves scan targets based on --targets, --org, --role, and --skip-self options. Self account uses original credentials (no AssumeRole), failed assumptions are collected as TargetError rather than being fatal. Add STSClient and OrgClient interfaces, Target and TargetError types, multi-target fields to ScanOptions, and organizations SDK dependency. Includes 6 tests covering: self-only, explicit targets, skip-self, partial failure, org discovery with suspended account filtering, and self-in-targets deduplication. --- go.mod | 11 +- go.sum | 18 +- internal/providers/aws/multiaccount.go | 165 +++++++++++ internal/providers/aws/multiaccount_test.go | 298 ++++++++++++++++++++ internal/providers/aws/scanner.go | 7 + 5 files changed, 486 insertions(+), 13 deletions(-) create mode 100644 internal/providers/aws/multiaccount.go create mode 100644 internal/providers/aws/multiaccount_test.go diff --git a/go.mod b/go.mod index e6bceb9..9b28a73 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,14 @@ module github.com/gpuaudit/cli go 1.24.0 require ( - github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2 v1.41.6 github.com/aws/aws-sdk-go-v2/config v1.32.14 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 github.com/aws/aws-sdk-go-v2/service/costexplorer v1.63.6 github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 github.com/aws/aws-sdk-go-v2/service/eks v1.82.0 + github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 github.com/prometheus/client_model v0.6.2 @@ -20,17 +22,16 @@ require ( ) require ( - github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect - github.com/aws/smithy-go v1.24.2 // indirect + github.com/aws/smithy-go v1.25.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index 08691a8..67088e7 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,15 @@ -github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= -github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg= +github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ= github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 h1:ud2A364lLBkhGAC7oYw/1xg9BF4acwJC+qdLykxy83o= @@ -24,6 +24,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhL github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2 h1:2TDersSNowBwSRTrnD0LxLilpr6Dr5coXwVsWO7f2rw= +github.com/aws/aws-sdk-go-v2/service/organizations v1.51.2/go.mod h1:UMm4MKZDJMbuJZF5QOJBsVRMLeKiEXAgCXFpocWPDFo= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0 h1:5jLvLVu20tlFgVOsX+ns4jNVzoUWP36AQc5sAvNJSMI= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.238.0/go.mod h1:zsRrjJIfG9a9b3VRU+uPa3dX5fqgI+zKMXD4tbIlbdA= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= @@ -34,8 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6f github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= -github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= -github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U= +github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/providers/aws/multiaccount.go b/internal/providers/aws/multiaccount.go new file mode 100644 index 0000000..fd8a99c --- /dev/null +++ b/internal/providers/aws/multiaccount.go @@ -0,0 +1,165 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// Target represents a resolved scan target with its credentials. +type Target struct { + AccountID string + Config aws.Config +} + +// TargetError records a target that failed credential resolution. +type TargetError struct { + AccountID string + Err error +} + +// STSClient is the subset of the STS API we need. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// OrgClient is the subset of the Organizations API we need. +type OrgClient interface { + ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) +} + +// ResolveTargets determines which accounts to scan and obtains credentials for each. +// +// Behaviour: +// - No --targets/--org: returns self only (uses baseCfg, no AssumeRole) +// - --targets + --role: AssumeRole for each, self included by default +// - --org + --role: ListAccounts, filter Active, AssumeRole for non-self accounts +// - --skip-self: exclude caller's account +// - Self account is never AssumeRole'd — uses original credentials +// - Failed AssumeRole calls are collected as TargetError, not fatal +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { + // Identify the caller's own account. + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + } + selfAccount := aws.ToString(identity.Account) + + // Determine the list of account IDs to scan. + var accountIDs []string + + switch { + case opts.OrgScan: + activeAccounts, listErr := listActiveOrgAccounts(ctx, orgClient) + if listErr != nil { + return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} + } + accountIDs = activeAccounts + case len(opts.Targets) > 0: + // Always include self unless it is already in the list or --skip-self is set. + seen := make(map[string]bool) + for _, id := range opts.Targets { + if !seen[id] { + accountIDs = append(accountIDs, id) + seen[id] = true + } + } + if !seen[selfAccount] && !opts.SkipSelf { + // Prepend self so it appears first. + accountIDs = append([]string{selfAccount}, accountIDs...) + } + default: + // No multi-target flags — scan self only. + return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + } + + // Resolve credentials for each account. + var targets []Target + var targetErrors []TargetError + + for _, acctID := range accountIDs { + if opts.SkipSelf && acctID == selfAccount { + continue + } + + if acctID == selfAccount { + // Self: use original credentials, no AssumeRole. + targets = append(targets, Target{AccountID: selfAccount, Config: baseCfg}) + continue + } + + // AssumeRole into the target account. + cfg, assumeErr := assumeRole(ctx, baseCfg, stsClient, acctID, opts.Role, opts.ExternalID) + if assumeErr != nil { + targetErrors = append(targetErrors, TargetError{AccountID: acctID, Err: assumeErr}) + continue + } + targets = append(targets, Target{AccountID: acctID, Config: cfg}) + } + + return targets, targetErrors +} + +// assumeRole assumes a role in the given account and returns an aws.Config +// with the temporary credentials. +func assumeRole(ctx context.Context, baseCfg aws.Config, stsClient STSClient, accountID, roleName, externalID string) (aws.Config, error) { + roleARN := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleName) + + input := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleARN), + RoleSessionName: aws.String("gpuaudit"), + } + if externalID != "" { + input.ExternalId = aws.String(externalID) + } + + result, err := stsClient.AssumeRole(ctx, input) + if err != nil { + return aws.Config{}, fmt.Errorf("AssumeRole %s: %w", roleARN, err) + } + + creds := result.Credentials + cfg := baseCfg.Copy() + cfg.Credentials = credentials.NewStaticCredentialsProvider( + aws.ToString(creds.AccessKeyId), + aws.ToString(creds.SecretAccessKey), + aws.ToString(creds.SessionToken), + ) + + return cfg, nil +} + +// listActiveOrgAccounts returns the account IDs of all active accounts in the organization. +func listActiveOrgAccounts(ctx context.Context, orgClient OrgClient) ([]string, error) { + var accountIDs []string + var nextToken *string + + for { + out, err := orgClient.ListAccounts(ctx, &organizations.ListAccountsInput{ + NextToken: nextToken, + }) + if err != nil { + return nil, err + } + for _, acct := range out.Accounts { + if acct.Status == orgtypes.AccountStatusActive { + accountIDs = append(accountIDs, aws.ToString(acct.Id)) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + + return accountIDs, nil +} diff --git a/internal/providers/aws/multiaccount_test.go b/internal/providers/aws/multiaccount_test.go new file mode 100644 index 0000000..2d40cce --- /dev/null +++ b/internal/providers/aws/multiaccount_test.go @@ -0,0 +1,298 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/organizations" + orgtypes "github.com/aws/aws-sdk-go-v2/service/organizations/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" +) + +// --- Mock STS client --- + +type mockSTSClient struct { + callerAccount string + assumeResults map[string]*sts.AssumeRoleOutput // accountID -> output + assumeErrors map[string]error // accountID -> error +} + +func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Account: aws.String(m.callerAccount), + }, nil +} + +func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + // Extract account ID from the role ARN: arn:aws:iam:::role/ + arn := aws.ToString(params.RoleArn) + // Simple parse: find the account between the 4th and 5th colons + accountID := parseAccountFromARN(arn) + + if err, ok := m.assumeErrors[accountID]; ok { + return nil, err + } + if out, ok := m.assumeResults[accountID]; ok { + return out, nil + } + return nil, fmt.Errorf("no mock configured for account %s", accountID) +} + +func parseAccountFromARN(arn string) string { + // arn:aws:iam::123456789012:role/name + colons := 0 + start := 0 + for i, c := range arn { + if c == ':' { + colons++ + if colons == 4 { + start = i + 1 + } + if colons == 5 { + return arn[start:i] + } + } + } + return "" +} + +// --- Mock Org client --- + +type mockOrgClient struct { + accounts []orgtypes.Account + err error +} + +func (m *mockOrgClient) ListAccounts(ctx context.Context, params *organizations.ListAccountsInput, optFns ...func(*organizations.Options)) (*organizations.ListAccountsOutput, error) { + if m.err != nil { + return nil, m.err + } + return &organizations.ListAccountsOutput{Accounts: m.accounts}, nil +} + +// Helper to build a successful AssumeRole result with dummy credentials. +func assumeRoleOK(accountID string) *sts.AssumeRoleOutput { + exp := time.Now().Add(1 * time.Hour) + return &sts.AssumeRoleOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: aws.String("AKID-" + accountID), + SecretAccessKey: aws.String("SECRET-" + accountID), + SessionToken: aws.String("TOKEN-" + accountID), + Expiration: &exp, + }, + } +} + +func TestResolveTargets_NoTargets_ReturnsSelfOnly(t *testing.T) { + stsClient := &mockSTSClient{callerAccount: "111111111111"} + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{} + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (self), got %d", len(targets)) + } + if targets[0].AccountID != "111111111111" { + t.Errorf("expected account 111111111111, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_ExplicitTargets_ReturnsSelfPlusAssumed(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "333333333333": assumeRoleOK("333333333333"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + // Self + 2 explicit targets = 3 + if len(targets) != 3 { + t.Fatalf("expected 3 targets, got %d", len(targets)) + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == "111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not found in targets") + } + + // Verify assumed targets + for _, acct := range []string{"222222222222", "333333333333"} { + found := false + for _, tgt := range targets { + if tgt.AccountID == acct { + found = true + break + } + } + if !found { + t.Errorf("account %s not found in targets", acct) + } + } +} + +func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222"}, + Role: "AuditRole", + SkipSelf: true, + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d", len(errs)) + } + if len(targets) != 1 { + t.Fatalf("expected 1 target (no self), got %d", len(targets)) + } + if targets[0].AccountID != "222222222222" { + t.Errorf("expected account 222222222222, got %s", targets[0].AccountID) + } +} + +func TestResolveTargets_PartialFailure(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + assumeErrors: map[string]error{ + "333333333333": fmt.Errorf("access denied"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"222222222222", "333333333333"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + // Self + 222 succeeded, 333 failed + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } + if len(errs) != 1 { + t.Fatalf("expected 1 error, got %d", len(errs)) + } + if errs[0].AccountID != "333333333333" { + t.Errorf("expected error for 333333333333, got %s", errs[0].AccountID) + } +} + +func TestResolveTargets_OrgDiscovery(t *testing.T) { + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + "444444444444": assumeRoleOK("444444444444"), + }, + } + orgClient := &mockOrgClient{ + accounts: []orgtypes.Account{ + {Id: aws.String("111111111111"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("222222222222"), Status: orgtypes.AccountStatusActive}, + {Id: aws.String("333333333333"), Status: orgtypes.AccountStatusSuspended}, + {Id: aws.String("444444444444"), Status: orgtypes.AccountStatusActive}, + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + OrgScan: true, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Active accounts: 111 (self), 222, 444. Suspended 333 is filtered. + if len(targets) != 3 { + t.Fatalf("expected 3 targets (self + 2 active non-self), got %d", len(targets)) + } + + // Verify suspended account is excluded + for _, tgt := range targets { + if tgt.AccountID == "333333333333" { + t.Error("suspended account 333333333333 should be excluded") + } + } + + // Verify self is included + found := false + for _, tgt := range targets { + if tgt.AccountID == "111111111111" { + found = true + break + } + } + if !found { + t.Error("self account 111111111111 not found in targets") + } +} + +func TestResolveTargets_SelfInExplicitTargets_NotAssumed(t *testing.T) { + // If the caller's own account appears in --targets, it should use baseCfg (no AssumeRole). + stsClient := &mockSTSClient{ + callerAccount: "111111111111", + assumeResults: map[string]*sts.AssumeRoleOutput{ + "222222222222": assumeRoleOK("222222222222"), + }, + // No AssumeRole result for self — it should not be called + assumeErrors: map[string]error{ + "111111111111": fmt.Errorf("should not assume role for self"), + }, + } + baseCfg := aws.Config{Region: "us-east-1"} + opts := ScanOptions{ + Targets: []string{"111111111111", "222222222222"}, + Role: "AuditRole", + } + + targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + + if len(errs) != 0 { + t.Fatalf("expected no errors, got %d: %v", len(errs), errs) + } + // Self (from targets list, no duplicate) + 222 + if len(targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(targets)) + } +} diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index e28cbab..34f716c 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -34,6 +34,13 @@ type ScanOptions struct { SkipCosts bool ExcludeTags map[string]string MinUptimeDays int + + // Multi-target options + Targets []string + Role string + ExternalID string + OrgScan bool + SkipSelf bool } // DefaultScanOptions returns sensible defaults. From cd49f917373677ac0e637b63928cf04686d88cff Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:23:00 +0100 Subject: [PATCH 24/39] Refactor Scan() for parallel multi-target scanning --- internal/providers/aws/scanner.go | 151 +++++++++++++++++++++++------- 1 file changed, 117 insertions(+), 34 deletions(-) diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index 34f716c..b9a0986 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -16,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/costexplorer" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/eks" + "github.com/aws/aws-sdk-go-v2/service/organizations" "github.com/aws/aws-sdk-go-v2/service/sagemaker" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -50,7 +51,7 @@ func DefaultScanOptions() ScanOptions { } } -// Scan performs a full GPU audit of the AWS account. +// Scan performs a full GPU audit across one or more AWS accounts. func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { start := time.Now() @@ -65,13 +66,26 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { return nil, fmt.Errorf("loading AWS config: %w", err) } - // Get account ID + // Resolve targets (accounts to scan) stsClient := sts.NewFromConfig(cfg) - identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("getting caller identity: %w", err) + + var orgClient OrgClient + if opts.OrgScan { + orgClient = organizations.NewFromConfig(cfg) + } + + targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + + // Print target errors to stderr and check for fatal failure + for _, te := range targetErrors { + fmt.Fprintf(os.Stderr, " warning: failed to resolve target %s: %v\n", te.AccountID, te.Err) } - accountID := aws.ToString(identity.Account) + if len(targets) == 0 { + return nil, fmt.Errorf("no scannable targets resolved (errors: %d)", len(targetErrors)) + } + + // Determine the caller account from the first target + callerAccount := targets[0].AccountID // Determine regions to scan regions := opts.Regions @@ -82,46 +96,55 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - fmt.Fprintf(os.Stderr," Scanning %d regions for GPU instances...\n", len(regions)) + if len(targets) > 1 { + fmt.Fprintf(os.Stderr, " Scanning %d accounts across %d regions for GPU instances...\n", len(targets), len(regions)) + } else { + fmt.Fprintf(os.Stderr, " Scanning %d regions for GPU instances...\n", len(regions)) + } - // Scan all regions concurrently - type regionResult struct { - region string + // Scan all targets in parallel + type targetResult struct { instances []models.GPUInstance + regions []string err error } - results := make(chan regionResult, len(regions)) + targetResults := make(chan targetResult, len(targets)) var wg sync.WaitGroup - for _, region := range regions { + for _, t := range targets { wg.Add(1) - go func(r string) { + go func(target Target) { defer wg.Done() - instances, err := scanRegion(ctx, cfg, accountID, r, opts) - results <- regionResult{region: r, instances: instances, err: err} - }(region) + instances, scannedRegions, scanErr := scanTarget(ctx, target, regions, opts) + targetResults <- targetResult{instances: instances, regions: scannedRegions, err: scanErr} + }(t) } go func() { wg.Wait() - close(results) + close(targetResults) }() var allInstances []models.GPUInstance - var scannedRegions []string + regionSet := make(map[string]bool) - for res := range results { + for res := range targetResults { if res.err != nil { - fmt.Fprintf(os.Stderr," warning: error scanning %s: %v\n", res.region, res.err) + fmt.Fprintf(os.Stderr, " warning: target scan error: %v\n", res.err) continue } - if len(res.instances) > 0 { - allInstances = append(allInstances, res.instances...) - scannedRegions = append(scannedRegions, res.region) + allInstances = append(allInstances, res.instances...) + for _, r := range res.regions { + regionSet[r] = true } } + var scannedRegions []string + for r := range regionSet { + scannedRegions = append(scannedRegions, r) + } + // Filter by excluded tags if len(opts.ExcludeTags) > 0 { filtered := allInstances[:0] @@ -139,14 +162,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { } } - // Enrich with Cost Explorer data (account-level, not per-region) - if !opts.SkipCosts && len(allInstances) > 0 { - ceClient := costexplorer.NewFromConfig(cfg) - if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { - fmt.Fprintf(os.Stderr," warning: could not enrich cost data: %v\n", err) - } - } - // Run analysis analysis.AnalyzeAll(allInstances) @@ -167,14 +182,82 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { // Build summary summary := BuildSummary(allInstances) - return &models.ScanResult{ + result := &models.ScanResult{ Timestamp: start, - AccountID: accountID, + AccountID: callerAccount, Regions: scannedRegions, ScanDuration: time.Since(start).Round(time.Millisecond).String(), Instances: allInstances, Summary: summary, - }, nil + } + + // Populate multi-target metadata when multiple targets are involved + isMultiTarget := len(targets) > 1 || len(targetErrors) > 0 + if isMultiTarget { + for _, t := range targets { + result.Targets = append(result.Targets, t.AccountID) + } + result.TargetSummaries = BuildTargetSummaries(allInstances) + for _, te := range targetErrors { + result.TargetErrors = append(result.TargetErrors, models.TargetErrorInfo{ + Target: te.AccountID, + Error: te.Err.Error(), + }) + } + } + + return result, nil +} + +// scanTarget scans all regions for a single target account, including +// Cost Explorer enrichment (which is account-scoped). +func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string, error) { + type regionResult struct { + region string + instances []models.GPUInstance + err error + } + + results := make(chan regionResult, len(regions)) + var wg sync.WaitGroup + + for _, region := range regions { + wg.Add(1) + go func(r string) { + defer wg.Done() + instances, err := scanRegion(ctx, target.Config, target.AccountID, r, opts) + results <- regionResult{region: r, instances: instances, err: err} + }(region) + } + + go func() { + wg.Wait() + close(results) + }() + + var allInstances []models.GPUInstance + var scannedRegions []string + + for res := range results { + if res.err != nil { + fmt.Fprintf(os.Stderr, " warning: error scanning %s in account %s: %v\n", res.region, target.AccountID, res.err) + continue + } + if len(res.instances) > 0 { + allInstances = append(allInstances, res.instances...) + scannedRegions = append(scannedRegions, res.region) + } + } + + // Enrich with Cost Explorer data (account-scoped) + if !opts.SkipCosts && len(allInstances) > 0 { + ceClient := costexplorer.NewFromConfig(target.Config) + if err := EnrichCostData(ctx, ceClient, allInstances); err != nil { + fmt.Fprintf(os.Stderr, " warning: could not enrich cost data for account %s: %v\n", target.AccountID, err) + } + } + + return allInstances, scannedRegions, nil } func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, opts ScanOptions) ([]models.GPUInstance, error) { From 698b0f6dc0e0dfac6673315127b28b35d20384c2 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:24:43 +0100 Subject: [PATCH 25/39] Add --targets, --role, --org, --external-id, --skip-self flags to scan command --- cmd/gpuaudit/main.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index a057c1c..6f2c708 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -57,6 +57,11 @@ var ( scanPromEndpoint string scanExcludeTags []string scanMinUptimeDays int + scanTargets []string + scanRole string + scanExternalID string + scanOrg bool + scanSkipSelf bool ) // --- diff command --- @@ -92,6 +97,12 @@ func init() { scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") + scanCmd.Flags().StringSliceVar(&scanTargets, "targets", nil, "Account IDs to scan (comma-separated)") + scanCmd.Flags().StringVar(&scanRole, "role", "", "IAM role name to assume in each target") + scanCmd.Flags().StringVar(&scanExternalID, "external-id", "", "STS external ID for cross-account role assumption") + scanCmd.Flags().BoolVar(&scanOrg, "org", false, "Auto-discover all accounts from AWS Organizations") + scanCmd.Flags().BoolVar(&scanSkipSelf, "skip-self", false, "Exclude the caller's own account from the scan") + scanCmd.MarkFlagsMutuallyExclusive("targets", "org") diffCmd.Flags().StringVar(&diffFormat, "format", "table", "Output format: table, json") @@ -109,6 +120,10 @@ func runScan(cmd *cobra.Command, args []string) error { ctx := context.Background() + if (len(scanTargets) > 0 || scanOrg) && scanRole == "" { + return fmt.Errorf("--role is required when using --targets or --org") + } + opts := awsprovider.DefaultScanOptions() opts.Profile = scanProfile opts.Regions = scanRegions @@ -118,6 +133,11 @@ func runScan(cmd *cobra.Command, args []string) error { opts.SkipCosts = scanSkipCosts opts.ExcludeTags = parseExcludeTags(scanExcludeTags) opts.MinUptimeDays = scanMinUptimeDays + opts.Targets = scanTargets + opts.Role = scanRole + opts.ExternalID = scanExternalID + opts.OrgScan = scanOrg + opts.SkipSelf = scanSkipSelf awsAvailable := true result, err := awsprovider.Scan(ctx, opts) From 19fbbb722599a31f489d367591a2aec8719f9795 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:26:29 +0100 Subject: [PATCH 26/39] Add per-target summary table and target column to table formatter --- internal/output/table.go | 76 ++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/internal/output/table.go b/internal/output/table.go index 3f73232..1e60464 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -34,6 +34,8 @@ func FormatTable(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, " │ Estimated monthly waste: $%-10.0f (%4.0f%%) │\n", s.TotalEstimatedWaste, s.WastePercent) fmt.Fprintf(w, " └──────────────────────────────────────────────────────────┘\n\n") + printTargetSummary(w, result) + if s.TotalInstances == 0 { fmt.Fprintf(w, " No GPU instances found.\n\n") return @@ -42,14 +44,16 @@ func FormatTable(w io.Writer, result *models.ScanResult) { // Group instances by severity critical, warning, healthy := groupBySeverity(result.Instances) + multiTarget := len(result.TargetSummaries) > 1 + if len(critical) > 0 { fmt.Fprintf(w, " CRITICAL — %d instance(s), $%.0f/mo potential savings\n\n", len(critical), sumSavings(critical)) - printInstanceTable(w, critical) + printInstanceTable(w, critical, multiTarget) } if len(warning) > 0 { fmt.Fprintf(w, " WARNING — %d instance(s), $%.0f/mo potential savings\n\n", len(warning), sumSavings(warning)) - printInstanceTable(w, warning) + printInstanceTable(w, warning, multiTarget) } if len(healthy) > 0 { @@ -57,17 +61,54 @@ func FormatTable(w io.Writer, result *models.ScanResult) { } } -func printInstanceTable(w io.Writer, instances []models.GPUInstance) { - // Header - fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", - "Instance", "Type", "Monthly", "Signal", "Recommendation") - fmt.Fprintf(w, " %s %s %s %s %s\n", - strings.Repeat("─", 36), - strings.Repeat("─", 26), - strings.Repeat("─", 10), - strings.Repeat("─", 16), - strings.Repeat("─", 50), - ) +func printTargetSummary(w io.Writer, result *models.ScanResult) { + if len(result.TargetSummaries) < 2 { + return + } + + fmt.Fprintf(w, " By Target\n") + fmt.Fprintf(w, " ┌──────────────┬───────────┬───────────┬───────────┬───────┐\n") + fmt.Fprintf(w, " │ Target │ Instances │ Spend/mo │ Waste/mo │ Waste │\n") + fmt.Fprintf(w, " ├──────────────┼───────────┼───────────┼───────────┼───────┤\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, " │ %-12s │ %9d │ $%8.0f │ $%8.0f │ %4.0f%% │\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintf(w, " └──────────────┴───────────┴───────────┴───────────┴───────┘\n\n") + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, " Warnings\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, " ✗ %s — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } +} + +func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget bool) { + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s %10s %-16s %s\n", + "Instance", "Target", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 14), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } else { + fmt.Fprintf(w, " %-36s %-26s %10s %-16s %s\n", + "Instance", "Type", "Monthly", "Signal", "Recommendation") + fmt.Fprintf(w, " %s %s %s %s %s\n", + strings.Repeat("─", 36), + strings.Repeat("─", 26), + strings.Repeat("─", 10), + strings.Repeat("─", 16), + strings.Repeat("─", 50), + ) + } for _, inst := range instances { name := inst.Name @@ -94,8 +135,13 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance) { rec = inst.Recommendations[0].Description } - fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", - name, typeDesc, inst.MonthlyCost, signal, rec) + if multiTarget { + fmt.Fprintf(w, " %-36s %-14s %-26s $%9.0f %-16s %s\n", + name, inst.AccountID, typeDesc, inst.MonthlyCost, signal, rec) + } else { + fmt.Fprintf(w, " %-36s %-26s $%9.0f %-16s %s\n", + name, typeDesc, inst.MonthlyCost, signal, rec) + } } fmt.Fprintln(w) } From 8737a44e2a56d3bc32645883f3195bcea50f7abb Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:27:59 +0100 Subject: [PATCH 27/39] Add per-target summaries to markdown and Slack formatters --- internal/output/markdown.go | 43 ++++++++++++++++++++++++++++++++----- internal/output/slack.go | 21 ++++++++++++++++++ 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/internal/output/markdown.go b/internal/output/markdown.go index 13290bb..58995c5 100644 --- a/internal/output/markdown.go +++ b/internal/output/markdown.go @@ -31,14 +31,41 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { fmt.Fprintf(w, "| Warning | %d |\n", s.WarningCount) fmt.Fprintf(w, "| Healthy | %d |\n\n", s.HealthyCount) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + fmt.Fprintf(w, "## By Target\n\n") + fmt.Fprintf(w, "| Target | Instances | Spend/mo | Waste/mo | Waste |\n") + fmt.Fprintf(w, "|---|---|---|---|---|\n") + for _, ts := range result.TargetSummaries { + fmt.Fprintf(w, "| %s | %d | $%.0f | $%.0f | %.0f%% |\n", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent) + } + fmt.Fprintln(w) + } + + if len(result.TargetErrors) > 0 { + fmt.Fprintf(w, "## Warnings\n\n") + for _, te := range result.TargetErrors { + fmt.Fprintf(w, "- **%s** — %s\n", te.Target, te.Error) + } + fmt.Fprintln(w) + } + if s.TotalInstances == 0 { fmt.Fprintf(w, "No GPU instances found.\n") return } fmt.Fprintf(w, "## Findings\n\n") - fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") - fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + multiTarget := len(result.TargetSummaries) > 1 + if multiTarget { + fmt.Fprintf(w, "| Instance | Target | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|---|\n") + } else { + fmt.Fprintf(w, "| Instance | Type | Monthly Cost | Signal | Savings | Recommendation |\n") + fmt.Fprintf(w, "|---|---|---|---|---|---|\n") + } for _, inst := range result.Instances { name := inst.Name @@ -61,8 +88,14 @@ func FormatMarkdown(w io.Writer, result *models.ScanResult) { savings = fmt.Sprintf("$%.0f/mo", inst.EstimatedSavings) } - fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", - name, inst.InstanceType, inst.GPUCount, inst.GPUModel, - inst.MonthlyCost, signal, savings, rec) + if multiTarget { + fmt.Fprintf(w, "| %s | %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.AccountID, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } else { + fmt.Fprintf(w, "| %s | %s (%d× %s) | $%.0f | %s | %s | %s |\n", + name, inst.InstanceType, inst.GPUCount, inst.GPUModel, + inst.MonthlyCost, signal, savings, rec) + } } } diff --git a/internal/output/slack.go b/internal/output/slack.go index 530afe7..f8fc334 100644 --- a/internal/output/slack.go +++ b/internal/output/slack.go @@ -34,6 +34,27 @@ func FormatSlack(w io.Writer, result *models.ScanResult) error { blocks = append(blocks, map[string]any{"type": "divider"}) + // Per-target breakdown + if len(result.TargetSummaries) > 1 { + lines := []string{"*By Target*"} + for _, ts := range result.TargetSummaries { + lines = append(lines, fmt.Sprintf("• `%s` — %d instances, $%.0f/mo spend, $%.0f/mo waste (%.0f%%)", + ts.Target, ts.TotalInstances, ts.TotalMonthlyCost, + ts.TotalEstimatedWaste, ts.WastePercent)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + blocks = append(blocks, map[string]any{"type": "divider"}) + } + + // Target errors + if len(result.TargetErrors) > 0 { + lines := []string{":warning: *Target Warnings*"} + for _, te := range result.TargetErrors { + lines = append(lines, fmt.Sprintf("• `%s` — %s", te.Target, te.Error)) + } + blocks = append(blocks, slackSection(strings.Join(lines, "\n"))) + } + // Critical findings critical, warning, _ := groupBySeverity(result.Instances) From 68fbeaa7789cecca2fceb59383b7c01f498aa04d Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:29:04 +0100 Subject: [PATCH 28/39] Add cross-account and Organizations permissions to iam-policy output --- cmd/gpuaudit/main.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 6f2c708..271d094 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -360,8 +360,21 @@ var iamPolicyCmd = &cobra.Command{ }, "Resource": "*", }, + { + "Sid": "GPUAuditCrossAccount", + "Effect": "Allow", + "Action": "sts:AssumeRole", + "Resource": "arn:aws:iam::*:role/gpuaudit-reader", + }, + { + "Sid": "GPUAuditOrganizations", + "Effect": "Allow", + "Action": "organizations:ListAccounts", + "Resource": "*", + }, }, } + fmt.Fprintln(os.Stdout, "// The last two statements (CrossAccount, Organizations) are only needed for --targets or --org scanning.") enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") enc.Encode(policy) From 1f58d921199f56af11c8af0389f1b38d8071fb27 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:30:06 +0100 Subject: [PATCH 29/39] Add multi-account scanning docs to README --- README.md | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/README.md b/README.md index c2396c7..b7592b5 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,106 @@ gpuaudit diff scan-apr-08.json scan-apr-15.json Matches instances by ID. Reports added, removed, and changed instances with per-field diffs (instance type, pricing model, cost, state, GPU allocation, waste severity). +## Multi-Account Scanning + +Scan multiple AWS accounts in a single invocation using STS AssumeRole. + +### Prerequisites + +Deploy a read-only IAM role (`gpuaudit-reader`) to each target account. See [Cross-Account Role Setup](#cross-account-role-setup) below. + +### Usage + +```bash +# Scan specific accounts +gpuaudit scan --targets 111111111111,222222222222 --role gpuaudit-reader + +# Scan entire AWS Organization +gpuaudit scan --org --role gpuaudit-reader + +# Exclude management account +gpuaudit scan --org --role gpuaudit-reader --skip-self + +# With external ID +gpuaudit scan --targets 111111111111 --role gpuaudit-reader --external-id my-secret +``` + +### Cross-Account Role Setup + +#### Terraform + +```hcl +variable "management_account_id" { + description = "AWS account ID where gpuaudit runs" + type = string +} + +resource "aws_iam_role" "gpuaudit_reader" { + name = "gpuaudit-reader" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [{ + Effect = "Allow" + Principal = { AWS = "arn:aws:iam::${var.management_account_id}:root" } + Action = "sts:AssumeRole" + }] + }) +} + +resource "aws_iam_role_policy" "gpuaudit_reader" { + name = "gpuaudit-policy" + role = aws_iam_role.gpuaudit_reader.id + policy = file("gpuaudit-policy.json") # from: gpuaudit iam-policy > gpuaudit-policy.json +} +``` + +Deploy to all accounts using Terraform workspaces or CloudFormation StackSets. + +#### CloudFormation StackSet + +```yaml +AWSTemplateFormatVersion: "2010-09-09" +Parameters: + ManagementAccountId: + Type: String +Resources: + GpuAuditRole: + Type: AWS::IAM::Role + Properties: + RoleName: gpuaudit-reader + AssumeRolePolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Principal: + AWS: !Sub "arn:aws:iam::${ManagementAccountId}:root" + Action: sts:AssumeRole + Policies: + - PolicyName: gpuaudit-policy + PolicyDocument: + Version: "2012-10-17" + Statement: + - Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:DescribeInstanceTypes + - ec2:DescribeRegions + - sagemaker:ListEndpoints + - sagemaker:DescribeEndpoint + - sagemaker:DescribeEndpointConfig + - eks:ListClusters + - eks:ListNodegroups + - eks:DescribeNodegroup + - cloudwatch:GetMetricData + - cloudwatch:GetMetricStatistics + - cloudwatch:ListMetrics + - ce:GetCostAndUsage + - ce:GetReservationUtilization + - ce:GetSavingsPlansUtilization + - pricing:GetProducts + Resource: "*" +``` + ## IAM permissions gpuaudit is read-only. It never modifies your infrastructure. Generate the minimal IAM policy: From 60015983c24d5e38635bed15fdbc5818d8fcc9ba Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sat, 18 Apr 2026 15:41:41 +0100 Subject: [PATCH 30/39] Fix callerAccount bug, deduplicate severity logic, clean up dead code ResolveTargets now returns selfAccount separately so Scan() always gets the correct caller identity regardless of --skip-self. Extracted models.MaxSeverity to replace three copies of severity classification. Removed dead error return from scanTarget. Added missing copyright headers. --- internal/models/models.go | 17 ++++++++++++++ internal/output/table.go | 19 +-------------- internal/providers/aws/multiaccount.go | 16 +++++-------- internal/providers/aws/multiaccount_test.go | 18 +++++++++----- internal/providers/aws/scanner.go | 18 ++++---------- internal/providers/aws/summary.go | 26 ++++----------------- internal/providers/aws/summary_test.go | 3 +++ 7 files changed, 49 insertions(+), 68 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index c523838..7b8a0e6 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -158,5 +158,22 @@ type TargetErrorInfo struct { Error string `json:"error"` } +// MaxSeverity returns the highest severity among the given waste signals. +func MaxSeverity(signals []WasteSignal) Severity { + max := Severity("") + for _, s := range signals { + if s.Severity == SeverityCritical { + return SeverityCritical + } + if s.Severity == SeverityWarning { + max = SeverityWarning + } + if s.Severity == SeverityInfo && max == "" { + max = SeverityInfo + } + } + return max +} + // Ptr is a convenience helper for creating pointer values in tests and literals. func Ptr[T any](v T) *T { return &v } diff --git a/internal/output/table.go b/internal/output/table.go index 1e60464..ece729c 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -148,8 +148,7 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy []models.GPUInstance) { for _, inst := range instances { - maxSev := maxSeverity(inst.WasteSignals) - switch maxSev { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: critical = append(critical, inst) case models.SeverityWarning: @@ -171,22 +170,6 @@ func groupBySeverity(instances []models.GPUInstance) (critical, warning, healthy return } -func maxSeverity(signals []models.WasteSignal) models.Severity { - max := models.Severity("") - for _, s := range signals { - if s.Severity == models.SeverityCritical { - return models.SeverityCritical - } - if s.Severity == models.SeverityWarning { - max = models.SeverityWarning - } - if s.Severity == models.SeverityInfo && max == "" { - max = models.SeverityInfo - } - } - return max -} - func sumSavings(instances []models.GPUInstance) float64 { total := 0.0 for _, inst := range instances { diff --git a/internal/providers/aws/multiaccount.go b/internal/providers/aws/multiaccount.go index fd8a99c..298c475 100644 --- a/internal/providers/aws/multiaccount.go +++ b/internal/providers/aws/multiaccount.go @@ -46,13 +46,13 @@ type OrgClient interface { // - --skip-self: exclude caller's account // - Self account is never AssumeRole'd — uses original credentials // - Failed AssumeRole calls are collected as TargetError, not fatal -func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) ([]Target, []TargetError) { +func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient, orgClient OrgClient, opts ScanOptions) (selfAccount string, targets []Target, targetErrors []TargetError) { // Identify the caller's own account. identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { - return nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} + return "", nil, []TargetError{{AccountID: "unknown", Err: fmt.Errorf("GetCallerIdentity: %w", err)}} } - selfAccount := aws.ToString(identity.Account) + selfAccount = aws.ToString(identity.Account) // Determine the list of account IDs to scan. var accountIDs []string @@ -61,7 +61,7 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient case opts.OrgScan: activeAccounts, listErr := listActiveOrgAccounts(ctx, orgClient) if listErr != nil { - return nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} + return selfAccount, nil, []TargetError{{AccountID: "org", Err: fmt.Errorf("ListAccounts: %w", listErr)}} } accountIDs = activeAccounts case len(opts.Targets) > 0: @@ -79,13 +79,9 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient } default: // No multi-target flags — scan self only. - return []Target{{AccountID: selfAccount, Config: baseCfg}}, nil + return selfAccount, []Target{{AccountID: selfAccount, Config: baseCfg}}, nil } - // Resolve credentials for each account. - var targets []Target - var targetErrors []TargetError - for _, acctID := range accountIDs { if opts.SkipSelf && acctID == selfAccount { continue @@ -106,7 +102,7 @@ func ResolveTargets(ctx context.Context, baseCfg aws.Config, stsClient STSClient targets = append(targets, Target{AccountID: acctID, Config: cfg}) } - return targets, targetErrors + return selfAccount, targets, targetErrors } // assumeRole assumes a role in the given account and returns an aws.Config diff --git a/internal/providers/aws/multiaccount_test.go b/internal/providers/aws/multiaccount_test.go index 2d40cce..bc2ba11 100644 --- a/internal/providers/aws/multiaccount_test.go +++ b/internal/providers/aws/multiaccount_test.go @@ -95,11 +95,14 @@ func TestResolveTargets_NoTargets_ReturnsSelfOnly(t *testing.T) { baseCfg := aws.Config{Region: "us-east-1"} opts := ScanOptions{} - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } if len(targets) != 1 { t.Fatalf("expected 1 target (self), got %d", len(targets)) } @@ -122,7 +125,7 @@ func TestResolveTargets_ExplicitTargets_ReturnsSelfPlusAssumed(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d", len(errs)) @@ -173,11 +176,14 @@ func TestResolveTargets_ExplicitTargets_SkipSelf(t *testing.T) { SkipSelf: true, } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + self, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d", len(errs)) } + if self != "111111111111" { + t.Errorf("expected self account 111111111111, got %s", self) + } if len(targets) != 1 { t.Fatalf("expected 1 target (no self), got %d", len(targets)) } @@ -202,7 +208,7 @@ func TestResolveTargets_PartialFailure(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) // Self + 222 succeeded, 333 failed if len(targets) != 2 { @@ -238,7 +244,7 @@ func TestResolveTargets_OrgDiscovery(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, orgClient, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) @@ -286,7 +292,7 @@ func TestResolveTargets_SelfInExplicitTargets_NotAssumed(t *testing.T) { Role: "AuditRole", } - targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) + _, targets, errs := ResolveTargets(context.Background(), baseCfg, stsClient, nil, opts) if len(errs) != 0 { t.Fatalf("expected no errors, got %d: %v", len(errs), errs) diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index b9a0986..a678867 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -74,7 +74,7 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { orgClient = organizations.NewFromConfig(cfg) } - targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) + callerAccount, targets, targetErrors := ResolveTargets(ctx, cfg, stsClient, orgClient, opts) // Print target errors to stderr and check for fatal failure for _, te := range targetErrors { @@ -84,9 +84,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { return nil, fmt.Errorf("no scannable targets resolved (errors: %d)", len(targetErrors)) } - // Determine the caller account from the first target - callerAccount := targets[0].AccountID - // Determine regions to scan regions := opts.Regions if len(regions) == 0 { @@ -106,7 +103,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { type targetResult struct { instances []models.GPUInstance regions []string - err error } targetResults := make(chan targetResult, len(targets)) @@ -116,8 +112,8 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { wg.Add(1) go func(target Target) { defer wg.Done() - instances, scannedRegions, scanErr := scanTarget(ctx, target, regions, opts) - targetResults <- targetResult{instances: instances, regions: scannedRegions, err: scanErr} + instances, scannedRegions := scanTarget(ctx, target, regions, opts) + targetResults <- targetResult{instances: instances, regions: scannedRegions} }(t) } @@ -130,10 +126,6 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { regionSet := make(map[string]bool) for res := range targetResults { - if res.err != nil { - fmt.Fprintf(os.Stderr, " warning: target scan error: %v\n", res.err) - continue - } allInstances = append(allInstances, res.instances...) for _, r := range res.regions { regionSet[r] = true @@ -211,7 +203,7 @@ func Scan(ctx context.Context, opts ScanOptions) (*models.ScanResult, error) { // scanTarget scans all regions for a single target account, including // Cost Explorer enrichment (which is account-scoped). -func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string, error) { +func scanTarget(ctx context.Context, target Target, regions []string, opts ScanOptions) ([]models.GPUInstance, []string) { type regionResult struct { region string instances []models.GPUInstance @@ -257,7 +249,7 @@ func scanTarget(ctx context.Context, target Target, regions []string, opts ScanO } } - return allInstances, scannedRegions, nil + return allInstances, scannedRegions } func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, opts ScanOptions) ([]models.GPUInstance, error) { diff --git a/internal/providers/aws/summary.go b/internal/providers/aws/summary.go index bae351a..5c6f715 100644 --- a/internal/providers/aws/summary.go +++ b/internal/providers/aws/summary.go @@ -1,3 +1,6 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + package aws import ( @@ -16,18 +19,7 @@ func BuildSummary(instances []models.GPUInstance) models.ScanSummary { s.TotalMonthlyCost += inst.MonthlyCost s.TotalEstimatedWaste += inst.EstimatedSavings - maxSeverity := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSeverity = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSeverity != models.SeverityCritical { - maxSeverity = models.SeverityWarning - } else if sig.Severity == models.SeverityInfo && maxSeverity == "" { - maxSeverity = models.SeverityInfo - } - } - - switch maxSeverity { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: s.CriticalCount++ case models.SeverityWarning: @@ -67,15 +59,7 @@ func BuildTargetSummaries(instances []models.GPUInstance) []models.TargetSummary ts.TotalMonthlyCost += inst.MonthlyCost ts.TotalEstimatedWaste += inst.EstimatedSavings - maxSev := models.Severity("") - for _, sig := range inst.WasteSignals { - if sig.Severity == models.SeverityCritical { - maxSev = models.SeverityCritical - } else if sig.Severity == models.SeverityWarning && maxSev != models.SeverityCritical { - maxSev = models.SeverityWarning - } - } - switch maxSev { + switch models.MaxSeverity(inst.WasteSignals) { case models.SeverityCritical: ts.CriticalCount++ case models.SeverityWarning: diff --git a/internal/providers/aws/summary_test.go b/internal/providers/aws/summary_test.go index b429e39..24702ec 100644 --- a/internal/providers/aws/summary_test.go +++ b/internal/providers/aws/summary_test.go @@ -1,3 +1,6 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + package aws import ( From 41b086748822f83450ab7a7fcb4e07f77d7dac1b Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 19:54:13 +0100 Subject: [PATCH 31/39] Add SpotHourlyCost field to GPUInstance model --- internal/models/models.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/models/models.go b/internal/models/models.go index 7b8a0e6..9ee0bf7 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -86,10 +86,11 @@ type GPUInstance struct { InvocationCount *int64 `json:"invocation_count,omitempty"` // Cost - PricingModel string `json:"pricing_model"` // on-demand, spot, reserved, savings-plan - HourlyCost float64 `json:"hourly_cost"` - MonthlyCost float64 `json:"monthly_cost"` - MTDCost *float64 `json:"mtd_cost,omitempty"` + PricingModel string `json:"pricing_model"` // on-demand, spot, reserved, savings-plan + HourlyCost float64 `json:"hourly_cost"` + MonthlyCost float64 `json:"monthly_cost"` + SpotHourlyCost *float64 `json:"spot_hourly_cost,omitempty"` + MTDCost *float64 `json:"mtd_cost,omitempty"` // Analysis results (populated by analysis engine) WasteSignals []WasteSignal `json:"waste_signals,omitempty"` From b081710780f6f66ab6c59c326c7bfafb8e69b3c0 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 19:56:37 +0100 Subject: [PATCH 32/39] Implement EnrichSpotPrices with DescribeSpotPriceHistory --- internal/providers/aws/spot.go | 79 +++++++++++++++++++++ internal/providers/aws/spot_test.go | 106 ++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 internal/providers/aws/spot.go create mode 100644 internal/providers/aws/spot_test.go diff --git a/internal/providers/aws/spot.go b/internal/providers/aws/spot.go new file mode 100644 index 0000000..7f3281a --- /dev/null +++ b/internal/providers/aws/spot.go @@ -0,0 +1,79 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "os" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + "github.com/gpuaudit/cli/internal/models" +) + +// SpotPriceClient is the subset of the EC2 API needed for spot price lookups. +type SpotPriceClient interface { + DescribeSpotPriceHistory(ctx context.Context, params *ec2.DescribeSpotPriceHistoryInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) +} + +// EnrichSpotPrices fetches current spot prices for EC2 GPU instances and +// populates SpotHourlyCost on each instance where spot is available. +func EnrichSpotPrices(ctx context.Context, client SpotPriceClient, instances []models.GPUInstance) { + // Collect unique EC2 instance types. + typeSet := make(map[string]bool) + for _, inst := range instances { + if inst.Source == models.SourceEC2 { + typeSet[inst.InstanceType] = true + } + } + if len(typeSet) == 0 { + return + } + + instanceTypes := make([]ec2types.InstanceType, 0, len(typeSet)) + for t := range typeSet { + instanceTypes = append(instanceTypes, ec2types.InstanceType(t)) + } + + input := &ec2.DescribeSpotPriceHistoryInput{ + InstanceTypes: instanceTypes, + ProductDescriptions: []string{"Linux/UNIX"}, + StartTime: aws.Time(time.Now().Add(-1 * time.Hour)), + } + + out, err := client.DescribeSpotPriceHistory(ctx, input) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: could not fetch spot prices: %v\n", err) + return + } + + // Take the most recent price per instance type (API returns newest first). + latestPrice := make(map[string]float64) + for _, sp := range out.SpotPriceHistory { + itype := string(sp.InstanceType) + if _, seen := latestPrice[itype]; seen { + continue + } + price, err := strconv.ParseFloat(aws.ToString(sp.SpotPrice), 64) + if err != nil { + continue + } + latestPrice[itype] = price + } + + // Populate SpotHourlyCost on matching instances. + for i := range instances { + if instances[i].Source != models.SourceEC2 { + continue + } + if price, ok := latestPrice[instances[i].InstanceType]; ok { + instances[i].SpotHourlyCost = &price + } + } +} diff --git a/internal/providers/aws/spot_test.go b/internal/providers/aws/spot_test.go new file mode 100644 index 0000000..d82fcbb --- /dev/null +++ b/internal/providers/aws/spot_test.go @@ -0,0 +1,106 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + + "github.com/gpuaudit/cli/internal/models" +) + +type mockSpotPriceClient struct { + prices []ec2types.SpotPrice + err error +} + +func (m *mockSpotPriceClient) DescribeSpotPriceHistory(ctx context.Context, params *ec2.DescribeSpotPriceHistoryInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSpotPriceHistoryOutput, error) { + if m.err != nil { + return nil, m.err + } + return &ec2.DescribeSpotPriceHistoryOutput{ + SpotPriceHistory: m.prices, + }, nil +} + +func TestEnrichSpotPrices_PopulatesSpotCost(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.500"), + Timestamp: aws.Time(time.Now().Add(-1 * time.Hour)), + }, + }, + } + instances := []models.GPUInstance{ + {InstanceID: "i-1", InstanceType: "g5.xlarge", Source: models.SourceEC2}, + {InstanceID: "i-2", InstanceType: "g5.2xlarge", Source: models.SourceEC2}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost == nil { + t.Fatal("expected spot price for g5.xlarge") + } + if *instances[0].SpotHourlyCost != 0.556 { + t.Errorf("expected 0.556, got %f", *instances[0].SpotHourlyCost) + } + if instances[1].SpotHourlyCost != nil { + t.Error("expected nil spot price for g5.2xlarge (not in API response)") + } +} + +func TestEnrichSpotPrices_SkipsNonEC2(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + }, + } + instances := []models.GPUInstance{ + {InstanceID: "ep-1", InstanceType: "ml.g5.xlarge", Source: models.SourceSageMakerEndpoint}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost != nil { + t.Error("expected nil spot price for SageMaker instance") + } +} + +func TestEnrichSpotPrices_HandlesAPIError(t *testing.T) { + client := &mockSpotPriceClient{ + err: fmt.Errorf("access denied"), + } + instances := []models.GPUInstance{ + {InstanceID: "i-1", InstanceType: "g5.xlarge", Source: models.SourceEC2}, + } + + EnrichSpotPrices(context.Background(), client, instances) + + if instances[0].SpotHourlyCost != nil { + t.Error("expected nil spot price after API error") + } +} + +func TestEnrichSpotPrices_EmptyInstances(t *testing.T) { + client := &mockSpotPriceClient{} + EnrichSpotPrices(context.Background(), client, nil) + EnrichSpotPrices(context.Background(), client, []models.GPUInstance{}) +} From 82ae997da4eb8b3736c3ed87c841a5ff809621de Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 19:58:30 +0100 Subject: [PATCH 33/39] Wire EnrichSpotPrices into scanRegion after EC2 discovery --- internal/providers/aws/scanner.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index a678867..15ed1a0 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -270,6 +270,7 @@ func scanRegion(ctx context.Context, cfg aws.Config, accountID, region string, o if err := EnrichEC2Metrics(ctx, cwClient, ec2Instances, opts.MetricWindow); err != nil { fmt.Fprintf(os.Stderr, " warning: could not enrich EC2 metrics in %s: %v\n", region, err) } + EnrichSpotPrices(ctx, ec2Client, ec2Instances) } allInstances = append(allInstances, ec2Instances...) } From 8f4973e1ed9b5922f39e91c3b52e5eae26d30124 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 20:00:17 +0100 Subject: [PATCH 34/39] Correct spot instance cost using live spot prices --- internal/providers/aws/spot.go | 13 ++++++-- internal/providers/aws/spot_test.go | 47 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/internal/providers/aws/spot.go b/internal/providers/aws/spot.go index 7f3281a..7bdd3b8 100644 --- a/internal/providers/aws/spot.go +++ b/internal/providers/aws/spot.go @@ -67,13 +67,20 @@ func EnrichSpotPrices(ctx context.Context, client SpotPriceClient, instances []m latestPrice[itype] = price } - // Populate SpotHourlyCost on matching instances. + // Populate SpotHourlyCost on matching instances and correct cost for + // instances already running as spot. for i := range instances { if instances[i].Source != models.SourceEC2 { continue } - if price, ok := latestPrice[instances[i].InstanceType]; ok { - instances[i].SpotHourlyCost = &price + price, ok := latestPrice[instances[i].InstanceType] + if !ok { + continue + } + instances[i].SpotHourlyCost = &price + if instances[i].PricingModel == "spot" { + instances[i].HourlyCost = price + instances[i].MonthlyCost = price * 730 } } } diff --git a/internal/providers/aws/spot_test.go b/internal/providers/aws/spot_test.go index d82fcbb..55c62f9 100644 --- a/internal/providers/aws/spot_test.go +++ b/internal/providers/aws/spot_test.go @@ -104,3 +104,50 @@ func TestEnrichSpotPrices_EmptyInstances(t *testing.T) { EnrichSpotPrices(context.Background(), client, nil) EnrichSpotPrices(context.Background(), client, []models.GPUInstance{}) } + +func TestEnrichSpotPrices_CorrectsCostForSpotInstances(t *testing.T) { + client := &mockSpotPriceClient{ + prices: []ec2types.SpotPrice{ + { + InstanceType: ec2types.InstanceTypeG5Xlarge, + SpotPrice: aws.String("0.556"), + Timestamp: aws.Time(time.Now()), + }, + }, + } + instances := []models.GPUInstance{ + { + InstanceID: "i-spot", + InstanceType: "g5.xlarge", + Source: models.SourceEC2, + PricingModel: "spot", + HourlyCost: 1.006, // on-demand price (wrong for spot) + MonthlyCost: 1.006 * 730, + }, + { + InstanceID: "i-ondemand", + InstanceType: "g5.xlarge", + Source: models.SourceEC2, + PricingModel: "on-demand", + HourlyCost: 1.006, + MonthlyCost: 1.006 * 730, + }, + } + + EnrichSpotPrices(context.Background(), client, instances) + + // Spot instance should have corrected cost + if instances[0].HourlyCost != 0.556 { + t.Errorf("spot instance hourly cost: expected 0.556, got %f", instances[0].HourlyCost) + } + expectedMonthlyCost := 0.556 * 730 + const epsilon = 0.0001 + if instances[0].MonthlyCost < expectedMonthlyCost-epsilon || instances[0].MonthlyCost > expectedMonthlyCost+epsilon { + t.Errorf("spot instance monthly cost: expected %f, got %f", expectedMonthlyCost, instances[0].MonthlyCost) + } + + // On-demand instance should keep original cost + if instances[1].HourlyCost != 1.006 { + t.Errorf("on-demand instance hourly cost should be unchanged, got %f", instances[1].HourlyCost) + } +} From 6b7bfab419da0b3f9c831c3ad79ce4f3578b96ba Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 20:05:09 +0100 Subject: [PATCH 35/39] Add ruleSpotEligible analysis rule for spot recommendations --- internal/analysis/rules.go | 47 ++++++++++++- internal/analysis/rules_test.go | 114 ++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index b975a1c..777d332 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -28,6 +28,7 @@ func analyzeInstance(inst *models.GPUInstance) { ruleSageMakerLowUtil, ruleSageMakerOversized, ruleK8sUnallocatedGPU, + ruleSpotEligible, ruleK8sLowGPUUtil, } for _, rule := range rules { @@ -349,7 +350,51 @@ func ruleK8sUnallocatedGPU(inst *models.GPUInstance) { } } -// Rule 8: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). +// Rule 8: On-demand instance eligible for Spot pricing. +func ruleSpotEligible(inst *models.GPUInstance) { + if inst.PricingModel != "on-demand" { + return + } + if inst.UptimeHours < 24 { + return + } + if inst.SpotHourlyCost == nil { + return + } + + spotHourly := *inst.SpotHourlyCost + savingsPercent := ((inst.HourlyCost - spotHourly) / inst.HourlyCost) * 100 + if savingsPercent <= 0 { + return + } + + monthlySavings := (inst.HourlyCost - spotHourly) * 730 + spotMonthlyCost := spotHourly * 730 + + // Higher savings → higher confidence + confidence := 0.35 + (savingsPercent / 120) + if confidence > 0.95 { + confidence = 0.95 + } + + inst.WasteSignals = append(inst.WasteSignals, models.WasteSignal{ + Type: "spot_eligible", + Severity: models.SeverityInfo, + Confidence: confidence, + Evidence: fmt.Sprintf("Spot pricing available at $%.3f/hr vs $%.3f/hr on-demand (%.0f%% savings).", spotHourly, inst.HourlyCost, savingsPercent), + }) + inst.Recommendations = append(inst.Recommendations, models.Recommendation{ + Action: models.ActionChangePricing, + Description: fmt.Sprintf("Spot pricing available at $%.2f/hr (%.0f%% savings). Spot instances may be interrupted — suitable for fault-tolerant workloads.", spotHourly, savingsPercent), + CurrentMonthlyCost: inst.MonthlyCost, + RecommendedMonthlyCost: spotMonthlyCost, + MonthlySavings: monthlySavings, + SavingsPercent: savingsPercent, + Risk: models.RiskHigh, + }) +} + +// Rule 9: K8s GPU node with low GPU utilization (requires DCGM/CW/Prometheus metrics). func ruleK8sLowGPUUtil(inst *models.GPUInstance) { if inst.Source != models.SourceK8sNode { return diff --git a/internal/analysis/rules_test.go b/internal/analysis/rules_test.go index c1d6223..80bfa6d 100644 --- a/internal/analysis/rules_test.go +++ b/internal/analysis/rules_test.go @@ -260,6 +260,120 @@ func TestAnalyzeAll_ComputesSavings(t *testing.T) { } } +func TestRuleSpotEligible_FlagsOnDemandWithSpotPrice(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + InstanceID: "i-test", + Source: models.SourceEC2, + PricingModel: "on-demand", + UptimeHours: 48, + HourlyCost: 1.006, + MonthlyCost: 1.006 * 730, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 1 { + t.Fatalf("expected 1 signal, got %d", len(inst.WasteSignals)) + } + if inst.WasteSignals[0].Type != "spot_eligible" { + t.Errorf("expected spot_eligible, got %s", inst.WasteSignals[0].Type) + } + if inst.WasteSignals[0].Severity != models.SeverityInfo { + t.Errorf("expected info severity, got %s", inst.WasteSignals[0].Severity) + } + if len(inst.Recommendations) != 1 { + t.Fatalf("expected 1 recommendation, got %d", len(inst.Recommendations)) + } + if inst.Recommendations[0].Action != models.ActionChangePricing { + t.Errorf("expected change_pricing, got %s", inst.Recommendations[0].Action) + } + expectedSavings := (1.006 - 0.556) * 730 + diff := inst.Recommendations[0].MonthlySavings - expectedSavings + if diff < -0.01 || diff > 0.01 { + t.Errorf("expected savings %.2f, got %.2f", expectedSavings, inst.Recommendations[0].MonthlySavings) + } +} + +func TestRuleSpotEligible_SkipsSpotInstances(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + PricingModel: "spot", + UptimeHours: 48, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for spot instance, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_SkipsRecentInstances(t *testing.T) { + spotPrice := 0.556 + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 12, + SpotHourlyCost: &spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals for recent instance, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_SkipsWhenNoSpotPrice(t *testing.T) { + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 48, + SpotHourlyCost: nil, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) != 0 { + t.Errorf("expected no signals when spot price unavailable, got %d", len(inst.WasteSignals)) + } +} + +func TestRuleSpotEligible_ConfidenceScalesWithSavings(t *testing.T) { + tests := []struct { + name string + onDemand float64 + spotPrice float64 + minConfidence float64 + }{ + {"large_savings_60pct", 1.0, 0.4, 0.85}, + {"moderate_savings_40pct", 1.0, 0.6, 0.65}, + {"small_savings_20pct", 1.0, 0.8, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inst := models.GPUInstance{ + PricingModel: "on-demand", + UptimeHours: 48, + HourlyCost: tt.onDemand, + MonthlyCost: tt.onDemand * 730, + SpotHourlyCost: &tt.spotPrice, + } + + ruleSpotEligible(&inst) + + if len(inst.WasteSignals) == 0 { + t.Fatal("expected signal") + } + if inst.WasteSignals[0].Confidence < tt.minConfidence { + t.Errorf("expected confidence >= %.2f, got %.2f", tt.minConfidence, inst.WasteSignals[0].Confidence) + } + }) + } +} + func TestRuleK8sLowGPUUtil_FlagsLowUtilization(t *testing.T) { inst := models.GPUInstance{ InstanceID: "i-node1", From 4b8c1c629561ffe0f7360fe55c9aec9fafeb0d59 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 20:06:31 +0100 Subject: [PATCH 36/39] Add ec2:DescribeSpotPriceHistory to IAM policy output --- cmd/gpuaudit/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 271d094..8fb807b 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -309,6 +309,7 @@ var iamPolicyCmd = &cobra.Command{ "ec2:DescribeInstances", "ec2:DescribeInstanceTypes", "ec2:DescribeRegions", + "ec2:DescribeSpotPriceHistory", }, "Resource": "*", }, From 1abda17739e986d281b640eabc017af2ae2d9203 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Sun, 19 Apr 2026 20:51:38 +0100 Subject: [PATCH 37/39] Address review: update signal type comment, add pagination note, guard div-by-zero --- internal/analysis/rules.go | 3 +++ internal/models/models.go | 2 +- internal/providers/aws/spot.go | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/analysis/rules.go b/internal/analysis/rules.go index 777d332..93c139e 100644 --- a/internal/analysis/rules.go +++ b/internal/analysis/rules.go @@ -361,6 +361,9 @@ func ruleSpotEligible(inst *models.GPUInstance) { if inst.SpotHourlyCost == nil { return } + if inst.HourlyCost <= 0 { + return + } spotHourly := *inst.SpotHourlyCost savingsPercent := ((inst.HourlyCost - spotHourly) / inst.HourlyCost) * 100 diff --git a/internal/models/models.go b/internal/models/models.go index 9ee0bf7..e819198 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -100,7 +100,7 @@ type GPUInstance struct { // WasteSignal represents a detected waste indicator on a GPU instance. type WasteSignal struct { - Type string `json:"type"` // idle, low_utilization, oversized_gpu, pricing_mismatch, stale, low_invocations + Type string `json:"type"` // idle, low_utilization, oversized_gpu, pricing_mismatch, stale, low_invocations, spot_eligible Severity Severity `json:"severity"` Confidence float64 `json:"confidence"` // 0.0 - 1.0 Evidence string `json:"evidence"` diff --git a/internal/providers/aws/spot.go b/internal/providers/aws/spot.go index 7bdd3b8..d8ddcd6 100644 --- a/internal/providers/aws/spot.go +++ b/internal/providers/aws/spot.go @@ -53,7 +53,10 @@ func EnrichSpotPrices(ctx context.Context, client SpotPriceClient, instances []m return } - // Take the most recent price per instance type (API returns newest first). + // Take the most recent price per instance type. The API returns entries + // per (type, AZ) sorted newest-first. We collapse across AZs — spot prices + // within a region are typically within a few percent. A 1-hour window with + // a handful of GPU types fits well within a single API page (1000 entries). latestPrice := make(map[string]float64) for _, sp := range out.SpotPriceHistory { itype := string(sp.InstanceType) From 2f05b9c91e3e79e20ab8ed7cf4aba535f1eae932 Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Mon, 20 Apr 2026 01:04:16 +0100 Subject: [PATCH 38/39] Add Prometheus GPU metrics support for EC2 instances Share --prom-url across both EC2 and K8s scan paths. EC2 instances are matched to DCGM Prometheus metrics via private DNS hostname, with fallback to IP extracted from the instance label. - Add PrivateDnsName field to GPUInstance, populated from EC2 API - Extract shared Prometheus HTTP query/parse into internal/prometheus - Refactor K8s metrics.go to use shared prometheus package - Add EnrichEC2PrometheusGPUMetrics with hostname + IP matching - Show GPU utilization in signal column for flagged instances - Wire --prom-url into AWS ScanOptions --- cmd/gpuaudit/main.go | 3 +- internal/models/models.go | 3 + internal/output/table.go | 3 + internal/prometheus/query.go | 74 +++++++++ internal/prometheus/query_test.go | 98 ++++++++++++ internal/providers/aws/ec2.go | 15 +- internal/providers/aws/prometheus.go | 130 ++++++++++++++++ internal/providers/aws/prometheus_test.go | 174 ++++++++++++++++++++++ internal/providers/aws/scanner.go | 8 + internal/providers/k8s/metrics.go | 79 ++-------- 10 files changed, 516 insertions(+), 71 deletions(-) create mode 100644 internal/prometheus/query.go create mode 100644 internal/prometheus/query_test.go create mode 100644 internal/providers/aws/prometheus.go create mode 100644 internal/providers/aws/prometheus_test.go diff --git a/cmd/gpuaudit/main.go b/cmd/gpuaudit/main.go index 8fb807b..9232ad9 100644 --- a/cmd/gpuaudit/main.go +++ b/cmd/gpuaudit/main.go @@ -93,7 +93,7 @@ func init() { scanCmd.Flags().BoolVar(&scanSkipCosts, "skip-costs", false, "Skip Cost Explorer data enrichment") scanCmd.Flags().StringVar(&scanKubeconfig, "kubeconfig", "", "Path to kubeconfig file (default: ~/.kube/config)") scanCmd.Flags().StringVar(&scanKubeContext, "kube-context", "", "Kubernetes context to use (default: current context)") - scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics (e.g., https://prometheus.corp.example.com)") + scanCmd.Flags().StringVar(&scanPromURL, "prom-url", "", "Prometheus URL for GPU metrics on EC2 and K8s (e.g., https://prometheus.corp.example.com)") scanCmd.Flags().StringVar(&scanPromEndpoint, "prom-endpoint", "", "In-cluster Prometheus service as namespace/service:port (e.g., monitoring/prometheus:9090)") scanCmd.Flags().StringSliceVar(&scanExcludeTags, "exclude-tag", nil, "Exclude instances matching tag (key=value, repeatable)") scanCmd.Flags().IntVar(&scanMinUptimeDays, "min-uptime-days", 0, "Only flag instances running for at least this many days") @@ -136,6 +136,7 @@ func runScan(cmd *cobra.Command, args []string) error { opts.Targets = scanTargets opts.Role = scanRole opts.ExternalID = scanExternalID + opts.PromURL = scanPromURL opts.OrgScan = scanOrg opts.SkipSelf = scanSkipSelf diff --git a/internal/models/models.go b/internal/models/models.go index e819198..153ecec 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -57,6 +57,9 @@ type GPUInstance struct { Name string `json:"name"` // from Name tag or endpoint name Tags map[string]string `json:"tags,omitempty"` + // Network (populated for EC2) + PrivateDnsName string `json:"private_dns_name,omitempty"` + // GPU hardware InstanceType string `json:"instance_type"` GPUModel string `json:"gpu_model"` diff --git a/internal/output/table.go b/internal/output/table.go index ece729c..6052fe8 100644 --- a/internal/output/table.go +++ b/internal/output/table.go @@ -128,6 +128,9 @@ func printInstanceTable(w io.Writer, instances []models.GPUInstance, multiTarget signal := "" if len(inst.WasteSignals) > 0 { signal = inst.WasteSignals[0].Type + if inst.AvgGPUUtilization != nil { + signal += fmt.Sprintf(" [GPU %.0f%%]", *inst.AvgGPUUtilization) + } } rec := "" diff --git a/internal/prometheus/query.go b/internal/prometheus/query.go new file mode 100644 index 0000000..063cf09 --- /dev/null +++ b/internal/prometheus/query.go @@ -0,0 +1,74 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package prometheus + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) + +// QueryHTTP executes a PromQL instant query against a Prometheus-compatible HTTP API +// and returns a map from the given labelName to its metric value. +func QueryHTTP(ctx context.Context, baseURL, query, labelName string) (map[string]float64, error) { + u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query)) + req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return ParseResponse(data, labelName) +} + +// ParseResponse extracts metric values from a Prometheus API JSON response, +// keyed by the given label name. +func ParseResponse(data []byte, labelName string) (map[string]float64, error) { + var resp struct { + Status string `json:"status"` + Data struct { + ResultType string `json:"resultType"` + Result []struct { + Metric map[string]string `json:"metric"` + Value []json.RawMessage `json:"value"` + } `json:"result"` + } `json:"data"` + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + if resp.Status != "success" { + return nil, fmt.Errorf("query returned status %q", resp.Status) + } + + results := make(map[string]float64) + for _, r := range resp.Data.Result { + key := r.Metric[labelName] + if key == "" || len(r.Value) < 2 { + continue + } + var valStr string + if err := json.Unmarshal(r.Value[1], &valStr); err != nil { + continue + } + val, err := strconv.ParseFloat(valStr, 64) + if err != nil { + continue + } + results[key] = val + } + return results, nil +} diff --git a/internal/prometheus/query_test.go b/internal/prometheus/query_test.go new file mode 100644 index 0000000..6849b94 --- /dev/null +++ b/internal/prometheus/query_test.go @@ -0,0 +1,98 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package prometheus + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestParseResponse_ExtractsByLabel(t *testing.T) { + data := []byte(`{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"Hostname": "ip-10-0-1-1"}, "value": [1700000000, "45.2"]}, + {"metric": {"Hostname": "ip-10-0-1-2"}, "value": [1700000000, "12.8"]} + ] + } + }`) + + results, err := ParseResponse(data, "Hostname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + if results["ip-10-0-1-1"] != 45.2 { + t.Errorf("expected 45.2, got %f", results["ip-10-0-1-1"]) + } + if results["ip-10-0-1-2"] != 12.8 { + t.Errorf("expected 12.8, got %f", results["ip-10-0-1-2"]) + } +} + +func TestParseResponse_SkipsMissingLabel(t *testing.T) { + data := []byte(`{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"other": "value"}, "value": [1700000000, "45.2"]} + ] + } + }`) + + results, err := ParseResponse(data, "Hostname") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} + +func TestParseResponse_ErrorStatus(t *testing.T) { + data := []byte(`{"status": "error", "errorType": "bad_data", "error": "parse error"}`) + + _, err := ParseResponse(data, "node") + if err == nil { + t.Error("expected error for non-success status") + } +} + +func TestQueryHTTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/query" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + query := r.URL.Query().Get("query") + if query == "" { + t.Error("expected query parameter") + } + fmt.Fprintf(w, `{ + "status": "success", + "data": { + "resultType": "vector", + "result": [ + {"metric": {"node": "host1"}, "value": [1700000000, "55.5"]} + ] + } + }`) + })) + defer srv.Close() + + results, err := QueryHTTP(context.Background(), srv.URL, "up", "node") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if results["host1"] != 55.5 { + t.Errorf("expected 55.5, got %f", results["host1"]) + } +} diff --git a/internal/providers/aws/ec2.go b/internal/providers/aws/ec2.go index 0fa6738..cb82a46 100644 --- a/internal/providers/aws/ec2.go +++ b/internal/providers/aws/ec2.go @@ -97,13 +97,14 @@ func ec2InstanceToGPU(inst ec2types.Instance, accountID, region string) *models. // TODO: detect RI/SP coverage via Cost Explorer return &models.GPUInstance{ - InstanceID: aws.ToString(inst.InstanceId), - Source: models.SourceEC2, - AccountID: accountID, - Region: region, - Name: name, - Tags: tags, - InstanceType: instanceType, + InstanceID: aws.ToString(inst.InstanceId), + Source: models.SourceEC2, + AccountID: accountID, + Region: region, + Name: name, + Tags: tags, + PrivateDnsName: aws.ToString(inst.PrivateDnsName), + InstanceType: instanceType, GPUModel: spec.GPUModel, GPUCount: spec.GPUCount, GPUVRAMGiB: spec.GPUVRAMGiB, diff --git a/internal/providers/aws/prometheus.go b/internal/providers/aws/prometheus.go new file mode 100644 index 0000000..bef722f --- /dev/null +++ b/internal/providers/aws/prometheus.go @@ -0,0 +1,130 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/gpuaudit/cli/internal/models" + prom "github.com/gpuaudit/cli/internal/prometheus" +) + +// EnrichEC2PrometheusGPUMetrics queries a Prometheus endpoint for DCGM GPU metrics +// on EC2 instances that don't already have AvgGPUUtilization populated. +// It matches Prometheus results to EC2 instances via private DNS hostname. +func EnrichEC2PrometheusGPUMetrics(ctx context.Context, promURL string, instances []models.GPUInstance) int { + if promURL == "" { + return 0 + } + + type instRef struct { + index int + hostname string + ip string + } + var refs []instRef + for i := range instances { + inst := &instances[i] + if inst.Source != models.SourceEC2 || inst.State != "running" { + continue + } + if inst.AvgGPUUtilization != nil { + continue + } + if inst.PrivateDnsName == "" { + continue + } + hostname := strings.SplitN(inst.PrivateDnsName, ".", 2)[0] + ip := extractIPFromDNS(inst.PrivateDnsName) + refs = append(refs, instRef{index: i, hostname: hostname, ip: ip}) + } + if len(refs) == 0 { + return 0 + } + + // Build lookup maps: hostname → index, ip → index + hostnameToIdx := make(map[string]int, len(refs)) + ipToIdx := make(map[string]int, len(refs)) + for _, ref := range refs { + hostnameToIdx[ref.hostname] = ref.index + if ref.ip != "" { + ipToIdx[ref.ip] = ref.index + } + } + + fmt.Fprintf(os.Stderr, " Querying Prometheus at %s for EC2 GPU metrics...\n", promURL) + + // Query GPU utilization — get all DCGM metrics and match locally. + // DCGM exporter labels vary by setup: "Hostname" for host identity, + // "instance" for scrape target (ip:port). + gpuByHostname, err := prom.QueryHTTP(ctx, promURL, + `avg by (Hostname) (avg_over_time(DCGM_FI_DEV_GPU_UTIL[7d]))`, "Hostname") + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus EC2 GPU query failed: %v\n", err) + return 0 + } + + memByHostname, _ := prom.QueryHTTP(ctx, promURL, + `avg by (Hostname) (avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL[7d]))`, "Hostname") + + enriched := 0 + + // First pass: match by Hostname label (short hostname like "ip-10-22-249-234") + for _, ref := range refs { + if val, ok := gpuByHostname[ref.hostname]; ok { + instances[ref.index].AvgGPUUtilization = &val + if memVal, ok := memByHostname[ref.hostname]; ok { + instances[ref.index].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + + // Second pass: try matching by instance label (ip:port) for instances still missing metrics + instanceSeriesCount := 0 + if enriched < len(refs) { + gpuByInstance, err := prom.QueryHTTP(ctx, promURL, + `avg by (instance) (avg_over_time(DCGM_FI_DEV_GPU_UTIL[7d]))`, "instance") + if err == nil { + instanceSeriesCount = len(gpuByInstance) + memByInstance, _ := prom.QueryHTTP(ctx, promURL, + `avg by (instance) (avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL[7d]))`, "instance") + + for instanceLabel, val := range gpuByInstance { + ip := strings.SplitN(instanceLabel, ":", 2)[0] + idx, ok := ipToIdx[ip] + if !ok || instances[idx].AvgGPUUtilization != nil { + continue + } + v := val + instances[idx].AvgGPUUtilization = &v + if memVal, ok := memByInstance[instanceLabel]; ok { + instances[idx].AvgGPUMemUtilization = &memVal + } + enriched++ + } + } + } + + if enriched > 0 { + fmt.Fprintf(os.Stderr, " Prometheus: matched %d of %d EC2 instances\n", enriched, len(refs)) + } else { + fmt.Fprintf(os.Stderr, " Prometheus: matched 0 of %d EC2 instances (server returned %d Hostname series, %d instance series)\n", + len(refs), len(gpuByHostname), instanceSeriesCount) + } + return enriched +} + +// extractIPFromDNS extracts the IP address from an EC2 private DNS name. +// e.g., "ip-10-22-249-234.ec2.internal" → "10.22.249.234" +func extractIPFromDNS(dnsName string) string { + hostname := strings.SplitN(dnsName, ".", 2)[0] + if !strings.HasPrefix(hostname, "ip-") { + return "" + } + return strings.ReplaceAll(hostname[3:], "-", ".") +} diff --git a/internal/providers/aws/prometheus_test.go b/internal/providers/aws/prometheus_test.go new file mode 100644 index 0000000..0096349 --- /dev/null +++ b/internal/providers/aws/prometheus_test.go @@ -0,0 +1,174 @@ +// Copyright 2026 the gpuaudit authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package aws + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gpuaudit/cli/internal/models" +) + +func TestExtractIPFromDNS(t *testing.T) { + tests := []struct { + dns string + wantIP string + }{ + {"ip-10-22-249-234.ec2.internal", "10.22.249.234"}, + {"ip-172-31-0-5.us-west-2.compute.internal", "172.31.0.5"}, + {"custom-hostname.ec2.internal", ""}, + {"", ""}, + } + for _, tt := range tests { + got := extractIPFromDNS(tt.dns) + if got != tt.wantIP { + t.Errorf("extractIPFromDNS(%q) = %q, want %q", tt.dns, got, tt.wantIP) + } + } +} + +func TestEnrichEC2PrometheusGPUMetrics_MatchesByHostname(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + if strings.Contains(query, "GPU_UTIL") { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"Hostname": "ip-10-0-1-100"}, "value": [1700000000, "72.5"]} + ]} + }`) + } else { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"Hostname": "ip-10-0-1-100"}, "value": [1700000000, "45.0"]} + ]} + }`) + } + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + }, + { + InstanceID: "i-def456", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-200.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + + if enriched != 1 { + t.Fatalf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 72.5 { + t.Errorf("expected GPU util 72.5, got %v", instances[0].AvgGPUUtilization) + } + if instances[0].AvgGPUMemUtilization == nil || *instances[0].AvgGPUMemUtilization != 45.0 { + t.Errorf("expected GPU mem util 45.0, got %v", instances[0].AvgGPUMemUtilization) + } + if instances[1].AvgGPUUtilization != nil { + t.Error("expected no GPU util for unmatched instance") + } +} + +func TestEnrichEC2PrometheusGPUMetrics_FallsBackToInstanceLabel(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + if strings.Contains(query, "Hostname") { + // No results by hostname + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + } else if strings.Contains(query, "instance") && strings.Contains(query, "GPU_UTIL") { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"instance": "10.0.1.100:9400"}, "value": [1700000000, "88.0"]} + ]} + }`) + } else { + fmt.Fprintf(w, `{ + "status": "success", + "data": {"resultType": "vector", "result": [ + {"metric": {"instance": "10.0.1.100:9400"}, "value": [1700000000, "60.0"]} + ]} + }`) + } + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + + if enriched != 1 { + t.Fatalf("expected 1 enriched, got %d", enriched) + } + if instances[0].AvgGPUUtilization == nil || *instances[0].AvgGPUUtilization != 88.0 { + t.Errorf("expected GPU util 88.0, got %v", instances[0].AvgGPUUtilization) + } +} + +func TestEnrichEC2PrometheusGPUMetrics_SkipsAlreadyEnriched(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not query Prometheus when all instances already have metrics") + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + })) + defer srv.Close() + + gpuUtil := 50.0 + instances := []models.GPUInstance{ + { + InstanceID: "i-abc123", + Source: models.SourceEC2, + State: "running", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + AvgGPUUtilization: &gpuUtil, + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} + +func TestEnrichEC2PrometheusGPUMetrics_SkipsNonEC2(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not query Prometheus for non-EC2 instances") + fmt.Fprintf(w, `{"status": "success", "data": {"resultType": "vector", "result": []}}`) + })) + defer srv.Close() + + instances := []models.GPUInstance{ + { + InstanceID: "node-1", + Source: models.SourceK8sNode, + State: "ready", + PrivateDnsName: "ip-10-0-1-100.ec2.internal", + }, + } + + enriched := EnrichEC2PrometheusGPUMetrics(context.Background(), srv.URL, instances) + if enriched != 0 { + t.Errorf("expected 0 enriched, got %d", enriched) + } +} diff --git a/internal/providers/aws/scanner.go b/internal/providers/aws/scanner.go index 15ed1a0..908c4f5 100644 --- a/internal/providers/aws/scanner.go +++ b/internal/providers/aws/scanner.go @@ -36,6 +36,9 @@ type ScanOptions struct { ExcludeTags map[string]string MinUptimeDays int + // Prometheus + PromURL string + // Multi-target options Targets []string Role string @@ -249,6 +252,11 @@ func scanTarget(ctx context.Context, target Target, regions []string, opts ScanO } } + // Enrich EC2 GPU metrics from Prometheus (for instances missing GPU utilization) + if !opts.SkipMetrics && opts.PromURL != "" && len(allInstances) > 0 { + EnrichEC2PrometheusGPUMetrics(ctx, opts.PromURL, allInstances) + } + return allInstances, scannedRegions } diff --git a/internal/providers/k8s/metrics.go b/internal/providers/k8s/metrics.go index 4a587c2..02bb259 100644 --- a/internal/providers/k8s/metrics.go +++ b/internal/providers/k8s/metrics.go @@ -6,13 +6,9 @@ package k8s import ( "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" "net/url" "os" - "strconv" "strings" dto "github.com/prometheus/client_model/go" @@ -22,6 +18,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/gpuaudit/cli/internal/models" + prom "github.com/gpuaudit/cli/internal/prometheus" ) // EnrichDCGMMetrics discovers dcgm-exporter pods and scrapes GPU metrics for K8s nodes @@ -198,9 +195,9 @@ func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances [] nodeRegex := strings.Join(nodeNames, "|") gpuResults := queryPrometheus(ctx, client, opts, - fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d])`, nodeRegex)) + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_GPU_UTIL{node=~"%s"}[7d])`, nodeRegex), "node") memResults := queryPrometheus(ctx, client, opts, - fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d])`, nodeRegex)) + fmt.Sprintf(`avg_over_time(DCGM_FI_DEV_MEM_COPY_UTIL{node=~"%s"}[7d])`, nodeRegex), "node") enriched := 0 for _, node := range nodes { @@ -217,35 +214,27 @@ func EnrichPrometheusMetrics(ctx context.Context, client K8sClient, instances [] return enriched } -func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query string) map[string]float64 { - var data []byte - var err error - +func queryPrometheus(ctx context.Context, client K8sClient, opts PrometheusOptions, query, labelName string) map[string]float64 { if opts.URL != "" { - data, err = queryPrometheusHTTP(ctx, opts.URL, query) - } else { - data, err = queryPrometheusProxy(ctx, client, opts.Endpoint, query) + results, err := prom.QueryHTTP(ctx, opts.URL, query, labelName) + if err != nil { + fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) + return nil + } + return results } + + data, err := queryPrometheusProxy(ctx, client, opts.Endpoint, query) if err != nil { fmt.Fprintf(os.Stderr, " warning: Prometheus query failed: %v\n", err) return nil } - - return parsePrometheusResponse(data) -} - -func queryPrometheusHTTP(ctx context.Context, baseURL, query string) ([]byte, error) { - u := fmt.Sprintf("%s/api/v1/query?query=%s", strings.TrimRight(baseURL, "/"), url.QueryEscape(query)) - req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + results, err := prom.ParseResponse(data, labelName) if err != nil { - return nil, err - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err + fmt.Fprintf(os.Stderr, " warning: Prometheus response parse failed: %v\n", err) + return nil } - defer resp.Body.Close() - return io.ReadAll(resp.Body) + return results } func queryPrometheusProxy(ctx context.Context, client K8sClient, endpoint, query string) ([]byte, error) { @@ -273,39 +262,3 @@ func parsePrometheusEndpoint(endpoint string) (namespace, service, port string, return namespace, service, port, nil } -func parsePrometheusResponse(data []byte) map[string]float64 { - var resp struct { - Status string `json:"status"` - Data struct { - ResultType string `json:"resultType"` - Result []struct { - Metric map[string]string `json:"metric"` - Value []json.RawMessage `json:"value"` - } `json:"result"` - } `json:"data"` - } - if err := json.Unmarshal(data, &resp); err != nil { - return nil - } - if resp.Status != "success" { - return nil - } - - results := make(map[string]float64) - for _, r := range resp.Data.Result { - node := r.Metric["node"] - if node == "" || len(r.Value) < 2 { - continue - } - var valStr string - if err := json.Unmarshal(r.Value[1], &valStr); err != nil { - continue - } - val, err := strconv.ParseFloat(valStr, 64) - if err != nil { - continue - } - results[node] = val - } - return results -} From 0c33990463bbd3ea0db2e5567ba1f3375ab90f4c Mon Sep 17 00:00:00 2001 From: Stas Maksimov Date: Tue, 21 Apr 2026 12:49:03 +0100 Subject: [PATCH 39/39] Sanitize example output in README Replace fleet-specific numbers with generic examples that don't expose real infrastructure details. --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b7592b5..225cdaf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Scan your cloud for GPU waste and get actionable recommendations to cut your spe ``` $ gpuaudit scan --skip-eks - Found 103 GPU nodes across 111 nodes in gpu-cluster + Found 38 GPU nodes across 47 nodes in gpu-cluster gpuaudit — GPU Cost Audit for AWS Account: 123456789012 | Regions: us-east-1 | Duration: 4.2s @@ -13,12 +13,12 @@ $ gpuaudit scan --skip-eks ┌──────────────────────────────────────────────────────────┐ │ GPU Fleet Summary │ ├──────────────────────────────────────────────────────────┤ - │ Total GPU instances: 103 │ - │ Total monthly GPU spend: $365155 │ - │ Estimated monthly waste: $23408 ( 6%) │ + │ Total GPU instances: 38 │ + │ Total monthly GPU spend: $127450 │ + │ Estimated monthly waste: $18200 ( 14%) │ └──────────────────────────────────────────────────────────┘ - CRITICAL — 4 instance(s), $21728/mo potential savings + CRITICAL — 3 instance(s), $15400/mo potential savings Instance Type Monthly Signal Recommendation ──────────────────────────────────── ────────────────────────── ──────── ──────────────── ────────────────────────────────────────────── @@ -107,12 +107,12 @@ gpuaudit diff scan-apr-08.json scan-apr-15.json ┌──────────────────────────────────────────────────────────┐ │ Cost Delta │ ├──────────────────────────────────────────────────────────┤ - │ Monthly spend: $372000 → $365155 (-$6845) │ - │ Estimated waste: $189000 → $23408 (-$165592) │ - │ Instances: 116 → 103 (-13 removed, +0 added) │ + │ Monthly spend: $142000 → $127450 (-$14550) │ + │ Estimated waste: $31000 → $18200 (-$12800) │ + │ Instances: 45 → 38 (-9 removed, +2 added) │ └──────────────────────────────────────────────────────────┘ - REMOVED — 13 instance(s), -$6845/mo + REMOVED — 9 instance(s), -$16200/mo ... ```