diff --git a/smol/examples/con1.rs b/smol/examples/con1.rs index 54dc556ad6fac37621203e91baf0feaefd21c5fb..81ce5c068e8389fe468271e452a31416c4d9829a 100644 --- a/smol/examples/con1.rs +++ b/smol/examples/con1.rs @@ -62,6 +62,7 @@ fn parse_args() -> Result<Args> { // TODO current user args.username = Some("matt".into()); } + Ok(args) } @@ -160,6 +161,7 @@ async fn run(args: &Args) -> Result<()> { pin_mut!(netio); // let mut f = future::try_zip(netwrite, netread).fuse(); // f.await; + let mut main_ch; loop { tokio::select! { @@ -175,8 +177,11 @@ async fn run(args: &Args) -> Result<()> { let ev = ev?; match ev { Some(Event::Authenticated) => { - info!("auth auth") - + info!("auth auth"); + let ch = door.with_runner(|runner| { + runner.open_client_session(Some("cowsay it works"), false) + }).await?; + main_ch = Some(ch); } Some(_) => unreachable!(), None => {}, diff --git a/smol/src/async_door.rs b/smol/src/async_door.rs index 7db796ed6fab052bc77eb745ea046133263433d5..ac476c6e2cc00c326f5fe829346fba63ff8410c6 100644 --- a/smol/src/async_door.rs +++ b/smol/src/async_door.rs @@ -54,15 +54,13 @@ impl<'a> AsyncDoor<'a> { } pub async fn progress<F, R>(&mut self, f: F) - -> Result<Option<R>> where F: FnOnce(door::Event) -> Result<Option<R>> { + -> Result<Option<R>> + where F: FnOnce(door::Event) -> Result<Option<R>> { { - info!("progress top"); let res = { let mut inner = self.inner.lock().await; - info!("progress locked"); let inner = inner.deref_mut(); let ev = inner.runner.progress(&mut inner.behaviour).await.context("progess")?; - info!("progress ev {ev:?}"); if let Some(ev) = ev { let r = f(ev); inner.runner.done_payload()?; @@ -102,6 +100,7 @@ impl<'a> AsyncRead for AsyncDoor<'a> { self.read_waker.register(cx.waker()); let mut inner = self.inner.try_lock(); let runner = if let Ok(ref mut inner) = inner { + self.read_waker.take(); &mut inner.deref_mut().runner } else { return Poll::Pending @@ -112,7 +111,10 @@ impl<'a> AsyncRead for AsyncDoor<'a> { let r = match r { // sz=0 means EOF - Ok(0) => Poll::Pending, + Ok(0) => { + runner.set_output_waker(cx.waker().clone()); + Poll::Pending + } Ok(sz) => { trace!("{:?}", (&b[..sz]).hex_dump()); buf.advance(sz); @@ -120,8 +122,13 @@ impl<'a> AsyncRead for AsyncDoor<'a> { } Err(e) => Poll::Ready(Err(e)), }; + + // drop the mutex guard before waking others drop(inner); - self.write_waker.take().map(|w| w.wake()); + self.write_waker.take().map(|w| { + trace!("wake write_waker"); + w.wake() + }); r } } @@ -138,6 +145,7 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { self.write_waker.register(cx.waker()); let mut inner = self.inner.try_lock(); let runner = if let Ok(ref mut inner) = inner { + self.write_waker.take(); &mut inner.deref_mut().runner } else { return Poll::Pending @@ -152,11 +160,23 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { .map_err(|e| IoError::new(std::io::ErrorKind::Other, e)); Poll::Ready(r) } else { + runner.set_input_waker(cx.waker().clone()); Poll::Pending }; + + // drop the mutex guard before waking others drop(inner); - self.progress_notify.notify_one(); - self.read_waker.take().map(|w| w.wake()); + + if let Poll::Ready(_) = r { + // TODO: only notify if packet traffic.payload().is_some() ? + self.progress_notify.notify_one(); + trace!("notify progress"); + } + // TODO: check output_pending() before waking? + self.read_waker.take().map(|w| { + trace!("wake read_waker"); + w.wake() + }); r } diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index f981ce4c70b288217574afd1a5254f0c2691f4f6..470468334c4bf2e474d2eaaf7aee31378e11c62c 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -143,7 +143,7 @@ impl<'a> Conn<'a> { } ConnState::PreAuth => { // TODO. need to figure how we'll do "unbounded" responses - // and backpressure. + // and backpressure. can_output() should have a size check? if traffic.can_output() { if let ClientServer::Client(cli) = &mut self.cliserv { cli.auth.start(&mut resp, b.client()?).await?; diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index c7825d366c47dbb4a18a686c29ecef263fda43d0..8430f2fdc4852a4684c9d351012a59e03498a0d9 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -46,30 +46,27 @@ impl<'a> Runner<'a> { } pub fn input(&mut self, buf: &[u8]) -> Result<usize, Error> { - trace!("in size {} {:?}", buf.len(), buf.hex_dump()); - let size = self.traffic.input( + self.traffic.input( &mut self.keys, &mut self.conn.remote_version, buf, - )?; - // payload will be handled when progress() is called - if self.traffic.payload(false).is_some() { - trace!("payload some, waker {:?}", self.output_waker); - if let Some(w) = self.output_waker.take() { - trace!("woke"); - w.wake() - } - } - Ok(size) + ) } + /// Write any pending output to the wire, returning the size written + pub fn output(&mut self, buf: &mut [u8]) -> Result<usize, Error> { + let r = self.traffic.output(buf); + self.wake(); + Ok(r) + } + + /// Drives connection progress, handling received payload and sending /// other packets as required. This must be polled/awaited regularly. /// Optionally returns `Event` which provides channel or session // event to the application. pub async fn progress<'f>(&'f mut self, b: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { - trace!("prog"); - let em = if let Some(payload) = self.traffic.payload(false) { + let em = if let Some(payload) = self.traffic.payload() { // Lifetimes here are a bit subtle. // `payload` has self.traffic lifetime, used until `handle_payload` // completes. @@ -102,19 +99,20 @@ impl<'a> Runner<'a> { let ev = if let Some(em) = em { match em { EventMaker::Channel(ChanEventMaker::DataIn(di)) => { - self.traffic.set_channel_input(di)?; self.traffic.done_payload()?; + self.traffic.set_channel_input(di)?; + // TODO: channel wakers None } _ => { // Some(payload) is only required for some variants in make_event() - trace!("event "); - let payload = self.traffic.payload(true); + let payload = self.traffic.payload_reborrow(); self.conn.make_event(payload, em)? } } } else { self.conn.progress(&mut self.traffic, &mut self.keys, b).await?; + self.wake(); None }; trace!("prog event {ev:?}"); @@ -126,17 +124,19 @@ impl<'a> Runner<'a> { self.traffic.done_payload() } - /// Write any pending output to the wire, returning the size written - pub fn output(&mut self, buf: &mut [u8]) -> Result<usize, Error> { - let r = self.traffic.output(buf); + pub fn wake(&mut self) { if self.ready_input() { if let Some(w) = self.input_waker.take() { + trace!("wake input waker"); + w.wake() + } + } + if self.output_pending() { + if let Some(w) = self.output_waker.take() { + trace!("wake output waker"); w.wake() } } - Ok(r) - // TODO: need some kind of progress() here which - // will return errors } pub fn open_client_session(&mut self, exec: Option<&str>, pty: bool) -> Result<u32> { @@ -193,14 +193,14 @@ impl<'a> Runner<'a> { self.conn.initial_sent() && self.traffic.ready_input() } - pub fn set_input_waker(&mut self, waker: Waker) { - self.input_waker = Some(waker); - } - pub fn output_pending(&self) -> bool { !self.conn.initial_sent() || self.traffic.output_pending() } + pub fn set_input_waker(&mut self, waker: Waker) { + self.input_waker = Some(waker); + } + pub fn set_output_waker(&mut self, waker: Waker) { self.output_waker = Some(waker); } diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 2db447b3742c62c762e960d4399a4c636b8950b7..b449238fc36d73dab1cdec0e5041205a9e82854c 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -121,6 +121,7 @@ impl<'a> Traffic<'a> { } pub fn can_output(&self) -> bool { + // TODO: test for full output buffer match self.state { TrafState::Write { .. } | TrafState::Idle => true, @@ -152,15 +153,25 @@ impl<'a> Traffic<'a> { /// For a given payload should be called once initially to pass to handle_payload(), /// with borrow=false. Subsequent calls will only return the payload if borrow=false, /// used for borrowing the payload for Event. - pub(crate) fn payload(&mut self, borrow: bool) -> Option<&[u8]> { - trace!("traf payload {:?} borrow {borrow}", self.state); + // TODO: get rid of the bool and make two functions. need to figure naming. + // "payload_reborrow()"? + pub(crate) fn payload(&mut self) -> Option<&[u8]> { let p = match self.state { | TrafState::InPayload { len, .. } => { let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; Some(payload) } - | TrafState::BorrowPayload { len, .. } if borrow + _ => None, + }; + trace!("traf 2 {:?}", self.state); + p + } + + pub(crate) fn payload_reborrow(&mut self) -> Option<&[u8]> { + let p = match self.state { + | TrafState::InPayload { len, .. } + | TrafState::BorrowPayload { len, .. } => { let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; Some(payload) @@ -171,6 +182,8 @@ impl<'a> Traffic<'a> { p } + /// Called when `payload()` has been handled once, can still be + /// `payload_reborrow()`ed later. pub(crate) fn handled_payload(&mut self) -> Result<(), Error> { match self.state { | TrafState::InPayload { len } @@ -183,6 +196,7 @@ impl<'a> Traffic<'a> { } } + /// Called when `payload()` and `payload_reborrow()` are complete. pub(crate) fn done_payload(&mut self) -> Result<(), Error> { match self.state { | TrafState::InPayload { .. } @@ -192,7 +206,10 @@ impl<'a> Traffic<'a> { self.state = TrafState::Idle; Ok(()) } - _ => Err(Error::bug()) + _ => { + /* Just ignore it */ + Ok(()) + } } } @@ -332,6 +349,7 @@ impl<'a> Traffic<'a> { } pub fn ready_channel_send(&self) -> bool { + // TODO: this should call can_output() match self.state { TrafState::Idle => true, _ => false, @@ -339,6 +357,7 @@ impl<'a> Traffic<'a> { } pub fn set_channel_input(&mut self, di: channel::DataIn) -> Result<()> { + trace!("traf chan input state {:?}", self.state); match self.state { TrafState::Idle => { let idx = SSH_PAYLOAD_START + di.offset;