//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
//!
//! # Why do I need to call `poll_flush`?
//!
//! Most TLS implementations keep an internal buffer to improve throughput,
//! and rustls is no exception.
//!
//! When we write data to a `TlsStream`, it is always written into the rustls buffer first;
//! the encrypted TLS records are then taken out of that buffer and written to the data
//! channel (such as a `TcpStream`). If the data channel is pending, some data may remain
//! in the rustls buffer.
//!
//! To keep things simple and correct, `tokio-rustls`' [`TlsStream`] behaves like a `BufWriter`.
//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed
//! to have been written to the `TcpStream`. You must call `poll_flush` to ensure that it has
//! been written to the `TcpStream`.
//!
//! You should call `poll_flush` at the appropriate time, such as when a batch of
//! `poll_write` calls is complete and there is no more data to write.
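//!
//! For example, using the `AsyncWriteExt` helpers (`write_all` drives `poll_write`
//! and `flush` drives `poll_flush`), a minimal sketch assuming an already-established
//! client stream:
//!
//! ```no_run
//! # async fn example<IO>(mut tls: tokio_rustls::client::TlsStream<IO>) -> std::io::Result<()>
//! # where
//! #     IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
//! # {
//! use tokio::io::AsyncWriteExt;
//!
//! // This only guarantees that the plaintext reached rustls' internal buffer;
//! // the encrypted records may not have reached the underlying stream yet.
//! tls.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
//!
//! // Flushing writes the buffered TLS records out to the underlying stream.
//! tls.flush().await?;
//! # Ok(())
//! # }
//! ```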
//!
//! ## Why don't we write during `poll_read`?
//!
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! Those bugs can be worked around, but the workarounds degrade performance
//! (spurious wakeups in the reverse direction).
//!
//! Writing during reads would also prevent us from implementing full duplex in the future.
//!
//! See <https://github.com/tokio-rs/tls/issues/40>.
//!
//! ## Why can't we handle it like `native-tls`?
//!
//! When the data channel returns pending, `native-tls` misreports the number of bytes
//! it has consumed: data that has been encrypted but not yet written to the data channel
//! is not acknowledged to the caller. As a result, `poll_write` does not return `Ready`
//! until the data has actually been written to the data channel, which avoids the need
//! to call `poll_flush`.
//!
//! However, this does not conform to the convention of the `AsyncWrite` trait: after
//! `Pending`, a caller is free to retry with different data. So if you pass inconsistent
//! data across two `poll_write` calls, it may cause unexpected behavior.
//!
//! See <https://github.com/tokio-rs/tls/issues/41>.
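//!
//! A hypothetical sketch of the call pattern that the `AsyncWrite` convention permits
//! (the buffer contents here are purely illustrative):
//!
//! ```no_run
//! # use std::pin::Pin;
//! # use std::task::{Context, Poll};
//! # use tokio::io::AsyncWrite;
//! # fn demo<W: AsyncWrite + Unpin>(w: &mut W, cx: &mut Context<'_>) {
//! // `Pending` means zero bytes were consumed, so the caller is allowed
//! // to come back later with *different* data...
//! if Pin::new(&mut *w).poll_write(cx, b"first").is_pending() {
//!     // ...but an implementation that secretly buffered `b"first"` (as
//!     // `native-tls` does) would now send stale data instead of this.
//!     let _ = Pin::new(w).poll_write(cx, b"second");
//! }
//! # }
//! ```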
#![warn(unreachable_pub, clippy::use_self)]
use std::io;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::task::{Context, Poll};
pub use rustls;
use rustls::CommonState;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
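/// Local equivalent of `std::task::ready!`: unwraps `Poll::Ready` or returns
/// `Poll::Pending` early from the enclosing function.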
macro_rules! ready {
    ( $e:expr ) => {
        match $e {
            std::task::Poll::Ready(t) => t,
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
    };
}
pub mod client;
pub use client::{Connect, FallibleConnect, TlsConnector, TlsConnectorWithAlpn};
mod common;
pub mod server;
pub use server::{Accept, FallibleAccept, LazyConfigAcceptor, StartHandshake, TlsAcceptor};
/// Unified TLS stream type
///
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
/// a single type to keep both client- and server-initiated TLS-encrypted connections.
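///
/// A minimal sketch of unifying the two (the `client_stream` and `server_stream`
/// parameters are placeholders for streams produced by a connector/acceptor):
///
/// ```no_run
/// # fn unify<T>(
/// #     client_stream: tokio_rustls::client::TlsStream<T>,
/// #     server_stream: tokio_rustls::server::TlsStream<T>,
/// # ) -> Vec<tokio_rustls::TlsStream<T>> {
/// vec![client_stream.into(), server_stream.into()]
/// # }
/// ```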
#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
#[derive(Debug)]
pub enum TlsStream<T> {
    Client(client::TlsStream<T>),
    Server(server::TlsStream<T>),
}
impl<T> TlsStream<T> {
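    /// Returns a shared reference to the underlying I/O stream and the rustls
    /// [`CommonState`] for the connection.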
    pub fn get_ref(&self) -> (&T, &CommonState) {
        use TlsStream::*;
        match self {
            Client(io) => {
                let (io, session) = io.get_ref();
                (io, session)
            }
            Server(io) => {
                let (io, session) = io.get_ref();
                (io, session)
            }
        }
    }
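    /// Returns a mutable reference to the underlying I/O stream and the rustls
    /// [`CommonState`] for the connection.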
    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
        use TlsStream::*;
        match self {
            Client(io) => {
                let (io, session) = io.get_mut();
                (io, &mut *session)
            }
            Server(io) => {
                let (io, session) = io.get_mut();
                (io, &mut *session)
            }
        }
    }
}
impl<T> From<client::TlsStream<T>> for TlsStream<T> {
    fn from(s: client::TlsStream<T>) -> Self {
        Self::Client(s)
    }
}
impl<T> From<server::TlsStream<T>> for TlsStream<T> {
    fn from(s: server::TlsStream<T>) -> Self {
        Self::Server(s)
    }
}
#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
    S: AsRawFd,
{
    fn as_raw_fd(&self) -> RawFd {
        self.get_ref().0.as_raw_fd()
    }
}
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        self.get_ref().0.as_raw_socket()
    }
}
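// The `AsyncRead`/`AsyncBufRead`/`AsyncWrite` impls below simply delegate
// each call to the wrapped client or server stream.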
impl<T> AsyncRead for TlsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_read(cx, buf),
            Self::Server(x) => Pin::new(x).poll_read(cx, buf),
        }
    }
}
impl<T> AsyncBufRead for TlsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_fill_buf(cx),
            Self::Server(x) => Pin::new(x).poll_fill_buf(cx),
        }
    }
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).consume(amt),
            Self::Server(x) => Pin::new(x).consume(amt),
        }
    }
}
impl<T> AsyncWrite for TlsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_write(cx, buf),
            Self::Server(x) => Pin::new(x).poll_write(cx, buf),
        }
    }
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
            Self::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
        }
    }
    #[inline]
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Client(x) => x.is_write_vectored(),
            Self::Server(x) => x.is_write_vectored(),
        }
    }
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_flush(cx),
            Self::Server(x) => Pin::new(x).poll_flush(cx),
        }
    }
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Client(x) => Pin::new(x).poll_shutdown(cx),
            Self::Server(x) => Pin::new(x).poll_shutdown(cx),
        }
    }
}