diff --git a/config.go b/config.go index 2dda5c34..74debaf1 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/robfig/cron/v3" + "github.com/snowflakedb/gosnowflake" "gopkg.in/yaml.v2" ) @@ -164,6 +165,8 @@ type connection struct { user string tokenExpirationTime time.Time iteratorValues []string + snowflakeConfig *gosnowflake.Config + snowflakeDSN string } // Query is an SQL query that is executed on a connection diff --git a/job.go b/job.go index 6852e6ed..3d1b1d26 100644 --- a/job.go +++ b/job.go @@ -9,6 +9,9 @@ import ( "strconv" "strings" "time" + "crypto/rsa" + "crypto/x509" + "encoding/pem" _ "github.com/ClickHouse/clickhouse-go/v2" // register the ClickHouse driver "github.com/cenkalti/backoff" @@ -374,36 +377,86 @@ func (j *Job) updateConnections() { } } if newConn.driver == "snowflake" { + u, err := url.Parse(conn) + if err != nil { + level.Error(j.log).Log("msg", "Failed to parse Snowflake URL", "url", conn, "err", err) + continue + } + + queryParams := u.Query() + privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file")) + cfg := &gosnowflake.Config{ Account: u.Host, User: u.User.Username(), + Role: queryParams.Get("role"), + Database: queryParams.Get("database"), + Schema: queryParams.Get("schema"), } - - pw, set := u.User.Password() - if set { - cfg.Password = pw - } - - if u.Port() != "" { - portStr, err := strconv.Atoi(u.Port()) + + if privateKeyPath != "" { + // RSA key auth + keyBytes, err := os.ReadFile(privateKeyPath) if err != nil { - level.Error(j.log).Log("msg", "Failed to parse Snowflake port", "connection", conn, "err", err) + level.Error(j.log).Log("msg", "Failed to read private key file", "path", privateKeyPath, "err", err) + continue + } + + keyBlock, _ := pem.Decode(keyBytes) + if keyBlock == nil { + level.Error(j.log).Log("msg", "Failed to decode PEM block", "path", privateKeyPath) + continue + } + + var privateKey *rsa.PrivateKey + if parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil { + privateKey, _ = parsedKey.(*rsa.PrivateKey) + } else if parsedKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes); err == nil { + privateKey = parsedKey + } else { + level.Error(j.log).Log("msg", "Failed to parse private key", "err", err) + continue + } + + cfg.Authenticator = gosnowflake.AuthTypeJwt + cfg.PrivateKey = privateKey + + dsn, err := gosnowflake.DSN(cfg) + if err != nil { + level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with RSA", "err", err) + continue + } + + newConn.snowflakeConfig = cfg + newConn.snowflakeDSN = dsn + newConn.host = u.Host + newConn.tokenExpirationTime = time.Now().Add(time.Hour) + } else { + // Password auth + if pw, set := u.User.Password(); set { + cfg.Password = pw + } + if u.Port() != "" { + if port, err := strconv.Atoi(u.Port()); err == nil { + cfg.Port = port + } + } + + dsn, err := gosnowflake.DSN(cfg) + if err != nil { + level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with password", "err", err) + continue + } + + newConn.conn, err = sqlx.Open("snowflake", dsn) + if err != nil { + level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "err", err) continue } - cfg.Port = portStr - } - - dsn, err := gosnowflake.DSN(cfg) - if err != nil { - level.Error(j.log).Log("msg", "Failed to create Snowflake DSN", "connection", conn, "err", err) - continue - } - - newConn.conn, err = sqlx.Open("snowflake", dsn) - if err != nil { - level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "connection", conn, "err", err) - continue } + + j.conns = append(j.conns, newConn) + continue } j.conns = append(j.conns, newConn) @@ -570,6 +623,34 @@ func (c *connection) connect(job *Job) error { } return nil } + if c.driver == "snowflake" { + if c.snowflakeDSN != "" { + if time.Now().After(c.tokenExpirationTime) { + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.tokenExpirationTime = time.Now().Add(time.Hour) + } + + db, err := sqlx.Open("snowflake", c.snowflakeDSN) + if err != nil { + return fmt.Errorf("failed to open Snowflake connection: %w (host: %s)", err, c.host) + } + + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(0) + db.SetConnMaxLifetime(30 * time.Minute) + + if err := db.Ping(); err != nil { + db.Close() + return fmt.Errorf("failed to ping Snowflake: %w (host: %s)", err, c.host) + } + + c.conn = db + return nil + } + } dsn := c.url switch c.driver { case "mysql":