From 429e1fc86cb89db37c9bf103d5cc4652b8e4ab67 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Mon, 2 Jun 2025 13:40:27 -0500 Subject: [PATCH 1/2] Invoke __identity_connected__ for sql http requests --- crates/client-api/src/routes/database.rs | 136 +++++++++++------- ...ient_connected_error_rejects_connection.py | 7 +- .../tests/connect_disconnect_from_cli.py | 11 +- 3 files changed, 99 insertions(+), 55 deletions(-) diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index e6e8091f27f..f380c66845c 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -8,7 +8,7 @@ use crate::auth::{ }; use crate::routes::subscribe::generate_random_connection_id; use crate::util::{ByteStringBody, NameOrIdentity}; -use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; +use crate::{log_and_500, ControlStateDelegate, DatabaseDef, Host, NodeDelegate}; use axum::body::{Body, Bytes}; use axum::extract::{Path, Query, State}; use axum::response::{ErrorResponse, IntoResponse}; @@ -20,16 +20,16 @@ use http::StatusCode; use serde::Deserialize; use spacetimedb::database_logger::DatabaseLogger; use spacetimedb::host::module_host::ClientConnectedError; -use spacetimedb::host::ReducerArgs; use spacetimedb::host::ReducerCallError; use spacetimedb::host::ReducerOutcome; use spacetimedb::host::UpdateDatabaseResult; +use spacetimedb::host::{ModuleHost, ReducerArgs}; use spacetimedb::identity::Identity; use spacetimedb::messages::control_db::{Database, HostType}; use spacetimedb_client_api_messages::name::{self, DatabaseName, DomainName, PublishOp, PublishResult}; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::sats; +use spacetimedb_lib::{sats, ConnectionId}; use super::subscribe::handle_websocket; @@ -41,22 +41,20 @@ pub struct CallParams { pub const NO_SUCH_DATABASE: (StatusCode, &str) = (StatusCode::NOT_FOUND, "No such database."); -pub async fn call( - State(worker_ctx): State, - Extension(auth): Extension, - Path(CallParams { - name_or_identity, - reducer, - }): Path, - TypedHeader(content_type): TypedHeader, - ByteStringBody(body): ByteStringBody, -) -> axum::response::Result { - if content_type != headers::ContentType::json() { - return Err(axum::extract::rejection::MissingJsonContentType::default().into()); - } - let caller_identity = auth.identity; +struct Connected { + database: Database, + leader: Host, + module: ModuleHost, + connection_id: ConnectionId, + caller_identity: Identity, +} - let args = ReducerArgs::Json(body); +async fn call_on_connect( + worker_ctx: S, + auth: SpacetimeAuth, + name_or_identity: NameOrIdentity, +) -> axum::response::Result { + let caller_identity = auth.identity; let db_identity = name_or_identity.resolve(&worker_ctx).await?; let database = worker_ctx_find_database(&worker_ctx, &db_identity) @@ -65,7 +63,6 @@ pub async fn call( log::error!("Could not find database: {}", db_identity.to_hex()); NO_SUCH_DATABASE })?; - let identity = database.owner_identity; let leader = worker_ctx .leader(database.id) @@ -81,32 +78,75 @@ pub async fn call( match module.call_identity_connected(caller_identity, connection_id).await { // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored, // meaning the connection was refused. Return 403 forbidden. - Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()), + Err(ClientConnectedError::Rejected(msg)) => Err((StatusCode::FORBIDDEN, msg).into()), // If `call_identity_connected` returns `Err(OutOfEnergy)`, // then, well, the database is out of energy. // Return 503 service unavailable. - Err(err @ ClientConnectedError::OutOfEnergy) => { - return Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into()) - } + Err(err @ ClientConnectedError::OutOfEnergy) => Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into()), // If `call_identity_connected` returns `Err(ReducerCall)`, // something went wrong while invoking the `client_connected` reducer. // I (pgoldman 2025-03-27) am not really sure how this would happen, // but we returned 404 not found in this case prior to my editing this code, // so I guess let's keep doing that. Err(ClientConnectedError::ReducerCall(e)) => { - return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into()) + Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into()) } // If `call_identity_connected` returns `Err(DBError)`, // then the module didn't define `client_connected`, // but something went wrong when we tried to insert into `st_client`. // That's weird and scary, so return 500 internal error. - Err(e @ ClientConnectedError::DBError(_)) => { - return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into()) - } + Err(e @ ClientConnectedError::DBError(_)) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into()), // If `call_identity_connected` returns `Ok`, then we can actually call the reducer we want. - Ok(()) => (), + Ok(()) => Ok(Connected { + caller_identity, + database, + leader, + module, + connection_id, + }), } +} + +async fn call_on_disconnect( + module: ModuleHost, + connection_id: ConnectionId, + caller_identity: Identity, +) -> axum::response::Result<()> { + if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await { + // If `call_identity_disconnected` errors, something is very wrong: + // it means we tried to delete the `st_client` row but failed. + // Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer. + // Slap a 500 on it and pray. + return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into()); + } + Ok(()) +} + +pub async fn call( + State(worker_ctx): State, + Extension(auth): Extension, + Path(CallParams { + name_or_identity, + reducer, + }): Path, + TypedHeader(content_type): TypedHeader, + ByteStringBody(body): ByteStringBody, +) -> axum::response::Result { + if content_type != headers::ContentType::json() { + return Err(axum::extract::rejection::MissingJsonContentType::default().into()); + } + + let Connected { + caller_identity, + database, + leader: _, + module, + connection_id, + } = call_on_connect(worker_ctx, auth, name_or_identity).await?; + + let args = ReducerArgs::Json(body); + let result = match module .call_reducer(caller_identity, Some(connection_id), None, None, None, &reducer, args) .await @@ -134,17 +174,11 @@ pub async fn call( } }; - if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await { - // If `call_identity_disconnected` errors, something is very wrong: - // it means we tried to delete the `st_client` row but failed. - // Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer. - // Slap a 500 on it and pray. - return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into()); - } + call_on_disconnect(module, connection_id, caller_identity).await?; match result { Ok(result) => { - let (status, body) = reducer_outcome_response(&identity, &reducer, result.outcome); + let (status, body) = reducer_outcome_response(&database.owner_identity, &reducer, result.outcome); Ok(( status, TypedHeader(SpacetimeEnergyUsed(result.energy_used)), @@ -400,22 +434,21 @@ where { // Anyone is authorized to execute SQL queries. The SQL engine will determine // which queries this identity is allowed to execute against the database. - - let db_identity = name_or_identity.resolve(&worker_ctx).await?; - let database = worker_ctx_find_database(&worker_ctx, &db_identity) - .await? - .ok_or(NO_SUCH_DATABASE)?; - - let auth = AuthCtx::new(database.owner_identity, auth.identity); + let Connected { + caller_identity, + database, + leader, + module, + connection_id, + } = call_on_connect(worker_ctx, auth, name_or_identity).await?; + + let auth = AuthCtx::new(database.owner_identity, caller_identity); log::debug!("auth: {auth:?}"); - let host = worker_ctx - .leader(database.id) - .await - .map_err(log_and_500)? - .ok_or(StatusCode::NOT_FOUND)?; - let json = host.exec_sql(auth, database, body).await?; - + // Notify the disconnect even if the SQL execution fails. + let json_result = leader.exec_sql(auth, database, body).await; + call_on_disconnect(module, connection_id, caller_identity).await?; + let json = json_result?; let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros); Ok(( @@ -766,9 +799,10 @@ where names_put: put(set_names::), identity_get: get(get_identity::), subscribe_get: get(handle_websocket::), - call_reducer_post: post(call::), schema_get: get(schema::), logs_get: get(logs::), + // Need calls to on_connect and on_disconnect... + call_reducer_post: post(call::), sql_post: post(sql::), } } diff --git a/smoketests/tests/client_connected_error_rejects_connection.py b/smoketests/tests/client_connected_error_rejects_connection.py index 8654643ad19..a95da85dde3 100644 --- a/smoketests/tests/client_connected_error_rejects_connection.py +++ b/smoketests/tests/client_connected_error_rejects_connection.py @@ -17,6 +17,7 @@ } """ + class ClientConnectedErrorRejectsConnection(Smoketest): MODULE_CODE = MODULE_HEADER + """ @@ -33,12 +34,13 @@ class ClientConnectedErrorRejectsConnection(Smoketest): def test_client_connected_error_rejects_connection(self): with self.assertRaises(Exception): - self.subscribe("select * from all_u8s", n = 0)() + self.subscribe("select * from all_u8s", n=0)() logs = self.logs(100) self.assertIn('Rejecting connection from client', logs) self.assertNotIn('This should never be called, since we reject all connections!', logs) + class ClientDisconnectedErrorStillDeletesStClient(Smoketest): MODULE_CODE = MODULE_HEADER + """ #[spacetimedb::reducer(client_connected)] @@ -53,13 +55,12 @@ class ClientDisconnectedErrorStillDeletesStClient(Smoketest): """ def test_client_disconnected_error_still_deletes_st_client(self): - self.subscribe("select * from all_u8s", n = 0)() + self.subscribe("select * from all_u8s", n=0)() logs = self.logs(100) self.assertIn('This should be called, but the `st_client` row should still be deleted', logs) sql_out = self.spacetime("sql", self.database_identity, "select * from st_client") - self.assertMultiLineEqual(sql_out, """ identity | connection_id ----------+--------------- """) diff --git a/smoketests/tests/connect_disconnect_from_cli.py b/smoketests/tests/connect_disconnect_from_cli.py index a2f46ac882e..739645609d7 100644 --- a/smoketests/tests/connect_disconnect_from_cli.py +++ b/smoketests/tests/connect_disconnect_from_cli.py @@ -21,7 +21,7 @@ class ConnDisconnFromCli(Smoketest): } """ - def test_conn_disconn(self): + def test_conn_disconn_cli(self): """ Ensure that the connect and disconnect functions are called when invoking a reducer from the CLI """ @@ -31,3 +31,12 @@ def test_conn_disconn(self): self.assertIn('_connect called', logs) self.assertIn('disconnect called', logs) self.assertIn('Hello, World!', logs) + + def test_conn_disconn_sql(self): + """ + Ensure that the connect and disconnect functions are called when invoking a sql from the CLI + """ + self.spacetime("sql", self.database_identity, "select * from st_client") + logs = self.logs(10) + self.assertIn('_connect called', logs) + self.assertIn('disconnect called', logs) From 44ce62589758a4702b926f1252e995797f2559e8 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Fri, 6 Jun 2025 10:14:55 -0500 Subject: [PATCH 2/2] Change smoketest to reflect new behaviour --- .../client_connected_error_rejects_connection.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/smoketests/tests/client_connected_error_rejects_connection.py b/smoketests/tests/client_connected_error_rejects_connection.py index a95da85dde3..407db4d67c4 100644 --- a/smoketests/tests/client_connected_error_rejects_connection.py +++ b/smoketests/tests/client_connected_error_rejects_connection.py @@ -55,12 +55,19 @@ class ClientDisconnectedErrorStillDeletesStClient(Smoketest): """ def test_client_disconnected_error_still_deletes_st_client(self): + # The st_client table should only have the data of the current connection self.subscribe("select * from all_u8s", n=0)() logs = self.logs(100) self.assertIn('This should be called, but the `st_client` row should still be deleted', logs) sql_out = self.spacetime("sql", self.database_identity, "select * from st_client") - self.assertMultiLineEqual(sql_out, """ identity | connection_id -----------+--------------- -""") + row = [] + # Get only the rows with numeric data + # identity | connection_id + # ------------------------------------------------------------------------------+----------------------------------------- + for x in sql_out.splitlines()[1:]: + x = x.split("|")[0].strip() + if x.isdigit(): + row.append(x) + self.assertEqual(len(row), 1)