0
0
mirror of https://github.com/tursodatabase/libsql.git synced 2024-11-22 02:16:15 +00:00
libsql/libsql-server/tests/common/net.rs
Lucio Franco 954029e31f
libsql: add Builder to construct the db (#993)
This adds a new `Builder` type that can now be used to construct the
`Database` type. This will scale better as we add more varied options.
This commit also deprecates the old builder types and will produce a
warning that will push users to using the new `Builder` type. This will
then allow us to remove the old deprecated constructors at some point in
the future.
2024-02-06 13:59:11 +00:00

186 lines
5.6 KiB
Rust

#![allow(deprecated)]
use std::io::Error as IoError;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Once;
use std::task::{Context, Poll};
use futures_core::Future;
use hyper::client::connect::Connected;
use hyper::server::accept::Accept as HyperAccept;
use hyper::Uri;
use metrics_util::debugging::DebuggingRecorder;
use tokio::io::{AsyncRead, AsyncWrite};
use tower::Service;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use libsql_server::net::Accept;
use libsql_server::net::AddrStream;
use libsql_server::Server;
type TurmoilAddrStream = AddrStream<turmoil::net::TcpStream>;
pub struct TurmoilAcceptor {
acceptor: Pin<
Box<dyn HyperAccept<Conn = TurmoilAddrStream, Error = IoError> + Send + Sync + 'static>,
>,
}
impl TurmoilAcceptor {
pub async fn bind(addr: impl Into<SocketAddr>) -> std::io::Result<Self> {
let addr = addr.into();
let stream = async_stream::stream! {
let listener = turmoil::net::TcpListener::bind(addr).await?;
loop {
yield listener.accept().await.and_then(|(stream, remote_addr)| Ok(AddrStream {
remote_addr,
local_addr: stream.local_addr()?,
stream,
}));
}
};
let acceptor = hyper::server::accept::from_stream(stream);
Ok(Self {
acceptor: Box::pin(acceptor),
})
}
}
impl Accept for TurmoilAcceptor {
type Connection = TurmoilAddrStream;
}
/// Hands hyper's accept loop off to the boxed acceptor built in `bind`.
impl HyperAccept for TurmoilAcceptor {
    type Conn = TurmoilAddrStream;
    type Error = IoError;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        // The only field is already a `Pin<Box<_>>`, which is `Unpin`, so the
        // outer pin can be removed with `get_mut` before re-pinning the inner one.
        let this = self.get_mut();
        let inner = this.acceptor.as_mut();
        inner.poll_accept(cx)
    }
}
/// Hyper client connector that dials turmoil hosts over the simulated network.
///
/// A stateless unit struct, so it is freely `Copy`able and has a trivial `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TurmoilConnector;
pin_project_lite::pin_project! {
    /// Client-side connection returned by `TurmoilConnector`: a thin wrapper
    /// around a turmoil TCP stream. Its `AsyncRead`/`AsyncWrite` impls forward
    /// directly to `inner`, and it implements hyper's client `Connection` trait.
    pub struct TurmoilStream {
        // Structurally pinned so the I/O impls can project to `Pin<&mut TcpStream>`.
        #[pin]
        inner: turmoil::net::TcpStream,
    }
}
impl AsyncWrite for TurmoilStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().inner.poll_shutdown(cx)
}
}
impl AsyncRead for TurmoilStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
impl hyper::client::connect::Connection for TurmoilStream {
fn connected(&self) -> hyper::client::connect::Connected {
Connected::new()
}
}
impl Service<Uri> for TurmoilConnector {
    type Response = TurmoilStream;
    type Error = IoError;
    type Future = Pin<Box<dyn Future<Output = std::io::Result<Self::Response>> + Send + 'static>>;

