diff --git a/go.mod b/go.mod index 92ccdc1f..20e160c5 100644 --- a/go.mod +++ b/go.mod @@ -111,6 +111,7 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/gosimple/slug v1.15.0 // indirect github.com/gosimple/unidecode v1.0.1 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect diff --git a/go.sum b/go.sum index 783e9070..1113e858 100644 --- a/go.sum +++ b/go.sum @@ -318,6 +318,8 @@ github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBY github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gosimple/slug v1.15.0 h1:wRZHsRrRcs6b0XnxMUBM6WK1U1Vg5B0R7VkIf1Xzobo= github.com/gosimple/slug v1.15.0/go.mod h1:UiRaFH+GEilHstLUmcBgWcI42viBN7mAb818JrYOeFQ= github.com/gosimple/unidecode v1.0.1 h1:hZzFTMMqSswvf0LBJZCZgThIZrpDHFXux9KeGmn6T/o= diff --git a/packages/api/model.go b/packages/api/model.go index cdd063e7..fa7f2d61 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -825,7 +825,8 @@ type PAMAccessApprovalRequestResponse struct { } type PAMSessionCredentialsResponse struct { - Credentials PAMSessionCredentials `json:"credentials"` + Credentials PAMSessionCredentials `json:"credentials"` + SharedSecret string `json:"sharedSecret,omitempty"` } type PAMSessionCredentials struct { diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go index d38ff31c..b42a5e60 100644 --- a/packages/cmd/relay.go +++ b/packages/cmd/relay.go @@ -53,6 +53,7 @@ var relayStartCmd = &cobra.Command{ RelayName: relayName, SSHPort: "2222", TLSPort: "8443", + WSPort: "8444", Host: host, Type: instanceType, }) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 62ba3613..6c5dc7e5 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -10,6 +10,7 @@ import ( "encoding/json" "encoding/pem" "fmt" + "io" "net" "strconv" "strings" @@ -550,6 +551,30 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { go ssh.DiscardRequests(requests) + // Peek first byte to detect connection type: + // 0x16 = TLS ClientHello → mTLS flow (CLI clients) + // 0x00 = Web ECDH magic byte → web flow (browser clients) + firstByte := make([]byte, 1) + if _, err := io.ReadFull(channel, firstByte); err != nil { + log.Info().Msgf("Failed to read first byte: %v", err) + return + } + + switch firstByte[0] { + case 0x16: + g.handleMTLSConnection(channel, firstByte) + case 0x00: + // ECDH magic byte - handle web proxy connection + virtualConn := &prefixedVirtualConnection{channel: channel} + if err := pam.HandlePAMWebProxy(g.ctx, virtualConn, g.httpClient, g.pamCredentialsManager, g.pamSessionUploader); err != nil { + log.Error().Err(err).Msg("PAM web proxy handler ended with error") + } + default: + log.Warn().Msgf("Unknown protocol byte: 0x%02x, closing channel", firstByte[0]) + } +} + +func (g *Gateway) handleMTLSConnection(channel ssh.Channel, firstByte []byte) { // Create mTLS server configuration tlsConfig := g.tlsConfig if tlsConfig == nil { @@ -557,10 +582,8 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } - // Create a virtual connection that pipes data between SSH channel and TLS - virtualConn := &virtualConnection{ - channel: channel, - } + // Create a virtual connection that prepends the peeked byte + virtualConn := newPrefixedVirtualConnection(firstByte, channel) // Wrap the virtual connection with TLS tlsConn := tls.Server(virtualConn, tlsConfig) @@ -776,40 +799,55 @@ func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *Forward return nil } -// virtualConnection implements net.Conn to bridge SSH channel and TLS -type virtualConnection struct { +// prefixedVirtualConnection implements net.Conn, prepending buffered bytes before reading from the channel +type prefixedVirtualConnection struct { channel ssh.Channel + prefix []byte + offset int } -func (vc *virtualConnection) Read(b []byte) (n int, err error) { +func newPrefixedVirtualConnection(prefix []byte, channel ssh.Channel) *prefixedVirtualConnection { + return &prefixedVirtualConnection{ + channel: channel, + prefix: prefix, + offset: 0, + } +} + +func (vc *prefixedVirtualConnection) Read(b []byte) (int, error) { + if vc.offset < len(vc.prefix) { + n := copy(b, vc.prefix[vc.offset:]) + vc.offset += n + return n, nil + } return vc.channel.Read(b) } -func (vc *virtualConnection) Write(b []byte) (n int, err error) { +func (vc *prefixedVirtualConnection) Write(b []byte) (int, error) { return vc.channel.Write(b) } -func (vc *virtualConnection) Close() error { +func (vc *prefixedVirtualConnection) Close() error { return vc.channel.Close() } -func (vc *virtualConnection) LocalAddr() net.Addr { +func (vc *prefixedVirtualConnection) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} } -func (vc *virtualConnection) RemoteAddr() net.Addr { +func (vc *prefixedVirtualConnection) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} } -func (vc *virtualConnection) SetDeadline(t time.Time) error { +func (vc *prefixedVirtualConnection) SetDeadline(t time.Time) error { return nil } -func (vc *virtualConnection) SetReadDeadline(t time.Time) error { +func (vc *prefixedVirtualConnection) SetReadDeadline(t time.Time) error { return nil } -func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { +func (vc *prefixedVirtualConnection) SetWriteDeadline(t time.Time) error { return nil } diff --git a/packages/pam/encrypted_conn.go b/packages/pam/encrypted_conn.go new file mode 100644 index 00000000..05783e80 --- /dev/null +++ b/packages/pam/encrypted_conn.go @@ -0,0 +1,140 @@ +package pam + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "net" + "sync" + "time" +) + +// EncryptedConn wraps a net.Conn with AES-256-GCM encryption. +// Frame format: [4-byte big-endian total-frame-length][12-byte random nonce][ciphertext + 16-byte GCM auth tag] +type EncryptedConn struct { + inner net.Conn + gcm cipher.AEAD + + readMu sync.Mutex + readBuf []byte + + writeMu sync.Mutex +} + +func NewEncryptedConn(inner net.Conn, aesKey []byte) (*EncryptedConn, error) { + block, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + return &EncryptedConn{ + inner: inner, + gcm: gcm, + }, nil +} + +func (ec *EncryptedConn) Read(b []byte) (int, error) { + ec.readMu.Lock() + defer ec.readMu.Unlock() + + // Serve from buffer if available + if len(ec.readBuf) > 0 { + n := copy(b, ec.readBuf) + ec.readBuf = ec.readBuf[n:] + return n, nil + } + + // Read frame length + lengthBuf := make([]byte, 4) + if _, err := io.ReadFull(ec.inner, lengthBuf); err != nil { + return 0, err + } + frameLen := binary.BigEndian.Uint32(lengthBuf) + if frameLen > 1<<24 { + return 0, fmt.Errorf("encrypted frame too large: %d bytes", frameLen) + } + + // Read the full frame (nonce + ciphertext) + frame := make([]byte, frameLen) + if _, err := io.ReadFull(ec.inner, frame); err != nil { + return 0, fmt.Errorf("failed to read encrypted frame: %w", err) + } + + nonceSize := ec.gcm.NonceSize() + if int(frameLen) < nonceSize { + return 0, fmt.Errorf("frame too short for nonce") + } + + nonce := frame[:nonceSize] + ciphertext := frame[nonceSize:] + + plaintext, err := ec.gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return 0, fmt.Errorf("GCM decryption failed: %w", err) + } + + n := copy(b, plaintext) + if n < len(plaintext) { + ec.readBuf = plaintext[n:] + } + return n, nil +} + +func (ec *EncryptedConn) Write(b []byte) (int, error) { + ec.writeMu.Lock() + defer ec.writeMu.Unlock() + + nonce := make([]byte, ec.gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return 0, fmt.Errorf("failed to generate nonce: %w", err) + } + + ciphertext := ec.gcm.Seal(nil, nonce, b, nil) + + // Frame = nonce + ciphertext + frameLen := len(nonce) + len(ciphertext) + lengthBuf := make([]byte, 4) + binary.BigEndian.PutUint32(lengthBuf, uint32(frameLen)) + + if _, err := ec.inner.Write(lengthBuf); err != nil { + return 0, fmt.Errorf("failed to write frame length: %w", err) + } + if _, err := ec.inner.Write(nonce); err != nil { + return 0, fmt.Errorf("failed to write nonce: %w", err) + } + if _, err := ec.inner.Write(ciphertext); err != nil { + return 0, fmt.Errorf("failed to write ciphertext: %w", err) + } + + return len(b), nil +} + +func (ec *EncryptedConn) Close() error { + return ec.inner.Close() +} + +func (ec *EncryptedConn) LocalAddr() net.Addr { + return ec.inner.LocalAddr() +} + +func (ec *EncryptedConn) RemoteAddr() net.Addr { + return ec.inner.RemoteAddr() +} + +func (ec *EncryptedConn) SetDeadline(t time.Time) error { + return ec.inner.SetDeadline(t) +} + +func (ec *EncryptedConn) SetReadDeadline(t time.Time) error { + return ec.inner.SetReadDeadline(t) +} + +func (ec *EncryptedConn) SetWriteDeadline(t time.Time) error { + return ec.inner.SetWriteDeadline(t) +} diff --git a/packages/pam/pam-web-proxy.go b/packages/pam/pam-web-proxy.go new file mode 100644 index 00000000..5703ef39 --- /dev/null +++ b/packages/pam/pam-web-proxy.go @@ -0,0 +1,331 @@ +package pam + +import ( + "context" + "crypto/ecdh" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net" + "net/url" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/handlers" + "github.com/Infisical/infisical-merge/packages/pam/handlers/kubernetes" + "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/redis" + "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/hkdf" +) + +type webHandshakeMessage struct { + SessionID string `json:"sessionId"` + ResourceType string `json:"resourceType"` + PublicKey string `json:"publicKey"` + Signature string `json:"signature"` +} + +type webHandshakeResponse struct { + PublicKey string `json:"publicKey"` + Signature string `json:"signature"` +} + +func readLengthPrefixedMessage(conn net.Conn) ([]byte, error) { + lengthBuf := make([]byte, 4) + if _, err := io.ReadFull(conn, lengthBuf); err != nil { + return nil, fmt.Errorf("failed to read length prefix: %w", err) + } + length := binary.BigEndian.Uint32(lengthBuf) + log.Debug().Msgf("[pam-web] readLengthPrefixed: raw length bytes=%x, decoded length=%d", lengthBuf, length) + if length > 1<<20 { + return nil, fmt.Errorf("message too large: %d bytes (length prefix bytes: %x)", length, lengthBuf) + } + data := make([]byte, length) + if _, err := io.ReadFull(conn, data); err != nil { + return nil, fmt.Errorf("failed to read message body (expected %d bytes): %w", length, err) + } + log.Debug().Msgf("[pam-web] readLengthPrefixed: read %d bytes of body", len(data)) + return data, nil +} + +func writeLengthPrefixedMessage(conn net.Conn, data []byte) error { + lengthBuf := make([]byte, 4) + binary.BigEndian.PutUint32(lengthBuf, uint32(len(data))) + if _, err := conn.Write(lengthBuf); err != nil { + return fmt.Errorf("failed to write length prefix: %w", err) + } + if _, err := conn.Write(data); err != nil { + return fmt.Errorf("failed to write message body: %w", err) + } + return nil +} + +func HandlePAMWebProxy( + ctx context.Context, + conn net.Conn, + httpClient *resty.Client, + credentialsManager *session.CredentialsManager, + sessionUploader *session.SessionUploader, +) error { + // Read browser's ECDH handshake + msgData, err := readLengthPrefixedMessage(conn) + if err != nil { + return fmt.Errorf("[pam-web] failed to read handshake message: %w", err) + } + + var handshake webHandshakeMessage + if err := json.Unmarshal(msgData, &handshake); err != nil { + log.Error().Msgf("[pam-web] Failed to parse handshake JSON (raw first 200 bytes): %s", string(msgData[:min(len(msgData), 200)])) + return fmt.Errorf("[pam-web] failed to parse handshake message: %w", err) + } + + log.Info(). + Str("sessionId", handshake.SessionID). + Str("resourceType", handshake.ResourceType). + Msg("[pam-web] Received web ECDH handshake") + + browserPubKeyBytes, err := base64.StdEncoding.DecodeString(handshake.PublicKey) + if err != nil { + return fmt.Errorf("failed to decode browser public key: %w", err) + } + + browserSigBytes, err := base64.StdEncoding.DecodeString(handshake.Signature) + if err != nil { + return fmt.Errorf("failed to decode browser signature: %w", err) + } + + // Fetch credentials + shared secret + credentials, err := credentialsManager.GetPAMSessionCredentials(handshake.SessionID, time.Now().Add(24*time.Hour)) + if err != nil { + return fmt.Errorf("[pam-web] failed to retrieve PAM session credentials: %w", err) + } + + if credentials.SharedSecret == "" { + return fmt.Errorf("[pam-web] shared secret not available for session %s", handshake.SessionID) + } + + sharedSecretBytes, err := base64.StdEncoding.DecodeString(credentials.SharedSecret) + if err != nil { + return fmt.Errorf("failed to decode shared secret: %w", err) + } + + // Verify browser's signature -- there are other packages that make this easier but its OK + mac := hmac.New(sha256.New, sharedSecretBytes) + mac.Write(browserPubKeyBytes) + expectedSig := mac.Sum(nil) + if !hmac.Equal(browserSigBytes, expectedSig) { + return fmt.Errorf("[pam-web] browser ECDH public key signature verification failed (possible MITM)") + } + + // Parse browser's ECDH public key + browserPubKey, err := ecdh.P256().NewPublicKey(browserPubKeyBytes) + if err != nil { + return fmt.Errorf("failed to parse browser ECDH public key: %w", err) + } + + // Generate gateway ECDH keypair + gatewayPrivKey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("failed to generate gateway ECDH keypair: %w", err) + } + gatewayPubKeyBytes := gatewayPrivKey.PublicKey().Bytes() + + // Sign and send gateway's public key + gwMac := hmac.New(sha256.New, sharedSecretBytes) + gwMac.Write(gatewayPubKeyBytes) + gwSig := gwMac.Sum(nil) + + responseMsg := webHandshakeResponse{ + PublicKey: base64.StdEncoding.EncodeToString(gatewayPubKeyBytes), + Signature: base64.StdEncoding.EncodeToString(gwSig), + } + + responseData, err := json.Marshal(responseMsg) + if err != nil { + return fmt.Errorf("failed to marshal handshake response: %w", err) + } + + if err := writeLengthPrefixedMessage(conn, responseData); err != nil { + return fmt.Errorf("[pam-web] failed to send handshake response: %w", err) + } + + // Derive AES-256 key via ECDH + HKDF + sharedECDH, err := gatewayPrivKey.ECDH(browserPubKey) + if err != nil { + return fmt.Errorf("failed to compute ECDH shared secret: %w", err) + } + + hkdfReader := hkdf.New(sha256.New, sharedECDH, nil, []byte("infisical-pam-web-encryption")) + aesKey := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, aesKey); err != nil { + return fmt.Errorf("failed to derive AES key: %w", err) + } + + // Create encrypted connection wrapper + encConn, err := NewEncryptedConn(conn, aesKey) + if err != nil { + return fmt.Errorf("failed to create encrypted connection: %w", err) + } + + log.Info(). + Str("sessionId", handshake.SessionID). + Str("resourceType", handshake.ResourceType). + Msg("[pam-web] ECDH handshake completed, AES key derived, encrypted tunnel established") + + // Session expiry monitoring + timeUntilExpiry := 24 * time.Hour + go func() { + if creds, err := credentialsManager.GetPAMSessionCredentials(handshake.SessionID, time.Now().Add(timeUntilExpiry)); err == nil { + _ = creds + } + + timer := time.NewTimer(timeUntilExpiry) + defer timer.Stop() + + select { + case <-timer.C: + log.Info(). + Str("sessionId", handshake.SessionID). + Msg("PAM web session expired, closing connection") + if err := sessionUploader.CleanupPAMSession(handshake.SessionID, "expiry"); err != nil { + log.Error().Err(err).Str("sessionId", handshake.SessionID).Msg("Failed to cleanup PAM web session on expiry") + } + encConn.Close() + case <-ctx.Done(): + return + } + }() + + // Session recording setup + encryptionKey, err := credentialsManager.GetPAMSessionEncryptionKey() + if err != nil { + return fmt.Errorf("failed to get PAM session encryption key: %w", err) + } + sessionLogger, err := session.NewSessionLogger(handshake.SessionID, encryptionKey, time.Now().Add(timeUntilExpiry), handshake.ResourceType) + if err != nil { + return fmt.Errorf("failed to create session logger: %w", err) + } + + serverName := credentials.Host + if handshake.ResourceType == session.ResourceTypeKubernetes { + parsed, parseErr := url.Parse(credentials.Url) + if parseErr != nil { + return fmt.Errorf("failed to parse URL: %w", parseErr) + } + serverName = parsed.Hostname() + } + + tlsConfig := &tls.Config{ + InsecureSkipVerify: !credentials.SSLRejectUnauthorized, + ServerName: serverName, + } + if credentials.SSLCertificate != "" { + certPool := x509.NewCertPool() + if certPool.AppendCertsFromPEM([]byte(credentials.SSLCertificate)) { + tlsConfig.RootCAs = certPool + } + } + + // Route to protocol handler by resource type + switch handshake.ResourceType { + case session.ResourceTypePostgres: + proxyConfig := handlers.PostgresProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDatabase: credentials.Database, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: handshake.SessionID, + SessionLogger: sessionLogger, + } + proxy := handlers.NewPostgresProxy(proxyConfig) + log.Info(). + Str("sessionId", handshake.SessionID). + Str("target", proxyConfig.TargetAddr). + Msg("Starting PostgreSQL PAM web proxy") + return proxy.HandleConnection(ctx, encConn) + + case session.ResourceTypeMysql: + mysqlConfig := mysql.MysqlProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDatabase: credentials.Database, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: handshake.SessionID, + SessionLogger: sessionLogger, + } + proxy := mysql.NewMysqlProxy(mysqlConfig) + log.Info(). + Str("sessionId", handshake.SessionID). + Str("target", mysqlConfig.TargetAddr). + Msg("Starting MySQL PAM web proxy") + return proxy.HandleConnection(ctx, encConn) + + case session.ResourceTypeRedis: + redisConfig := redis.RedisProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: handshake.SessionID, + SessionLogger: sessionLogger, + } + proxy := redis.NewRedisProxy(redisConfig) + log.Info(). + Str("sessionId", handshake.SessionID). + Str("target", redisConfig.TargetAddr). + Msg("Starting Redis PAM web proxy") + return proxy.HandleConnection(ctx, encConn) + + case session.ResourceTypeSSH: + sshConfig := ssh.SSHProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + AuthMethod: credentials.AuthMethod, + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectPrivateKey: credentials.PrivateKey, + InjectCertificate: credentials.Certificate, + SessionID: handshake.SessionID, + SessionLogger: sessionLogger, + } + proxy := ssh.NewSSHProxy(sshConfig) + log.Info(). + Str("sessionId", handshake.SessionID). + Str("target", sshConfig.TargetAddr). + Msg("Starting SSH PAM web proxy") + return proxy.HandleConnection(ctx, encConn) + + case session.ResourceTypeKubernetes: + kubernetesConfig := kubernetes.KubernetesProxyConfig{ + AuthMethod: credentials.AuthMethod, + InjectServiceAccountToken: credentials.ServiceAccountToken, + TargetApiServer: credentials.Url, + TLSConfig: tlsConfig, + SessionID: handshake.SessionID, + SessionLogger: sessionLogger, + } + proxy := kubernetes.NewKubernetesProxy(kubernetesConfig) + log.Info(). + Str("sessionId", handshake.SessionID). + Str("target", kubernetesConfig.TargetApiServer). + Msg("Starting Kubernetes PAM web proxy") + return proxy.HandleConnection(ctx, encConn) + + default: + return fmt.Errorf("unsupported resource type: %s", handshake.ResourceType) + } +} diff --git a/packages/pam/session/credentials.go b/packages/pam/session/credentials.go index 2abdd7cf..4d02db1a 100644 --- a/packages/pam/session/credentials.go +++ b/packages/pam/session/credentials.go @@ -24,6 +24,7 @@ type PAMCredentials struct { SSLCertificate string Url string ServiceAccountToken string + SharedSecret string } type cachedCredentials struct { @@ -98,6 +99,7 @@ func (cm *CredentialsManager) GetPAMSessionCredentials(sessionId string, expiryT SSLCertificate: response.Credentials.SSLCertificate, Url: response.Credentials.Url, ServiceAccountToken: response.Credentials.ServiceAccountToken, + SharedSecret: response.SharedSecret, } cm.cacheMutex.Lock() diff --git a/packages/relay/relay.go b/packages/relay/relay.go index cce24431..18ecd3e0 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -39,6 +39,7 @@ type RelayConfig struct { // Server Ports SSHPort string TLSPort string + WSPort string // Network Configuration Host string @@ -67,6 +68,7 @@ type Relay struct { // Server listeners sshListener net.Listener tlsListener net.Listener + wsListener net.Listener } func NewRelay(config *RelayConfig) (*Relay, error) { @@ -188,6 +190,11 @@ func (r *Relay) Start(ctx context.Context) error { // Start TLS server go r.startTLSServer() + // Start WebSocket server (if configured) + if r.config.WSPort != "" { + go r.startWSServer() + } + log.Info().Msg("Relay server started successfully") systemd.SdNotify(false, systemd.SdNotifyReady) @@ -592,6 +599,9 @@ func (r *Relay) cleanup() { if r.tlsListener != nil { r.tlsListener.Close() } + if r.wsListener != nil { + r.wsListener.Close() + } log.Info().Msg("Relay server shutdown complete") } diff --git a/packages/relay/relay_ws.go b/packages/relay/relay_ws.go new file mode 100644 index 00000000..ff14fc27 --- /dev/null +++ b/packages/relay/relay_ws.go @@ -0,0 +1,263 @@ +package relay + +import ( + "crypto/rand" + "crypto/x509" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +const pendingAuthTTL = 30 * time.Second + +type pendingWSAuth struct { + gatewayID string + gatewayName string + expiresAt time.Time +} + +type wsAuthStore struct { + mu sync.Mutex + pending map[string]*pendingWSAuth +} + +func newWSAuthStore() *wsAuthStore { + return &wsAuthStore{pending: make(map[string]*pendingWSAuth)} +} + +func (s *wsAuthStore) store(id string, auth *pendingWSAuth) { + s.mu.Lock() + s.pending[id] = auth + s.mu.Unlock() +} + +func (s *wsAuthStore) consume(id string) (*pendingWSAuth, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + auth, ok := s.pending[id] + if !ok { + return nil, false + } + delete(s.pending, id) + + if time.Now().After(auth.expiresAt) { + return nil, false + } + + return auth, true +} + +var wsUpgrader = websocket.Upgrader{ + CheckOrigin: func(req *http.Request) bool { return true }, +} + +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + if req.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, req) + }) +} + +func (r *Relay) startWSServer() { + listener, err := net.Listen("tcp", ":"+r.config.WSPort) + if err != nil { + log.Fatal().Msgf("Failed to start WebSocket server: %v", err) + } + r.wsListener = listener + + authStore := newWSAuthStore() + + mux := http.NewServeMux() + mux.HandleFunc("/ws/authenticate", r.handleWSAuthenticate(authStore)) + mux.HandleFunc("/ws", r.handleWSUpgrade(authStore)) + + server := &http.Server{Handler: corsMiddleware(mux)} + log.Info().Msgf("WebSocket server listening on :%s for browser clients", r.config.WSPort) + + if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { + log.Error().Msgf("WebSocket server error: %v", err) + } +} + +func (r *Relay) handleWSAuthenticate(authStore *wsAuthStore) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(io.LimitReader(req.Body, 8192)) + if err != nil || len(body) == 0 { + http.Error(w, "missing certificate", http.StatusBadRequest) + return + } + + block, _ := pem.Decode(body) + if block == nil { + http.Error(w, "invalid PEM", http.StatusBadRequest) + return + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + http.Error(w, "invalid certificate", http.StatusBadRequest) + return + } + + if _, err := cert.Verify(x509.VerifyOptions{ + Roots: r.tlsConfig.ClientCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }); err != nil { + log.Debug().Msgf("WebSocket client certificate verification failed: %v", err) + http.Error(w, "certificate verification failed", http.StatusUnauthorized) + return + } + + gatewayID := cert.Subject.CommonName + if gatewayID == "" { + http.Error(w, "missing gateway ID in certificate", http.StatusBadRequest) + return + } + + var gatewayName string + for _, ext := range cert.Extensions { + if ext.Id.String() == RELAY_CONNECTING_GATEWAY_INFO_OID { + var info ConnectingGatewayInfo + if err := json.Unmarshal(ext.Value, &info); err != nil { + log.Warn().Msgf("Failed to unmarshal gateway info from WS auth cert: %v", err) + } else { + gatewayName = info.Name + } + } + } + + idBytes := make([]byte, 16) + if _, err := rand.Read(idBytes); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + connectionID := hex.EncodeToString(idBytes) + + authStore.store(connectionID, &pendingWSAuth{ + gatewayID: gatewayID, + gatewayName: gatewayName, + expiresAt: time.Now().Add(pendingAuthTTL), + }) + + log.Info().Msgf("WebSocket auth successful for gateway %s (%s), connectionId=%s", gatewayName, gatewayID, connectionID) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"connectionId": connectionID}) + } +} + +func (r *Relay) handleWSUpgrade(authStore *wsAuthStore) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + connectionID := req.URL.Query().Get("connectionId") + if connectionID == "" { + http.Error(w, "missing connectionId", http.StatusUnauthorized) + return + } + + auth, ok := authStore.consume(connectionID) + if !ok { + http.Error(w, "invalid or expired connectionId", http.StatusUnauthorized) + return + } + + wsConn, err := wsUpgrader.Upgrade(w, req, nil) + if err != nil { + log.Error().Msgf("WebSocket upgrade failed: %v", err) + return + } + + r.handleWSClient(wsConn, auth.gatewayID, auth.gatewayName) + } +} + +func (r *Relay) handleWSClient(wsConn *websocket.Conn, gatewayID string, gatewayName string) { + defer wsConn.Close() + + log.Info().Msgf("WebSocket client connected for gateway %s (%s)", gatewayName, gatewayID) + + r.mu.RLock() + sshConn, exists := r.tunnels[gatewayID] + r.mu.RUnlock() + + if !exists { + log.Warn().Msgf("Gateway '%s' (%s) not connected (WebSocket client)", gatewayName, gatewayID) + wsConn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "gateway not connected")) + return + } + + channel, _, err := sshConn.OpenChannel("direct-tcpip", nil) + if err != nil { + log.Error().Msgf("Failed to open SSH channel to gateway %s (%s): %v", gatewayName, gatewayID, err) + wsConn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "gateway connection failed")) + return + } + defer channel.Close() + + done := make(chan struct{}, 1) + + // WS -> SSH + go func() { + defer func() { + done <- struct{}{} + }() + for { + msgType, msg, err := wsConn.ReadMessage() + if err != nil { + return + } + if msgType != websocket.BinaryMessage { + continue + } + if _, err := channel.Write(msg); err != nil { + return + } + } + }() + + // SSH -> WS + go func() { + defer func() { + done <- struct{}{} + }() + buf := make([]byte, 32*1024) + for { + n, err := channel.Read(buf) + if n > 0 { + if writeErr := wsConn.WriteMessage(websocket.BinaryMessage, buf[:n]); writeErr != nil { + return + } + } + if err != nil { + return + } + } + }() + + <-done + log.Info().Msgf("WebSocket client disconnected for gateway %s (%s)", gatewayName, gatewayID) +}