diff --git a/Cargo.lock b/Cargo.lock index c19f3a0e401a39365c7254024a6101b80a00b687..50830f5204046b574fc8914788ffa6b71974bd88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -320,8 +320,6 @@ dependencies = [ "door-sshproto", "futures", "log", - "parking_lot", - "pin-utils", "pretty-hex 0.3.0", "rpassword", "simplelog", @@ -421,9 +419,8 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "futures" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" dependencies = [ "futures-channel", "futures-core", @@ -436,9 +433,8 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" dependencies = [ "futures-core", "futures-sink", @@ -446,15 +442,13 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +version = "1.0.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" [[package]] name = "futures-executor" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" dependencies = [ "futures-core", "futures-task", @@ -464,14 +458,12 @@ dependencies = [ [[package]] name = "futures-io" version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" [[package]] name = "futures-macro" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" dependencies = [ "proc-macro2", "quote", @@ -480,21 +472,18 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" [[package]] name = "futures-task" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" [[package]] name = "futures-util" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +version = "0.4.0-alpha.0" +source = "git+https://github.com/rust-lang/futures-rs#7f2603402a1ffbf6ad3a31f15598b72216bec242" dependencies = [ "futures-channel", "futures-core", diff --git a/smol/Cargo.toml b/smol/Cargo.toml index ee81e5f1a3acbf3f3f755e112f5a02e3660fd98f..8a5e337a503a5bcb628fb3a628aa8f811f733e03 100644 --- a/smol/Cargo.toml +++ b/smol/Cargo.toml @@ -15,11 +15,13 @@ argh = "0.1" # futures-micro = "0.5" # async-dup = "1.2" -pin-utils = "0.1" -tokio = { version = "1.17", features = ["full"] } -futures = "0.3" -parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } +# pin-utils = "0.1" +# pin-project = "1.0" +# parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } +tokio = { version = "1.17", features = ["sync"] } +# require alpha for https://github.com/rust-lang/futures-rs/pull/2571 +futures = { version = "0.4.0-alpha.0", git = "https://github.com/rust-lang/futures-rs", revision = "8b0f812f53ada0d0aeb74abc32be22ab9dafae05" } async-trait = "0.1" # TODO @@ -27,6 +29,7 @@ anyhow = { version = "1.0" } pretty-hex = "0.3" [dev-dependencies] -snafu = { version = "0.7", default-features = true } +snafu = { default-features = true } +tokio = { features = ["full"] } pretty-hex = "0.3" simplelog = "0.12" diff --git a/smol/examples/con1.rs b/smol/examples/con1.rs index e15a79655197ec9f65a5353e0017fdd8033db1e1..7d0bc2618b3135e8229ddc771e9ff34f278a1011 100644 --- a/smol/examples/con1.rs +++ b/smol/examples/con1.rs @@ -3,10 +3,9 @@ use { // crate::error::Error, log::{debug, error, info, log, trace, warn}, }; -use anyhow::{Context, Result, Error}; +use anyhow::{Context, Result, Error, bail}; use pretty_hex::PrettyHex; -use pin_utils::*; use tokio::net::TcpStream; use std::{net::Ipv6Addr, io::Read}; @@ -157,50 +156,61 @@ async fn run(args: &Args) -> Result<()> { // let door = async_dup::Mutex::new(door_smol::AsyncDoor { runner }); - let mut d = door.clone(); - let netio = tokio::io::copy_bidirectional(&mut stream, &mut d); - pin_mut!(netio); // let mut f = future::try_zip(netwrite, netread).fuse(); // f.await; - loop { - tokio::select! { - e = &mut netio => break e.map(|_| ()).context("net loop"), - ev = door.progress(|ev| { + let mut d = door.clone(); + let netloop = tokio::io::copy_bidirectional(&mut stream, &mut d); + + let prog = tokio::spawn(async move { + + loop { + let ev = door.progress(|ev| { trace!("progress event {ev:?}"); let e = match ev { Event::Authenticated => Some(Event::Authenticated), _ => None, }; Ok(e) - }) => { - let ev = ev?; - match ev { - Some(Event::Authenticated) => { - info!("auth auth"); - let r = door.open_client_session(Some("Cowsay it works"), false).await; - match r { - Ok((mut stdio, mut _stderr)) => { - tokio::spawn(async move { - trace!("io copy thread"); - let r = tokio::io::copy(&mut stdio, &mut tokio::io::stdout()).await; - if let Err(e) = r { - warn!("IO error: {e}"); - } - }); - }, - Err(e) => { - warn!("Failed opening session: {e}") - } - } - } - Some(_) => unreachable!(), - None => {}, + }).await.context("progress loop")?; + + match ev { + Some(Event::Authenticated) => { + info!("auth auth"); + let r = door.open_client_session(Some("cowsay it works"), false).await + .context("Opening session")?; + let (mut io, mut err) = r; + tokio::spawn(async move { + trace!("io copy thread"); + // let mut i = tokio::io::stdin(); + let mut o = tokio::io::stdout(); + let mut e = tokio::io::stderr(); + let mut io2 = io.clone(); + let co = tokio::io::copy(&mut io, &mut o); + // let ci = tokio::io::copy(&mut i, &mut io2); + let ce = tokio::io::copy(&mut err, &mut e); + let r = futures::join!(co, ce); + r.0?; + r.1?; + // r.2?; + Ok::<_, anyhow::Error>(()) + }); } + Some(_) => unreachable!(), + None => {}, + } + } + Ok::<_, anyhow::Error>(()) + }); + + loop { + tokio::select! { + e = prog => { + bail!("progress loop {e:?}"); + } + e = netloop => { + bail!("net loop {e:?}"); } - // q = door.next_request() => { - // handle_request(&door, q).await - // } } } diff --git a/smol/src/async_client.rs b/smol/src/async_client.rs index 0602748acc3cc1e8d591188e7d547dc4226b40c3..8de183714c89bd4ff078a294196e24824ead28b1 100644 --- a/smol/src/async_client.rs +++ b/smol/src/async_client.rs @@ -9,8 +9,6 @@ use door_sshproto::{BhError, BhResult}; use door_sshproto::{ChanMsg, ChanMsgDetails, Error, RespPackets, Result, Runner}; use std::collections::VecDeque; -use std::io::Write; -use tokio::io::AsyncWriteExt; use async_trait::async_trait; @@ -34,10 +32,10 @@ impl SimpleClient { pub fn add_authkey(&mut self, k: SignKey) { self.authkeys.push_back(k) } - } -#[async_trait(?Send)] +// #[async_trait(?Send)] +#[async_trait] impl door::AsyncCliBehaviour for SimpleClient { async fn chan_handler( &mut self, diff --git a/smol/src/async_door.rs b/smol/src/async_door.rs index fb46ff957076f86557f40348292901c93890c4f0..d522a9ae954470d512da4b0135788c4161244fa6 100644 --- a/smol/src/async_door.rs +++ b/smol/src/async_door.rs @@ -5,29 +5,27 @@ use log::{debug, error, info, log, trace, warn}; use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; -use pin_utils::pin_mut; -use futures::lock::Mutex; +use futures::lock::{Mutex, OwnedMutexLockFuture, OwnedMutexGuard}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::Notify as TokioNotify; use std::io::Error as IoError; use std::io::ErrorKind; +use std::collections::HashMap; use core::task::Waker; +use core::ops::DerefMut; use std::sync::Arc; -use futures::task::AtomicWaker; // TODO use anyhow::{anyhow, Context as _, Error, Result}; -use core::ops::DerefMut; use door::{Behaviour, Runner}; use door_sshproto as door; use door_sshproto::error::Error as DoorError; // use door_sshproto::client::*; -use async_trait::async_trait; use pretty_hex::PrettyHex; @@ -35,6 +33,9 @@ pub struct Inner<'a> { runner: Runner<'a>, // TODO: perhaps behaviour can move to runner? unsure of lifetimes. behaviour: Behaviour<'a>, + + chan_read_wakers: HashMap<(u32, Option<u32>), Waker>, + chan_write_wakers: HashMap<(u32, Option<u32>), Waker>, } pub struct AsyncDoor<'a> { @@ -43,44 +44,75 @@ pub struct AsyncDoor<'a> { inner: Arc<Mutex<Inner<'a>>>, progress_notify: Arc<TokioNotify>, + + read_lock_fut: Option<OwnedMutexLockFuture<Inner<'a>>>, + write_lock_fut: Option<OwnedMutexLockFuture<Inner<'a>>>, } impl<'a> AsyncDoor<'a> { pub fn new(runner: Runner<'a>, behaviour: Behaviour<'a>) -> Self { - let inner = Arc::new(Mutex::new(Inner { runner, behaviour })); + let chan_read_wakers = HashMap::new(); + let chan_write_wakers = HashMap::new(); + let inner = Arc::new(Mutex::new(Inner { runner, behaviour, + chan_read_wakers, chan_write_wakers })); let progress_notify = Arc::new(TokioNotify::new()); - Self { inner, progress_notify } - } - - pub fn clone(&'_ self) -> Self { - Self { inner: self.inner.clone(), - progress_notify: self.progress_notify.clone() } + Self { inner, progress_notify, read_lock_fut: None, write_lock_fut: None } } pub async fn progress<F, R>(&mut self, f: F) -> Result<Option<R>> where F: FnOnce(door::Event) -> Result<Option<R>> { - { - let res = { - let mut inner = self.inner.lock().await; - let inner = inner.deref_mut(); - let ev = inner.runner.progress(&mut inner.behaviour).await.context("progess")?; - if let Some(ev) = ev { - let r = f(ev); - inner.runner.done_payload()?; - r - } else { - Ok(None) - } + trace!("progress"); + let mut wakers = Vec::new(); + let res = { + let mut inner = self.inner.lock().await; + let inner = inner.deref_mut(); + let ev = inner.runner.progress(&mut inner.behaviour).await.context("progess")?; + let r = if let Some(ev) = ev { + let r = f(ev); + inner.runner.done_payload()?; + r + } else { + Ok(None) }; - // TODO: currently this is only woken by incoming data, should it - // also wake internally from runner or conn? It runs once at start - // to kick off the outgoing handshake at least. - if let Ok(None) = res { - self.progress_notify.notified().await; + + if let Some(ce) = inner.runner.ready_channel_input() { + inner.chan_read_wakers.remove(&ce) + .map(|w| wakers.push(w)); } - res + + // Pending https://github.com/rust-lang/rust/issues/59618 + // HashMap::drain_filter + // TODO: untested. + // TODO: fairness? Also it's not clear whether progress notify + // will always get woken by runner.wake() to update this... + inner.chan_write_wakers.retain(|(ch, ext), w| { + if inner.runner.ready_channel_send(*ch, *ext) { + wakers.push(w.clone()); + false + } else { + true + } + }); + + r + }; + // lock is dropped before waker or notify + + for w in wakers { + trace!("woken {w:?}"); + w.wake() + } + + // TODO: currently this is only woken by incoming data, should it + // also wake internally from runner or conn? It runs once at start + // to kick off the outgoing handshake at least. + if let Ok(None) = res { + trace!("progress wait"); + self.progress_notify.notified().await; + trace!("progress awaited"); } + res } pub async fn with_runner<F, R>(&mut self, f: F) -> R @@ -89,51 +121,73 @@ impl<'a> AsyncDoor<'a> { f(&mut inner.runner) } - // fn channel_poll_read( - // self: Pin<&mut Self>, - // cx: &mut Context<'_>, - // buf: &mut ReadBuf, - + // TODO: return a Channel object that gives events like WinChange or exit status + // TODO: move to SimpleClient or something? pub async fn open_client_session(&mut self, exec: Option<&str>, pty: bool) -> Result<(ChanInOut<'a>, ChanExtIn<'a>)> { let chan = self.with_runner(|runner| { runner.open_client_session(exec, pty) }).await?; - let door = self.clone(); - let cstd = ChanInOut { door, chan }; - let door = self.clone(); - let cerr = ChanExtIn { door, chan, ext: SSH_EXTENDED_DATA_STDERR }; + let cstd = ChanInOut::new(chan, &self); + let cerr = ChanExtIn::new(chan, SSH_EXTENDED_DATA_STDERR, &self); Ok((cstd, cerr)) } } +impl Clone for AsyncDoor<'_> { + fn clone(&self) -> Self { + Self { inner: self.inner.clone(), + progress_notify: self.progress_notify.clone(), + read_lock_fut: None, + write_lock_fut: None, + } + } +} + + +/// Tries to locks Inner for a poll_read()/poll_write(). +/// lock_fut from the caller holds the future so that it can +/// be woken later if the lock was contended +fn poll_lock<'a>(inner: Arc<Mutex<Inner<'a>>>, cx: &mut Context<'_>, + lock_fut: &mut Option<OwnedMutexLockFuture<Inner<'a>>>) + -> Poll<OwnedMutexGuard<Inner<'a>>> { + let mut g = inner.lock_owned(); + let p = Pin::new(&mut g).poll(cx); + *lock_fut = match p { + Poll::Ready(_) => None, + Poll::Pending => Some(g), + }; + p +} + impl<'a> AsyncRead for AsyncDoor<'a> { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { trace!("poll_read"); - // TODO: can this go into a common function returning a Poll<MappedMutexGuard<Runner>>? - // Lifetimes seem tricky. - let g = self.inner.lock(); - pin_mut!(g); - let mut g = g.poll(cx); - let runner = match g { + let mut p = poll_lock(self.inner.clone(), cx, &mut self.read_lock_fut); + + let runner = match p { Poll::Ready(ref mut i) => &mut i.runner, - Poll::Pending => return Poll::Pending, + Poll::Pending => { + trace!("poll_read pending lock"); + return Poll::Pending + } }; + runner.set_output_waker(cx.waker().clone()); let b = buf.initialize_unfilled(); let r = runner.output(b).map_err(|e| IoError::new(ErrorKind::Other, e)); match r { - // sz=0 means EOF, we don't want that + // poll_read() returning 0 means EOF, we don't want that Ok(0) => { - runner.set_output_waker(cx.waker().clone()); + trace!("set output waker"); Poll::Pending } Ok(sz) => { @@ -148,20 +202,23 @@ impl<'a> AsyncRead for AsyncDoor<'a> { impl<'a> AsyncWrite for AsyncDoor<'a> { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, IoError>> { trace!("poll_write"); - let g = self.inner.lock(); - pin_mut!(g); - let mut g = g.poll(cx); - let runner = match g { + let mut p = poll_lock(self.inner.clone(), cx, &mut self.write_lock_fut); + + let runner = match p { Poll::Ready(ref mut i) => &mut i.runner, - Poll::Pending => return Poll::Pending, + Poll::Pending => { + trace!("poll_write pending lock"); + return Poll::Pending; + } }; + runner.set_input_waker(cx.waker().clone()); // TODO: should runner just have poll_write/poll_read? // TODO: is ready_input necessary? .input() should return size=0 // if nothing is consumed. Or .input() could return a Poll<Result<usize>> @@ -171,7 +228,6 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { .map_err(|e| IoError::new(std::io::ErrorKind::Other, e)); Poll::Ready(r) } else { - runner.set_input_waker(cx.waker().clone()); Poll::Pending }; @@ -180,6 +236,7 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { if let Poll::Ready(_) = r { // TODO: only notify if packet traffic.payload().is_some() ? + // Though we also are using progress() for other events. self.progress_notify.notify_one(); trace!("notify progress"); } @@ -204,24 +261,41 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { pub struct ChanInOut<'a> { chan: u32, door: AsyncDoor<'a>, + + rlfut: Option<OwnedMutexLockFuture<Inner<'a>>>, + wlfut: Option<OwnedMutexLockFuture<Inner<'a>>>, } pub struct ChanExtIn<'a> { chan: u32, ext: u32, door: AsyncDoor<'a>, + + rlfut: Option<OwnedMutexLockFuture<Inner<'a>>>, } pub struct ChanExtOut<'a> { chan: u32, ext: u32, door: AsyncDoor<'a>, + + wlfut: Option<OwnedMutexLockFuture<Inner<'a>>>, } impl<'a> ChanInOut<'a> { fn new(chan: u32, door: &AsyncDoor<'a>) -> Self { Self { chan, door: door.clone(), + rlfut: None, wlfut: None, + } + } +} + +impl Clone for ChanInOut<'_> { + fn clone(&self) -> Self { + Self { + chan: self.chan, door: self.door.clone(), + rlfut: None, wlfut: None, } } } @@ -230,6 +304,7 @@ impl<'a> ChanExtIn<'a> { fn new(chan: u32, ext: u32, door: &AsyncDoor<'a>) -> Self { Self { chan, ext, door: door.clone(), + rlfut: None, } } } @@ -241,7 +316,7 @@ impl<'a> AsyncRead for ChanInOut<'a> { buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { let this = self.deref_mut(); - chan_poll_read(&mut this.door, this.chan, None, cx, buf) + chan_poll_read(&mut this.door, this.chan, None, cx, buf, &mut this.rlfut) } } @@ -252,33 +327,29 @@ impl<'a> AsyncRead for ChanExtIn<'a> { buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { let this = self.deref_mut(); - chan_poll_read(&mut this.door, this.chan, Some(this.ext), cx, buf) + chan_poll_read(&mut this.door, this.chan, Some(this.ext), cx, buf, &mut this.rlfut) } } // Common for `ChanInOut` and `ChanExtIn` -fn chan_poll_read( - door: &mut AsyncDoor, +fn chan_poll_read<'a>( + door: &mut AsyncDoor<'a>, chan: u32, ext: Option<u32>, cx: &mut Context, buf: &mut ReadBuf, + lock_fut: &mut Option<OwnedMutexLockFuture<Inner<'a>>>, ) -> Poll<Result<(), IoError>> { - error!("chan_poll_read {chan} {ext:?}"); - - let g = door.inner.lock(); - pin_mut!(g); - let mut g = g.poll(cx); - let runner = match g { - Poll::Ready(ref mut i) => &mut i.runner, + let mut p = poll_lock(door.inner.clone(), cx, lock_fut); + let inner = match p { + Poll::Ready(ref mut i) => i, Poll::Pending => { - trace!("lock pending"); return Poll::Pending } }; - trace!("chan_poll_read locked"); + let runner = &mut inner.runner; let b = buf.initialize_unfilled(); let r = runner.channel_input(chan, ext, b) @@ -287,7 +358,8 @@ fn chan_poll_read( match r { // sz=0 means EOF, we don't want that Ok(0) => { - runner.set_output_waker(cx.waker().clone()); + let w = cx.waker().clone(); + inner.chan_read_wakers.insert((chan, ext), w); Poll::Pending } Ok(sz) => { @@ -305,7 +377,7 @@ impl<'a> AsyncWrite for ChanInOut<'a> { buf: &[u8], ) -> Poll<Result<usize, IoError>> { let this = self.deref_mut(); - chan_poll_write(&mut this.door, this.chan, None, cx, buf) + chan_poll_write(&mut this.door, this.chan, None, cx, buf, &mut this.wlfut) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { @@ -318,17 +390,17 @@ impl<'a> AsyncWrite for ChanInOut<'a> { } } -fn chan_poll_write( - door: &mut AsyncDoor, +fn chan_poll_write<'a>( + door: &mut AsyncDoor<'a>, chan: u32, ext: Option<u32>, cx: &mut Context<'_>, buf: &[u8], + lock_fut: &mut Option<OwnedMutexLockFuture<Inner<'a>>>, ) -> Poll<Result<usize, IoError>> { - let mut g = door.inner.lock(); - pin_mut!(g); - let runner = match g.poll(cx) { + let mut p = poll_lock(door.inner.clone(), cx, lock_fut); + let runner = match p { Poll::Ready(ref mut i) => &mut i.runner, Poll::Pending => return Poll::Pending, }; diff --git a/sshproto/src/async_behaviour.rs b/sshproto/src/async_behaviour.rs index 7fe1a31c09653fc3c7fd369bc35a99528d062e07..df4c59f2e2daca79cc4634b01f5374e1d4891b4a 100644 --- a/sshproto/src/async_behaviour.rs +++ b/sshproto/src/async_behaviour.rs @@ -58,8 +58,12 @@ impl AsyncCliServ { } } -#[async_trait(?Send)] -pub trait AsyncCliBehaviour { +// Send+Sync bound here is required for trait objects since there are +// default implementations of some methods. +// https://docs.rs/async-trait/latest/async_trait/index.html#dyn-traits +// #[async_trait(?Send)] +#[async_trait] +pub trait AsyncCliBehaviour: Sync+Send { async fn chan_handler(&mut self, resp: &mut RespPackets, chan_msg: ChanMsg) -> Result<()>; /// Should not block @@ -112,8 +116,9 @@ pub trait AsyncCliBehaviour { // TODO: postauth channel callbacks } -#[async_trait(?Send)] -pub trait AsyncServBehaviour { +// #[async_trait(?Send)] +#[async_trait] +pub trait AsyncServBehaviour: Sync+Send { fn progress(&mut self, runner: &mut Runner) -> Result<()> { Ok(()) } fn chan_handler(&mut self, resp: &mut RespPackets, chan_msg: ChanMsg) -> Result<()>; diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index 43db98e0c99cb2296dce8c6891b695a45613d1f1..a0f21624b2e60b9977919646e982ada2cb436168 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -124,7 +124,7 @@ impl<'a> Behaviour<'a> pub struct CliBehaviour<'a> { #[cfg(feature = "std")] - pub inner: &'a mut dyn async_behaviour::AsyncCliBehaviour, + pub inner: &'a mut (dyn async_behaviour::AsyncCliBehaviour + Send), #[cfg(not(feature = "std"))] pub inner: &'a mut dyn block_behaviour::BlockCliBehaviour, // pub phantom: core::marker::PhantomData<&'a ()>, diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index de031a6da2d03a6121750de0b7d9d4a931c08cf5..bc566fc714225d0a7690dcaf483b11093fc2f002 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -56,7 +56,10 @@ impl<'a> Runner<'a> { /// Write any pending output to the wire, returning the size written pub fn output(&mut self, buf: &mut [u8]) -> Result<usize, Error> { let r = self.traffic.output(buf); - self.wake(); + if r > 0 { + trace!("output() wake"); + self.wake(); + } Ok(r) } @@ -99,6 +102,7 @@ impl<'a> Runner<'a> { let ev = if let Some(em) = em { match em { EventMaker::Channel(ChanEventMaker::DataIn(di)) => { + trace!("chanmaaker {di:?}"); self.traffic.done_payload()?; self.traffic.set_channel_input(di)?; // TODO: channel wakers @@ -127,16 +131,21 @@ impl<'a> Runner<'a> { } pub fn wake(&mut self) { + error!("wake"); if self.ready_input() { + trace!("wake ready_input, waker {:?}", self.input_waker); if let Some(w) = self.input_waker.take() { trace!("wake input waker"); w.wake() } } + if self.output_pending() { if let Some(w) = self.output_waker.take() { trace!("wake output waker"); w.wake() + } else { + trace!("no waker"); } } } @@ -180,9 +189,11 @@ impl<'a> Runner<'a> { ext: Option<u32>, buf: &mut [u8], ) -> Result<usize> { + trace!("runner chan in"); let (len, complete) = self.traffic.channel_input(chan, ext, buf); if complete { self.conn.channels.finished_input(chan)?; + self.wake(); } Ok(len) } @@ -191,12 +202,8 @@ impl<'a> Runner<'a> { self.conn.initial_sent() && self.traffic.ready_input() } - pub fn ready_progress(&self) -> bool { - self.conn.initial_sent() && self.traffic.ready_input() - } - pub fn output_pending(&self) -> bool { - !self.conn.initial_sent() || self.traffic.output_pending() + self.traffic.output_pending() } pub fn set_input_waker(&mut self, waker: Waker) { @@ -207,13 +214,13 @@ impl<'a> Runner<'a> { self.output_waker = Some(waker); } - pub fn ready_channel_input(&self, chan: u32, ext: Option<u32>) -> bool { - self.traffic.ready_channel_input(chan, ext) + pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { + self.traffic.ready_channel_input() } - // TODO check the chan/ext are valid + // TODO check the chan/ext are valid, SSH window pub fn ready_channel_send(&self, _chan: u32, _ext: Option<u32>) -> bool { - self.traffic.ready_channel_send() + self.traffic.can_output() // && self.conn.channels.ready_send_data(chan, ext) } diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index b449238fc36d73dab1cdec0e5041205a9e82854c..59ebccdbf6c4588adc204aa1d8b846370ebf8c70 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -262,9 +262,7 @@ impl<'a> Traffic<'a> { /// Write any pending output, returning the size written pub fn output(&mut self, buf: &mut [u8]) -> usize { - trace!("output state {:?}", self.state); - - match self.state { + let r = match self.state { TrafState::Write { ref mut idx, len } => { let wlen = (len - *idx).min(buf.len()); buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); @@ -277,7 +275,9 @@ impl<'a> Traffic<'a> { wlen } _ => 0, - } + }; + trace!("output state now {:?}", self.state); + r } fn fill_input( @@ -340,19 +340,10 @@ impl<'a> Traffic<'a> { Ok(buf.len() - r.len()) } - pub fn ready_channel_input(&self, chan: u32, ext: Option<u32>) -> bool { + pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { match self.state { - TrafState::InChannelData { chan: c, ext: e, .. } - if (c, e) == (chan, ext) => true, - _ => false, - } - } - - pub fn ready_channel_send(&self) -> bool { - // TODO: this should call can_output() - match self.state { - TrafState::Idle => true, - _ => false, + TrafState::InChannelData { chan, ext, .. } => Some((chan, ext)), + _ => None, } } @@ -376,9 +367,7 @@ impl<'a> Traffic<'a> { ext: Option<u32>, buf: &mut [u8], ) -> (usize, bool) { - if !matches!(self.state, TrafState::Idle) { - return (0, false) - } + trace!("channel input {chan} {ext:?} st {:?}", self.state); match self.state { TrafState::InChannelData { chan: c, ext: e, ref mut idx, len }