From c50322b1dde8d5ff0b90e9c503d68726ed0ef60b Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 27 Apr 2026 21:36:06 -0700 Subject: [PATCH 1/3] feat(uffd): lift pageTracker `removed` state and REMOVE-event handling from #1896 This commit takes the production-side UFFD code on PR #1896 (feat/free-page-reporting) at the chore-cleanup tip f31027327 and lifts it onto main as a self-contained subset, WITHOUT the production fix from 24cb1e1fc that closes the stale-source race the next commits will demonstrate. Specifically: - pageTracker gains a `removed` pageState distinct from `missing`/`faulted` (page_tracker.go). - Serve() now drains the UFFD fd with a per-poll-cycle batch and processes UFFD_EVENT_REMOVE events: REMOVE batches take settleRequests.Lock() and call pageTracker.setState(..., removed) for every page in the removed range, before any pagefault dispatch in that iteration. - For each pagefault in the batch, Serve() reads pageTracker state in the PARENT loop (NOT under settleRequests.RLock) and captures `source = u.src` there before dispatching the worker goroutine. **This is the buggy form that the next stacked PR fixes.** A REMOVE event arriving between the parent-loop state read and the worker actually acquiring RLock leaves the worker with a stale `source = u.src` snapshot, which it then UFFDIO_COPYs into a page the kernel just MADV_DONTNEED'd. - New deferred.go batches up pagefaults that returned EAGAIN/short-circuit so the next poll cycle picks them up; a self-pipe wakeupPipe wakes poll immediately when a deferred fault is enqueued. - prefault.go grows the same batched-event handling. - Test scaffolding to support cross-process REMOVE testing is lifted as-is from f31027327 (cross_process_helpers_test.go, helpers_test.go, remove_test.go, fd_helpers_test.go). The harness in those files still uses the older signals/pipes wiring; the next commit replaces it with the unix-socket RPC harness. Out of scope (intentionally NOT lifted from #1896): Firecracker v1.14 bump, free-page-reporting feature flag, template-manager / API plumbing, proto regen, fcversion sandbox feature, anything outside packages/orchestrator /pkg/sandbox/uffd/userfaultfd/. The next two commits add the deterministic race tests; the stacked PR on top of this branch ports the fix. --- .../userfaultfd/cross_process_helpers_test.go | 151 ++++-- .../pkg/sandbox/uffd/userfaultfd/deferred.go | 26 ++ .../uffd/userfaultfd/fd_helpers_test.go | 7 +- .../sandbox/uffd/userfaultfd/helpers_test.go | 108 ++++- .../sandbox/uffd/userfaultfd/page_tracker.go | 13 + .../pkg/sandbox/uffd/userfaultfd/prefault.go | 44 +- .../sandbox/uffd/userfaultfd/remove_test.go | 290 ++++++++++++ .../sandbox/uffd/userfaultfd/userfaultfd.go | 437 ++++++++++++------ 8 files changed, 883 insertions(+), 193 deletions(-) create mode 100644 packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go create mode 100644 packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index 768e67f7cf..ed26eb263e 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -110,6 +110,9 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error if tt.alwaysWP { cmd.Env = append(cmd.Env, "GO_ALWAYS_WP=1") } + if tt.gated { + cmd.Env = append(cmd.Env, "GO_GATED=1") + } dup, err := syscall.Dup(int(uffdFd)) require.NoError(t, err) @@ -153,12 +156,34 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error readySignal <- struct{}{} }() - cmd.ExtraFiles = []*os.File{ + extraFiles := []*os.File{ uffdFile, contentReader, offsetsWriter, readyWriter, } + + var gateCmdWriter *os.File + var gateSyncReader *os.File + if tt.gated { + var gateCmdReader *os.File + gateCmdReader, gateCmdWriter, err = os.Pipe() + require.NoError(t, err) + + var gateSyncWriter *os.File + gateSyncReader, gateSyncWriter, err = os.Pipe() + require.NoError(t, err) + + t.Cleanup(func() { + gateCmdWriter.Close() + gateSyncReader.Close() + }) + + extraFiles = append(extraFiles, gateCmdReader) // fd 7 + extraFiles = append(extraFiles, gateSyncWriter) // fd 8 + } + + cmd.ExtraFiles = extraFiles cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -169,6 +194,10 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error offsetsWriter.Close() readyWriter.Close() uffdFile.Close() + if tt.gated { + extraFiles[4].Close() // gateCmdReader + extraFiles[5].Close() // gateSyncWriter + } t.Cleanup(func() { signalErr := cmd.Process.Signal(syscall.SIGUSR1) @@ -189,11 +218,11 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error ) // Tear down the UFFD registration before the early uffdFd.close() - // cleanup runs. Today this is a no-op (no test enables - // UFFD_FEATURE_EVENT_REMOVE) but a follow-up that does will - // otherwise see munmap block on un-acked REMOVE events queued - // against the still-registered range. Cleanups run LIFO, so - // this fires before the close registered earlier. + // cleanup runs. This branch enables UFFD_FEATURE_EVENT_REMOVE + // (see configureApi in fd_helpers_test.go), so without the + // unregister, munmap can block on un-acked REMOVE events queued + // by the kernel against the still-registered range. Cleanups + // run LIFO, so this fires before the close registered earlier. assert.NoError(t, unregister(uffdFd, memoryStart, uint64(size))) }) @@ -222,12 +251,16 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error return handlerPageStates{}, fmt.Errorf("decoding page state entry: %w", err) } - if pageState(entry.State) == faulted { + switch pageState(entry.State) { + case faulted: result.faulted = append(result.faulted, uint(entry.Offset)) + case removed: + result.removed = append(result.removed, uint(entry.Offset)) } } slices.Sort(result.faulted) + slices.Sort(result.removed) return result, nil } @@ -238,12 +271,31 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error case <-readySignal: } - return &testHandler{ + h := &testHandler{ memoryArea: &memoryArea, pagesize: tt.pagesize, data: data, pageStatesOnce: pageStatesOnce, - }, nil + } + + if tt.gated { + h.servePause = func() error { + if _, err := gateCmdWriter.Write([]byte{'P'}); err != nil { + return err + } + var buf [1]byte + _, err := gateSyncReader.Read(buf[:]) + + return err + } + h.serveResume = func() error { + _, err := gateCmdWriter.Write([]byte{'R'}) + + return err + } + } + + return h, nil } // Secondary process, orchestrator in our case @@ -303,9 +355,6 @@ func crossProcessServe() error { }, }) - exitUffd := make(chan struct{}, 1) - defer close(exitUffd) - l, err := logger.NewDevelopmentLogger() if err != nil { return fmt.Errorf("exit creating logger: %w", err) @@ -361,39 +410,78 @@ func crossProcessServe() error { } defer fdExit.Close() + exitUffd := make(chan struct{}, 1) + go func() { - defer func() { - exitUffd <- struct{}{} - }() + defer func() { exitUffd <- struct{}{} }() serverErr := uffd.Serve(ctx, fdExit) if serverErr != nil { msg := fmt.Errorf("error serving: %w", serverErr) - fmt.Fprint(os.Stderr, msg.Error()) - cancel(msg) - - return } }() cleanup := func() { - err := fdExit.SignalExit() - if err != nil { - msg := fmt.Errorf("error signaling exit: %w", err) + fdExit.SignalExit() + <-exitUffd + } + defer func() { cleanup() }() - fmt.Fprint(os.Stderr, msg.Error()) + if os.Getenv("GO_GATED") == "1" { + gateCmdFile := os.NewFile(uintptr(7), "gate-cmd") + defer gateCmdFile.Close() - cancel(msg) + gateSyncFile := os.NewFile(uintptr(8), "gate-sync") + defer gateSyncFile.Close() - return + startServe := func() func() { + newExit, fdErr := fdexit.New() + if fdErr != nil { + cancel(fmt.Errorf("error creating fd exit: %w", fdErr)) + + return func() {} + } + + done := make(chan struct{}) + go func() { + defer close(done) + if err := uffd.Serve(ctx, newExit); err != nil { + cancel(fmt.Errorf("error serving: %w", err)) + } + }() + + return func() { + newExit.SignalExit() + <-done + newExit.Close() + } } - <-exitUffd - } + stopServe := func() { + cleanup() + } - defer cleanup() + go func() { + var buf [1]byte + for { + if _, err := gateCmdFile.Read(buf[:]); err != nil { + return + } + + switch buf[0] { + case 'P': + stopServe() + gateSyncFile.Write([]byte{1}) + case 'R': + newStop := startServe() + stopServe = newStop + cleanup = newStop + } + } + }() + } exitSignal := make(chan os.Signal, 1) signal.Notify(exitSignal, syscall.SIGUSR1) @@ -427,9 +515,12 @@ type pageStateEntry struct { // can mutate the pageTracker while we iterate. func (u *Userfaultfd) pageStateEntries() ([]pageStateEntry, error) { u.settleRequests.Lock() - defer u.settleRequests.Unlock() + u.settleRequests.Unlock() //nolint:staticcheck // SA2001: intentional — settle the read locks. + + u.pageTracker.mu.RLock() + defer u.pageTracker.mu.RUnlock() - entries := make([]pageStateEntry, 0, len(u.pageTracker.m)) + var entries []pageStateEntry for addr, state := range u.pageTracker.m { offset, err := u.ma.GetOffset(addr) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go new file mode 100644 index 0000000000..6089ad7660 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go @@ -0,0 +1,26 @@ +package userfaultfd + +import "sync" + +// deferredFaults collects pagefaults that couldn't be handled (EAGAIN) +// and need to be retried on the next poll iteration. Safe for concurrent push. +type deferredFaults struct { + mu sync.Mutex + pf []*UffdPagefault +} + +func (d *deferredFaults) push(pf *UffdPagefault) { + d.mu.Lock() + d.pf = append(d.pf, pf) + d.mu.Unlock() +} + +// drain returns all accumulated pagefaults and resets the internal list. +func (d *deferredFaults) drain() []*UffdPagefault { + d.mu.Lock() + out := d.pf + d.pf = nil + d.mu.Unlock() + + return out +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd_helpers_test.go index c85d8b1233..c30d1823f5 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd_helpers_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -31,6 +31,7 @@ func configureApi(f Fd, pagesize uint64) error { } features |= UFFD_FEATURE_WP_ASYNC + features |= UFFD_FEATURE_EVENT_REMOVE api := newUffdioAPI(UFFD_API, features) ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_API, uintptr(unsafe.Pointer(&api))) @@ -42,9 +43,9 @@ func configureApi(f Fd, pagesize uint64) error { } // unregister tears down the UFFD registration over [addr, addr+size). -// Used in test cleanup so that any in-flight REMOVE events the kernel -// may have queued (once UFFD_FEATURE_EVENT_REMOVE is enabled in a -// follow-up) don't keep munmap blocked on un-acked events. +// Used in test cleanup so in-flight REMOVE events queued by the kernel +// (configureApi enables UFFD_FEATURE_EVENT_REMOVE on this branch) don't +// keep munmap blocked on un-acked events. func unregister(f Fd, addr uintptr, size uint64) error { r := newUffdioRange(CULong(addr), CULong(size)) diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go index 054aa9e333..003ee1523f 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go @@ -4,14 +4,15 @@ import ( "bytes" "context" "fmt" - "slices" "sync" "testing" + "time" "unsafe" "github.com/RoaringBitmap/roaring/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" ) @@ -26,6 +27,8 @@ type testConfig struct { operations []operation // alwaysWP makes the handler copy with UFFDIO_COPY_MODE_WP for all faults. alwaysWP bool + // gated enables pause/resume control over the handler's serve loop. + gated bool } type operationMode uint32 @@ -33,35 +36,54 @@ type operationMode uint32 const ( operationModeRead operationMode = 1 << iota operationModeWrite + operationModeRemove + operationModeServePause + operationModeServeResume + // operationModeSleep pauses for a short duration to let async goroutines + // enter their blocking syscalls before proceeding. + operationModeSleep ) type operation struct { // Offset in bytes. Must be smaller than the (numberOfPages-1) * pagesize as it reads a page and it must be aligned to the pagesize from the testConfig. offset int64 mode operationMode + // async runs the operation in a background goroutine. + async bool } // handlerPageStates is a snapshot of the pageTracker grouped by state. It // lets tests assert on the set of pages that the handler observed in each -// state, rather than a flat list of "accessed" offsets. Follow-up PRs can -// add more state-specific fields (e.g. removed) without touching the -// existing call sites. +// state, rather than a flat list of "accessed" offsets. type handlerPageStates struct { faulted []uint + removed []uint } // allAccessed returns the sorted union of offsets that the handler touched -// in any non-missing state. Tests that only care about "which pages did the -// handler see" can compare directly against this. +// in any non-missing state. // -// pageStatesOnce already returns each per-state slice sorted, and a page +// pageStatesOnce returns each per-state slice already sorted, and a page // has exactly one state at a time in pageTracker, so the per-state slices -// are disjoint. Follow-up PRs that add more state-specific fields should -// sorted-merge them here instead of reaching for a bitset — byte offsets -// make poor bit indices (a single hugepage offset would force ~1.8 MB of -// backing storage). +// are disjoint. We merge them with a simple sorted merge instead of a +// bitset — byte offsets make poor bit indices (a single hugepage offset +// would force ~1.8 MB of backing storage). func (s handlerPageStates) allAccessed() []uint { - return slices.Clone(s.faulted) + out := make([]uint, 0, len(s.faulted)+len(s.removed)) + i, j := 0, 0 + for i < len(s.faulted) && j < len(s.removed) { + if s.faulted[i] <= s.removed[j] { + out = append(out, s.faulted[i]) + i++ + } else { + out = append(out, s.removed[j]) + j++ + } + } + out = append(out, s.faulted[i:]...) + out = append(out, s.removed[j:]...) + + return out } type testHandler struct { @@ -71,23 +93,51 @@ type testHandler struct { // pageStatesOnce returns a per-state snapshot of the handler's pageTracker. // It can only be called once. pageStatesOnce func() (handlerPageStates, error) - mutex sync.Mutex + // servePause and serveResume gate the UFFD event loop in the child process. + // Tests use them to deterministically drain a batch of REMOVE events + // before more faults are processed. + servePause func() error + serveResume func() error + mutex sync.Mutex } func (h *testHandler) executeAll(t *testing.T, operations []operation) { t.Helper() + var asyncErrors []chan error + for i, op := range operations { + if op.async { + errCh := make(chan error, 1) + asyncErrors = append(asyncErrors, errCh) + + go func() { + errCh <- h.executeOperation(t.Context(), op) + }() + + continue + } + err := h.executeOperation(t.Context(), op) require.NoError(t, err, "step %d: %v at offset %d", i, op.mode, op.offset) } + + for _, errCh := range asyncErrors { + select { + case err := <-errCh: + require.NoError(t, err, "async operation") + case <-t.Context().Done(): + t.Fatal("timed out waiting for async operation") + } + } } type pageExpectation uint8 const ( - expectClean pageExpectation = iota // read-only: present + WP set - expectDirty // written: present + WP cleared + expectClean pageExpectation = iota // read-only: present + WP set + expectDirty // written: present + WP cleared + expectRemoved // removed: not present ) func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) { @@ -100,17 +150,25 @@ func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) { memStart := uintptr(unsafe.Pointer(&(*h.memoryArea)[0])) // Track the final expected state per offset by replaying operations in order. + // A remove after a read/write makes the page not present. + // A read/write after a remove makes it present again. expected := make(map[uint]pageExpectation) for _, op := range operations { off := uint(op.offset) switch op.mode { case operationModeRead: - if _, seen := expected[off]; !seen { + curr, seen := expected[off] + // If we haven't seen this page before or the page + // has previously been removed then the page should be clean + // after this read operation. + if !seen || curr == expectRemoved { expected[off] = expectClean } case operationModeWrite: expected[off] = expectDirty + case operationModeRemove: + expected[off] = expectRemoved } } @@ -119,6 +177,8 @@ func (h *testHandler) checkDirtiness(t *testing.T, operations []operation) { require.NoError(t, err, "pagemap read at offset %d", off) switch expect { + case expectRemoved: + assert.False(t, entry.IsPresent(), "removed page at offset %d should not be present", off) case expectDirty: assert.True(t, entry.IsPresent(), "written page at offset %d should be present", off) assert.False(t, entry.IsWriteProtected(), "written page at offset %d should be dirty", off) @@ -135,11 +195,27 @@ func (h *testHandler) executeOperation(ctx context.Context, op operation) error return h.executeRead(ctx, op) case operationModeWrite: return h.executeWrite(ctx, op) + case operationModeRemove: + return h.executeRemove(op) + case operationModeServePause: + return h.servePause() + case operationModeServeResume: + return h.serveResume() + case operationModeSleep: + time.Sleep(50 * time.Millisecond) + + return nil default: return fmt.Errorf("invalid operation mode: %d", op.mode) } } +func (h *testHandler) executeRemove(op operation) error { + page := (*h.memoryArea)[op.offset : op.offset+int64(h.pagesize)] + + return unix.Madvise(page, unix.MADV_DONTNEED) +} + func (h *testHandler) executeRead(ctx context.Context, op operation) error { readBytes := (*h.memoryArea)[op.offset : op.offset+int64(h.pagesize)] diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go index da76d310a8..2f57ab9966 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/page_tracker.go @@ -7,6 +7,7 @@ type pageState uint8 const ( missing pageState = iota faulted + removed ) type pageTracker struct { @@ -23,6 +24,18 @@ func newPageTracker(pageSize uintptr) *pageTracker { } } +func (pt *pageTracker) get(addr uintptr) pageState { + pt.mu.RLock() + defer pt.mu.RUnlock() + + state, ok := pt.m[addr] + if !ok { + return missing + } + + return state +} + func (pt *pageTracker) setState(start, end uintptr, state pageState) { pt.mu.Lock() defer pt.mu.Unlock() diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/prefault.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/prefault.go index 89bae9fedc..c1c2bff97a 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/prefault.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/prefault.go @@ -2,6 +2,7 @@ package userfaultfd import ( "context" + "errors" "fmt" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" @@ -11,6 +12,9 @@ import ( // This is used to speed up sandbox starts by prefetching pages that are known to be needed. // Returns nil on success, or if the page is already mapped (EEXIST is handled gracefully). func (u *Userfaultfd) Prefault(ctx context.Context, offset int64, data []byte) error { + u.settleRequests.RLock() + defer u.settleRequests.RUnlock() + ctx, span := tracer.Start(ctx, "prefault page") defer span.End() @@ -23,7 +27,45 @@ func (u *Userfaultfd) Prefault(ctx context.Context, offset int64, data []byte) e return fmt.Errorf("data length (%d) does not match pagesize (%d)", len(data), u.pageSize) } - return u.faultPage(ctx, addr, offset, directDataSource{data, int64(u.pageSize)}, nil, block.Prefetch) + // If page has already been faulted in due to on-demand page fault handling or removed because + // Firecracker called madvise() on it, skip it. + state := u.pageTracker.get(addr) + if state == faulted || state == removed { + return nil + } + + // We're treating prefault handling as if it was caused by a read access. + // This way, we will fault the page with UFFD_COPY_MODE_WP which will set + // the WP bit for the page. This works even in the case of a race with a + // concurrent on-demand write access. + // + // If the on-demand fault handler beats us, we will get an EEXIST here. + // If we beat the on-demand handler, it will get the EEXIST. + // + // In both cases, the WP bit will be cleared because it is handled asynchronously + // by the kernel. + handled, err := u.faultPage( + ctx, + addr, + offset, + block.Read, + directDataSource{data, int64(u.pageSize)}, + nil, + ) + if err != nil { + span.RecordError(errors.New("could not prefault page")) + + return fmt.Errorf("failed to fault page: %w", err) + } + + if !handled { + span.AddEvent("prefault: page already faulted or write returned EAGAIN") + } else { + u.pageTracker.setState(addr, addr+u.pageSize, faulted) + u.prefetchTracker.Add(offset, block.Prefetch) + } + + return nil } // directDataSource wraps a byte slice to implement block.Slicer for prefaulting. diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go new file mode 100644 index 0000000000..3f96d555c1 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go @@ -0,0 +1,290 @@ +package userfaultfd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestRemove(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "4k read then remove", + pagesize: header.PageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + { + name: "hugepage read then remove", + pagesize: header.HugepageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + { + name: "4k write then remove", + pagesize: header.PageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeWrite}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + { + name: "hugepage write then remove", + pagesize: header.HugepageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeWrite}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + { + name: "4k selective remove", + pagesize: header.PageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: int64(header.PageSize), mode: operationModeWrite}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + { + name: "hugepage selective remove", + pagesize: header.HugepageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: int64(header.HugepageSize), mode: operationModeWrite}, + {offset: 0, mode: operationModeRemove}, + {mode: operationModeSleep}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + h.executeAll(t, tt.operations) + + states, err := h.pageStatesOnce() + require.NoError(t, err) + + removedOffsets := getOperationsOffsets(tt.operations, operationModeRemove) + assert.ElementsMatch(t, removedOffsets, states.removed) + + faultedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + for _, r := range removedOffsets { + faultedOffsets = removeOffset(faultedOffsets, r) + } + assert.ElementsMatch(t, faultedOffsets, states.faulted) + + h.checkDirtiness(t, tt.operations) + }) + } +} + +// TestRemoveThenFault asserts that after MADV_DONTNEED + a subsequent write, +// the handler re-faults the page (state transitions: faulted → removed → faulted). +func TestRemoveThenFault(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "4k read, remove, write", + pagesize: header.PageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: 0, mode: operationModeRemove}, + {offset: 0, mode: operationModeWrite}, + }, + }, + { + name: "hugepage read, remove, write", + pagesize: header.HugepageSize, + numberOfPages: 2, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {offset: 0, mode: operationModeRemove}, + {offset: 0, mode: operationModeWrite}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + h.executeAll(t, tt.operations) + + states, err := h.pageStatesOnce() + require.NoError(t, err) + + assert.Empty(t, states.removed, "page should not be in removed state after re-fault") + assert.Contains(t, states.faulted, uint(0), "page should be back in faulted state") + + h.checkDirtiness(t, tt.operations) + }) + } +} + +// TestRemoveThenWriteGated verifies that when the handler is stopped, the +// kernel keeps the page mapped until REMOVE is acked. A concurrent write +// succeeds without faulting because MADV_DONTNEED blocks (waiting for ack) +// and doesn't unmap the page until the handler processes the event. +// When the handler resumes, it only sees the REMOVE — no MISSING fault. +func TestRemoveThenWriteGated(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "4k gated remove with concurrent write", + pagesize: header.PageSize, + numberOfPages: 2, + gated: true, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {mode: operationModeServePause}, + {offset: 0, mode: operationModeRemove, async: true}, + {mode: operationModeSleep}, + {offset: 0, mode: operationModeWrite, async: true}, + {mode: operationModeSleep}, + {mode: operationModeServeResume}, + }, + }, + { + name: "hugepage gated remove with concurrent write", + pagesize: header.HugepageSize, + numberOfPages: 2, + gated: true, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {mode: operationModeServePause}, + {offset: 0, mode: operationModeRemove, async: true}, + {mode: operationModeSleep}, + {offset: 0, mode: operationModeWrite, async: true}, + {mode: operationModeSleep}, + {mode: operationModeServeResume}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + h.executeAll(t, tt.operations) + + states, err := h.pageStatesOnce() + require.NoError(t, err) + + // The page stays mapped until REMOVE is acked, so the concurrent + // write succeeds without triggering a MISSING fault. The handler + // only processes the REMOVE event. + assert.ElementsMatch(t, []uint{0}, states.removed) + assert.Empty(t, states.faulted) + }) + } +} + +// TestWriteThenRemoveGated verifies the serve loop's ordering guarantee: +// REMOVE events are processed before pagefaults even when the MISSING pagefault +// was queued first. The write to a missing page triggers MISSING (queued first), +// then MADV_DONTNEED triggers REMOVE (queued second). When the handler resumes, +// it processes REMOVE first, then MISSING — the write is not skipped. +func TestWriteThenRemoveGated(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "4k write then remove in same batch", + pagesize: header.PageSize, + numberOfPages: 2, + gated: true, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {mode: operationModeServePause}, + // MISSING for page 1 queued first + {offset: int64(header.PageSize), mode: operationModeWrite, async: true}, + {mode: operationModeSleep}, + // REMOVE for page 0 queued second + {offset: 0, mode: operationModeRemove, async: true}, + {mode: operationModeSleep}, + {mode: operationModeServeResume}, + }, + }, + { + name: "hugepage write then remove in same batch", + pagesize: header.HugepageSize, + numberOfPages: 2, + gated: true, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {mode: operationModeServePause}, + {offset: int64(header.HugepageSize), mode: operationModeWrite, async: true}, + {mode: operationModeSleep}, + {offset: 0, mode: operationModeRemove, async: true}, + {mode: operationModeSleep}, + {mode: operationModeServeResume}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + h.executeAll(t, tt.operations) + + states, err := h.pageStatesOnce() + require.NoError(t, err) + + // Page 0 was removed + assert.Contains(t, states.removed, uint(0)) + // Page 1 was faulted by the write — not skipped + pageOffset := uint(tt.pagesize) + assert.Contains(t, states.faulted, pageOffset, + "write pagefault should not be skipped even when batched with REMOVE") + }) + } +} + +func removeOffset(offsets []uint, target uint) []uint { + result := make([]uint, 0, len(offsets)) + for _, o := range offsets { + if o != target { + result = append(result, o) + } + } + + return result +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go index 133af0f547..5496cd4e2b 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go @@ -21,6 +21,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/fdexit" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/userfaultfd") @@ -53,16 +54,22 @@ type Userfaultfd struct { pageSize uintptr pageTracker *pageTracker - // We use the settleRequests to guard the pageTracker so we can access a consistent state of the pageTracker after the requests are finished. - settleRequests sync.RWMutex - + // settleRequests guards the pageTracker and prefetchTracker so we can access a + // consistent state after in-flight requests have finished, and so REMOVE events + // can update the pageTracker without racing with concurrent faultPage workers. + settleRequests sync.RWMutex prefetchTracker *block.PrefetchTracker - wg errgroup.Group - // defaultCopyMode overrides the UFFDIO_COPY mode for all faults when non-zero. defaultCopyMode CULong + wg errgroup.Group + + // wakeupPipe is a self-pipe used to wake the poll loop when a goroutine + // defers a page fault. Without this, a deferred fault could be orphaned + // if no new UFFD events arrive to wake poll. + wakeupPipe [2]int + logger logger.Logger } @@ -76,6 +83,11 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge } } + var wakeupPipe [2]int + if err := syscall.Pipe2(wakeupPipe[:], syscall.O_NONBLOCK|syscall.O_CLOEXEC); err != nil { + return nil, fmt.Errorf("failed to create wakeup pipe: %w", err) + } + u := &Userfaultfd{ fd: Fd(fd), src: src, @@ -83,6 +95,7 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge pageTracker: newPageTracker(uintptr(blockSize)), prefetchTracker: block.NewPrefetchTracker(blockSize), ma: m, + wakeupPipe: wakeupPipe, logger: logger, } @@ -94,8 +107,61 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge return u, nil } -func (u *Userfaultfd) Close() error { - return u.fd.close() +// readEvents reads all available UFFD events from the file descriptor, +// returning removes and pagefaults separately. +func (u *Userfaultfd) readEvents(ctx context.Context) ([]*UffdRemove, []*UffdPagefault, error) { + // We are reusing the same buffer for all events, but that's fine, + // because getMsgArg, will make a copy of the actual event from `buf` + // and it's a pointer to this copy that we are returning to caller. + buf := make([]byte, unsafe.Sizeof(UffdMsg{})) + + var removes []*UffdRemove + var pagefaults []*UffdPagefault + + for { + n, err := syscall.Read(int(u.fd), buf) + if errors.Is(err, syscall.EINTR) { + u.logger.Debug(ctx, "uffd: interrupted read. Reading again") + + continue + } + + if errors.Is(err, syscall.EAGAIN) { + // EAGAIN means that we have drained all the available events for the file descriptor. + // We are done. + break + } + + if err != nil { + return nil, nil, fmt.Errorf("failed reading uffd: %w", err) + } + + // `Read` returned with 0 bytes actually read. No more events to read + // and the writing end has been closed. This should never happen, unless + // something (us or Firecracker) closes the file descriptor + // TODO: Ignore it for now, but maybe we should return an error(?) + if n == 0 { + break + } + + msg := (*UffdMsg)(unsafe.Pointer(&buf[0])) + + event := getMsgEvent(msg) + arg := getMsgArg(msg) + + switch event { + case UFFD_EVENT_PAGEFAULT: + v := *(*UffdPagefault)(unsafe.Pointer(&arg[0])) + pagefaults = append(pagefaults, &v) + case UFFD_EVENT_REMOVE: + v := *(*UffdRemove)(unsafe.Pointer(&arg[0])) + removes = append(removes, &v) + default: + return nil, nil, ErrUnexpectedEventType + } + } + + return removes, pagefaults, nil } func (u *Userfaultfd) Serve( @@ -105,6 +171,7 @@ func (u *Userfaultfd) Serve( pollFds := []unix.PollFd{ {Fd: int32(u.fd), Events: unix.POLLIN}, {Fd: fdExit.Reader(), Events: unix.POLLIN}, + {Fd: int32(u.wakeupPipe[0]), Events: unix.POLLIN}, } eagainCounter := newCounterReporter(u.logger, "uffd: eagain with no pagefaults (accumulated)") @@ -125,6 +192,8 @@ func (u *Userfaultfd) Serve( unix.POLLNVAL: "POLLNVAL", } + var deferred deferredFaults + for { if _, err := unix.Poll( pollFds, @@ -166,6 +235,11 @@ func (u *Userfaultfd) Serve( } } + // Drain the wakeup pipe if it fired (a goroutine deferred a fault). + if hasEvent(pollFds[2].Revents, unix.POLLIN) { + u.drainWakeupPipe() + } + uffdFd := pollFds[0] // Track uffd error events @@ -175,56 +249,41 @@ func (u *Userfaultfd) Serve( } } - if !hasEvent(uffdFd.Revents, unix.POLLIN) { - // Uffd is not ready for reading as there is nothing to read on the fd. - // https://github.com/firecracker-microvm/firecracker/issues/5056 - // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 - // TODO: Check for all the errors - // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html - // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c - // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html - // It might be possible to just check for data != 0 in the syscall.Read loop - // but I don't feel confident about doing that. - noDataCounter.Increase("POLLIN") - - continue - } - - buf := make([]byte, unsafe.Sizeof(UffdMsg{})) - + var removes []*UffdRemove var pagefaults []*UffdPagefault - for { - _, err := syscall.Read(int(u.fd), buf) - if err == syscall.EINTR { - u.logger.Debug(ctx, "uffd: interrupted read, reading again") - - continue - } - - if err == syscall.EAGAIN { - break - } + if hasEvent(uffdFd.Revents, unix.POLLIN) { + var err error + removes, pagefaults, err = u.readEvents(ctx) if err != nil { u.logger.Error(ctx, "uffd: read error", zap.Error(err)) return fmt.Errorf("failed to read: %w", err) } + } else { + noDataCounter.Increase("POLLIN") + } - msg := *(*UffdMsg)(unsafe.Pointer(&buf[0])) - - if msgEvent := getMsgEvent(&msg); msgEvent != UFFD_EVENT_PAGEFAULT { - u.logger.Error(ctx, "UFFD serve unexpected event type", zap.Any("event_type", msgEvent)) - - return ErrUnexpectedEventType + // First handle the UFFD_EVENT_REMOVE events. Take the settleRequests write lock to ensure that no + // other page or pre-fault operation is running concurrently. + // A goroutine from the previous batch or a prefault operation could still be executing + // setState(faulted) after its UFFDIO_COPY returned. If we process a REMOVE for the same + // page before that goroutine finishes, the goroutine's setState(faulted) would + // overwrite the removed state we just set. + if len(removes) > 0 { + u.settleRequests.Lock() + for _, rm := range removes { + u.pageTracker.setState(uintptr(rm.start), uintptr(rm.end), removed) } - - arg := getMsgArg(&msg) - pagefault := *(*UffdPagefault)(unsafe.Pointer(&arg[0])) - pagefaults = append(pagefaults, &pagefault) + u.settleRequests.Unlock() } + // Collect deferred pagefaults from previous goroutines that got EAGAIN. + // The wakeup pipe ensures we don't sleep through these. + pagefaults = append(deferred.drain(), pagefaults...) + if len(pagefaults) == 0 { + // Woke up but nothing to do (e.g., only REMOVE events, or spurious wakeup). eagainCounter.Increase("EMPTY_DRAIN") continue @@ -233,169 +292,261 @@ func (u *Userfaultfd) Serve( eagainCounter.Log(ctx) noDataCounter.Log(ctx) - for _, pagefault := range pagefaults { - flags := pagefault.flags + for _, pf := range pagefaults { + // We don't handle minor page faults. + if pf.flags&UFFD_PAGEFAULT_FLAG_MINOR != 0 { + return errors.New("unexpected MINOR pagefault event, closing UFFD") + } - addr := getPagefaultAddress(pagefault) + // We don't handle write-protection page faults, we're using asynchronous write protection. + if pf.flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + return errors.New("unexpected WP pagefault event, closing UFFD") + } + addr := getPagefaultAddress(pf) offset, err := u.ma.GetOffset(addr) if err != nil { - u.logger.Error(ctx, "UFFD serve get mapping error", zap.Error(err)) + u.logger.Error(ctx, "UFFD serve got mapping error", zap.Error(err)) return fmt.Errorf("failed to map: %w", err) } - // Handle write to missing page (WRITE flag) - // If the event has WRITE flag, it was a write to a missing page. - // For the write to be executed, we first need to copy the page from the source to the guest memory. - if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - u.wg.Go(func() error { - return u.faultPage(ctx, addr, offset, u.src, fdExit.SignalExit, block.Write) - }) + var source block.Slicer + switch state := u.pageTracker.get(addr); state { + case faulted: + // Skip faulting the page. This has already been faulted, either during pre-faulting + // or because we handled another page fault on the same address in the current + // iteration. It can only transition out of `faulted` via a UFFD_EVENT_REMOVE, which + // will mark the page as `removed`. + // For this to work correctly, the used pages cannot be swappable. continue + case removed: + // Fault the page as empty. + case missing: + source = u.src + default: + return fmt.Errorf("unexpected pageState: %#v", state) } - // Handle read to missing page ("MISSING" flag) - // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. - if flags == 0 { - u.wg.Go(func() error { - return u.faultPage(ctx, addr, offset, u.src, fdExit.SignalExit, block.Read) - }) - - continue - } - - // MINOR and WP flags are not expected as we don't register the uffd with these flags. - return fmt.Errorf("unexpected event type: %d, closing uffd", flags) + u.wg.Go(func() error { + // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, + // even if the errgroup is cancelled or the goroutine returns early. + // This check protects us against race condition between marking the request for prefetching and accessing the prefetchTracker. + u.settleRequests.RLock() + defer u.settleRequests.RUnlock() + + var accessType block.AccessType + + if pf.flags&UFFD_PAGEFAULT_FLAG_WRITE == 0 { + accessType = block.Read + } else { + accessType = block.Write + } + + handled, err := u.faultPage( + ctx, + addr, + offset, + accessType, + source, + fdExit.SignalExit, + ) + if err != nil { + return err + } + + if handled { + u.pageTracker.setState(addr, addr+u.pageSize, faulted) + u.prefetchTracker.Add(offset, accessType) + } else { + deferred.push(pf) + u.signalWakeup() + } + + return nil + }) } } } -func (u *Userfaultfd) PrefetchData() block.PrefetchData { - // This will be at worst cancelled when the uffd is closed. - u.settleRequests.Lock() - // The locking here would work even without using defer (just lock-then-unlock the mutex), but at this point let's make it lock to the clone, - // so it is consistent even if there is a another uffd call after. - defer u.settleRequests.Unlock() - - return u.prefetchTracker.PrefetchData() -} - func (u *Userfaultfd) faultPage( ctx context.Context, addr uintptr, offset int64, + accessType block.AccessType, source block.Slicer, onFailure func() error, - accessType block.AccessType, -) error { +) (bool, error) { span := trace.SpanFromContext(ctx) - // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, - // even if the errgroup is cancelled or the goroutine returns early. - // This guards against races between marking the page faulted / prefetched - // and another caller observing the pageTracker or prefetchTracker. - u.settleRequests.RLock() - defer u.settleRequests.RUnlock() - defer func() { if r := recover(); r != nil { u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", u.pageSize), zap.Any("panic", r)) } }() - var b []byte - var dataErr error - var attempt int + var writeErr error + + mode := u.defaultCopyMode + if accessType == block.Read { + mode = UFFDIO_COPY_MODE_WP + } -retryLoop: - for attempt = range sliceMaxRetries + 1 { - b, dataErr = source.Slice(ctx, offset, int64(u.pageSize)) - if dataErr == nil { + // Write to guest memory. nil data means zero-fill + switch { + case source == nil && u.pageSize == header.PageSize && accessType == block.Read: + // Firecracker uses anonymous mappings for 4K pages. Anonymous mappings can only + // be write protected once pages are populated. We need to enable write-protection + // *after* we serve the page fault. + // + // To avoid the race condition, first serve the page without waking the thread + writeErr = u.fd.zero(addr, u.pageSize, UFFDIO_ZEROPAGE_MODE_DONTWAKE) + if writeErr != nil { break } - - if attempt >= sliceMaxRetries || ctx.Err() != nil { + // Then, write-protect the page + writeErr = u.fd.writeProtect(addr, u.pageSize, UFFDIO_WRITEPROTECT_MODE_WP) + if writeErr != nil { break } + // And, finally, wake up the faulting thread + writeErr = u.fd.wake(addr, u.pageSize) + case source == nil && u.pageSize == header.PageSize && accessType == block.Write: + // If this was a write access to a 4K page simply provide the zero page (clearing the WP bit) + // and wake up the thread in one step. + writeErr = u.fd.zero(addr, u.pageSize, 0) + case source == nil && u.pageSize == header.HugepageSize: + writeErr = u.fd.copy(addr, u.pageSize, header.EmptyHugePage, mode) + default: + var b []byte + var dataErr error + var attempt int + + retryLoop: + for attempt = range sliceMaxRetries + 1 { + b, dataErr = source.Slice(ctx, offset, int64(u.pageSize)) + if dataErr == nil { + break + } - u.logger.Warn(ctx, "UFFD serve slice error, retrying", - zap.Int("attempt", attempt+1), - zap.Int("max_attempts", sliceMaxRetries+1), - zap.Error(dataErr), - ) - - delay := min(sliceRetryBaseDelay<= sliceMaxRetries || ctx.Err() != nil { + break + } - backoff := time.NewTimer(delay + jitter) + u.logger.Warn(ctx, "UFFD serve slice error, retrying", + zap.Int("attempt", attempt+1), + zap.Int("max_attempts", sliceMaxRetries+1), + zap.Error(dataErr), + ) - select { - case <-ctx.Done(): - backoff.Stop() + delay := min(sliceRetryBaseDelay< Date: Mon, 27 Apr 2026 21:48:02 -0700 Subject: [PATCH 2/3] test(uffd): add net/rpc/jsonrpc test harness over unix socket MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the cross-process userfaultfd test harness's pile of single-purpose pipes (offsets, ready, gate-cmd, gate-sync) plus SIGUSR1/SIGUSR2 signals with one bidirectional Unix domain socket carrying stdlib `net/rpc` + `net/rpc/jsonrpc`. The kernel userfaultfd is the only fd still handed off out-of-band (via ExtraFiles); the source data is written to a temp file. Concurrent in-flight calls and request-id correlation are handled by the standard library, so the harness only needs to register one Service struct and dial. Replaced surface, exposed as RPC methods on the child: - Service.WaitReady (replaces ready pipe + ReadAll handshake) - Service.PageStates (replaces SIGUSR2 + offsets pipe + binary.Write protocol) - Service.ServePause (replaces gate-cmd/gate-sync byte protocol) - Service.ServeResume (replaces gate-cmd/gate-sync byte protocol) - Service.InstallFaultBarrier (NEW: arms a deterministic barrier in the child's worker goroutine at one of two hook points — beforeWorkerRLockHook or beforeFaultPageHook; returns a token) - Service.WaitFaultHeld (NEW: blocks until the worker reaches the barrier — the RPC reply IS the rendezvous) - Service.ReleaseFault (NEW: lets the parked worker proceed) - Service.Shutdown (replaces SIGUSR1 graceful exit) Add two test-only fields on Userfaultfd: `beforeWorkerRLockHook` and `beforeFaultPageHook`, both `func(addr uintptr)`, both default to nil and nil-checked on the hot path so production sees zero behavioral change. They are only assigned in the child's crossProcessServe wiring (via the test helper that stands up the subprocess). The hooks let the parent install deterministic "park here, fire racing op, release" handshakes — necessary for the race tests in the next commit. testConfig gains a `sourcePatcher` hook so race tests can plant a deterministic sentinel byte into the random source data BEFORE the content file is written, without depending on the happenstance value of any randomly-generated byte. Also serialise the gated cross-process tests (`TestRemoveThenWriteGated`, `TestWriteThenRemoveGated`, `TestFaultedShortCircuitOrdering`) by removing `t.Parallel()`. While the handler is in the gated `paused` state, any user thread that triggers a queued pagefault on the registered region is suspended in the kernel's pagefault path. From the Go runtime's perspective that goroutine is "running" (not in syscall, since it's a plain memory store) and cannot be preempted until the fault is served. If a CONCURRENT cross-process test in the same binary triggers a stop-the-world GC pause during this window, STW will wait forever for the suspended goroutine to reach a safe point — the kernel cannot deliver the SIGURG preempt signal until the pagefault is served, and the handler is paused. Running the gated tests sequentially avoids that interleaving while leaving every other test (including the rest of the race suite) on `t.Parallel()`. --- .../userfaultfd/cross_process_helpers_test.go | 880 +++++++++++------- .../sandbox/uffd/userfaultfd/helpers_test.go | 57 +- .../sandbox/uffd/userfaultfd/remove_test.go | 23 +- .../sandbox/uffd/userfaultfd/userfaultfd.go | 42 + 4 files changed, 666 insertions(+), 336 deletions(-) diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index ed26eb263e..4148c357b1 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -1,28 +1,39 @@ package userfaultfd -// This tests is creating uffd in the main process and handling the page faults in another process. -// It prevents problems with Go mmap during testing (https://pojntfx.github.io/networked-linux-memsync/main.html#limitations) and also more accurately simulates what we do with Firecracker. -// These problems are not affecting Firecracker, because: -// 1. It is a different process that handles the page faults -// 2. Does not use garbage collection +// This test creates the userfaultfd in the parent test process and +// drives it from a child helper process. We do this so the actual +// page-fault handling runs in a process where we can fully control +// memory layout (no Go GC scanning / touching the registered region) +// — which mirrors how Firecracker uses UFFD in production. +// +// All parent ↔ child coordination — readiness, page-state queries, +// pause/resume, fault barriers, shutdown — flows over a single Unix +// domain socket using the standard-library net/rpc + jsonrpc codec. +// Each in-flight RPC runs in its own server-side goroutine, so a +// blocking handler (e.g. WaitFaultHeld) does not stall other RPCs. +// The only fd we still hand off out-of-band is the userfaultfd +// itself (kernel object, has to go through ExtraFiles); the initial +// source data is written to a temp file under t.TempDir() because +// base64-stuffing megabytes through the JSON envelope would be silly. import ( "context" "crypto/rand" - "encoding/binary" "errors" "fmt" - "io" + "net" + "net/rpc" + "net/rpc/jsonrpc" "os" "os/exec" - "os/signal" + "path/filepath" "slices" "strconv" - "strings" + "sync" "syscall" "testing" + "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sys/unix" @@ -33,8 +44,8 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) -// MemorySlicer exposes byte slice via the Slicer interface. -// This is used for testing purposes. +// MemorySlicer exposes a byte slice via the Slicer interface. +// Test-only. type MemorySlicer struct { content []byte pagesize int64 @@ -43,10 +54,7 @@ type MemorySlicer struct { var _ block.Slicer = (*MemorySlicer)(nil) func NewMemorySlicer(content []byte, pagesize int64) *MemorySlicer { - return &MemorySlicer{ - content: content, - pagesize: pagesize, - } + return &MemorySlicer{content: content, pagesize: pagesize} } func (s *MemorySlicer) Slice(_ context.Context, offset, size int64) ([]byte, error) { @@ -67,9 +75,7 @@ func (s *MemorySlicer) BlockSize() int64 { func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer { size := pagesize * numberOfPages - - n := int(size) - buf := make([]byte, n) + buf := make([]byte, int(size)) if _, err := rand.Read(buf); err != nil { panic(err) } @@ -77,442 +83,516 @@ func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer { return NewMemorySlicer(buf, int64(pagesize)) } -// Main process, FC in our case +// Env vars used by the child helper process. +const ( + envHelperFlag = "GO_TEST_HELPER_PROCESS" + envSocketPath = "GO_UFFD_SOCKET" + envContentPath = "GO_UFFD_CONTENT" + envMmapStart = "GO_UFFD_MMAP_START" + envMmapPagesize = "GO_UFFD_MMAP_PAGESIZE" + envMmapTotalSize = "GO_UFFD_MMAP_SIZE" + envAlwaysWP = "GO_UFFD_ALWAYS_WP" + envGated = "GO_UFFD_GATED" + // envBarriers gates the test-only worker hooks. Only race tests + // need them; for everyone else we leave the hook fields nil so + // the hot path stays a single nil-pointer load + branch. + envBarriers = "GO_UFFD_BARRIERS" +) + +// ---- RPC method types --------------------------------------------------- +// +// net/rpc requires methods of the form: +// +// func (s *Service) Method(args *ArgsT, reply *ReplyT) error +// +// where both args and reply are exported pointer types. For methods +// that take or return nothing meaningful we still need a type — Empty +// fills that role. + +type Empty struct{} + +type PageStatesReply struct { + Entries []pageStateEntry +} + +type FaultBarrierArgs struct { + Addr uint64 + Point uint8 +} + +type FaultBarrierReply struct { + Token uint64 +} + +type TokenArgs struct { + Token uint64 +} + +// pageStateEntry is the wire format for PageStates RPC results. +type pageStateEntry struct { + State uint8 + Offset uint64 +} + +// ---- Parent side -------------------------------------------------------- + +// childForkMu serialises the cmd.Start() call across all parallel +// cross-process tests in this binary. Without it, the duplicated +// uffd fd we hand to one child via ExtraFiles is briefly visible in +// the parent's fd table while ANOTHER concurrent test calls fork() +// — so that other test's child inherits a uffd fd it does not own. +// The leaked fd keeps the original test's uffd kernel object alive +// after its owner closes its end, prevents madvise from completing +// once the owning child exits, and produces hard-to-diagnose +// -parallel-only deadlocks. +// +// Holding the mutex only across cmd.Start (which itself holds the +// process lock for the underlying syscall.ForkExec) is enough — by +// the time Start returns the dup'd fd is already mapped into fd 3 +// in the new child and we close it immediately in the parent below. +var childForkMu sync.Mutex + +// Main process, FC in our case. func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error) { t.Helper() data := RandomPages(tt.pagesize, tt.numberOfPages) + if tt.sourcePatcher != nil { + tt.sourcePatcher(data.Content()) + } + size, err := data.Size() require.NoError(t, err) memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), tt.pagesize) require.NoError(t, err) - // We can pass mapping nil as the serve is used only in the helper process. uffdFd, err := newFd(syscall.O_CLOEXEC | syscall.O_NONBLOCK) require.NoError(t, err) - t.Cleanup(func() { uffdFd.close() }) - err = configureApi(uffdFd, tt.pagesize) - require.NoError(t, err) + require.NoError(t, configureApi(uffdFd, tt.pagesize)) + require.NoError(t, register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP)) + + t.Cleanup(func() { + // Tear the registration down before the late close. With + // UFFD_FEATURE_EVENT_REMOVE enabled (see configureApi), + // munmap can otherwise block on un-acked REMOVE events. + _ = unregister(uffdFd, memoryStart, uint64(size)) + }) + + tmpDir := t.TempDir() + + contentPath := filepath.Join(tmpDir, "content.bin") + require.NoError(t, os.WriteFile(contentPath, data.Content(), 0o600)) - err = register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP) + socketPath := filepath.Join(tmpDir, "rpc.sock") + listener, err := net.Listen("unix", socketPath) require.NoError(t, err) cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess", "-test.timeout=0") - cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") - cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) - cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) + cmd.Env = append(os.Environ(), + envHelperFlag+"=1", + envSocketPath+"="+socketPath, + envContentPath+"="+contentPath, + fmt.Sprintf("%s=%d", envMmapStart, memoryStart), + fmt.Sprintf("%s=%d", envMmapPagesize, tt.pagesize), + fmt.Sprintf("%s=%d", envMmapTotalSize, size), + ) if tt.alwaysWP { - cmd.Env = append(cmd.Env, "GO_ALWAYS_WP=1") + cmd.Env = append(cmd.Env, envAlwaysWP+"=1") } if tt.gated { - cmd.Env = append(cmd.Env, "GO_GATED=1") + cmd.Env = append(cmd.Env, envGated+"=1") + } + if tt.barriers { + cmd.Env = append(cmd.Env, envBarriers+"=1") } - dup, err := syscall.Dup(int(uffdFd)) - require.NoError(t, err) + // We hand the uffd fd to the child via ExtraFiles. The child- + // side dup3 inside fork+exec clears CLOEXEC on the destination + // fd (i.e. fd 3 in the child) automatically, so the SOURCE fd + // in our parent should remain CLOEXEC — otherwise every other + // test fork()'d concurrently from this binary inherits a uffd + // it does not own, the kernel keeps the original test's uffd + // alive after its real owner exits, and madvise stops draining. + // At higher -parallel this surfaces as long, hard-to-diagnose + // hangs. + // + // syscall.Dup creates the new fd WITHOUT CLOEXEC, so we + // re-arm it explicitly. Holding childForkMu across the + // dup → cmd.Start window further guarantees no concurrent + // fork can race the F_SETFD. + childForkMu.Lock() - // clear FD_CLOEXEC on the dup we pass across exec - _, err = unix.FcntlInt(uintptr(dup), unix.F_SETFD, 0) + dup, err := syscall.Dup(int(uffdFd)) require.NoError(t, err) + if _, err := unix.FcntlInt(uintptr(dup), unix.F_SETFD, unix.FD_CLOEXEC); err != nil { + childForkMu.Unlock() + require.NoError(t, err) + } uffdFile := os.NewFile(uintptr(dup), "uffd") + cmd.ExtraFiles = []*os.File{uffdFile} + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - contentReader, contentWriter, err := os.Pipe() - require.NoError(t, err) - - go func() { - _, writeErr := contentWriter.Write(data.Content()) - assert.NoError(t, writeErr) - - closeErr := contentWriter.Close() - assert.NoError(t, closeErr) - }() - - offsetsReader, offsetsWriter, err := os.Pipe() - require.NoError(t, err) - - t.Cleanup(func() { - offsetsReader.Close() - }) + startErr := cmd.Start() + uffdFile.Close() + childForkMu.Unlock() - readyReader, readyWriter, err := os.Pipe() - require.NoError(t, err) + require.NoError(t, startErr) - t.Cleanup(func() { - readyReader.Close() - }) - - readySignal := make(chan struct{}, 1) + // Accept the child's connection. Tight deadline so a wedged + // child surfaces fast instead of hanging the test. + type acceptResult struct { + conn net.Conn + err error + } + acceptCh := make(chan acceptResult, 1) go func() { - _, err := io.ReadAll(readyReader) - assert.NoError(t, err) - - readySignal <- struct{}{} + c, err := listener.Accept() + acceptCh <- acceptResult{conn: c, err: err} }() - extraFiles := []*os.File{ - uffdFile, - contentReader, - offsetsWriter, - readyWriter, + var conn net.Conn + select { + case res := <-acceptCh: + require.NoError(t, res.err) + conn = res.conn + case <-time.After(10 * time.Second): + listener.Close() + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + t.Fatalf("child did not connect within 10s") } + listener.Close() - var gateCmdWriter *os.File - var gateSyncReader *os.File - if tt.gated { - var gateCmdReader *os.File - gateCmdReader, gateCmdWriter, err = os.Pipe() - require.NoError(t, err) - - var gateSyncWriter *os.File - gateSyncReader, gateSyncWriter, err = os.Pipe() - require.NoError(t, err) + client := jsonrpc.NewClient(conn) - t.Cleanup(func() { - gateCmdWriter.Close() - gateSyncReader.Close() - }) - - extraFiles = append(extraFiles, gateCmdReader) // fd 7 - extraFiles = append(extraFiles, gateSyncWriter) // fd 8 + h := &testHandler{ + memoryArea: &memoryArea, + pagesize: tt.pagesize, + data: data, + client: client, + conn: conn, + cmd: cmd, } - cmd.ExtraFiles = extraFiles - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - err = cmd.Start() - require.NoError(t, err) - - contentReader.Close() - offsetsWriter.Close() - readyWriter.Close() - uffdFile.Close() - if tt.gated { - extraFiles[4].Close() // gateCmdReader - extraFiles[5].Close() // gateSyncWriter - } + // WaitReady blocks on the child until its initial setup is done + // (uffd serve goroutine running, hooks installed). The RPC reply + // IS the readiness signal — no separate ready pipe / signal + // needed. + require.NoError(t, h.client.Call("Service.WaitReady", &Empty{}, &Empty{})) t.Cleanup(func() { - signalErr := cmd.Process.Signal(syscall.SIGUSR1) - assert.NoError(t, signalErr) + // Best-effort graceful shutdown via RPC. If the child has + // already crashed the RPC will error and we fall back to + // killing the process below. + _ = h.client.Call("Service.Shutdown", &Empty{}, &Empty{}) + _ = client.Close() waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) - - // Tear down the UFFD registration before the early uffdFd.close() - // cleanup runs. This branch enables UFFD_FEATURE_EVENT_REMOVE - // (see configureApi in fd_helpers_test.go), so without the - // unregister, munmap can block on un-acked REMOVE events queued - // by the kernel against the still-registered range. Cleanups - // run LIFO, so this fires before the close registered earlier. - assert.NoError(t, unregister(uffdFd, memoryStart, uint64(size))) + if waitErr != nil { + var exitErr *exec.ExitError + if !errors.As(waitErr, &exitErr) { + t.Logf("helper process Wait: %v", waitErr) + } + } }) - // pageStatesOnce asks the serving process for a snapshot of its pageTracker - // and decodes it into a per-state view. It can only be called once. - pageStatesOnce := func() (handlerPageStates, error) { - err := cmd.Process.Signal(syscall.SIGUSR2) - if err != nil { - return handlerPageStates{}, err + if tt.gated { + h.servePause = func() error { + return h.client.Call("Service.ServePause", &Empty{}, &Empty{}) } - - var result handlerPageStates - - for { - var entry pageStateEntry - - // binary.Read uses the same field layout as binary.Write on - // the producer side (sum of fixed-size fields, no struct - // padding), so we never have to hard-code the wire size. - err := binary.Read(offsetsReader, binary.LittleEndian, &entry) - if errors.Is(err, io.EOF) { - break - } - - if err != nil { - return handlerPageStates{}, fmt.Errorf("decoding page state entry: %w", err) - } - - switch pageState(entry.State) { - case faulted: - result.faulted = append(result.faulted, uint(entry.Offset)) - case removed: - result.removed = append(result.removed, uint(entry.Offset)) - } + h.serveResume = func() error { + return h.client.Call("Service.ServeResume", &Empty{}, &Empty{}) } - - slices.Sort(result.faulted) - slices.Sort(result.removed) - - return result, nil - } - - select { - case <-t.Context().Done(): - return nil, t.Context().Err() - case <-readySignal: } - h := &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - pageStatesOnce: pageStatesOnce, - } + h.pageStatesOnce = func() (handlerPageStates, error) { + var reply PageStatesReply + if err := h.client.Call("Service.PageStates", &Empty{}, &reply); err != nil { + return handlerPageStates{}, err + } - if tt.gated { - h.servePause = func() error { - if _, err := gateCmdWriter.Write([]byte{'P'}); err != nil { - return err + var states handlerPageStates + for _, e := range reply.Entries { + switch pageState(e.State) { + case faulted: + states.faulted = append(states.faulted, uint(e.Offset)) + case removed: + states.removed = append(states.removed, uint(e.Offset)) } - var buf [1]byte - _, err := gateSyncReader.Read(buf[:]) - - return err } - h.serveResume = func() error { - _, err := gateCmdWriter.Write([]byte{'R'}) + slices.Sort(states.faulted) + slices.Sort(states.removed) - return err - } + return states, nil } return h, nil } -// Secondary process, orchestrator in our case +// ---- Child side --------------------------------------------------------- + +// Secondary process, orchestrator in our case. func TestHelperServingProcess(t *testing.T) { t.Parallel() - if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { + if os.Getenv(envHelperFlag) != "1" { t.Skip("this is a helper process, skipping direct execution") } - err := crossProcessServe() - if err != nil { - fmt.Println("exit serving process", err) + if err := crossProcessServe(); err != nil { + fmt.Fprintln(os.Stderr, "exit serving process:", err) os.Exit(1) } os.Exit(0) } +// crossProcessServe wires up the child side: connects back to the +// parent socket, registers the RPC service, and runs uffd.Serve in a +// background goroutine that pause/resume RPCs can stop and restart. func crossProcessServe() error { - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) + socketPath := os.Getenv(envSocketPath) + if socketPath == "" { + return fmt.Errorf("missing %s", envSocketPath) + } - startRaw, err := strconv.Atoi(os.Getenv("GO_MMAP_START")) + conn, err := net.Dial("unix", socketPath) if err != nil { - return fmt.Errorf("exit parsing mmap start: %w", err) + return fmt.Errorf("dial parent socket: %w", err) } + defer conn.Close() + startRaw, err := strconv.ParseUint(os.Getenv(envMmapStart), 10, 64) + if err != nil { + return fmt.Errorf("parse %s: %w", envMmapStart, err) + } memoryStart := uintptr(startRaw) - uffdFile := os.NewFile(uintptr(3), os.Getenv("GO_UFFD_FILE")) - defer uffdFile.Close() - - uffdFd := uffdFile.Fd() - - contentFile := os.NewFile(uintptr(4), "content") - defer contentFile.Close() + pagesize, err := strconv.ParseInt(os.Getenv(envMmapPagesize), 10, 64) + if err != nil { + return fmt.Errorf("parse %s: %w", envMmapPagesize, err) + } - content, err := io.ReadAll(contentFile) + totalSize, err := strconv.ParseInt(os.Getenv(envMmapTotalSize), 10, 64) if err != nil { - return fmt.Errorf("exit reading content: %w", err) + return fmt.Errorf("parse %s: %w", envMmapTotalSize, err) } - pageSize, err := strconv.ParseInt(os.Getenv("GO_MMAP_PAGE_SIZE"), 10, 64) + content, err := os.ReadFile(os.Getenv(envContentPath)) if err != nil { - return fmt.Errorf("exit parsing page size: %w", err) + return fmt.Errorf("read content: %w", err) } + if int64(len(content)) != totalSize { + return fmt.Errorf("content size %d != expected %d", len(content), totalSize) + } + + data := NewMemorySlicer(content, pagesize) - data := NewMemorySlicer(content, pageSize) + uffdFile := os.NewFile(uintptr(3), "uffd") + defer uffdFile.Close() + uffdFd := uffdFile.Fd() - m := memory.NewMapping([]memory.Region{ + mapping := memory.NewMapping([]memory.Region{ { BaseHostVirtAddr: memoryStart, - Size: uintptr(len(content)), + Size: uintptr(totalSize), Offset: 0, - PageSize: uintptr(pageSize), + PageSize: uintptr(pagesize), }, }) l, err := logger.NewDevelopmentLogger() if err != nil { - return fmt.Errorf("exit creating logger: %w", err) + return fmt.Errorf("logger: %w", err) } - uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l) + uffd, err := NewUserfaultfdFromFd(uffdFd, data, mapping, l) if err != nil { - return fmt.Errorf("exit creating uffd: %w", err) + return fmt.Errorf("NewUserfaultfdFromFd: %w", err) } - if os.Getenv("GO_ALWAYS_WP") == "1" { + if os.Getenv(envAlwaysWP) == "1" { uffd.defaultCopyMode = UFFDIO_COPY_MODE_WP } - offsetsFile := os.NewFile(uintptr(5), "offsets") + br := newBarrierRegistry() - offsetsSignal := make(chan os.Signal, 1) - signal.Notify(offsetsSignal, syscall.SIGUSR2) - defer signal.Stop(offsetsSignal) + // Hooks are only wired up when the test asked for them (race + // tests). For everyone else we leave the fields nil so the hot + // path is a single nil-pointer load + branch — keeps the high- + // throughput tests (TestParallelMissingWriteWithPrefault, etc.) + // from paying for a Mutex per fault. + if os.Getenv(envBarriers) == "1" { + uffd.beforeWorkerRLockHook = br.hookFor(barrierBeforeRLock) + uffd.beforeFaultPageHook = br.hookFor(barrierBeforeFaultPage) + } - go func() { - defer offsetsFile.Close() - - for { - select { - case <-ctx.Done(): - return - case <-offsetsSignal: - entries, entriesErr := uffd.pageStateEntries() - if entriesErr != nil { - cancel(fmt.Errorf("error getting page state entries: %w", entriesErr)) - - return - } - - for _, entry := range entries { - writeErr := binary.Write(offsetsFile, binary.LittleEndian, entry) - if writeErr != nil { - cancel(fmt.Errorf("error writing page state entry: %w", writeErr)) - - return - } - } - - return - } - } - }() + gated := os.Getenv(envGated) == "1" - fdExit, err := fdexit.New() - if err != nil { - return fmt.Errorf("exit creating fd exit: %w", err) + svc := &Service{ + uffd: uffd, + br: br, + gated: gated, + shutdown: make(chan struct{}), } - defer fdExit.Close() + svc.startServe() - exitUffd := make(chan struct{}, 1) + server := rpc.NewServer() + if err := server.Register(svc); err != nil { + return fmt.Errorf("rpc Register: %w", err) + } + // Run the codec in a goroutine so we can react to Shutdown + // without depending on the codec returning. + codecDone := make(chan struct{}) go func() { - defer func() { exitUffd <- struct{}{} }() - - serverErr := uffd.Serve(ctx, fdExit) - if serverErr != nil { - msg := fmt.Errorf("error serving: %w", serverErr) - fmt.Fprint(os.Stderr, msg.Error()) - cancel(msg) - } + defer close(codecDone) + server.ServeCodec(jsonrpc.NewServerCodec(conn)) }() - cleanup := func() { - fdExit.SignalExit() - <-exitUffd + select { + case <-svc.shutdown: + fmt.Fprintln(os.Stderr, "child: shutdown received") + case <-codecDone: + fmt.Fprintln(os.Stderr, "child: codec done") } - defer func() { cleanup() }() - if os.Getenv("GO_GATED") == "1" { - gateCmdFile := os.NewFile(uintptr(7), "gate-cmd") - defer gateCmdFile.Close() + // Release any still-parked barriers so the serve goroutine can + // finish, then stop the serve goroutine. + br.releaseAll() + fmt.Fprintln(os.Stderr, "child: barriers released") + svc.stopServe() + fmt.Fprintln(os.Stderr, "child: serve stopped") - gateSyncFile := os.NewFile(uintptr(8), "gate-sync") - defer gateSyncFile.Close() + // Closing the conn is sufficient to unblock ServeCodec if it + // hasn't already returned. + _ = conn.Close() + <-codecDone + fmt.Fprintln(os.Stderr, "child: codec exited") - startServe := func() func() { - newExit, fdErr := fdexit.New() - if fdErr != nil { - cancel(fmt.Errorf("error creating fd exit: %w", fdErr)) + return nil +} - return func() {} - } +// Service is the RPC surface exposed to the parent. Methods follow +// net/rpc's required signature. +type Service struct { + uffd *Userfaultfd + br *barrierRegistry - done := make(chan struct{}) - go func() { - defer close(done) - if err := uffd.Serve(ctx, newExit); err != nil { - cancel(fmt.Errorf("error serving: %w", err)) - } - }() - - return func() { - newExit.SignalExit() - <-done - newExit.Close() - } - } + gated bool + + mu sync.Mutex + stop func() // currently active serve-stop function, nil if paused + shutdown chan struct{} + closed bool +} + +func (s *Service) startServe() { + exit, err := fdexit.New() + if err != nil { + fmt.Fprintln(os.Stderr, "fdexit.New:", err) + + return + } - stopServe := func() { - cleanup() + done := make(chan struct{}) + go func() { + defer close(done) + if err := s.uffd.Serve(context.Background(), exit); err != nil { + fmt.Fprintln(os.Stderr, "uffd.Serve:", err) } + }() - go func() { - var buf [1]byte - for { - if _, err := gateCmdFile.Read(buf[:]); err != nil { - return - } - - switch buf[0] { - case 'P': - stopServe() - gateSyncFile.Write([]byte{1}) - case 'R': - newStop := startServe() - stopServe = newStop - cleanup = newStop - } - } - }() + s.stop = func() { + _ = exit.SignalExit() + <-done + exit.Close() } +} - exitSignal := make(chan os.Signal, 1) - signal.Notify(exitSignal, syscall.SIGUSR1) - defer signal.Stop(exitSignal) +func (s *Service) stopServe() { + s.mu.Lock() + defer s.mu.Unlock() + if s.stop != nil { + s.stop() + s.stop = nil + } +} - readyFile := os.NewFile(uintptr(6), "ready") +// WaitReady is a no-op handler whose successful reply is the +// readiness signal for the parent. +func (s *Service) WaitReady(_ *Empty, _ *Empty) error { + return nil +} - closeErr := readyFile.Close() - if closeErr != nil { - return fmt.Errorf("error closing ready file: %w", closeErr) +func (s *Service) PageStates(_ *Empty, reply *PageStatesReply) error { + entries, err := s.uffd.pageStateEntries() + if err != nil { + return err } + reply.Entries = entries - select { - case <-ctx.Done(): - return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx)) - case <-exitSignal: - return nil + return nil +} + +func (s *Service) ServePause(_ *Empty, _ *Empty) error { + if !s.gated { + return errors.New("ServePause called on a non-gated handler") } + s.stopServe() + + return nil } -// pageStateEntry is the wire format used between the main test process -// and the serving helper process. State is emitted as a single byte so it -// can be written directly with binary.Write and decoded on the other side. -type pageStateEntry struct { - State uint8 - Offset uint64 +func (s *Service) ServeResume(_ *Empty, _ *Empty) error { + if !s.gated { + return errors.New("ServeResume called on a non-gated handler") + } + s.mu.Lock() + defer s.mu.Unlock() + s.startServe() + + return nil +} + +func (s *Service) InstallFaultBarrier(args *FaultBarrierArgs, reply *FaultBarrierReply) error { + reply.Token = s.br.install(uintptr(args.Addr), barrierPoint(args.Point)) + + return nil +} + +func (s *Service) WaitFaultHeld(args *TokenArgs, _ *Empty) error { + return s.br.waitArrived(context.Background(), args.Token) +} + +func (s *Service) ReleaseFault(args *TokenArgs, _ *Empty) error { + s.br.release(args.Token) + + return nil } -// pageStateEntries returns a snapshot of every tracked page and its state. -// It holds the settleRequests write lock so no in-flight faultPage worker -// can mutate the pageTracker while we iterate. +func (s *Service) Shutdown(_ *Empty, _ *Empty) error { + s.mu.Lock() + defer s.mu.Unlock() + if !s.closed { + s.closed = true + close(s.shutdown) + } + + return nil +} + +// pageStateEntries returns a snapshot of every tracked page and its +// state. It briefly takes settleRequests.Lock so no in-flight worker +// can mutate the pageTracker while we read it. func (u *Userfaultfd) pageStateEntries() ([]pageStateEntry, error) { u.settleRequests.Lock() u.settleRequests.Unlock() //nolint:staticcheck // SA2001: intentional — settle the read locks. @@ -520,15 +600,163 @@ func (u *Userfaultfd) pageStateEntries() ([]pageStateEntry, error) { u.pageTracker.mu.RLock() defer u.pageTracker.mu.RUnlock() - var entries []pageStateEntry + entries := make([]pageStateEntry, 0, len(u.pageTracker.m)) for addr, state := range u.pageTracker.m { offset, err := u.ma.GetOffset(addr) if err != nil { return nil, fmt.Errorf("address %#x not in mapping: %w", addr, err) } - - entries = append(entries, pageStateEntry{uint8(state), uint64(offset)}) + entries = append(entries, pageStateEntry{State: uint8(state), Offset: uint64(offset)}) } return entries, nil } + +// ---- Barrier registry --------------------------------------------------- + +// barrierPoint identifies WHICH hook a barrier should park on. +type barrierPoint uint8 + +const ( + // barrierBeforeRLock parks the worker BEFORE settleRequests.RLock(), + // i.e. before it can read the page state. Use this for the + // stale-source race: a parallel REMOVE batch on the parent loop + // can take the write lock immediately because no worker holds + // the read lock. + barrierBeforeRLock barrierPoint = 1 + // barrierBeforeFaultPage parks the worker AFTER it has taken + // settleRequests.RLock and decided on `source`, but BEFORE the + // actual UFFDIO_COPY syscall. Use this for the in-flight COPY + // deadlock test: the parent's madvise must still return even + // though a worker holds RLock. + barrierBeforeFaultPage barrierPoint = 2 +) + +// barrierRegistry is the child-process side of the barrier. The +// hooks installed on Userfaultfd consult this registry by addr+point +// to decide whether to park, and the RPC handlers manipulate it from +// the parent over the socket. +type barrierRegistry struct { + mu sync.Mutex + next uint64 + tokens map[uint64]*barrierSlot + byKey map[barrierKey]uint64 +} + +type barrierKey struct { + addr uintptr + point barrierPoint +} + +type barrierSlot struct { + addr uintptr + point barrierPoint + arrived chan struct{} + release chan struct{} + arrivedOnce sync.Once +} + +func newBarrierRegistry() *barrierRegistry { + return &barrierRegistry{ + tokens: make(map[uint64]*barrierSlot), + byKey: make(map[barrierKey]uint64), + } +} + +func (b *barrierRegistry) install(addr uintptr, point barrierPoint) uint64 { + b.mu.Lock() + defer b.mu.Unlock() + + b.next++ + token := b.next + slot := &barrierSlot{ + addr: addr, + point: point, + arrived: make(chan struct{}), + release: make(chan struct{}), + } + b.tokens[token] = slot + b.byKey[barrierKey{addr, point}] = token + + return token +} + +func (b *barrierRegistry) lookupByAddr(addr uintptr, point barrierPoint) *barrierSlot { + b.mu.Lock() + defer b.mu.Unlock() + + token, ok := b.byKey[barrierKey{addr, point}] + if !ok { + return nil + } + + return b.tokens[token] +} + +func (b *barrierRegistry) waitArrived(ctx context.Context, token uint64) error { + b.mu.Lock() + slot, ok := b.tokens[token] + b.mu.Unlock() + if !ok { + return fmt.Errorf("unknown barrier token %d", token) + } + + select { + case <-slot.arrived: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (b *barrierRegistry) release(token uint64) { + b.mu.Lock() + slot, ok := b.tokens[token] + delete(b.tokens, token) + if ok { + delete(b.byKey, barrierKey{slot.addr, slot.point}) + } + b.mu.Unlock() + + if !ok { + return + } + + select { + case <-slot.release: + default: + close(slot.release) + } +} + +func (b *barrierRegistry) releaseAll() { + b.mu.Lock() + tokens := make([]uint64, 0, len(b.tokens)) + for t := range b.tokens { + tokens = append(tokens, t) + } + b.mu.Unlock() + + for _, t := range tokens { + b.release(t) + } +} + +// hookFor returns the function to assign to a Userfaultfd +// beforeXxxHook field. The returned function is a no-op for any +// (addr, point) pair that hasn't been Install'd, so non-targeted +// faults see no scheduling distortion. +func (b *barrierRegistry) hookFor(point barrierPoint) func(addr uintptr) { + return func(addr uintptr) { + slot := b.lookupByAddr(addr, point) + if slot == nil { + return + } + + slot.arrivedOnce.Do(func() { + close(slot.arrived) + }) + + <-slot.release + } +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go index 003ee1523f..5cd7752a58 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/helpers_test.go @@ -4,6 +4,9 @@ import ( "bytes" "context" "fmt" + "io" + "net/rpc" + "os/exec" "sync" "testing" "time" @@ -29,6 +32,18 @@ type testConfig struct { alwaysWP bool // gated enables pause/resume control over the handler's serve loop. gated bool + // barriers wires up the per-worker fault hooks in the child + // (used by race tests). Off by default so the worker hot path + // stays a single nil-pointer load + branch in non-race tests. + barriers bool + // sourcePatcher, if non-nil, is invoked on the random source data + // AFTER it's generated but BEFORE it's written to the on-disk + // content file the child reads. Tests can use this to plant + // deterministic sentinel bytes in the source so the post-test + // assertion can distinguish "post-fix zero-fault" from "pre-fix + // UFFDIO_COPY of stale src bytes" without depending on the + // happenstance value of randomly-generated bytes. + sourcePatcher func([]byte) } type operationMode uint32 @@ -91,14 +106,52 @@ type testHandler struct { pagesize uint64 data *MemorySlicer // pageStatesOnce returns a per-state snapshot of the handler's pageTracker. - // It can only be called once. + // Backed by the PageStates RPC; callable any number of times. + // The "Once" suffix is kept for source-stability with the existing + // test sites. pageStatesOnce func() (handlerPageStates, error) // servePause and serveResume gate the UFFD event loop in the child process. // Tests use them to deterministically drain a batch of REMOVE events // before more faults are processed. servePause func() error serveResume func() error - mutex sync.Mutex + + // client is the RPC channel to the child helper process. + client *rpc.Client + conn io.Closer + cmd *exec.Cmd + + mutex sync.Mutex +} + +// installFaultBarrier asks the child to park the next worker that +// hits `point` for `addr`. Returns a token that must be passed to +// waitFaultHeld and releaseFault. +func (h *testHandler) installFaultBarrier(_ context.Context, addr uintptr, point barrierPoint) (uint64, error) { + var reply FaultBarrierReply + err := h.client.Call("Service.InstallFaultBarrier", &FaultBarrierArgs{Addr: uint64(addr), Point: uint8(point)}, &reply) + + return reply.Token, err +} + +// waitFaultHeld blocks until the child reports that a worker has +// reached the barrier identified by token. The wait is bounded via +// context by issuing the call on a goroutine and racing it against +// ctx; net/rpc's Call doesn't take a context directly. +func (h *testHandler) waitFaultHeld(ctx context.Context, token uint64) error { + call := h.client.Go("Service.WaitFaultHeld", &TokenArgs{Token: token}, &Empty{}, nil) + select { + case <-call.Done: + return call.Error + case <-ctx.Done(): + return ctx.Err() + } +} + +// releaseFault releases a parked worker so it proceeds past the +// barrier. +func (h *testHandler) releaseFault(_ context.Context, token uint64) error { + return h.client.Call("Service.ReleaseFault", &TokenArgs{Token: token}, &Empty{}) } func (h *testHandler) executeAll(t *testing.T, operations []operation) { diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go index 3f96d555c1..6ec229f78f 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/remove_test.go @@ -156,9 +156,20 @@ func TestRemoveThenFault(t *testing.T) { // succeeds without faulting because MADV_DONTNEED blocks (waiting for ack) // and doesn't unmap the page until the handler processes the event. // When the handler resumes, it only sees the REMOVE — no MISSING fault. +// +// NOTE: this test (and the other gated tests below) deliberately does +// NOT call t.Parallel(). While the handler is paused, any user thread +// that triggers a queued pagefault on the registered region is +// suspended in the kernel's pagefault path. From the Go runtime's +// perspective that goroutine is "running" (not in syscall, since it's +// a plain memory store) and cannot be preempted until the fault is +// served. If a CONCURRENT cross-process test in the same binary +// triggers a stop-the-world GC pause during this window, STW will +// wait forever for the suspended goroutine to reach a safe point — +// the kernel cannot deliver the SIGURG preempt signal until the +// pagefault is served, and the handler is paused. Running the gated +// tests sequentially avoids that interleaving. func TestRemoveThenWriteGated(t *testing.T) { - t.Parallel() - tests := []testConfig{ { name: "4k gated remove with concurrent write", @@ -194,8 +205,6 @@ func TestRemoveThenWriteGated(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - h, err := configureCrossProcessTest(t, tt) require.NoError(t, err) @@ -218,9 +227,9 @@ func TestRemoveThenWriteGated(t *testing.T) { // was queued first. The write to a missing page triggers MISSING (queued first), // then MADV_DONTNEED triggers REMOVE (queued second). When the handler resumes, // it processes REMOVE first, then MISSING — the write is not skipped. +// +// See TestRemoveThenWriteGated for why this test is not parallel. func TestWriteThenRemoveGated(t *testing.T) { - t.Parallel() - tests := []testConfig{ { name: "4k write then remove in same batch", @@ -258,8 +267,6 @@ func TestWriteThenRemoveGated(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - h, err := configureCrossProcessTest(t, tt) require.NoError(t, err) diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go index 5496cd4e2b..1593e9ef3a 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go @@ -70,6 +70,29 @@ type Userfaultfd struct { // if no new UFFD events arrive to wake poll. wakeupPipe [2]int + // Test-only synchronisation hooks. Both default to nil and the nil + // branch costs a single un-predictable load + branch in the hot path, + // so they are effectively free in production. They MUST only be set + // from _test.go files. They let tests park a worker goroutine at a + // known point so a racing event (REMOVE, MISSING) can be issued + // deterministically before the worker proceeds. + // + // - beforeWorkerRLockHook: called as the very first thing in the + // worker goroutine, BEFORE settleRequests.RLock(). At this point + // the test holds the goroutine before it can claim the read lock, + // so a parallel REMOVE batch in the parent loop can take the + // write lock immediately and mutate page state. This is the + // window the production fix actually closes — the post-fix + // worker reads state under RLock, so it observes the REMOVE. + // + // - beforeFaultPageHook: called inside the worker AFTER RLock and + // AFTER the state-vs-source decision, but BEFORE the actual + // UFFDIO_COPY/UFFDIO_ZEROPAGE syscall. Lets a test simulate a + // slow data fetch / in-flight COPY so a parent madvise can race + // against an in-flight worker. + beforeWorkerRLockHook func(addr uintptr) + beforeFaultPageHook func(addr uintptr) + logger logger.Logger } @@ -330,6 +353,16 @@ func (u *Userfaultfd) Serve( } u.wg.Go(func() error { + // Test-only barrier: park the worker BEFORE it takes + // RLock. While parked, the parent loop is free to take + // settleRequests.Lock() to process REMOVE events, which + // is exactly the window the production fix had to close + // (pre-fix the worker had already captured a stale state + // snapshot in the parent loop). + if hook := u.beforeWorkerRLockHook; hook != nil { + hook(addr) + } + // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, // even if the errgroup is cancelled or the goroutine returns early. // This check protects us against race condition between marking the request for prefetching and accessing the prefetchTracker. @@ -344,6 +377,15 @@ func (u *Userfaultfd) Serve( accessType = block.Write } + // Test-only barrier: park the worker AFTER state has been + // read under RLock but BEFORE the actual UFFDIO_* syscall. + // Lets tests simulate a slow / in-flight COPY so the + // parent's madvise (and the subsequent REMOVE batch) can + // race against a worker that already holds RLock. + if hook := u.beforeFaultPageHook; hook != nil { + hook(addr) + } + handled, err := u.faultPage( ctx, addr, From 151736b8193c69a06af839b0bd478168f1c8f6c2 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 27 Apr 2026 21:48:31 -0700 Subject: [PATCH 3/3] test(uffd): add deterministic stale-source / madvise-deadlock / faulted-short-circuit race tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three new race tests built on the unix-socket RPC harness and the test-only fault-barrier hooks. None of them use sleeps, retries, or soak loops — each test installs explicit barriers on the child's worker goroutine, drives the racing kernel operation from the parent, and asserts on a concrete post-state. - TestStaleSourceRaceMissingAndRemove: regression test for the production fix in the next stacked PR. Plants a non-zero sentinel into the source page, parks the worker via barrierBeforeRLock, fires madvise on the same page, waits for the REMOVE batch to commit (state == removed), releases the worker, then asserts the page is zero-filled. PRE-FIX (this PR) the worker UFFDIO_COPYs the planted sentinel because it captured `source = u.src` in the parent loop before the REMOVE landed and never re-reads state inside the goroutine; the assertion fires. POST-FIX (next PR) the worker re-reads state under RLock, observes `removed`, and zero-faults; assertion passes. Wallclock < 50ms / variant. - TestNoMadviseDeadlockWithInflightCopy: liveness regression test for the user-visible symptom that originally surfaced the race — parent madvise deadlocking while the worker holds RLock. Parks the worker via barrierBeforeFaultPage (i.e. holding RLock, mid- handler), fires MADV_DONTNEED, asserts madvise returns within 2s. Catches any future change that accidentally couples readEvents() to settleRequests as a fast assertion failure rather than a 30m CI timeout. Wallclock < 50ms / variant. - TestFaultedShortCircuitOrdering: smoke test on the REMOVE-then-pagefault batch ordering using the gated harness. Pins the invariant that REMOVE batches drain before pagefault dispatch in a single Serve iteration. Wallclock ~120ms / variant. These tests are EXPECTED TO FAIL on this PR — they are the bug demonstration. The production fix lands in the stacked PR fix/uffd-stale-source-race and flips them green. Once the fix lands, all three pass `-count=20 -timeout=30s` deterministically. --- .../pkg/sandbox/uffd/userfaultfd/race_test.go | 407 ++++++++++++++++++ 1 file changed, 407 insertions(+) create mode 100644 packages/orchestrator/pkg/sandbox/uffd/userfaultfd/race_test.go diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/race_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/race_test.go new file mode 100644 index 0000000000..b742ef25e2 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/race_test.go @@ -0,0 +1,407 @@ +package userfaultfd + +import ( + "context" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// raceHappyPathBudget bounds every race test in this file. The whole +// point of these tests is that they detect a regression as a fast, +// targeted assertion rather than as a CI -timeout 30m hang. None of +// these tests should approach this budget on a healthy build. +const raceHappyPathBudget = 5 * time.Second + +// barrierArrivalDeadline is how long the test will wait for a worker +// to reach an installed barrier. The hook fires the first thing in +// the worker goroutine, so on a healthy build it's a sub-millisecond +// rendezvous over the unix-socket RPC. Anything approaching this +// deadline means the handler dispatch is wedged. +const barrierArrivalDeadline = 2 * time.Second + +// madviseBudget is how long we allow MADV_DONTNEED to spend in the +// kernel after we've parked a worker mid-handler. The fix guarantees +// madvise unblocks as soon as the handler drains the REMOVE event +// from the uffd fd, regardless of any worker holding RLock — +// readEvents requires no lock. +const madviseBudget = 2 * time.Second + +// withRaceContext bounds a single race test to raceHappyPathBudget, +// failing with a clear "deadlock" message if the budget is exceeded. +func withRaceContext(t *testing.T, body func(ctx context.Context)) { + t.Helper() + + ctx, cancel := context.WithTimeout(t.Context(), raceHappyPathBudget) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + body(ctx) + }() + + select { + case <-done: + case <-ctx.Done(): + t.Fatalf("race test exceeded happy-path budget of %s — handler is wedged", raceHappyPathBudget) + } +} + +// TestStaleSourceRaceMissingAndRemove is the deterministic regression +// test for the production fix in Serve(): +// +// - Pre-fix the parent serve loop captured `state == missing` and +// `source = u.src` BEFORE handing the work to a worker goroutine. +// A REMOVE event for the same page that arrived between then and +// the worker actually running would silently leave the worker +// with a stale `source = u.src` snapshot, which it would then +// UFFDIO_COPY into the page that the kernel had just unmapped. +// +// - Post-fix the worker reads pageTracker state INSIDE the +// goroutine, under settleRequests.RLock, atomically with the +// decision of which `source` to use. +// +// The test installs a barrierBeforeRLock on page X (so the worker +// for X parks before it can read state), triggers a MISSING-write +// fault on X from the parent, waits for the worker to park, fires +// MADV_DONTNEED on X (which can take settleRequests.Lock immediately +// — no worker holds RLock), and then releases the worker. After +// release the worker, post-fix, observes state=removed under RLock +// and zero-faults; pre-fix it would have UFFDIO_COPY'd the planted +// sentinel byte from u.src. A direct read of the page contents +// distinguishes the two outcomes deterministically. +func TestStaleSourceRaceMissingAndRemove(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + pagesize uint64 + }{ + {name: "4k", pagesize: header.PageSize}, + {name: "hugepage", pagesize: header.HugepageSize}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + withRaceContext(t, func(ctx context.Context) { + // Plant a deterministic, non-zero sentinel as the + // first byte of the source data for the page we'll + // race on. Pre-fix, the worker would UFFDIO_COPY this + // sentinel into the page after the REMOVE has already + // unmapped it. Post-fix the worker reads + // state == removed under RLock and zero-fills. + // Planting goes through testConfig.sourcePatcher so + // it lands BOTH in the parent's MemorySlicer and in + // the on-disk content file the child reads. + const sentinel = byte(0xC3) + const pageIdx = 1 + pageOffset := int64(pageIdx) * int64(tt.pagesize) + + cfg := testConfig{ + pagesize: tt.pagesize, + numberOfPages: 4, + barriers: true, + sourcePatcher: func(content []byte) { + content[pageOffset] = sentinel + }, + } + + h, err := configureCrossProcessTest(t, cfg) + require.NoError(t, err) + + memStart := uintptr(unsafe.Pointer(&(*h.memoryArea)[0])) + addr := memStart + uintptr(pageIdx)*uintptr(tt.pagesize) + + token, err := h.installFaultBarrier(ctx, addr, barrierBeforeRLock) + require.NoError(t, err) + + // Trigger a READ fault (NOT a write — a write would + // overwrite the very byte we want to inspect to + // distinguish the two outcomes). h.executeRead does + // the touch + content check; we run it in a goroutine + // because it blocks on the fault until we release the + // barrier. + readErrCh := make(chan error, 1) + go func() { + readErrCh <- h.executeRead(ctx, operation{offset: pageOffset, mode: operationModeRead}) + }() + + // Wait for the worker for `addr` to park at the + // pre-RLock barrier. + waitCtx, waitCancel := context.WithTimeout(ctx, barrierArrivalDeadline) + err = h.waitFaultHeld(waitCtx, token) + waitCancel() + require.NoError(t, err, "worker for page %d (addr %#x) did not park at barrier", pageIdx, addr) + + // Fire MADV_DONTNEED on the same page from the + // parent. The serve loop can take Lock immediately + // because the parked worker has not yet acquired + // RLock. + madviseCtx, madviseCancel := context.WithTimeout(ctx, madviseBudget) + err = h.executeRemove(operation{offset: pageOffset, mode: operationModeRemove}) + madviseCancel() + _ = madviseCtx + require.NoError(t, err, "MADV_DONTNEED on page %d did not return — handler dispatch wedged", pageIdx) + + // Wait for the handler to commit setState(removed). + // A tight poll loop with a hard deadline is used + // rather than a sleep — the transition is + // microseconds in the happy path. + require.NoError(t, waitForState(ctx, h, uint64(pageOffset), removed, barrierArrivalDeadline), + "handler did not transition page %d to `removed` after MADV_DONTNEED", pageIdx) + + // Release the parked worker. Post-fix it will + // observe state == removed and zero-fault; pre-fix + // it would proceed with the captured stale source. + require.NoError(t, h.releaseFault(ctx, token)) + + select { + case err := <-readErrCh: + // Pre-fix: executeRead's bytes.Equal succeeds + // (page contains src bytes), so err == nil but + // the page is observably wrong. Post-fix: + // bytes.Equal fails (page is zero-filled), so + // err != nil. We use the page-content assertion + // below instead of relying on this side-channel. + _ = err + case <-ctx.Done(): + t.Fatalf("read of page %d did not unblock after barrier release", pageIdx) + } + + // THE bug-detection assertion: post-fix the page + // MUST be zero-filled. Pre-fix the worker + // UFFDIO_COPY'd the planted sentinel. + page := (*h.memoryArea)[pageOffset : pageOffset+int64(tt.pagesize)] + assert.Equalf(t, byte(0), page[0], + "page %d first byte: want 0 (post-fix zero-fault for `removed` state), got %#x — "+ + "if this equals the sentinel %#x, the worker used a stale `source = u.src` snapshot (regression)", + pageIdx, page[0], sentinel, + ) + + // Sanity: verify with /proc/self/pagemap that the + // page is in fact present after the racing read was + // served (worker re-mapped it as zero). + pagemap, err := testutils.NewPagemapReader() + require.NoError(t, err) + defer pagemap.Close() + entry, err := pagemap.ReadEntry(addr) + require.NoError(t, err) + assert.True(t, entry.IsPresent(), "page %d should be present after the racing read", pageIdx) + }) + }) + } +} + +// TestNoMadviseDeadlockWithInflightCopy is a liveness regression test +// for the user-visible symptom that originally surfaced the stale- +// source race: the orchestrator's parent madvise(MADV_DONTNEED) +// blocking forever because the UFFD handler loop was wedged behind a +// worker. +// +// The harness parks the worker AFTER it has taken settleRequests.RLock +// AND captured `source` (i.e. as if its UFFDIO_COPY was in flight). +// From the parent we then issue MADV_DONTNEED on the same page and +// require that madvise returns within `madviseBudget`. madvise +// unblocks as soon as the handler's readEvents drains the REMOVE +// event, and readEvents requires no lock — so any future change that +// accidentally couples readEvents to settleRequests fails this test +// at the `madviseBudget` boundary instead of as a 30-minute CI +// timeout. +func TestNoMadviseDeadlockWithInflightCopy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + pagesize uint64 + }{ + {name: "4k", pagesize: header.PageSize}, + {name: "hugepage", pagesize: header.HugepageSize}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + withRaceContext(t, func(ctx context.Context) { + cfg := testConfig{ + pagesize: tt.pagesize, + numberOfPages: 4, + barriers: true, + } + + h, err := configureCrossProcessTest(t, cfg) + require.NoError(t, err) + + const pageIdx = 2 + pageOffset := int64(pageIdx) * int64(tt.pagesize) + + memStart := uintptr(unsafe.Pointer(&(*h.memoryArea)[0])) + addr := memStart + uintptr(pageIdx)*uintptr(tt.pagesize) + + token, err := h.installFaultBarrier(ctx, addr, barrierBeforeFaultPage) + require.NoError(t, err) + + writeErrCh := make(chan error, 1) + go func() { + writeErrCh <- h.executeWrite(ctx, operation{offset: pageOffset, mode: operationModeWrite}) + }() + + waitCtx, waitCancel := context.WithTimeout(ctx, barrierArrivalDeadline) + err = h.waitFaultHeld(waitCtx, token) + waitCancel() + require.NoError(t, err, "worker for page %d (addr %#x) did not park at pre-COPY barrier", pageIdx, addr) + + // Worker is parked AFTER RLock. Issue MADV_DONTNEED + // on the same page from the parent. The handler's + // readEvents must drain the REMOVE event (so madvise + // returns) even while the worker holds RLock. + madviseDone := make(chan error, 1) + go func() { + madviseDone <- unix.Madvise((*h.memoryArea)[pageOffset:pageOffset+int64(tt.pagesize)], unix.MADV_DONTNEED) + }() + + select { + case err := <-madviseDone: + require.NoError(t, err) + case <-time.After(madviseBudget): + _ = h.releaseFault(ctx, token) + <-writeErrCh + t.Fatalf("DEADLOCK: madvise(MADV_DONTNEED) on page %d did not return within %s "+ + "while a worker was parked holding settleRequests.RLock — readEvents must not require any lock", + pageIdx, madviseBudget) + } + + require.NoError(t, h.releaseFault(ctx, token)) + + select { + case err := <-writeErrCh: + require.NoError(t, err) + case <-ctx.Done(): + t.Fatalf("user-side write of page %d did not unblock after barrier release", pageIdx) + } + }) + }) + } +} + +// TestFaultedShortCircuitOrdering uses the gated harness to +// deterministically queue a WRITE pagefault for a fresh page AND a +// REMOVE for an already-faulted page in the SAME serve-loop +// iteration. After resume, the post-batch state is asserted: the +// REMOVE'd page is `removed` and the racing-write page is `faulted`. +// +// Both pre-fix and post-fix code reach the same end state for this +// scenario (REMOVE batch runs before the pagefault dispatch loop in +// every Serve iteration). This test guards the batch-processing +// invariant itself: any future change that, for example, dispatched +// pagefaults before draining REMOVEs would fail this test as a +// concrete state-mismatch assertion rather than a 30-minute hang. +// +// NOTE: this test deliberately does NOT call t.Parallel(). While the +// handler is in the gated `paused` state, the user thread that +// triggered the queued WRITE fault is suspended in the kernel's +// pagefault path. From the Go runtime's perspective that goroutine +// is "running" (not in syscall, since it's a plain memory store) but +// can't be preempted. If a CONCURRENT cross-process test in the same +// binary triggers a stop-the-world GC pause during this window, STW +// will wait forever for the suspended goroutine to reach a safe +// point — the kernel can't deliver the SIGURG preempt signal until +// the pagefault is served, and the handler is paused. Running this +// test sequentially avoids that interleaving. +func TestFaultedShortCircuitOrdering(t *testing.T) { + tests := []struct { + name string + pagesize uint64 + }{ + {name: "4k", pagesize: header.PageSize}, + {name: "hugepage", pagesize: header.HugepageSize}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + withRaceContext(t, func(_ context.Context) { + cfg := testConfig{ + pagesize: tt.pagesize, + numberOfPages: 2, + gated: true, + operations: []operation{ + {offset: 0, mode: operationModeRead}, + {mode: operationModeServePause}, + {offset: 0, mode: operationModeRemove, async: true}, + {mode: operationModeSleep}, + {offset: int64(tt.pagesize), mode: operationModeWrite, async: true}, + {mode: operationModeSleep}, + {mode: operationModeServeResume}, + }, + } + + h, err := configureCrossProcessTest(t, cfg) + require.NoError(t, err) + + h.executeAll(t, cfg.operations) + + states, err := h.pageStatesOnce() + require.NoError(t, err) + + assert.Contains(t, states.removed, uint(0), + "page 0 should be `removed` after REMOVE batch (got removed=%v faulted=%v)", + states.removed, states.faulted, + ) + assert.Contains(t, states.faulted, uint(tt.pagesize), + "page 1 (offset %d) should be `faulted` after the racing write was served (got removed=%v faulted=%v)", + tt.pagesize, states.removed, states.faulted, + ) + }) + }) + } +} + +// waitForState polls the child's PageStates RPC until the page at +// the given offset reaches `want` or `deadline` elapses. Each RPC +// round-trip is microseconds-to-low-milliseconds; we yield with a +// small sleep between polls so the harness doesn't burn an entire +// CPU on tight-loop encoding while the rest of the suite is also +// running cross-process tests. +func waitForState(ctx context.Context, h *testHandler, offset uint64, want pageState, deadline time.Duration) error { + const pollInterval = 1 * time.Millisecond + + end := time.Now().Add(deadline) + for { + states, err := h.pageStatesOnce() + if err != nil { + return err + } + + var bucket []uint + switch want { + case removed: + bucket = states.removed + case faulted: + bucket = states.faulted + } + + for _, off := range bucket { + if uint64(off) == offset { + return nil + } + } + + if time.Now().After(end) { + return fmt.Errorf("page state at offset %d: want %d after %s — last seen removed=%v faulted=%v", + offset, want, deadline, states.removed, states.faulted) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(pollInterval): + } + } +}