mirror of
https://github.com/tursodatabase/libsql.git
synced 2024-12-16 05:38:47 +00:00
954029e31f
This adds a new `Builder` type that can now be used to construct the `Database` type. This will scale better as we add more varied options. This commit also deprecates the old builder types and will produce a warning that will push users to using the new `Builder` type. This will then allow us to remove the old deprecated constructors at some point in the future.
186 lines
5.6 KiB
Rust
186 lines
5.6 KiB
Rust
// This module still calls deprecated libsql constructors
// (`libsql::Database::open_in_memory` below); presumably this silences those
// warnings until the code moves to the new `Builder` API — TODO confirm.
#![allow(deprecated)]
|
|
|
|
use std::io::Error as IoError;
|
|
use std::net::SocketAddr;
|
|
use std::pin::Pin;
|
|
use std::sync::Once;
|
|
use std::task::{Context, Poll};
|
|
|
|
use futures_core::Future;
|
|
use hyper::client::connect::Connected;
|
|
use hyper::server::accept::Accept as HyperAccept;
|
|
use hyper::Uri;
|
|
use metrics_util::debugging::DebuggingRecorder;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
use tower::Service;
|
|
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
|
|
|
use libsql_server::net::Accept;
|
|
use libsql_server::net::AddrStream;
|
|
use libsql_server::Server;
|
|
|
|
type TurmoilAddrStream = AddrStream<turmoil::net::TcpStream>;
|
|
|
|
pub struct TurmoilAcceptor {
|
|
acceptor: Pin<
|
|
Box<dyn HyperAccept<Conn = TurmoilAddrStream, Error = IoError> + Send + Sync + 'static>,
|
|
>,
|
|
}
|
|
|
|
impl TurmoilAcceptor {
|
|
pub async fn bind(addr: impl Into<SocketAddr>) -> std::io::Result<Self> {
|
|
let addr = addr.into();
|
|
let stream = async_stream::stream! {
|
|
let listener = turmoil::net::TcpListener::bind(addr).await?;
|
|
loop {
|
|
yield listener.accept().await.and_then(|(stream, remote_addr)| Ok(AddrStream {
|
|
remote_addr,
|
|
local_addr: stream.local_addr()?,
|
|
stream,
|
|
}));
|
|
}
|
|
};
|
|
let acceptor = hyper::server::accept::from_stream(stream);
|
|
Ok(Self {
|
|
acceptor: Box::pin(acceptor),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Accept for TurmoilAcceptor {
|
|
type Connection = TurmoilAddrStream;
|
|
}
|
|
|
|
impl HyperAccept for TurmoilAcceptor {
|
|
type Conn = TurmoilAddrStream;
|
|
type Error = IoError;
|
|
|
|
fn poll_accept(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
|
self.acceptor.as_mut().poll_accept(cx)
|
|
}
|
|
}
|
|
|
|
/// A hyper client connector that opens TCP connections through the turmoil
/// simulated network.
///
/// It is a stateless unit type, so it is cheap to copy and has a trivial default.
#[derive(Clone, Copy, Debug, Default)]
pub struct TurmoilConnector;
|
|
|
|
pin_project_lite::pin_project! {
    /// Client-side connection returned by `TurmoilConnector`, wrapping a
    /// turmoil TCP stream so hyper can drive it.
    pub struct TurmoilStream {
        // Structurally pinned so the I/O trait impls can project to
        // `Pin<&mut turmoil::net::TcpStream>`.
        #[pin]
        inner: turmoil::net::TcpStream,
    }
}
|
|
|
|
impl AsyncWrite for TurmoilStream {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<std::io::Result<usize>> {
|
|
self.project().inner.poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
self.project().inner.poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
self.project().inner.poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for TurmoilStream {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut tokio::io::ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
self.project().inner.poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl hyper::client::connect::Connection for TurmoilStream {
|
|
fn connected(&self) -> hyper::client::connect::Connected {
|
|
Connected::new()
|
|
}
|
|
}
|
|
|
|
impl Service<Uri> for TurmoilConnector {
|
|
type Response = TurmoilStream;
|
|
type Error = IoError;
|
|
type Future = Pin<Box<dyn Future<Output = std::io::Result<Self::Response>> + Send + 'static>>;
|
|
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn call(&mut self, uri: Uri) -> Self::Future {
|
|
Box::pin(async move {
|
|
let host = uri.host().unwrap();
|
|
let host = host.split('.').collect::<Vec<_>>();
|
|
// get the domain from `namespace.domain` and `domain` hosts
|
|
let domain = if host.len() == 1 { host[0] } else { host[1] };
|
|
let addr = turmoil::lookup(domain);
|
|
let port = uri.port().unwrap().as_u16();
|
|
let inner = turmoil::net::TcpStream::connect((addr, port)).await?;
|
|
Ok(TurmoilStream { inner })
|
|
})
|
|
}
|
|
}
|
|
|
|
pub type TestServer = Server<TurmoilConnector, TurmoilAcceptor, TurmoilConnector>;
|
|
|
|
#[async_trait::async_trait]
|
|
pub trait SimServer {
|
|
async fn start_sim(self, user_api_port: usize) -> anyhow::Result<()>;
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl SimServer for TestServer {
|
|
async fn start_sim(mut self, user_api_port: usize) -> anyhow::Result<()> {
|
|
let _ = tracing_subscriber::fmt::try_init();
|
|
|
|
// We need to ensure that libsql's init code runs before we do anything
|
|
// with rusqlite in sqld. This is because libsql has saftey checks and
|
|
// needs to configure the sqlite api. Thus if we init sqld first
|
|
// it will fail. To work around this we open a temp db in memory db
|
|
// to ensure we run libsql's init code first. This DB is not actually
|
|
// used in the test only for its run once init code.
|
|
//
|
|
// This does change the serialization mode for sqld but because the mode
|
|
// that we use in libsql is safer than the sqld one it is still safe.
|
|
let db = libsql::Database::open_in_memory().unwrap();
|
|
db.connect().unwrap();
|
|
|
|
// Ignore the result because we may set it many times in a single process.
|
|
let _ = DebuggingRecorder::per_thread().install();
|
|
|
|
let user_api = TurmoilAcceptor::bind(([0, 0, 0, 0], user_api_port as u16)).await?;
|
|
self.user_api_config.http_acceptor = Some(user_api);
|
|
|
|
// Disable prom metrics since we already created our recorder.
|
|
if let Some(admin_api) = &mut self.admin_api_config {
|
|
admin_api.disable_metrics = true;
|
|
}
|
|
|
|
self.start().await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub fn init_tracing() {
|
|
static INIT_TRACING: Once = Once::new();
|
|
INIT_TRACING.call_once(|| {
|
|
tracing_subscriber::registry()
|
|
.with(fmt::layer())
|
|
.with(EnvFilter::from_default_env())
|
|
.init();
|
|
});
|
|
}
|