Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions app_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package gitgrab

import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"os"
"time"
)

// GitHubAppCredentials holds the credentials needed to authenticate as a GitHub App.
type GitHubAppCredentials struct {
AppID string
PrivateKeyPath string
InstallationID string
}

// loadPrivateKey reads and parses an RSA private key from a PEM-encoded file.
// Both PKCS#1 and PKCS#8 formats are supported.
func loadPrivateKey(path string) (*rsa.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read private key file %q: %w", path, err)
}

block, _ := pem.Decode(data)
if block == nil {
return nil, fmt.Errorf("no PEM block found in %q", path)
}

// Try PKCS#1 (traditional RSA private key) first.
if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return key, nil
}

// Fall back to PKCS#8 (used by some key generation tools).
parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key from %q: %w", path, err)
}

rsaKey, ok := parsed.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("private key in %q is not an RSA key", path)
}

return rsaKey, nil
}

// buildJWT creates a signed RS256 JWT suitable for GitHub App authentication.
// The token is valid for 10 minutes with a 60-second back-dated iat to tolerate
// minor clock skew between the client and GitHub's servers.
func buildJWT(appID string, privateKey *rsa.PrivateKey) (string, error) {
now := time.Now()

headerJSON, err := json.Marshal(map[string]string{"alg": "RS256", "typ": "JWT"})
if err != nil {
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
}

payloadJSON, err := json.Marshal(map[string]interface{}{
"iat": now.Add(-60 * time.Second).Unix(),
"exp": now.Add(10 * time.Minute).Unix(),
"iss": appID,
})
if err != nil {
return "", fmt.Errorf("failed to marshal JWT payload: %w", err)
}

header := base64.RawURLEncoding.EncodeToString(headerJSON)
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
signingInput := header + "." + payload

h := sha256.New()
h.Write([]byte(signingInput))
digest := h.Sum(nil)

sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, digest)
if err != nil {
return "", fmt.Errorf("failed to sign JWT: %w", err)
}

return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
}

// GetInstallationToken exchanges GitHub App credentials for a short-lived
// installation access token. The returned GitHubToken can be used directly
// in place of a PAT — it uses the same Authorization header format.
func GetInstallationToken(creds GitHubAppCredentials, client HTTPClient) (GitHubToken, error) {
privateKey, err := loadPrivateKey(creds.PrivateKeyPath)
if err != nil {
return "", err
}

jwt, err := buildJWT(creds.AppID, privateKey)
if err != nil {
return "", err
}

url := fmt.Sprintf("https://api.github.com/app/installations/%s/access_tokens", creds.InstallationID)
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", fmt.Errorf("failed to create installation token request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+jwt)
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "GitHub-Repo-Cloner")

resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request installation token: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("failed to get installation token: %s - %s", resp.Status, string(body))
}