    /// Always ready: the connector holds no state that could be busy.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Opens a simulated TCP connection to the turmoil host named by `uri`.
    ///
    /// The turmoil host name is derived from the URI host. A single-label host
    /// (`domain`) is used as is; for a dotted host such as `namespace.domain`
    /// the second label is used. The port must be given explicitly in the URI.
    ///
    /// # Errors
    ///
    /// Resolves to an `InvalidInput` error if the URI has no host or no explicit
    /// port, and to the connect error if opening the TCP stream fails.
    fn call(&mut self, uri: Uri) -> Self::Future {
        Box::pin(async move {
            let host = uri.host().ok_or_else(|| {
                IoError::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("uri `{uri}` has no host"),
                )
            })?;
            // get the domain from `namespace.domain` and `domain` hosts
            let mut labels = host.split('.');
            let first = labels.next().unwrap_or(host);
            let domain = labels.next().unwrap_or(first);
            // Validate the port before touching the simulation so malformed
            // URIs fail with an error instead of panicking mid-connect.
            let port = uri.port_u16().ok_or_else(|| {
                IoError::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("uri `{uri}` has no explicit port"),
                )
            })?;
            let addr = turmoil::lookup(domain);
            let inner = turmoil::net::TcpStream::connect((addr, port)).await?;
            Ok(TurmoilStream { inner })
        })
    }
}
/// `libsql_server::Server` wired to the turmoil simulated network: instantiated
/// with `TurmoilConnector` in the first and third type-parameter positions and
/// `TurmoilAcceptor` in the second.
pub type TestServer = Server<TurmoilConnector, TurmoilAcceptor, TurmoilConnector>;
/// Starts a server inside a turmoil simulation.
#[async_trait::async_trait]
pub trait SimServer {
    /// Consumes the server and starts it; `user_api_port` is the port the user
    /// API listener should bind to.
    async fn start_sim(self, user_api_port: usize) -> anyhow::Result<()>;
}
#[async_trait::async_trait]
impl SimServer for TestServer {
    /// Starts the server in the current turmoil host, serving the user HTTP API
    /// on `0.0.0.0:user_api_port`.
    ///
    /// # Errors
    ///
    /// Fails if `user_api_port` does not fit in a `u16`, if binding the turmoil
    /// listener fails, or if the server returns an error from `start`.
    async fn start_sim(mut self, user_api_port: usize) -> anyhow::Result<()> {
        // Reject out-of-range ports up front instead of silently truncating them
        // with `as u16`, and before any process-global state is touched.
        let port = u16::try_from(user_api_port).map_err(|_| {
            anyhow::anyhow!("user api port {user_api_port} does not fit in a u16")
        })?;

        let _ = tracing_subscriber::fmt::try_init();
        // libsql's one-time init code must run before sqld uses rusqlite: libsql
        // has safety checks and needs to configure the SQLite API itself, so
        // initialising sqld first would make it fail. Opening and connecting to a
        // throwaway in-memory database forces libsql's init code to run first;
        // the database is not otherwise used by the test.
        //
        // This does change the serialization mode for sqld, but because the mode
        // libsql uses is safer than sqld's, this is still safe.
        let db = libsql::Database::open_in_memory()
            .expect("opening an in-memory libsql database to run its init code");
        db.connect()
            .expect("connecting to the in-memory libsql database to run its init code");
        // Ignore the result because we may set it many times in a single process.
        let _ = DebuggingRecorder::per_thread().install();

        let user_api = TurmoilAcceptor::bind(([0, 0, 0, 0], port)).await?;
        self.user_api_config.http_acceptor = Some(user_api);
        // Disable prom metrics since we already created our recorder.
        if let Some(admin_api) = &mut self.admin_api_config {
            admin_api.disable_metrics = true;
        }
        self.start().await?;
        Ok(())
    }
}
/// Installs a process-global `tracing` subscriber for tests: fmt output filtered
/// by `EnvFilter::from_default_env()`.
///
/// Safe to call repeatedly and after another subscriber has been installed; in
/// that case the existing subscriber is kept.
pub fn init_tracing() {
    static INIT_TRACING: Once = Once::new();
    INIT_TRACING.call_once(|| {
        // `try_init` rather than `init`: other code in the same test process
        // (e.g. `SimServer::start_sim`, via `fmt::try_init`) may already have set
        // a global subscriber, and `init` would panic in that case. `Once` only
        // guards against repeated calls to this function.
        let _ = tracing_subscriber::registry()
            .with(fmt::layer())
            .with(EnvFilter::from_default_env())
            .try_init();
    });
}