| use std::future::Future; |
| use std::io::{self, BufRead as _}; |
| #[cfg(unix)] |
| use std::os::unix::io::{AsRawFd, RawFd}; |
| #[cfg(windows)] |
| use std::os::windows::io::{AsRawSocket, RawSocket}; |
| use std::pin::Pin; |
| use std::sync::Arc; |
| use std::task::{Context, Poll}; |
| |
| use rustls::server::AcceptedAlert; |
| use rustls::{ServerConfig, ServerConnection}; |
| use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; |
| |
| use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState}; |
| |
| /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. |
| #[derive(Clone)] |
| pub struct TlsAcceptor { |
| inner: Arc<ServerConfig>, |
| } |
| |
| impl From<Arc<ServerConfig>> for TlsAcceptor { |
| fn from(inner: Arc<ServerConfig>) -> Self { |
| Self { inner } |
| } |
| } |
| |
| impl TlsAcceptor { |
| #[inline] |
| pub fn accept<IO>(&self, stream: IO) -> Accept<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| self.accept_with(stream, |_| ()) |
| } |
| |
| pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| F: FnOnce(&mut ServerConnection), |
| { |
| let mut session = match ServerConnection::new(self.inner.clone()) { |
| Ok(session) => session, |
| Err(error) => { |
| return Accept(MidHandshake::Error { |
| io: stream, |
| // TODO(eliza): should this really return an `io::Error`? |
| // Probably not... |
| error: io::Error::new(io::ErrorKind::Other, error), |
| }); |
| } |
| }; |
| f(&mut session); |
| |
| Accept(MidHandshake::Handshaking(TlsStream { |
| session, |
| io: stream, |
| state: TlsState::Stream, |
| need_flush: false, |
| })) |
| } |
| |
| /// Get a read-only reference to underlying config |
| pub fn config(&self) -> &Arc<ServerConfig> { |
| &self.inner |
| } |
| } |
| |
| pub struct LazyConfigAcceptor<IO> { |
| acceptor: rustls::server::Acceptor, |
| io: Option<IO>, |
| alert: Option<(rustls::Error, AcceptedAlert)>, |
| } |
| |
| impl<IO> LazyConfigAcceptor<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| #[inline] |
| pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { |
| Self { |
| acceptor, |
| io: Some(io), |
| alert: None, |
| } |
| } |
| |
| /// Takes back the client connection. Will return `None` if called more than once or if the |
| /// connection has been accepted. |
| /// |
| /// # Example |
| /// |
| /// ```no_run |
| /// # fn choose_server_config( |
| /// # _: rustls::server::ClientHello, |
| /// # ) -> std::sync::Arc<rustls::ServerConfig> { |
| /// # unimplemented!(); |
| /// # } |
| /// # #[allow(unused_variables)] |
| /// # async fn listen() { |
| /// use tokio::io::AsyncWriteExt; |
| /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap(); |
| /// let (stream, _) = listener.accept().await.unwrap(); |
| /// |
| /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream); |
| /// tokio::pin!(acceptor); |
| /// |
| /// match acceptor.as_mut().await { |
| /// Ok(start) => { |
| /// let clientHello = start.client_hello(); |
| /// let config = choose_server_config(clientHello); |
| /// let stream = start.into_stream(config).await.unwrap(); |
| /// // Proceed with handling the ServerConnection... |
| /// } |
| /// Err(err) => { |
| /// if let Some(mut stream) = acceptor.take_io() { |
| /// stream |
| /// .write_all( |
| /// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err) |
| /// .as_bytes() |
| /// ) |
| /// .await |
| /// .unwrap(); |
| /// } |
| /// } |
| /// } |
| /// # } |
| /// ``` |
| pub fn take_io(&mut self) -> Option<IO> { |
| self.io.take() |
| } |
| } |
| |
| impl<IO> Future for LazyConfigAcceptor<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| type Output = Result<StartHandshake<IO>, io::Error>; |
| |
| fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| let this = self.get_mut(); |
| loop { |
| let io = match this.io.as_mut() { |
| Some(io) => io, |
| None => { |
| return Poll::Ready(Err(io::Error::new( |
| io::ErrorKind::Other, |
| "acceptor cannot be polled after acceptance", |
| ))) |
| } |
| }; |
| |
| if let Some((err, mut alert)) = this.alert.take() { |
| match alert.write(&mut SyncWriteAdapter { io, cx }) { |
| Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
| this.alert = Some((err, alert)); |
| return Poll::Pending; |
| } |
| Ok(0) | Err(_) => { |
| return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))) |
| } |
| Ok(_) => { |
| this.alert = Some((err, alert)); |
| continue; |
| } |
| }; |
| } |
| |
| let mut reader = SyncReadAdapter { io, cx }; |
| match this.acceptor.read_tls(&mut reader) { |
| Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), |
| Ok(_) => {} |
| Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
| Err(e) => return Err(e).into(), |
| } |
| |
| match this.acceptor.accept() { |
| Ok(Some(accepted)) => { |
| let io = this.io.take().unwrap(); |
| return Poll::Ready(Ok(StartHandshake { accepted, io })); |
| } |
| Ok(None) => {} |
| Err((err, alert)) => { |
| this.alert = Some((err, alert)); |
| } |
| } |
| } |
| } |
| } |
| |
| /// An incoming connection received through [`LazyConfigAcceptor`]. |
| /// |
| /// This contains the generic `IO` asynchronous transport, |
| /// [`ClientHello`](rustls::server::ClientHello) data, |
| /// and all the state required to continue the TLS handshake (e.g. via |
| /// [`StartHandshake::into_stream`]). |
| #[non_exhaustive] |
| #[derive(Debug)] |
| pub struct StartHandshake<IO> { |
| pub accepted: rustls::server::Accepted, |
| pub io: IO, |
| } |
| |
| impl<IO> StartHandshake<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| /// Create a new object from an `IO` transport and prior TLS metadata. |
| pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self { |
| Self { |
| accepted, |
| io: transport, |
| } |
| } |
| |
| pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { |
| self.accepted.client_hello() |
| } |
| |
| pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> { |
| self.into_stream_with(config, |_| ()) |
| } |
| |
| pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO> |
| where |
| F: FnOnce(&mut ServerConnection), |
| { |
| let mut conn = match self.accepted.into_connection(config) { |
| Ok(conn) => conn, |
| Err((error, alert)) => { |
| return Accept(MidHandshake::SendAlert { |
| io: self.io, |
| alert, |
| // TODO(eliza): should this really return an `io::Error`? |
| // Probably not... |
| error: io::Error::new(io::ErrorKind::InvalidData, error), |
| }); |
| } |
| }; |
| f(&mut conn); |
| |
| Accept(MidHandshake::Handshaking(TlsStream { |
| session: conn, |
| io: self.io, |
| state: TlsState::Stream, |
| need_flush: false, |
| })) |
| } |
| } |
| |
| /// Future returned from `TlsAcceptor::accept` which will resolve |
| /// once the accept handshake has finished. |
| pub struct Accept<IO>(MidHandshake<TlsStream<IO>>); |
| |
| impl<IO> Accept<IO> { |
| #[inline] |
| pub fn into_fallible(self) -> FallibleAccept<IO> { |
| FallibleAccept(self.0) |
| } |
| |
| pub fn get_ref(&self) -> Option<&IO> { |
| match &self.0 { |
| MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), |
| MidHandshake::SendAlert { io, .. } => Some(io), |
| MidHandshake::Error { io, .. } => Some(io), |
| MidHandshake::End => None, |
| } |
| } |
| |
| pub fn get_mut(&mut self) -> Option<&mut IO> { |
| match &mut self.0 { |
| MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), |
| MidHandshake::SendAlert { io, .. } => Some(io), |
| MidHandshake::Error { io, .. } => Some(io), |
| MidHandshake::End => None, |
| } |
| } |
| } |
| |
| impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> { |
| type Output = io::Result<TlsStream<IO>>; |
| |
| #[inline] |
| fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) |
| } |
| } |
| |
| /// Like [Accept], but returns `IO` on failure. |
| pub struct FallibleAccept<IO>(MidHandshake<TlsStream<IO>>); |
| |
| impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> { |
| type Output = Result<TlsStream<IO>, (io::Error, IO)>; |
| |
| #[inline] |
| fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| Pin::new(&mut self.0).poll(cx) |
| } |
| } |
| |
| /// A wrapper around an underlying raw stream which implements the TLS or SSL |
| /// protocol. |
| #[derive(Debug)] |
| pub struct TlsStream<IO> { |
| pub(crate) io: IO, |
| pub(crate) session: ServerConnection, |
| pub(crate) state: TlsState, |
| pub(crate) need_flush: bool, |
| } |
| |
| impl<IO> TlsStream<IO> { |
| #[inline] |
| pub fn get_ref(&self) -> (&IO, &ServerConnection) { |
| (&self.io, &self.session) |
| } |
| |
| #[inline] |
| pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { |
| (&mut self.io, &mut self.session) |
| } |
| |
| #[inline] |
| pub fn into_inner(self) -> (IO, ServerConnection) { |
| (self.io, self.session) |
| } |
| } |
| |
| impl<IO> IoSession for TlsStream<IO> { |
| type Io = IO; |
| type Session = ServerConnection; |
| |
| #[inline] |
| fn skip_handshake(&self) -> bool { |
| false |
| } |
| |
| #[inline] |
| fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) { |
| ( |
| &mut self.state, |
| &mut self.io, |
| &mut self.session, |
| &mut self.need_flush, |
| ) |
| } |
| |
| #[inline] |
| fn into_io(self) -> Self::Io { |
| self.io |
| } |
| } |
| |
| impl<IO> AsyncRead for TlsStream<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| let data = ready!(self.as_mut().poll_fill_buf(cx))?; |
| let len = data.len().min(buf.remaining()); |
| buf.put_slice(&data[..len]); |
| self.consume(len); |
| Poll::Ready(Ok(())) |
| } |
| } |
| |
| impl<IO> AsyncBufRead for TlsStream<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
| match self.state { |
| TlsState::Stream | TlsState::WriteShutdown => { |
| let this = self.get_mut(); |
| let stream = |
| Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
| |
| match stream.poll_fill_buf(cx) { |
| Poll::Ready(Ok(buf)) => { |
| if buf.is_empty() { |
| this.state.shutdown_read(); |
| } |
| |
| Poll::Ready(Ok(buf)) |
| } |
| Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { |
| this.state.shutdown_read(); |
| Poll::Ready(Err(err)) |
| } |
| output => output, |
| } |
| } |
| TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), |
| #[cfg(feature = "early-data")] |
| ref s => unreachable!("server TLS can not hit this state: {:?}", s), |
| } |
| } |
| |
| fn consume(mut self: Pin<&mut Self>, amt: usize) { |
| self.session.reader().consume(amt); |
| } |
| } |
| |
| impl<IO> AsyncWrite for TlsStream<IO> |
| where |
| IO: AsyncRead + AsyncWrite + Unpin, |
| { |
| /// Note: that it does not guarantee the final data to be sent. |
| /// To be cautious, you must manually call `flush`. |
| fn poll_write( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| let this = self.get_mut(); |
| let mut stream = |
| Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
| stream.as_mut_pin().poll_write(cx, buf) |
| } |
| |
| /// Note: that it does not guarantee the final data to be sent. |
| /// To be cautious, you must manually call `flush`. |
| fn poll_write_vectored( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| bufs: &[io::IoSlice<'_>], |
| ) -> Poll<io::Result<usize>> { |
| let this = self.get_mut(); |
| let mut stream = |
| Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
| stream.as_mut_pin().poll_write_vectored(cx, bufs) |
| } |
| |
| #[inline] |
| fn is_write_vectored(&self) -> bool { |
| true |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| let this = self.get_mut(); |
| let mut stream = |
| Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
| stream.as_mut_pin().poll_flush(cx) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| if self.state.writeable() { |
| self.session.send_close_notify(); |
| self.state.shutdown_write(); |
| } |
| |
| let this = self.get_mut(); |
| let mut stream = |
| Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
| stream.as_mut_pin().poll_shutdown(cx) |
| } |
| } |
| |
| #[cfg(unix)] |
| impl<IO> AsRawFd for TlsStream<IO> |
| where |
| IO: AsRawFd, |
| { |
| fn as_raw_fd(&self) -> RawFd { |
| self.get_ref().0.as_raw_fd() |
| } |
| } |
| |
| #[cfg(windows)] |
| impl<IO> AsRawSocket for TlsStream<IO> |
| where |
| IO: AsRawSocket, |
| { |
| fn as_raw_socket(&self) -> RawSocket { |
| self.get_ref().0.as_raw_socket() |
| } |
| } |