diff --git a/lightning-block-sync/Cargo.toml b/lightning-block-sync/Cargo.toml index 97f199963ac..d8d71da3fae 100644 --- a/lightning-block-sync/Cargo.toml +++ b/lightning-block-sync/Cargo.toml @@ -16,15 +16,16 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -rest-client = [ "serde_json", "chunked_transfer" ] -rpc-client = [ "serde_json", "chunked_transfer" ] +rest-client = [ "serde_json", "dep:bitreq" ] +rpc-client = [ "serde_json", "dep:bitreq" ] +tokio = [ "dep:tokio", "bitreq?/async" ] [dependencies] bitcoin = "0.32.2" lightning = { version = "0.3.0", path = "../lightning" } tokio = { version = "1.35", features = [ "io-util", "net", "time", "rt" ], optional = true } serde_json = { version = "1.0", optional = true } -chunked_transfer = { version = "1.4", optional = true } +bitreq = { version = "0.3", default-features = false, features = ["std"], optional = true } [dev-dependencies] lightning = { version = "0.3.0", path = "../lightning", features = ["_test_utils"] } diff --git a/lightning-block-sync/src/convert.rs b/lightning-block-sync/src/convert.rs index a31b329a5af..2f1edd86d1f 100644 --- a/lightning-block-sync/src/convert.rs +++ b/lightning-block-sync/src/convert.rs @@ -1,4 +1,6 @@ -use crate::http::{BinaryResponse, JsonResponse}; +use crate::http::{BinaryResponse, HttpClientError, JsonResponse}; +#[cfg(feature = "rpc-client")] +use crate::rpc::RpcClientError; use crate::utils::hex_to_work; use crate::{BlockHeaderData, BlockSourceError}; @@ -35,6 +37,48 @@ impl From for BlockSourceError { } } +/// Conversion from `HttpClientError` into `BlockSourceError`. +impl From for BlockSourceError { + fn from(e: HttpClientError) -> BlockSourceError { + match &e { + // Transport errors (connection, timeout, etc.) are transient + HttpClientError::Transport(_) => BlockSourceError::transient(e), + // HTTP non-2xx errors are transient - e.g. "not found" must not stop polling + HttpClientError::Http(_) => BlockSourceError::transient(e), + // I/O errors follow the same logic as std::io::Error + HttpClientError::Io(io_error) => match io_error.kind() { + io::ErrorKind::InvalidData => BlockSourceError::persistent(e), + io::ErrorKind::InvalidInput => BlockSourceError::persistent(e), + _ => BlockSourceError::transient(e), + }, + } + } +} + +/// Conversion from `RpcClientError` into `BlockSourceError`. +#[cfg(feature = "rpc-client")] +impl From for BlockSourceError { + fn from(e: RpcClientError) -> BlockSourceError { + match &e { + RpcClientError::Http(http_error) => match http_error { + HttpClientError::Transport(_) => BlockSourceError::transient(e), + // HTTP non-2xx errors are transient + HttpClientError::Http(_) => BlockSourceError::transient(e), + HttpClientError::Io(io_error) => match io_error.kind() { + io::ErrorKind::InvalidData => BlockSourceError::persistent(e), + io::ErrorKind::InvalidInput => BlockSourceError::persistent(e), + _ => BlockSourceError::transient(e), + }, + }, + // RPC errors are transient + // e.g. "block not found" should not stop polling + RpcClientError::Rpc(_) => BlockSourceError::transient(e), + // Malformed response data is persistent + RpcClientError::InvalidData(_) => BlockSourceError::persistent(e), + } + } +} + /// Parses binary data as a block. impl TryInto for BinaryResponse { type Error = io::Error; diff --git a/lightning-block-sync/src/http.rs b/lightning-block-sync/src/http.rs index 0fb82b4acde..f51bae60869 100644 --- a/lightning-block-sync/src/http.rs +++ b/lightning-block-sync/src/http.rs @@ -1,30 +1,16 @@ //! Simple HTTP implementation which supports both async and traditional execution environments //! with minimal dependencies. This is used as the basis for REST and RPC clients. -use chunked_transfer; use serde_json; +#[cfg(feature = "tokio")] +use bitreq::RequestExt; + use std::convert::TryFrom; use std::fmt; -#[cfg(not(feature = "tokio"))] -use std::io::Write; use std::net::{SocketAddr, ToSocketAddrs}; use std::time::Duration; -#[cfg(feature = "tokio")] -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; - -#[cfg(not(feature = "tokio"))] -use std::io::BufRead; -use std::io::Read; -#[cfg(not(feature = "tokio"))] -use std::net::TcpStream; - -/// Timeout for operations on TCP streams. -const TCP_STREAM_TIMEOUT: Duration = Duration::from_secs(5); - /// Timeout for reading the first byte of a response. This is separate from the general read /// timeout as it is not uncommon for Bitcoin Core to be blocked waiting on UTXO cache flushes for /// upwards of 10 minutes on slow devices (e.g. RPis with SSDs over USB). Note that we always retry @@ -32,13 +18,59 @@ const TCP_STREAM_TIMEOUT: Duration = Duration::from_secs(5); /// value. const TCP_STREAM_RESPONSE_TIMEOUT: Duration = Duration::from_secs(300); -/// Maximum HTTP message header size in bytes. -const MAX_HTTP_MESSAGE_HEADER_SIZE: usize = 8192; - /// Maximum HTTP message body size in bytes. Enough for a hex-encoded block in JSON format and any /// overhead for HTTP chunked transfer encoding. const MAX_HTTP_MESSAGE_BODY_SIZE: usize = 2 * 4_000_000 + 32_000; +/// Error type for HTTP client operations. +#[derive(Debug)] +pub enum HttpClientError { + /// transport-level error (connection, timeout, protocol parsing, etc.) + Transport(bitreq::Error), + /// HTTP error response + Http(HttpError), + /// I/O error (DNS resolution, etc.) + Io(std::io::Error), +} + +impl std::error::Error for HttpClientError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + HttpClientError::Transport(e) => Some(e), + HttpClientError::Http(e) => Some(e), + HttpClientError::Io(e) => Some(e), + } + } +} + +impl fmt::Display for HttpClientError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + HttpClientError::Transport(e) => write!(f, "transport error: {}", e), + HttpClientError::Http(e) => write!(f, "HTTP error: {}", e), + HttpClientError::Io(e) => write!(f, "I/O error: {}", e), + } + } +} + +impl From for HttpClientError { + fn from(e: std::io::Error) -> Self { + HttpClientError::Io(e) + } +} + +impl From for HttpClientError { + fn from(e: bitreq::Error) -> Self { + HttpClientError::Transport(e) + } +} + +impl From for HttpClientError { + fn from(e: HttpError) -> Self { + HttpClientError::Http(e) + } +} + /// Endpoint for interacting with an HTTP-based API. #[derive(Debug)] pub struct HttpEndpoint { @@ -92,308 +124,141 @@ impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint { } } +/// Maximum number of cached connections in the connection pool. +#[cfg(feature = "tokio")] +const MAX_CONNECTIONS: usize = 10; + /// Client for making HTTP requests. pub(crate) struct HttpClient { address: SocketAddr, - stream: TcpStream, + #[cfg(feature = "tokio")] + client: bitreq::Client, } impl HttpClient { /// Opens a connection to an HTTP endpoint. - pub fn connect(endpoint: E) -> std::io::Result { + pub fn connect(endpoint: E) -> Result { let address = match endpoint.to_socket_addrs()?.next() { None => { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "could not resolve to any addresses", - )); + ) + .into()); }, Some(address) => address, }; - let stream = std::net::TcpStream::connect_timeout(&address, TCP_STREAM_TIMEOUT)?; - stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT))?; - stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT))?; - #[cfg(feature = "tokio")] - let stream = { - stream.set_nonblocking(true)?; - TcpStream::from_std(stream)? - }; - - Ok(Self { address, stream }) + Ok(Self { + address, + #[cfg(feature = "tokio")] + client: bitreq::Client::new(MAX_CONNECTIONS), + }) } - /// Sends a `GET` request for a resource identified by `uri` at the `host`. + /// Sends a `GET` request for a resource identified by `uri`. /// /// Returns the response body in `F` format. #[allow(dead_code)] - pub async fn get(&mut self, uri: &str, host: &str) -> std::io::Result + pub async fn get(&mut self, uri: &str) -> Result where F: TryFrom, Error = std::io::Error>, { - let request = format!( - "GET {} HTTP/1.1\r\n\ - Host: {}\r\n\ - Connection: keep-alive\r\n\ - \r\n", - uri, host - ); - let response_body = self.send_request_with_retry(&request).await?; - F::try_from(response_body) - } - - /// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP + let address = self.address; + let response_body = self + .send_request_with_retry(|| { + let url = format!("http://{}{}", address, uri); + bitreq::get(url) + .with_timeout(TCP_STREAM_RESPONSE_TIMEOUT.as_secs()) + .with_max_body_size(Some(MAX_HTTP_MESSAGE_BODY_SIZE)) + }) + .await?; + F::try_from(response_body).map_err(HttpClientError::Io) + } + + /// Sends a `POST` request for a resource identified by `uri` using the given HTTP /// authentication credentials. /// /// The request body consists of the provided JSON `content`. Returns the response body in `F` /// format. #[allow(dead_code)] pub async fn post( - &mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value, - ) -> std::io::Result + &mut self, uri: &str, auth: &str, content: serde_json::Value, + ) -> Result where F: TryFrom, Error = std::io::Error>, { + let address = self.address; let content = content.to_string(); - let request = format!( - "POST {} HTTP/1.1\r\n\ - Host: {}\r\n\ - Authorization: {}\r\n\ - Connection: keep-alive\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - \r\n\ - {}", - uri, - host, - auth, - content.len(), - content - ); - let response_body = self.send_request_with_retry(&request).await?; - F::try_from(response_body) + let response_body = self + .send_request_with_retry(|| { + let url = format!("http://{}{}", address, uri); + bitreq::post(url) + .with_header("Authorization", auth) + .with_header("Content-Type", "application/json") + .with_timeout(TCP_STREAM_RESPONSE_TIMEOUT.as_secs()) + .with_max_body_size(Some(MAX_HTTP_MESSAGE_BODY_SIZE)) + .with_body(content.clone()) + }) + .await?; + F::try_from(response_body).map_err(HttpClientError::Io) } /// Sends an HTTP request message and reads the response, returning its body. Attempts to - /// reconnect and retry if the connection has been closed. - async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result> { - match self.send_request(request).await { + /// reconnect and retry only on transport failures (not on HTTP errors like 500/404). + async fn send_request_with_retry( + &mut self, build_request: impl Fn() -> bitreq::Request, + ) -> Result, HttpClientError> { + match self.send_request(build_request()).await { Ok(bytes) => Ok(bytes), - Err(_) => { - // Reconnect and retry on fail. This can happen if the connection was closed after - // the keep-alive limits are reached, or generally if the request timed out due to - // Bitcoin Core being stuck on a long-running operation or its RPC queue being - // full. - // Block 100ms before retrying the request as in many cases the source of the error - // may be persistent for some time. + Err(HttpClientError::Http(e)) => { + // Don't retry on HTTP errors (non-2xx responses) + Err(HttpClientError::Http(e)) + }, + Err(HttpClientError::Io(e)) => { + // Don't retry on I/O errors (e.g., response parsing failures). + Err(HttpClientError::Io(e)) + }, + Err(HttpClientError::Transport(_)) => { + // Reconnect and retry on transport failures. This can happen if the connection + // was closed after the keep-alive limits are reached, or generally if the + // request timed out due to Bitcoin Core being stuck on a long-running operation + // or its RPC queue being full. #[cfg(feature = "tokio")] tokio::time::sleep(Duration::from_millis(100)).await; #[cfg(not(feature = "tokio"))] std::thread::sleep(Duration::from_millis(100)); *self = Self::connect(self.address)?; - self.send_request(request).await + self.send_request(build_request()).await }, } } /// Sends an HTTP request message and reads the response, returning its body. - async fn send_request(&mut self, request: &str) -> std::io::Result> { - self.write_request(request).await?; - self.read_response().await - } - - /// Writes an HTTP request message. - async fn write_request(&mut self, request: &str) -> std::io::Result<()> { + async fn send_request(&self, request: bitreq::Request) -> Result, HttpClientError> { #[cfg(feature = "tokio")] - { - self.stream.write_all(request.as_bytes()).await?; - self.stream.flush().await - } + let response = request.send_async_with_client(&self.client).await?; #[cfg(not(feature = "tokio"))] - { - self.stream.write_all(request.as_bytes())?; - self.stream.flush() - } - } - - /// Reads an HTTP response message. - async fn read_response(&mut self) -> std::io::Result> { - #[cfg(feature = "tokio")] - let stream = self.stream.split().0; - #[cfg(not(feature = "tokio"))] - let stream = std::io::Read::by_ref(&mut self.stream); - - let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64); + let response = request.send()?; - #[cfg(feature = "tokio")] - let mut reader = tokio::io::BufReader::new(limited_stream); - #[cfg(not(feature = "tokio"))] - let mut reader = std::io::BufReader::new(limited_stream); + let status_code = response.status_code; + let body = response.into_bytes(); - macro_rules! read_line { - () => { - read_line!(0) - }; - ($retry_count: expr) => {{ - let mut line = String::new(); - let mut timeout_count: u64 = 0; - let bytes_read = loop { - #[cfg(feature = "tokio")] - let read_res = reader.read_line(&mut line).await; - #[cfg(not(feature = "tokio"))] - let read_res = reader.read_line(&mut line); - match read_res { - Ok(bytes_read) => break bytes_read, - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - timeout_count += 1; - if timeout_count > $retry_count { - return Err(e); - } else { - continue; - } - }, - Err(e) => return Err(e), - } - }; - - match bytes_read { - 0 => None, - _ => { - // Remove trailing CRLF - if line.ends_with('\n') { - line.pop(); - if line.ends_with('\r') { - line.pop(); - } - } - Some(line) - }, - } - }}; + if !(200..300).contains(&status_code) { + return Err(HttpError { status_code, contents: body }.into()); } - // Read and parse status line - // Note that we allow retrying a few times to reach TCP_STREAM_RESPONSE_TIMEOUT. - let status_line = - read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs()) - .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?; - let status = HttpStatus::parse(&status_line)?; - - // Read and parse relevant headers - let mut message_length = HttpMessageLength::Empty; - loop { - let line = read_line!() - .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?; - if line.is_empty() { - break; - } - - let header = HttpHeader::parse(&line)?; - if header.has_name("Content-Length") { - let length = header - .value - .parse() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - if let HttpMessageLength::Empty = message_length { - message_length = HttpMessageLength::ContentLength(length); - } - continue; - } - - if header.has_name("Transfer-Encoding") { - message_length = HttpMessageLength::TransferEncoding(header.value.into()); - continue; - } - } - - // Read message body - let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len(); - reader.get_mut().set_limit(read_limit as u64); - let contents = match message_length { - HttpMessageLength::Empty => Vec::new(), - HttpMessageLength::ContentLength(length) => { - if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("invalid response length: {} bytes", length), - )); - } else { - let mut content = vec![0; length]; - #[cfg(feature = "tokio")] - reader.read_exact(&mut content[..]).await?; - #[cfg(not(feature = "tokio"))] - reader.read_exact(&mut content[..])?; - content - } - }, - HttpMessageLength::TransferEncoding(coding) => { - if !coding.eq_ignore_ascii_case("chunked") { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "unsupported transfer coding", - )); - } else { - let mut content = Vec::new(); - #[cfg(feature = "tokio")] - { - // Since chunked_transfer doesn't have an async interface, only use it to - // determine the size of each chunk to read. - // - // TODO: Replace with an async interface when available. - // https://github.com/frewsxcv/rust-chunked-transfer/issues/7 - loop { - // Read the chunk header which contains the chunk size. - let mut chunk_header = String::new(); - reader.read_line(&mut chunk_header).await?; - if chunk_header == "0\r\n" { - // Read the terminator chunk since the decoder consumes the CRLF - // immediately when this chunk is encountered. - reader.read_line(&mut chunk_header).await?; - } - - // Decode the chunk header to obtain the chunk size. - let mut buffer = Vec::new(); - let mut decoder = - chunked_transfer::Decoder::new(chunk_header.as_bytes()); - decoder.read_to_end(&mut buffer)?; - - // Read the chunk body. - let chunk_size = match decoder.remaining_chunks_size() { - None => break, - Some(chunk_size) => chunk_size, - }; - let chunk_offset = content.len(); - content.resize(chunk_offset + chunk_size + "\r\n".len(), 0); - reader.read_exact(&mut content[chunk_offset..]).await?; - content.resize(chunk_offset + chunk_size, 0); - } - content - } - #[cfg(not(feature = "tokio"))] - { - let mut decoder = chunked_transfer::Decoder::new(reader); - decoder.read_to_end(&mut content)?; - content - } - } - }, - }; - - if !status.is_ok() { - // TODO: Handle 3xx redirection responses. - let error = HttpError { status_code: status.code.to_string(), contents }; - return Err(std::io::Error::new(std::io::ErrorKind::Other, error)); - } - - Ok(contents) + Ok(body) } } /// HTTP error consisting of a status code and body contents. #[derive(Debug)] -pub(crate) struct HttpError { - pub(crate) status_code: String, - pub(crate) contents: Vec, +pub struct HttpError { + /// The HTTP status code. + pub status_code: i32, + /// The response body contents. + pub contents: Vec, } impl std::error::Error for HttpError {} @@ -405,94 +270,6 @@ impl fmt::Display for HttpError { } } -/// HTTP response status code as defined by [RFC 7231]. -/// -/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6 -struct HttpStatus<'a> { - code: &'a str, -} - -impl<'a> HttpStatus<'a> { - /// Parses an HTTP status line as defined by [RFC 7230]. - /// - /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2 - fn parse(line: &'a String) -> std::io::Result> { - let mut tokens = line.splitn(3, ' '); - - let http_version = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no HTTP-Version"))?; - if !http_version.eq_ignore_ascii_case("HTTP/1.1") - && !http_version.eq_ignore_ascii_case("HTTP/1.0") - { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid HTTP-Version", - )); - } - - let code = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Status-Code"))?; - if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid Status-Code", - )); - } - - let _reason = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Reason-Phrase"))?; - - Ok(Self { code }) - } - - /// Returns whether the status is successful (i.e., 2xx status class). - fn is_ok(&self) -> bool { - self.code.starts_with('2') - } -} - -/// HTTP response header as defined by [RFC 7231]. -/// -/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7 -struct HttpHeader<'a> { - name: &'a str, - value: &'a str, -} - -impl<'a> HttpHeader<'a> { - /// Parses an HTTP header field as defined by [RFC 7230]. - /// - /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2 - fn parse(line: &'a String) -> std::io::Result> { - let mut tokens = line.splitn(2, ':'); - let name = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?; - let value = tokens - .next() - .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))? - .trim_start(); - Ok(Self { name, value }) - } - - /// Returns whether the header field has the given name. - fn has_name(&self, name: &str) -> bool { - self.name.eq_ignore_ascii_case(name) - } -} - -/// HTTP message body length as defined by [RFC 7230]. -/// -/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3 -enum HttpMessageLength { - Empty, - ContentLength(usize), - TransferEncoding(String), -} - /// An HTTP response body in binary format. pub struct BinaryResponse(pub Vec); @@ -572,27 +349,41 @@ mod endpoint_tests { #[cfg(test)] pub(crate) mod client_tests { use super::*; - use std::io::BufRead; - use std::io::Write; + use std::io::{BufRead, Read, Write}; /// Server for handling HTTP client requests with a stock response. pub struct HttpServer { address: std::net::SocketAddr, - handler: std::thread::JoinHandle<()>, + handler: Option>, shutdown: std::sync::Arc, } + impl Drop for HttpServer { + fn drop(&mut self) { + self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst); + // Make a connection to unblock the listener's accept() call + let _ = std::net::TcpStream::connect(self.address); + if let Some(handler) = self.handler.take() { + let _ = handler.join(); + } + } + } + /// Body of HTTP response messages. pub enum MessageBody { Empty, Content(T), - ChunkedContent(T), } impl HttpServer { fn responding_with_body(status: &str, body: MessageBody) -> Self { let response = match body { - MessageBody::Empty => format!("{}\r\n\r\n", status), + MessageBody::Empty => format!( + "{}\r\n\ + Content-Length: 0\r\n\ + \r\n", + status + ), MessageBody::Content(body) => { let body = body.to_string(); format!( @@ -605,22 +396,6 @@ pub(crate) mod client_tests { body ) }, - MessageBody::ChunkedContent(body) => { - let mut chuncked_body = Vec::new(); - { - use chunked_transfer::Encoder; - let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8); - encoder.write_all(body.to_string().as_bytes()).unwrap(); - } - format!( - "{}\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - {}", - status, - String::from_utf8(chuncked_body).unwrap() - ) - }, }; HttpServer::responding_with(response) } @@ -645,39 +420,77 @@ pub(crate) mod client_tests { let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); let shutdown_signaled = std::sync::Arc::clone(&shutdown); let handler = std::thread::spawn(move || { + let timeout = Duration::from_secs(5); for stream in listener.incoming() { - let mut stream = stream.unwrap(); - stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap(); - - let lines_read = std::io::BufReader::new(&stream) - .lines() - .take_while(|line| !line.as_ref().unwrap().is_empty()) - .count(); - if lines_read == 0 { - continue; + if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { + return; } - for chunk in response.as_bytes().chunks(16) { + let stream = stream.unwrap(); + stream.set_write_timeout(Some(timeout)).unwrap(); + stream.set_read_timeout(Some(timeout)).unwrap(); + + let mut reader = std::io::BufReader::new(stream); + + // Handle multiple requests on the same connection (keep-alive) + loop { if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { return; - } else { - if let Err(_) = stream.write(chunk) { + } + + // Read request headers + let mut lines_read = 0; + let mut content_length: usize = 0; + loop { + let mut line = String::new(); + match reader.read_line(&mut line) { + Ok(0) => break, // eof + Ok(_) => { + if line == "\r\n" || line == "\n" { + break; // end of headers + } + // Parse content_length for POST body handling + if let Some(value) = line.strip_prefix("Content-Length:") { + content_length = value.trim().parse().unwrap_or(0); + } + lines_read += 1; + }, + Err(_) => break, // Read error or timeout + } + } + + if lines_read == 0 { + break; // No request received, connection closed + } + + // Consume request body if present (needed for POST keep-alive) + if content_length > 0 { + let mut body = vec![0u8; content_length]; + if reader.read_exact(&mut body).is_err() { break; } - if let Err(_) = stream.flush() { + } + + // Send response + let stream = reader.get_mut(); + let mut write_error = false; + for chunk in response.as_bytes().chunks(16) { + if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { + return; + } + if stream.write(chunk).is_err() || stream.flush().is_err() { + write_error = true; break; } } + if write_error { + break; + } } } }); - Self { address, handler, shutdown } - } - - fn shutdown(self) { - self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst); - self.handler.join().unwrap(); + Self { address, handler: Some(handler), shutdown } } pub fn endpoint(&self) -> HttpEndpoint { @@ -703,121 +516,28 @@ pub(crate) mod client_tests { #[test] fn connect_with_no_socket_address() { match HttpClient::connect(&vec![][..]) { - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput), + Err(HttpClientError::Io(e)) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput) + }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } - #[test] - fn connect_with_unknown_server() { - // get an unused port by binding to port 0 + #[tokio::test] + async fn request_to_unreachable_server() { + // Get an unused port by binding to port 0 let port = { let t = std::net::TcpListener::bind(("127.0.0.1", 0)).unwrap(); t.local_addr().unwrap().port() }; - match HttpClient::connect(("::", port)) { - #[cfg(target_os = "windows")] - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::AddrNotAvailable), - #[cfg(not(target_os = "windows"))] - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionRefused), - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn connect_with_valid_endpoint() { - let server = HttpServer::responding_with_ok::(MessageBody::Empty); - - match HttpClient::connect(&server.endpoint()) { - Err(e) => panic!("Unexpected error: {:?}", e), - Ok(_) => {}, - } - } - - #[tokio::test] - async fn read_empty_message() { - let server = HttpServer::responding_with("".to_string()); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no status line"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_incomplete_message() { - let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string()); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no headers"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_too_large_message_headers() { - let response = format!( - "HTTP/1.1 302 Found\r\n\ - Location: {}\r\n\ - \r\n", - "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE) - ); - let server = HttpServer::responding_with(response); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); - assert_eq!(e.get_ref().unwrap().to_string(), "no headers"); - }, - Ok(_) => panic!("Expected error"), - } - } - - #[tokio::test] - async fn read_too_large_message_body() { - let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1); - let server = HttpServer::responding_with_ok::(MessageBody::Content(body)); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!( - e.get_ref().unwrap().to_string(), - "invalid response length: 8032001 bytes" - ); - }, - Ok(_) => panic!("Expected error"), - } - server.shutdown(); - } - - #[tokio::test] - async fn read_message_with_unsupported_transfer_coding() { - let response = String::from( - "HTTP/1.1 200 OK\r\n\ - Transfer-Encoding: gzip\r\n\ - \r\n\ - foobar", - ); - let server = HttpServer::responding_with(response); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput); - assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding"); - }, + // Connect succeeds (just resolves address), but request fails + let endpoint = HttpEndpoint::for_host("127.0.0.1".to_string()).with_port(port); + let mut client = HttpClient::connect(&endpoint).unwrap(); + match client.get::("/").await { + Err(HttpClientError::Transport(_)) => {}, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -827,49 +547,24 @@ pub(crate) mod client_tests { let server = HttpServer::responding_with_server_error("foo"); let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::Other); - let http_error = e.into_inner().unwrap().downcast::().unwrap(); - assert_eq!(http_error.status_code, "500"); + match client.get::("/foo").await { + Err(HttpClientError::Http(http_error)) => { + assert_eq!(http_error.status_code, 500); assert_eq!(http_error.contents, "foo".as_bytes()); }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } #[tokio::test] - async fn read_empty_message_body() { - let server = HttpServer::responding_with_ok::(MessageBody::Empty); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => panic!("Unexpected error: {:?}", e), - Ok(bytes) => assert_eq!(bytes.0, Vec::::new()), - } - } - - #[tokio::test] - async fn read_message_body_with_length() { + async fn read_message_body() { let body = "foo bar baz qux".repeat(32); let content = MessageBody::Content(body.clone()); let server = HttpServer::responding_with_ok::(content); let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { - Err(e) => panic!("Unexpected error: {:?}", e), - Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()), - } - } - - #[tokio::test] - async fn read_chunked_message_body() { - let body = "foo bar baz qux".repeat(32); - let chunked_content = MessageBody::ChunkedContent(body.clone()); - let server = HttpServer::responding_with_ok::(chunked_content); - - let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - match client.get::("/foo", "foo.com").await { + match client.get::("/foo").await { Err(e) => panic!("Unexpected error: {:?}", e), Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()), } @@ -880,8 +575,8 @@ pub(crate) mod client_tests { let server = HttpServer::responding_with_ok::(MessageBody::Empty); let mut client = HttpClient::connect(&server.endpoint()).unwrap(); - assert!(client.get::("/foo", "foo.com").await.is_ok()); - match client.get::("/foo", "foo.com").await { + assert!(client.get::("/foo").await.is_ok()); + match client.get::("/foo").await { Err(e) => panic!("Unexpected error: {:?}", e), Ok(bytes) => assert_eq!(bytes.0, Vec::::new()), } diff --git a/lightning-block-sync/src/rest.rs b/lightning-block-sync/src/rest.rs index 619981bb4d0..a6482b1ff6d 100644 --- a/lightning-block-sync/src/rest.rs +++ b/lightning-block-sync/src/rest.rs @@ -3,7 +3,7 @@ use crate::convert::GetUtxosResponse; use crate::gossip::UtxoSource; -use crate::http::{BinaryResponse, HttpClient, HttpEndpoint, JsonResponse}; +use crate::http::{BinaryResponse, HttpClient, HttpClientError, HttpEndpoint, JsonResponse}; use crate::{BlockData, BlockHeaderData, BlockSource, BlockSourceResult}; use bitcoin::hash_types::BlockHash; @@ -29,11 +29,10 @@ impl RestClient { } /// Requests a resource encoded in `F` format and interpreted as type `T`. - pub async fn request_resource(&self, resource_path: &str) -> std::io::Result + pub async fn request_resource(&self, resource_path: &str) -> Result where F: TryFrom, Error = std::io::Error> + TryInto, { - let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); let uri = format!("{}/{}", self.endpoint.path().trim_end_matches("/"), resource_path); let reserved_client = self.client.lock().unwrap().take(); let mut client = if let Some(client) = reserved_client { @@ -41,7 +40,7 @@ impl RestClient { } else { HttpClient::connect(&self.endpoint)? }; - let res = client.get::(&uri, &host).await?.try_into(); + let res = client.get::(&uri).await?.try_into().map_err(HttpClientError::Io); *self.client.lock().unwrap() = Some(client); res } @@ -126,7 +125,8 @@ mod tests { let client = RestClient::new(server.endpoint()); match client.request_resource::("/").await { - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other), + Err(HttpClientError::Http(e)) => assert_eq!(e.status_code, 404), + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -137,7 +137,8 @@ mod tests { let client = RestClient::new(server.endpoint()); match client.request_resource::("/").await { - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidData), + Err(HttpClientError::Io(_)) => {}, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } diff --git a/lightning-block-sync/src/rpc.rs b/lightning-block-sync/src/rpc.rs index d851ba2ccf0..60ef62eb6a6 100644 --- a/lightning-block-sync/src/rpc.rs +++ b/lightning-block-sync/src/rpc.rs @@ -2,7 +2,7 @@ //! endpoint. use crate::gossip::UtxoSource; -use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse}; +use crate::http::{HttpClient, HttpClientError, HttpEndpoint, JsonResponse}; use crate::{BlockData, BlockHeaderData, BlockSource, BlockSourceResult}; use bitcoin::hash_types::BlockHash; @@ -36,6 +36,49 @@ impl fmt::Display for RpcError { impl Error for RpcError {} +/// Error type for RPC client operations. +#[derive(Debug)] +pub enum RpcClientError { + /// An HTTP client error (transport or HTTP error). + Http(HttpClientError), + /// An RPC error returned by the server. + Rpc(RpcError), + /// Invalid data in the response. + InvalidData(String), +} + +impl std::error::Error for RpcClientError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RpcClientError::Http(e) => Some(e), + RpcClientError::Rpc(e) => Some(e), + RpcClientError::InvalidData(_) => None, + } + } +} + +impl fmt::Display for RpcClientError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + RpcClientError::Http(e) => write!(f, "HTTP error: {}", e), + RpcClientError::Rpc(e) => write!(f, "{}", e), + RpcClientError::InvalidData(msg) => write!(f, "invalid data: {}", msg), + } + } +} + +impl From for RpcClientError { + fn from(e: HttpClientError) -> Self { + RpcClientError::Http(e) + } +} + +impl From for RpcClientError { + fn from(e: RpcError) -> Self { + RpcClientError::Rpc(e) + } +} + /// A simple RPC client for calling methods using HTTP `POST`. /// /// Implements [`BlockSource`] and may return an `Err` containing [`RpcError`]. See @@ -61,16 +104,12 @@ impl RpcClient { } /// Calls a method with the response encoded in JSON format and interpreted as type `T`. - /// - /// When an `Err` is returned, [`std::io::Error::into_inner`] may contain an [`RpcError`] if - /// [`std::io::Error::kind`] is [`std::io::ErrorKind::Other`]. pub async fn call_method( &self, method: &str, params: &[serde_json::Value], - ) -> std::io::Result + ) -> Result where JsonResponse: TryFrom, Error = std::io::Error> + TryInto, { - let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); let uri = self.endpoint.path(); let content = serde_json::json!({ "method": method, @@ -84,52 +123,42 @@ impl RpcClient { } else { HttpClient::connect(&self.endpoint)? }; - let http_response = - client.post::(&uri, &host, &self.basic_auth, content).await; + let http_response = client.post::(&uri, &self.basic_auth, content).await; *self.client.lock().unwrap() = Some(client); let mut response = match http_response { Ok(JsonResponse(response)) => response, - Err(e) if e.kind() == std::io::ErrorKind::Other => { - match e.get_ref().unwrap().downcast_ref::() { - Some(http_error) => match JsonResponse::try_from(http_error.contents.clone()) { - Ok(JsonResponse(response)) => response, - Err(_) => Err(e)?, - }, - None => Err(e)?, + Err(HttpClientError::Http(http_error)) => { + // Try to parse the error body as JSON-RPC response + match JsonResponse::try_from(http_error.contents.clone()) { + Ok(JsonResponse(response)) => response, + Err(_) => return Err(HttpClientError::Http(http_error).into()), } }, - Err(e) => Err(e)?, + Err(e) => return Err(e.into()), }; if !response.is_object() { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expected JSON object", - )); + return Err(RpcClientError::InvalidData("expected JSON object".to_string())); } let error = &response["error"]; if !error.is_null() { - // TODO: Examine error code for a more precise std::io::ErrorKind. let rpc_error = RpcError { code: error["code"].as_i64().unwrap_or(-1), message: error["message"].as_str().unwrap_or("unknown error").to_string(), }; - return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error)); + return Err(rpc_error.into()); } let result = match response.get_mut("result") { Some(result) => result.take(), - None => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expected JSON result", - )) - }, + None => return Err(RpcClientError::InvalidData("expected JSON result".to_string())), }; - JsonResponse(result).try_into() + JsonResponse(result) + .try_into() + .map_err(|e: std::io::Error| RpcClientError::InvalidData(e.to_string())) } } @@ -212,7 +241,10 @@ mod tests { let client = RpcClient::new(CREDENTIALS, server.endpoint()); match client.call_method::("getblockcount", &[]).await { - Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other), + Err(RpcClientError::Http(HttpClientError::Http(e))) => { + assert_eq!(e.status_code, 404); + }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -224,10 +256,10 @@ mod tests { let client = RpcClient::new(CREDENTIALS, server.endpoint()); match client.call_method::("getblockcount", &[]).await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object"); + Err(RpcClientError::InvalidData(msg)) => { + assert_eq!(msg, "expected JSON object"); }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -242,12 +274,11 @@ mod tests { let invalid_block_hash = serde_json::json!("foo"); match client.call_method::("getblock", &[invalid_block_hash]).await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::Other); - let rpc_error: Box = e.into_inner().unwrap().downcast().unwrap(); + Err(RpcClientError::Rpc(rpc_error)) => { assert_eq!(rpc_error.code, -8); assert_eq!(rpc_error.message, "invalid parameter"); }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -259,10 +290,10 @@ mod tests { let client = RpcClient::new(CREDENTIALS, server.endpoint()); match client.call_method::("getblockcount", &[]).await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON result"); + Err(RpcClientError::InvalidData(msg)) => { + assert_eq!(msg, "expected JSON result"); }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } } @@ -274,10 +305,10 @@ mod tests { let client = RpcClient::new(CREDENTIALS, server.endpoint()); match client.call_method::("getblockcount", &[]).await { - Err(e) => { - assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!(e.get_ref().unwrap().to_string(), "not a number"); + Err(RpcClientError::InvalidData(msg)) => { + assert!(msg.contains("not a number")); }, + Err(e) => panic!("Unexpected error type: {:?}", e), Ok(_) => panic!("Expected error"), } }