Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 85 additions & 51 deletions crates/client-api/src/routes/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::auth::{
};
use crate::routes::subscribe::generate_random_connection_id;
use crate::util::{ByteStringBody, NameOrIdentity};
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate};
use crate::{log_and_500, ControlStateDelegate, DatabaseDef, Host, NodeDelegate};
use axum::body::{Body, Bytes};
use axum::extract::{Path, Query, State};
use axum::response::{ErrorResponse, IntoResponse};
Expand All @@ -20,16 +20,16 @@ use http::StatusCode;
use serde::Deserialize;
use spacetimedb::database_logger::DatabaseLogger;
use spacetimedb::host::module_host::ClientConnectedError;
use spacetimedb::host::ReducerArgs;
use spacetimedb::host::ReducerCallError;
use spacetimedb::host::ReducerOutcome;
use spacetimedb::host::UpdateDatabaseResult;
use spacetimedb::host::{ModuleHost, ReducerArgs};
use spacetimedb::identity::Identity;
use spacetimedb::messages::control_db::{Database, HostType};
use spacetimedb_client_api_messages::name::{self, DatabaseName, DomainName, PublishOp, PublishResult};
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::sats;
use spacetimedb_lib::{sats, ConnectionId};

use super::subscribe::handle_websocket;

Expand All @@ -41,22 +41,20 @@ pub struct CallParams {

pub const NO_SUCH_DATABASE: (StatusCode, &str) = (StatusCode::NOT_FOUND, "No such database.");

pub async fn call<S: ControlStateDelegate + NodeDelegate>(
State(worker_ctx): State<S>,
Extension(auth): Extension<SpacetimeAuth>,
Path(CallParams {
name_or_identity,
reducer,
}): Path<CallParams>,
TypedHeader(content_type): TypedHeader<headers::ContentType>,
ByteStringBody(body): ByteStringBody,
) -> axum::response::Result<impl IntoResponse> {
if content_type != headers::ContentType::json() {
return Err(axum::extract::rejection::MissingJsonContentType::default().into());
}
let caller_identity = auth.identity;
struct Connected {
database: Database,
leader: Host,
module: ModuleHost,
connection_id: ConnectionId,
caller_identity: Identity,
}

let args = ReducerArgs::Json(body);
async fn call_on_connect<S: ControlStateDelegate + NodeDelegate>(
worker_ctx: S,
auth: SpacetimeAuth,
name_or_identity: NameOrIdentity,
) -> axum::response::Result<Connected> {
let caller_identity = auth.identity;

let db_identity = name_or_identity.resolve(&worker_ctx).await?;
let database = worker_ctx_find_database(&worker_ctx, &db_identity)
Expand All @@ -65,7 +63,6 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
log::error!("Could not find database: {}", db_identity.to_hex());
NO_SUCH_DATABASE
})?;
let identity = database.owner_identity;

let leader = worker_ctx
.leader(database.id)
Expand All @@ -81,32 +78,75 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
match module.call_identity_connected(caller_identity, connection_id).await {
// If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored,
// meaning the connection was refused. Return 403 forbidden.
Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()),
Err(ClientConnectedError::Rejected(msg)) => Err((StatusCode::FORBIDDEN, msg).into()),
// If `call_identity_connected` returns `Err(OutOfEnergy)`,
// then, well, the database is out of energy.
// Return 503 service unavailable.
Err(err @ ClientConnectedError::OutOfEnergy) => {
return Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into())
}
Err(err @ ClientConnectedError::OutOfEnergy) => Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into()),
// If `call_identity_connected` returns `Err(ReducerCall)`,
// something went wrong while invoking the `client_connected` reducer.
// I (pgoldman 2025-03-27) am not really sure how this would happen,
// but we returned 404 not found in this case prior to my editing this code,
// so I guess let's keep doing that.
Err(ClientConnectedError::ReducerCall(e)) => {
return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into())
Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into())
}
// If `call_identity_connected` returns `Err(DBError)`,
// then the module didn't define `client_connected`,
// but something went wrong when we tried to insert into `st_client`.
// That's weird and scary, so return 500 internal error.
Err(e @ ClientConnectedError::DBError(_)) => {
return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into())
}
Err(e @ ClientConnectedError::DBError(_)) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into()),

// If `call_identity_connected` returns `Ok`, then we can actually call the reducer we want.
Ok(()) => (),
Ok(()) => Ok(Connected {
caller_identity,
database,
leader,
module,
connection_id,
}),
}
}

async fn call_on_disconnect(
module: ModuleHost,
connection_id: ConnectionId,
caller_identity: Identity,
) -> axum::response::Result<()> {
if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await {
// If `call_identity_disconnected` errors, something is very wrong:
// it means we tried to delete the `st_client` row but failed.
// Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
// Slap a 500 on it and pray.
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into());
}
Ok(())
}

