| mod utils { |
| use std::collections::VecDeque; |
| use std::io::IoSlice; |
| use std::pin::Pin; |
| use std::task::{Context, Poll}; |
| |
| use rustls::{ |
| pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, |
| ClientConfig, RootCertStore, ServerConfig, |
| }; |
| use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}; |
| |
| #[allow(dead_code)] |
| pub(crate) fn make_configs() -> (ServerConfig, ClientConfig) { |
| // A test root certificate that is the trust anchor for the CHAIN. |
| const ROOT: &str = include_str!("certs/root.pem"); |
| // A server certificate chain that includes both an end-entity server certificate |
| // and the intermediate certificate that issued it. The ROOT is configured |
| // out-of-band. |
| const CHAIN: &str = include_str!("certs/chain.pem"); |
| // A private key corresponding to the end-entity server certificate in CHAIN. |
| const EE_KEY: &str = include_str!("certs/end.key"); |
| |
| let cert = CertificateDer::pem_slice_iter(CHAIN.as_bytes()) |
| .collect::<Result<Vec<_>, _>>() |
| .unwrap(); |
| let key = PrivateKeyDer::from_pem_slice(EE_KEY.as_bytes()).unwrap(); |
| let sconfig = ServerConfig::builder() |
| .with_no_client_auth() |
| .with_single_cert(cert, key) |
| .unwrap(); |
| |
| let mut client_root_cert_store = RootCertStore::empty(); |
| for root in CertificateDer::pem_slice_iter(ROOT.as_bytes()) { |
| client_root_cert_store.add(root.unwrap()).unwrap(); |
| } |
| |
| let cconfig = ClientConfig::builder() |
| .with_root_certificates(client_root_cert_store) |
| .with_no_client_auth(); |
| |
| (sconfig, cconfig) |
| } |
| |
| #[allow(dead_code)] |
| pub(crate) async fn write<W: AsyncWrite + Unpin>( |
| w: &mut W, |
| data: &[u8], |
| vectored: bool, |
| ) -> io::Result<()> { |
| if !vectored { |
| return w.write_all(data).await; |
| } |
| |
| let mut data = data; |
| |
| while !data.is_empty() { |
| let chunk_size = (data.len() / 4).max(1); |
| let vectors = data |
| .chunks(chunk_size) |
| .map(IoSlice::new) |
| .collect::<Vec<_>>(); |
| let written = w.write_vectored(&vectors).await?; |
| data = &data[written..]; |
| } |
| |
| Ok(()) |
| } |
| |
| #[allow(dead_code)] |
| pub(crate) const TEST_SERVER_DOMAIN: &str = "foobar.com"; |
| |
| /// An IO wrapper that never flushes when writing, and always returns pending on first flush. |
| /// |
| /// This is used to test that rustls always flushes to completion during handshake. |
| pub(crate) struct FlushWrapper<S> { |
| stream: S, |
| buf: VecDeque<Vec<u8>>, |
| queued: Vec<u8>, |
| } |
| |
| impl<S> FlushWrapper<S> { |
| #[allow(dead_code)] |
| pub(crate) fn new(stream: S) -> Self { |
| Self { |
| stream, |
| buf: VecDeque::new(), |
| queued: Vec::new(), |
| } |
| } |
| } |
| |
| impl<S: AsyncRead + Unpin> AsyncRead for FlushWrapper<S> { |
| fn poll_read( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut tokio::io::ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| Pin::new(&mut self.get_mut().stream).poll_read(cx, buf) |
| } |
| } |
| |
| impl<S: AsyncWrite + Unpin> FlushWrapper<S> { |
| fn poll_flush_inner<F>( |
| &mut self, |
| cx: &mut Context<'_>, |
| flush_inner: F, |
| ) -> Poll<Result<(), io::Error>> |
| where |
| F: FnOnce(Pin<&mut S>, &mut Context<'_>) -> Poll<Result<(), io::Error>>, |
| { |
| loop { |
| let stream = Pin::new(&mut self.stream); |
| if !self.queued.is_empty() { |
| // write out the queued data |
| let n = std::task::ready!(stream.poll_write(cx, &self.queued))?; |
| self.queued = self.queued[n..].to_vec(); |
| } else if let Some(buf) = self.buf.pop_front() { |
| // queue the flush, but don't trigger the write immediately. |
| self.queued = buf; |
| cx.waker().wake_by_ref(); |
| return Poll::Pending; |
| } else { |
| // nothing more to flush to the inner stream, flush the inner stream instead. |
| return flush_inner(stream, cx); |
| } |
| } |
| } |
| } |
| |
| impl<S: AsyncWrite + Unpin> AsyncWrite for FlushWrapper<S> { |
| fn poll_write( |
| self: Pin<&mut Self>, |
| _: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<Result<usize, io::Error>> { |
| self.get_mut().buf.push_back(buf.to_vec()); |
| Poll::Ready(Ok(buf.len())) |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
| self.get_mut() |
| .poll_flush_inner(cx, |s, cx| s.poll_flush(cx)) |
| } |
| |
| fn poll_shutdown( |
| self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| ) -> Poll<Result<(), io::Error>> { |
| self.get_mut() |
| .poll_flush_inner(cx, |s, cx| s.poll_shutdown(cx)) |
| } |
| } |
| } |