diff --git a/github/github.go b/github/github.go index 6537938e54b..9f1eee4b90e 100644 --- a/github/github.go +++ b/github/github.go @@ -920,6 +920,11 @@ const ( SleepUntilPrimaryRateLimitResetWhenRateLimited ) +// maxErrorBodySize is the maximum number of bytes read from an HTTP error +// response body. Limits memory allocation when a server returns an +// unexpectedly large error body. +const maxErrorBodySize = 1 * 1024 * 1024 // 1 MiB + // bareDo sends an API request using `caller` http.Client passed in the parameters // and lets you handle the api response. If an error or API Error occurs, the error // will contain more information. Otherwise, you are supposed to read and close the @@ -997,7 +1002,7 @@ func (c *Client) bareDo(caller *http.Client, req *http.Request) (*Response, erro // Issue #1022 var aerr *AcceptedError if errors.As(err, &aerr) { - b, readErr := io.ReadAll(resp.Body) + b, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) if readErr != nil { return response, readErr } @@ -1502,7 +1507,7 @@ func CheckResponse(r *http.Response) error { } errorResponse := &ErrorResponse{Response: r} - data, err := io.ReadAll(r.Body) + data, err := io.ReadAll(io.LimitReader(r.Body, maxErrorBodySize)) if err == nil && data != nil { err = json.Unmarshal(data, errorResponse) if err != nil { diff --git a/github/github_test.go b/github/github_test.go index b41ebd8f849..87703736508 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1400,6 +1400,36 @@ func TestDo_preservesResponseInHTTPError(t *testing.T) { } } +// TestDo_AcceptedError_LargeBodyTruncated verifies that when the API returns a +// 202 Accepted with a body larger than maxErrorBodySize, the client reads at +// most maxErrorBodySize bytes into AcceptedError.Raw and does not allocate +// unbounded memory. +func TestDo_AcceptedError_LargeBodyTruncated(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + // Serve a 202 response whose body exceeds the cap by one byte. + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + fmt.Fprint(w, strings.Repeat("x", maxErrorBodySize+1)) + }) + + req, _ := client.NewRequest(t.Context(), "GET", ".", nil) + _, err := client.Do(req, nil) + if err == nil { + t.Fatal("Expected AcceptedError, got nil") + } + + var aerr *AcceptedError + if !errors.As(err, &aerr) { + t.Fatalf("Expected *AcceptedError, got %T: %v", err, err) + } + + if got, want := len(aerr.Raw), maxErrorBodySize; got != want { + t.Errorf("AcceptedError.Raw length = %v, want %v (maxErrorBodySize)", got, want) + } +} + // Test that an error caused by the internal http client's Do() function // does not leak the client secret. func TestDo_sanitizeURL(t *testing.T) { @@ -2982,6 +3012,38 @@ func TestCheckResponse_unexpectedErrorStructure(t *testing.T) { } } +// TestCheckResponse_LargeBodyTruncated verifies that CheckResponse reads at +// most maxErrorBodySize bytes from an error response body, so that a +// malicious or buggy server cannot cause the client to allocate unbounded +// memory. +func TestCheckResponse_LargeBodyTruncated(t *testing.T) { + t.Parallel() + // Build a body that is one byte larger than the cap. + body := strings.Repeat("x", maxErrorBodySize+1) + res := &http.Response{ + Request: &http.Request{}, + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(body)), + } + + // CheckResponse should not return an error from the read itself; the HTTP + // error status is the expected error. + if err := CheckResponse(res); err == nil { + t.Fatal("Expected error from CheckResponse, got nil") + } + + // After CheckResponse, the body is restored with the (truncated) bytes that + // were actually read. Verify the restored body is exactly maxErrorBodySize + // bytes — not the full maxErrorBodySize+1 that the server sent. + restored, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("io.ReadAll on restored body: %v", err) + } + if got, want := len(restored), maxErrorBodySize; got != want { + t.Errorf("restored body length = %v, want %v (maxErrorBodySize)", got, want) + } +} + func TestParseBooleanResponse_true(t *testing.T) { t.Parallel() result, err := parseBoolResponse(nil)