pub async fn call<S: ControlStateDelegate + NodeDelegate>(
State(worker_ctx): State<S>,
Extension(auth): Extension<SpacetimeAuth>,
Path(CallParams {
name_or_identity,
reducer,
}): Path<CallParams>,
TypedHeader(content_type): TypedHeader<headers::ContentType>,
ByteStringBody(body): ByteStringBody,
) -> axum::response::Result<impl IntoResponse> {
if content_type != headers::ContentType::json() {
return Err(axum::extract::rejection::MissingJsonContentType::default().into());
}

let Connected {
caller_identity,
database,
leader: _,
module,
connection_id,
} = call_on_connect(worker_ctx, auth, name_or_identity).await?;

let args = ReducerArgs::Json(body);

let result = match module
.call_reducer(caller_identity, Some(connection_id), None, None, None, &reducer, args)
.await
Expand Down Expand Up @@ -134,17 +174,11 @@ pub async fn call<S: ControlStateDelegate + NodeDelegate>(
}
};

if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await {
// If `call_identity_disconnected` errors, something is very wrong:
// it means we tried to delete the `st_client` row but failed.
// Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer.
// Slap a 500 on it and pray.
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into());
}
call_on_disconnect(module, connection_id, caller_identity).await?;

match result {
Ok(result) => {
let (status, body) = reducer_outcome_response(&identity, &reducer, result.outcome);
let (status, body) = reducer_outcome_response(&database.owner_identity, &reducer, result.outcome);
Ok((
status,
TypedHeader(SpacetimeEnergyUsed(result.energy_used)),
Expand Down Expand Up @@ -400,22 +434,21 @@ where
{
// Anyone is authorized to execute SQL queries. The SQL engine will determine
// which queries this identity is allowed to execute against the database.

let db_identity = name_or_identity.resolve(&worker_ctx).await?;
let database = worker_ctx_find_database(&worker_ctx, &db_identity)
.await?
.ok_or(NO_SUCH_DATABASE)?;

let auth = AuthCtx::new(database.owner_identity, auth.identity);
let Connected {
caller_identity,
database,
leader,
module,
connection_id,
} = call_on_connect(worker_ctx, auth, name_or_identity).await?;

let auth = AuthCtx::new(database.owner_identity, caller_identity);
log::debug!("auth: {auth:?}");

let host = worker_ctx
.leader(database.id)
.await
.map_err(log_and_500)?
.ok_or(StatusCode::NOT_FOUND)?;
let json = host.exec_sql(auth, database, body).await?;

// Notify the disconnect even if the SQL execution fails.
let json_result = leader.exec_sql(auth, database, body).await;
call_on_disconnect(module, connection_id, caller_identity).await?;
let json = json_result?;
let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);

Ok((
Expand Down Expand Up @@ -766,9 +799,10 @@ where
names_put: put(set_names::<S>),
identity_get: get(get_identity::<S>),
subscribe_get: get(handle_websocket::<S>),
call_reducer_post: post(call::<S>),
schema_get: get(schema::<S>),
logs_get: get(logs::<S>),
// Need calls to on_connect and on_disconnect...
call_reducer_post: post(call::<S>),
sql_post: post(sql::<S>),
}
}
Expand Down
20 changes: 14 additions & 6 deletions smoketests/tests/client_connected_error_rejects_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
}
"""


class ClientConnectedErrorRejectsConnection(Smoketest):
MODULE_CODE = MODULE_HEADER + """

Expand All @@ -33,12 +34,13 @@ class ClientConnectedErrorRejectsConnection(Smoketest):

def test_client_connected_error_rejects_connection(self):
with self.assertRaises(Exception):
self.subscribe("select * from all_u8s", n = 0)()
self.subscribe("select * from all_u8s", n=0)()

logs = self.logs(100)
self.assertIn('Rejecting connection from client', logs)
self.assertNotIn('This should never be called, since we reject all connections!', logs)


class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
MODULE_CODE = MODULE_HEADER + """
#[spacetimedb::reducer(client_connected)]
Expand All @@ -53,13 +55,19 @@ class ClientDisconnectedErrorStillDeletesStClient(Smoketest):
"""

def test_client_disconnected_error_still_deletes_st_client(self):
self.subscribe("select * from all_u8s", n = 0)()
# The st_client table should only have the data of the current connection
self.subscribe("select * from all_u8s", n=0)()

logs = self.logs(100)
self.assertIn('This should be called, but the `st_client` row should still be deleted', logs)

sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")

self.assertMultiLineEqual(sql_out, """ identity | connection_id
----------+---------------
""")
row = []
# Get only the rows with numeric data
# identity | connection_id
# ------------------------------------------------------------------------------+-----------------------------------------
for x in sql_out.splitlines()[1:]:
x = x.split("|")[0].strip()
if x.isdigit():
row.append(x)
self.assertEqual(len(row), 1)
11 changes: 10 additions & 1 deletion smoketests/tests/connect_disconnect_from_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ConnDisconnFromCli(Smoketest):
}
"""

def test_conn_disconn(self):
def test_conn_disconn_cli(self):
"""
Ensure that the connect and disconnect functions are called when invoking a reducer from the CLI
"""
Expand All @@ -31,3 +31,12 @@ def test_conn_disconn(self):
self.assertIn('_connect called', logs)
self.assertIn('disconnect called', logs)
self.assertIn('Hello, World!', logs)

def test_conn_disconn_sql(self):
"""
Ensure that the connect and disconnect functions are called when invoking a sql from the CLI
"""
self.spacetime("sql", self.database_identity, "select * from st_client")
logs = self.logs(10)
self.assertIn('_connect called', logs)
self.assertIn('disconnect called', logs)
Loading