diff --git a/Cargo.lock b/Cargo.lock index 60fbaac65ab..a8f8021bfa2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8193,6 +8193,7 @@ dependencies = [ "http 1.3.1", "log", "pgwire", + "spacetimedb-auth", "spacetimedb-client-api", "spacetimedb-client-api-messages", "spacetimedb-lib 2.0.3", diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 20e6196c78b..cc6e7954e4b 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -24,6 +24,7 @@ use futures::TryStreamExt; use http::StatusCode; use log::{info, warn}; use serde::Deserialize; +use spacetimedb::auth::identity::ConnectionAuthCtx; use spacetimedb::database_logger::DatabaseLogger; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::{CallResult, UpdateDatabaseResult}; @@ -518,22 +519,46 @@ pub async fn sql_direct( SqlParams { name_or_identity }: SqlParams, SqlQueryParams { confirmed }: SqlQueryParams, caller_identity: Identity, + caller_auth: ConnectionAuthCtx, sql: String, ) -> axum::response::Result>> where S: NodeDelegate + ControlStateDelegate + Authorization, { - // Anyone is authorized to execute SQL queries. The SQL engine will determine - // which queries this identity is allowed to execute against the database. + let connection_id = generate_random_connection_id(); let (host, database) = find_leader_and_database(&worker_ctx, name_or_identity).await?; - let auth = worker_ctx - .authorize_sql(caller_identity, database.database_identity) - .await?; + // Run the module's client_connected reducer, if any. + // If it rejects the connection, bail before executing SQL. + let module = host.module().await.map_err(log_and_500)?; + module + .call_identity_connected(caller_auth, connection_id) + .await + .map_err(client_connected_error_to_response)?; + + let result = async { + let sql_auth = worker_ctx + .authorize_sql(caller_identity, database.database_identity) + .await?; + + host.exec_sql( + sql_auth, + database, + confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS), + sql, + ) + .await + } + .await; - host.exec_sql(auth, database, confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS), sql) + // Always disconnect, even if authorization or execution failed. + module + .call_identity_disconnected(caller_identity, connection_id, false) .await + .map_err(client_disconnected_error_to_response)?; + + result } pub async fn sql( @@ -546,7 +571,9 @@ pub async fn sql( where S: NodeDelegate + ControlStateDelegate + Authorization, { - let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?; + let caller_identity = auth.claims.identity; + let caller_auth: ConnectionAuthCtx = auth.into(); + let json = sql_direct(worker_ctx, name_or_identity, params, caller_identity, caller_auth, body).await?; let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros); diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml index dd49122dea0..a3e1f0c4f3f 100644 --- a/crates/pg/Cargo.toml +++ b/crates/pg/Cargo.toml @@ -7,6 +7,7 @@ license-file = "LICENSE" description = "Postgres wire protocol Server support for SpacetimeDB" [dependencies] +spacetimedb-auth.workspace = true spacetimedb-client-api-messages.workspace = true spacetimedb-client-api.workspace = true spacetimedb-lib.workspace = true diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 10dae652a02..138ab6430b2 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -22,6 +22,7 @@ use pgwire::messages::data::DataRow; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use pgwire::tokio::process_socket; +use spacetimedb_auth::identity::ConnectionAuthCtx; use spacetimedb_client_api::auth::validate_token; use spacetimedb_client_api::routes::database; use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; @@ -64,6 +65,7 @@ impl From for PgWireError { struct Metadata { database: String, caller_identity: Identity, + caller_auth: ConnectionAuthCtx, } pub(crate) fn to_rows( @@ -163,6 +165,7 @@ where db, SqlQueryParams { confirmed: Some(true) }, params.caller_identity, + params.caller_auth.clone(), query.to_string(), ) .await, @@ -266,8 +269,8 @@ impl claims.identity, + let claims = match validate_token(&self.ctx, &pwd.password).await { + Ok(claims) => claims, Err(err) => { log::error!( "PG: Authentication failed for identity `{}` on database {database}: {err}", @@ -277,12 +280,22 @@ impl