diff --git a/docs/design.md b/docs/design.md index 5bac2d5c3b6e09c5eb2c8581f236cab490c297eb..f5395462ac4e331329da5658f057581e8880c2dc 100644 --- a/docs/design.md +++ b/docs/design.md @@ -67,3 +67,7 @@ so `.await` fits well. The problem is that `async_trait` requires `Box`, won't work on `no_std`. The `Behaviour` struct has a `cfg` to switch between async and non-async traits, hiding that from the main code. Eventually `async fn` [should work OK](https://github.com/rust-lang/rust/issues/91611) in static traits on `no_std`, and then it can be unified. +## Async + +The majority of packet dispatch handling isn't async, it just returns Ready straight away. Becaues of that we just have a Tokio `Mutex` which occassionally +holds the mutex across the `.await` boundary - it should seldom be contended. diff --git a/smol/examples/con1.rs b/smol/examples/con1.rs index 81ce5c068e8389fe468271e452a31416c4d9829a..e15a79655197ec9f65a5353e0017fdd8033db1e1 100644 --- a/smol/examples/con1.rs +++ b/smol/examples/con1.rs @@ -142,6 +142,7 @@ async fn run(args: &Args) -> Result<()> { // let mut stream = TcpStream::connect("130.95.13.18:22").await?; let mut work = vec![0; 3000]; + let work = Box::leak(Box::new(work)); let mut sess = door_smol::SimpleClient::new(args.username.as_ref().unwrap()); for i in &args.identityfile { sess.add_authkey(read_key(&i) @@ -161,7 +162,6 @@ async fn run(args: &Args) -> Result<()> { pin_mut!(netio); // let mut f = future::try_zip(netwrite, netread).fuse(); // f.await; - let mut main_ch; loop { tokio::select! { @@ -178,10 +178,21 @@ async fn run(args: &Args) -> Result<()> { match ev { Some(Event::Authenticated) => { info!("auth auth"); - let ch = door.with_runner(|runner| { - runner.open_client_session(Some("cowsay it works"), false) - }).await?; - main_ch = Some(ch); + let r = door.open_client_session(Some("Cowsay it works"), false).await; + match r { + Ok((mut stdio, mut _stderr)) => { + tokio::spawn(async move { + trace!("io copy thread"); + let r = tokio::io::copy(&mut stdio, &mut tokio::io::stdout()).await; + if let Err(e) = r { + warn!("IO error: {e}"); + } + }); + }, + Err(e) => { + warn!("Failed opening session: {e}") + } + } } Some(_) => unreachable!(), None => {}, diff --git a/smol/src/async_door.rs b/smol/src/async_door.rs index ac476c6e2cc00c326f5fe829346fba63ff8410c6..fb46ff957076f86557f40348292901c93890c4f0 100644 --- a/smol/src/async_door.rs +++ b/smol/src/async_door.rs @@ -1,3 +1,4 @@ +use door::sshnames::SSH_EXTENDED_DATA_STDERR; #[allow(unused_imports)] use log::{debug, error, info, log, trace, warn}; @@ -6,15 +7,16 @@ use core::pin::Pin; use core::task::{Context, Poll}; use pin_utils::pin_mut; +use futures::lock::Mutex; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::sync::Mutex as TokioMutex; use tokio::sync::Notify as TokioNotify; use std::io::Error as IoError; use std::io::ErrorKind; use core::task::Waker; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; use futures::task::AtomicWaker; // TODO @@ -35,22 +37,24 @@ pub struct Inner<'a> { behaviour: Behaviour<'a>, } -#[derive(Clone)] pub struct AsyncDoor<'a> { - inner: Arc<TokioMutex<Inner<'a>>>, + // Not contended much since the Runner is inherently single threaded anyway, + // using a single buffer for input/output. + inner: Arc<Mutex<Inner<'a>>>, - read_waker: Arc<AtomicWaker>, - write_waker: Arc<AtomicWaker>, progress_notify: Arc<TokioNotify>, } impl<'a> AsyncDoor<'a> { pub fn new(runner: Runner<'a>, behaviour: Behaviour<'a>) -> Self { - let inner = Arc::new(TokioMutex::new(Inner { runner, behaviour })); - let read_waker = Arc::new(AtomicWaker::new()); - let write_waker = Arc::new(AtomicWaker::new()); + let inner = Arc::new(Mutex::new(Inner { runner, behaviour })); let progress_notify = Arc::new(TokioNotify::new()); - Self { inner, read_waker, write_waker, progress_notify } + Self { inner, progress_notify } + } + + pub fn clone(&'_ self) -> Self { + Self { inner: self.inner.clone(), + progress_notify: self.progress_notify.clone() } } pub async fn progress<F, R>(&mut self, f: F) @@ -69,8 +73,6 @@ impl<'a> AsyncDoor<'a> { Ok(None) } }; - self.read_waker.take().map(|w| w.wake()); - self.write_waker.take().map(|w| w.wake()); // TODO: currently this is only woken by incoming data, should it // also wake internally from runner or conn? It runs once at start // to kick off the outgoing handshake at least. @@ -86,6 +88,25 @@ impl<'a> AsyncDoor<'a> { let mut inner = self.inner.lock().await; f(&mut inner.runner) } + + // fn channel_poll_read( + // self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // buf: &mut ReadBuf, + + pub async fn open_client_session(&mut self, exec: Option<&str>, pty: bool) + -> Result<(ChanInOut<'a>, ChanExtIn<'a>)> { + let chan = self.with_runner(|runner| { + runner.open_client_session(exec, pty) + }).await?; + + let door = self.clone(); + let cstd = ChanInOut { door, chan }; + let door = self.clone(); + let cerr = ChanExtIn { door, chan, ext: SSH_EXTENDED_DATA_STDERR }; + Ok((cstd, cerr)) + } + } impl<'a> AsyncRead for AsyncDoor<'a> { @@ -96,21 +117,21 @@ impl<'a> AsyncRead for AsyncDoor<'a> { ) -> Poll<Result<(), IoError>> { trace!("poll_read"); - // try to lock, or return pending - self.read_waker.register(cx.waker()); - let mut inner = self.inner.try_lock(); - let runner = if let Ok(ref mut inner) = inner { - self.read_waker.take(); - &mut inner.deref_mut().runner - } else { - return Poll::Pending + // TODO: can this go into a common function returning a Poll<MappedMutexGuard<Runner>>? + // Lifetimes seem tricky. + let g = self.inner.lock(); + pin_mut!(g); + let mut g = g.poll(cx); + let runner = match g { + Poll::Ready(ref mut i) => &mut i.runner, + Poll::Pending => return Poll::Pending, }; let b = buf.initialize_unfilled(); let r = runner.output(b).map_err(|e| IoError::new(ErrorKind::Other, e)); - let r = match r { - // sz=0 means EOF + match r { + // sz=0 means EOF, we don't want that Ok(0) => { runner.set_output_waker(cx.waker().clone()); Poll::Pending @@ -121,15 +142,7 @@ impl<'a> AsyncRead for AsyncDoor<'a> { Poll::Ready(Ok(())) } Err(e) => Poll::Ready(Err(e)), - }; - - // drop the mutex guard before waking others - drop(inner); - self.write_waker.take().map(|w| { - trace!("wake write_waker"); - w.wake() - }); - r + } } } @@ -141,14 +154,12 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { ) -> Poll<Result<usize, IoError>> { trace!("poll_write"); - // try to lock, or return pending - self.write_waker.register(cx.waker()); - let mut inner = self.inner.try_lock(); - let runner = if let Ok(ref mut inner) = inner { - self.write_waker.take(); - &mut inner.deref_mut().runner - } else { - return Poll::Pending + let g = self.inner.lock(); + pin_mut!(g); + let mut g = g.poll(cx); + let runner = match g { + Poll::Ready(ref mut i) => &mut i.runner, + Poll::Pending => return Poll::Pending, }; // TODO: should runner just have poll_write/poll_read? @@ -164,19 +175,14 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { Poll::Pending }; - // drop the mutex guard before waking others - drop(inner); + // drop before waking others + drop(runner); if let Poll::Ready(_) = r { // TODO: only notify if packet traffic.payload().is_some() ? self.progress_notify.notify_one(); trace!("notify progress"); } - // TODO: check output_pending() before waking? - self.read_waker.take().map(|w| { - trace!("wake read_waker"); - w.wake() - }); r } @@ -194,3 +200,137 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { todo!("poll_close") } } + +pub struct ChanInOut<'a> { + chan: u32, + door: AsyncDoor<'a>, +} + +pub struct ChanExtIn<'a> { + chan: u32, + ext: u32, + door: AsyncDoor<'a>, +} + +pub struct ChanExtOut<'a> { + chan: u32, + ext: u32, + door: AsyncDoor<'a>, +} + +impl<'a> ChanInOut<'a> { + fn new(chan: u32, door: &AsyncDoor<'a>) -> Self { + Self { + chan, door: door.clone(), + } + } +} + +impl<'a> ChanExtIn<'a> { + fn new(chan: u32, ext: u32, door: &AsyncDoor<'a>) -> Self { + Self { + chan, ext, door: door.clone(), + } + } +} + +impl<'a> AsyncRead for ChanInOut<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll<Result<(), IoError>> { + let this = self.deref_mut(); + chan_poll_read(&mut this.door, this.chan, None, cx, buf) + } +} + +impl<'a> AsyncRead for ChanExtIn<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll<Result<(), IoError>> { + let this = self.deref_mut(); + chan_poll_read(&mut this.door, this.chan, Some(this.ext), cx, buf) + } +} + +// Common for `ChanInOut` and `ChanExtIn` +fn chan_poll_read( + door: &mut AsyncDoor, + chan: u32, + ext: Option<u32>, + cx: &mut Context, + buf: &mut ReadBuf, +) -> Poll<Result<(), IoError>> { + + error!("chan_poll_read {chan} {ext:?}"); + + let g = door.inner.lock(); + pin_mut!(g); + let mut g = g.poll(cx); + let runner = match g { + Poll::Ready(ref mut i) => &mut i.runner, + Poll::Pending => { + trace!("lock pending"); + return Poll::Pending + } + }; + + trace!("chan_poll_read locked"); + + let b = buf.initialize_unfilled(); + let r = runner.channel_input(chan, ext, b) + .map_err(|e| IoError::new(std::io::ErrorKind::Other, e)); + + match r { + // sz=0 means EOF, we don't want that + Ok(0) => { + runner.set_output_waker(cx.waker().clone()); + Poll::Pending + } + Ok(sz) => { + buf.advance(sz); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), + } +} + +impl<'a> AsyncWrite for ChanInOut<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, IoError>> { + let this = self.deref_mut(); + chan_poll_write(&mut this.door, this.chan, None, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + // perhaps common between InOut and ExtOut? + todo!("channel poll_shutdown") + } +} + +fn chan_poll_write( + door: &mut AsyncDoor, + chan: u32, + ext: Option<u32>, + cx: &mut Context<'_>, + buf: &[u8], +) -> Poll<Result<usize, IoError>> { + + let mut g = door.inner.lock(); + pin_mut!(g); + let runner = match g.poll(cx) { + Poll::Ready(ref mut i) => &mut i.runner, + Poll::Pending => return Poll::Pending, + }; + todo!() +} diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 8430f2fdc4852a4684c9d351012a59e03498a0d9..de031a6da2d03a6121750de0b7d9d4a931c08cf5 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -121,7 +121,9 @@ impl<'a> Runner<'a> { } pub fn done_payload(&mut self) -> Result<()> { - self.traffic.done_payload() + self.traffic.done_payload()?; + self.wake(); + Ok(()) } pub fn wake(&mut self) {