var result struct {
Token string `json:"token"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode installation token response: %w", err)
}

if result.Token == "" {
return "", fmt.Errorf("received empty token from GitHub API")
}

return GitHubToken(result.Token), nil
}
233 changes: 233 additions & 0 deletions app_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
package gitgrab

import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)

// writeTestKey generates a 2048-bit RSA key, writes it as PKCS#1 PEM to a
// temp file, and returns the key and its path.
func writeTestKey(t *testing.T) (*rsa.PrivateKey, string) {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate test RSA key: %v", err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
path := filepath.Join(t.TempDir(), "test.pem")
if err := os.WriteFile(path, pemBytes, 0600); err != nil {
t.Fatalf("failed to write test key file: %v", err)
}
return key, path
}

func TestLoadPrivateKey_PKCS1(t *testing.T) {
_, path := writeTestKey(t)
key, err := loadPrivateKey(path)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if key == nil {
t.Fatal("expected non-nil key")
}
}

func TestLoadPrivateKey_PKCS8(t *testing.T) {
rawKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(rawKey)
if err != nil {
t.Fatalf("failed to marshal PKCS8 key: %v", err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Bytes,
})
path := filepath.Join(t.TempDir(), "pkcs8.pem")
if err := os.WriteFile(path, pemBytes, 0600); err != nil {
t.Fatalf("failed to write key file: %v", err)
}

key, err := loadPrivateKey(path)
if err != nil {
t.Fatalf("expected no error for PKCS8 key, got %v", err)
}
if key == nil {
t.Fatal("expected non-nil key")
}
}

func TestLoadPrivateKey_FileNotFound(t *testing.T) {
_, err := loadPrivateKey("/nonexistent/path/key.pem")
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
if !strings.Contains(err.Error(), "failed to read private key file") {
t.Errorf("unexpected error message: %v", err)
}
}

func TestLoadPrivateKey_InvalidPEM(t *testing.T) {
path := filepath.Join(t.TempDir(), "bad.pem")
if err := os.WriteFile(path, []byte("this is not pem"), 0600); err != nil {
t.Fatalf("failed to write file: %v", err)
}
_, err := loadPrivateKey(path)
if err == nil {
t.Fatal("expected error for invalid PEM, got nil")
}
if !strings.Contains(err.Error(), "no PEM block found") {
t.Errorf("unexpected error message: %v", err)
}
}

func TestBuildJWT_Structure(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}

jwt, err := buildJWT("12345", key)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

parts := strings.Split(jwt, ".")
if len(parts) != 3 {
t.Fatalf("expected 3 JWT parts, got %d", len(parts))
}

// Each part must be non-empty.
for i, p := range parts {
if p == "" {
t.Errorf("JWT part %d is empty", i)
}
}
}

func TestGetInstallationToken_Success(t *testing.T) {
_, keyPath := writeTestKey(t)

mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
// Verify the request shape.
if req.Method != "POST" {
t.Errorf("expected POST, got %s", req.Method)
}
if !strings.HasSuffix(req.URL.Path, "/access_tokens") {
t.Errorf("unexpected path: %s", req.URL.Path)
}
auth := req.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
t.Errorf("expected Bearer token in Authorization header, got %q", auth)
}

recorder := httptest.NewRecorder()
recorder.WriteHeader(http.StatusCreated)
json.NewEncoder(recorder).Encode(map[string]string{"token": "ghs_test_installation_token"})
return recorder.Result(), nil
},
}

creds := GitHubAppCredentials{
AppID: "12345",
PrivateKeyPath: keyPath,
InstallationID: "67890",
}

token, err := GetInstallationToken(creds, mockClient)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if token != GitHubToken("ghs_test_installation_token") {
t.Errorf("expected token 'ghs_test_installation_token', got %q", token)
}
}

func TestGetInstallationToken_APIError(t *testing.T) {
_, keyPath := writeTestKey(t)

mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
recorder := httptest.NewRecorder()
recorder.WriteHeader(http.StatusUnauthorized)
recorder.Write([]byte(`{"message":"Bad credentials"}`))
return recorder.Result(), nil
},
}

creds := GitHubAppCredentials{
AppID: "12345",
PrivateKeyPath: keyPath,
InstallationID: "67890",
}

_, err := GetInstallationToken(creds, mockClient)
if err == nil {
t.Fatal("expected error for API failure, got nil")
}
if !strings.Contains(err.Error(), "failed to get installation token") {
t.Errorf("unexpected error message: %v", err)
}
}

func TestGetInstallationToken_EmptyToken(t *testing.T) {
_, keyPath := writeTestKey(t)

mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
recorder := httptest.NewRecorder()
recorder.WriteHeader(http.StatusCreated)
json.NewEncoder(recorder).Encode(map[string]string{"token": ""})
return recorder.Result(), nil
},
}

creds := GitHubAppCredentials{
AppID: "12345",
PrivateKeyPath: keyPath,
InstallationID: "67890",
}

_, err := GetInstallationToken(creds, mockClient)
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
if !strings.Contains(err.Error(), "empty token") {
t.Errorf("unexpected error message: %v", err)
}
}

func TestGetInstallationToken_BadKeyPath(t *testing.T) {
mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
t.Error("HTTP client should not be called when key file is missing")
return nil, nil
},
}

creds := GitHubAppCredentials{
AppID: "12345",
PrivateKeyPath: "/nonexistent/key.pem",
InstallationID: "67890",
}

_, err := GetInstallationToken(creds, mockClient)
if err == nil {
t.Fatal("expected error for bad key path, got nil")
}
}
Loading
Loading