From 7c36f43bad11aa2bf3919bfadf206b56a1ebf02e Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Sat, 25 Jun 2022 13:39:24 +0800 Subject: [PATCH] work on send packet. not really tested yet --- Cargo.lock | 8 ++++---- async/Cargo.toml | 2 +- async/examples/con1.rs | 12 +++++------- async/src/async_door.rs | 28 +++++++++++++++++++++------- async/src/fdio.rs | 15 +++++++++------ sshproto/src/channel.rs | 41 ++++++++++++++++++++++++++++------------- sshproto/src/conn.rs | 4 ++-- sshproto/src/runner.rs | 34 ++++++++++++++++++++++------------ sshproto/src/traffic.rs | 21 ++++++++++++++++++--- 9 files changed, 110 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8d47f97..94d231a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1402,9 +1402,9 @@ checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792" [[package]] name = "tokio" -version = "1.18.2" +version = "1.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4903bf0427cf68dddd5aa6a93220756f8be0c34fcfa9f5e6191e103e15a31395" +checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" dependencies = [ "bytes", "libc", @@ -1422,9 +1422,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b557f72f448c511a979e2564e55d74e6c4432fc96ff4f6241bc6bded342643b7" +checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" dependencies = [ "proc-macro2", "quote", diff --git a/async/Cargo.toml b/async/Cargo.toml index c0880fc..94de37f 100644 --- a/async/Cargo.toml +++ b/async/Cargo.toml @@ -20,7 +20,7 @@ argh = "0.1" # parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } # "net" for AsyncFd on unix -tokio = { version = "1.17", features = ["sync", "net"] } +tokio = { version = "1.19", features = ["sync", "net"] } # require alpha for https://github.com/rust-lang/futures-rs/pull/2571 futures = { version = "0.4.0-alpha.0", git = "https://github.com/rust-lang/futures-rs", revision = "8b0f812f53ada0d0aeb74abc32be22ab9dafae05" } async-trait = "0.1" diff --git a/async/examples/con1.rs b/async/examples/con1.rs index 67ead30..2f54ddc 100644 --- a/async/examples/con1.rs +++ b/async/examples/con1.rs @@ -168,13 +168,11 @@ async fn run(args: &Args) -> Result<()> { let mut o = door_async::stdout()?; let mut e = door_async::stderr()?; let mut io2 = io.clone(); - let co = tokio::io::copy(&mut io, &mut o); - let ci = tokio::io::copy(&mut i, &mut io2); - let ce = tokio::io::copy(&mut err, &mut e); - let (r1, r2, r3) = futures::join!(co, ci, ce); - r1?; - r2?; - r3?; + moro::async_scope!(|scope| { + scope.spawn(tokio::io::copy(&mut io, &mut o)); + scope.spawn(tokio::io::copy(&mut i, &mut io2)); + scope.spawn(tokio::io::copy(&mut err, &mut e)); + }).await; Ok::<_, anyhow::Error>(()) }); // TODO: handle channel completion diff --git a/async/src/async_door.rs b/async/src/async_door.rs index 67c8ce4..4432fe1 100644 --- a/async/src/async_door.rs +++ b/async/src/async_door.rs @@ -74,15 +74,17 @@ impl<'a> AsyncDoor<'a> { let r = if let Some(ev) = ev { let r = match ev { Event::Channel(ChanEvent::Eof { num }) => { + // TODO Ok(None) }, _ => f(ev), }; - inner.runner.done_payload()?; + trace!("async prog done payload"); r } else { Ok(None) }; + inner.runner.done_payload()?; if let Some(ce) = inner.runner.ready_channel_input() { inner.chan_read_wakers.remove(&ce) @@ -95,11 +97,12 @@ impl<'a> AsyncDoor<'a> { // TODO: fairness? Also it's not clear whether progress notify // will always get woken by runner.wake() to update this... inner.chan_write_wakers.retain(|(ch, ext), w| { - if inner.runner.ready_channel_send(*ch, *ext) { - wakers.push(w.clone()); - false - } else { - true + match inner.runner.ready_channel_send(*ch) { + Some(n) if n > 0 => { + wakers.push(w.clone()); + false + } + _ => true } }); @@ -336,6 +339,7 @@ fn chan_poll_read<'a>( buf: &mut ReadBuf, lock_fut: &mut Option<OwnedMutexLockFuture<Inner<'a>>>, ) -> Poll<Result<(), IoError>> { + trace!("chan read"); let mut p = poll_lock(door.inner.clone(), cx, lock_fut); let inner = match p { @@ -352,6 +356,8 @@ fn chan_poll_read<'a>( .map_err(|e| IoError::new(std::io::ErrorKind::Other, e)); match r { + // poll_read() returns 0 on EOF, if the channel isn't eof yet + // we want to return pending Ok(0) if !runner.channel_eof(chan) => { let w = cx.waker().clone(); inner.chan_read_wakers.insert((chan, ext), w); @@ -393,11 +399,19 @@ fn chan_poll_write<'a>( buf: &[u8], lock_fut: &mut Option<OwnedMutexLockFuture<Inner<'a>>>, ) -> Poll<Result<usize, IoError>> { + trace!("chan write"); let mut p = poll_lock(door.inner.clone(), cx, lock_fut); let runner = match p { Poll::Ready(ref mut i) => &mut i.runner, Poll::Pending => return Poll::Pending, }; - todo!() + + match runner.channel_send(chan, ext, buf) { + Ok(Some(l)) if l == 0 => Poll::Pending, + Ok(Some(l)) => Poll::Ready(Ok(l)), + // return 0 for EOF + Ok(None) => Poll::Ready(Ok(0)), + Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::Other, e))), + } } diff --git a/async/src/fdio.rs b/async/src/fdio.rs index 0fb9aef..814b60c 100644 --- a/async/src/fdio.rs +++ b/async/src/fdio.rs @@ -3,7 +3,7 @@ use log::{debug, error, info, log, trace, warn}; use snafu::{prelude::*, Whatever}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Interest}; use tokio::io::unix::AsyncFd; use std::os::unix::io::RawFd; @@ -14,10 +14,11 @@ use core::task::{Context, Poll}; use nix::fcntl::{fcntl, FcntlArg, OFlag}; -fn dup_async(orig_fd: libc::c_int) -> Result<AsyncFd<RawFd>, IoError> { +fn dup_async(orig_fd: libc::c_int, interest: Interest) -> Result<AsyncFd<RawFd>, IoError> { let fd = nix::unistd::dup(orig_fd)?; fcntl(fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?; - AsyncFd::new(fd) + // TODO: is with_interest necessary? + AsyncFd::with_interest(fd, interest) } pub struct InFd { @@ -28,17 +29,17 @@ pub struct OutFd { } pub fn stdin() -> Result<InFd, IoError> { Ok(InFd { - f: dup_async(libc::STDIN_FILENO)?, + f: dup_async(libc::STDIN_FILENO, Interest::READABLE)?, }) } pub fn stdout() -> Result<OutFd, IoError> { Ok(OutFd { - f: dup_async(libc::STDOUT_FILENO)?, + f: dup_async(libc::STDOUT_FILENO, Interest::WRITABLE)?, }) } pub fn stderr() -> Result<OutFd, IoError> { Ok(OutFd { - f: dup_async(libc::STDERR_FILENO)?, + f: dup_async(libc::STDERR_FILENO, Interest::WRITABLE)?, }) } @@ -48,6 +49,7 @@ impl AsyncRead for InFd { cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { + trace!("infd rd {:?}", self.f); // XXX loop was copy pasted from docs, perhaps it could be simpler loop { let mut guard = match self.f.poll_read_ready(cx)? { @@ -81,6 +83,7 @@ impl AsyncWrite for OutFd { cx: &mut Context<'_>, buf: &[u8] ) -> Poll<std::io::Result<usize>> { + trace!("outfd wr {:?}", self.f); loop { let mut guard = match self.f.poll_write_ready(cx)? { Poll::Ready(r) => r, diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 417cc06..1bca0e8 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -1,4 +1,4 @@ -#[allow(unused_imports)] + #[allow(unused_imports)] use { crate::error::{Error, Result, TrapBug}, log::{debug, error, info, log, trace, warn}, @@ -85,20 +85,22 @@ impl Channels { // Ok(()) } - /// Returns the channel data packet to send, and the length of data consumed + /// Returns the channel data packet to send, and the length of data consumed. + /// Caller has already checked valid length with send_allowed() pub(crate) fn send_data<'b>(&mut self, num: u32, ext: Option<u32>, data: &'b [u8]) - -> Result<(Packet<'b>, usize)> { - let send_ch = self.get(num)?.send.as_ref().trap()?.num; - // TODO: check: channel state, channel window, maxpacket - let len = data.len(); + -> Result<Packet<'b>> { + let send = self.get(num)?.send.as_ref().trap()?; + if data.len() > send.max_packet || data.len() > send.window { + return Err(Error::bug()) + } let data = BinString(data); let p = if let Some(code) = ext { // TODO: check code is valid for this channel - packets::ChannelDataExt { num: send_ch, code, data }.into() + packets::ChannelDataExt { num: send.num, code, data }.into() } else { - packets::ChannelData { num: send_ch, data }.into() + packets::ChannelData { num: send.num, data }.into() }; - Ok((p, len)) + Ok(p) } /// Informs the channel layer that an incoming packet has been read out, @@ -117,8 +119,12 @@ impl Channels { } } - pub(crate) fn recv_eof(&self, num: u32) -> bool { - self.get(num).map_or(false, |c| c.recv_eof()) + pub(crate) fn have_recv_eof(&self, num: u32) -> bool { + self.get(num).map_or(false, |c| c.have_recv_eof()) + } + + pub(crate) fn send_allowed(&self, num: u32) -> Option<usize> { + self.get(num).map_or(Some(0), |c| c.send_allowed()) } // incoming packet handling @@ -194,7 +200,9 @@ impl Channels { Ok(Some(ChanEventMaker::Eof { num: p.num })) } Packet::ChannelClose(_p) => { - todo!(); + // todo!(); + error!("ignoring channel close"); + Ok(None) } Packet::ChannelRequest(p) => { match self.get(p.num) { @@ -397,7 +405,7 @@ impl Channel { self.pending_adjust = self.pending_adjust.saturating_add(len) } - pub fn recv_eof(&self) -> bool { + fn have_recv_eof(&self) -> bool { match self.state { |ChanState::RecvEof |ChanState::RecvClose @@ -406,6 +414,11 @@ impl Channel { } } + + // None on close + fn send_allowed(&self) -> Option<usize> { + self.send.as_ref().map(|s| usize::max(s.window, s.max_packet)) + } } pub struct ChanMsg { @@ -491,6 +504,8 @@ impl ChanEventMaker { } } } else { + // TODO: return a bug result? + warn!("Req event maker but not request packet"); None } } diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index 4704683..286075f 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -363,8 +363,8 @@ impl<'a> Conn<'a> { let p = payload.map(|pl| sshwire::packet_from_bytes(pl, &self.parse_ctx)).transpose()?; let r = match ev { EventMaker::Channel(ChanEventMaker::DataIn(_)) => { - // no event returned, handled specially by caller - None + // caller should have handled it instead + return Err(Error::bug()) } EventMaker::Channel(cev) => { let c = cev.make(p.trap()?); diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 871e163..7238528 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -67,7 +67,8 @@ impl<'a> Runner<'a> { /// Drives connection progress, handling received payload and sending /// other packets as required. This must be polled/awaited regularly. /// Optionally returns `Event` which provides channel or session - // event to the application. + /// event to the application. + /// [`done_payload()`] must be called after any `Ok` result. pub async fn progress<'f>(&'f mut self, b: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { let em = if let Some(payload) = self.traffic.payload() { // Lifetimes here are a bit subtle. @@ -100,6 +101,7 @@ impl<'a> Runner<'a> { // "Borrow checker extends borrow range in code with early return" // https://github.com/rust-lang/rust/issues/54663 let ev = if let Some(em) = em { + trace!("em"); match em { EventMaker::Channel(ChanEventMaker::DataIn(di)) => { trace!("chanmaaker {di:?}"); @@ -115,6 +117,7 @@ impl<'a> Runner<'a> { } } } else { + trace!("no em, conn progress"); self.conn.progress(&mut self.traffic, &mut self.keys, b).await?; self.wake(); None @@ -131,7 +134,6 @@ impl<'a> Runner<'a> { } pub fn wake(&mut self) { - error!("wake"); if self.ready_input() { trace!("wake ready_input, waker {:?}", self.input_waker); if let Some(w) = self.input_waker.take() { @@ -169,17 +171,24 @@ impl<'a> Runner<'a> { } /// Send data from this application out the wire. - /// Must have already checked `ready_channel_send()`. - /// Returns the length of `buf` consumed. + /// Returns `Some` the length of `buf` consumed, or `None` on EOF pub fn channel_send( &mut self, chan: u32, ext: Option<u32>, buf: &[u8], - ) -> Result<usize> { - let (p, len) = self.conn.channels.send_data(chan, ext, buf)?; + ) -> Result<Option<usize>> { + let len = self.ready_channel_send(chan); + let len = match len { + Some(l) => l, + None => return Ok(None), + }; + + let len = len.min(buf.len()); + + let p = self.conn.channels.send_data(chan, ext, &buf[..len])?; self.traffic.send_packet(p, &mut self.keys)?; - Ok(len) + Ok(Some(len)) } /// Receive data coming from the wire into this application @@ -219,13 +228,14 @@ impl<'a> Runner<'a> { } pub fn channel_eof(&self, chan: u32) -> bool { - self.conn.channels.recv_eof(chan) + self.conn.channels.have_recv_eof(chan) } - // TODO check the chan/ext are valid, SSH window - pub fn ready_channel_send(&self, _chan: u32, _ext: Option<u32>) -> bool { - self.traffic.can_output() - // && self.conn.channels.ready_send_data(chan, ext) + // Returns None on channel closed + pub fn ready_channel_send(&self, chan: u32) -> Option<usize> { + // minimum of buffer space and channel window available + let buf_space = self.traffic.send_allowed(&self.keys); + self.conn.channels.send_allowed(chan).map(|s| s.min(buf_space)) } // pub fn chan_pending(&self) -> bool { diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 59ebccd..09c60be 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -120,8 +120,9 @@ impl<'a> Traffic<'a> { } } + /// A simple test if a packet can be sent. `send_allowed` should be used + /// for more general situations pub fn can_output(&self) -> bool { - // TODO: test for full output buffer match self.state { TrafState::Write { .. } | TrafState::Idle => true, @@ -129,6 +130,20 @@ impl<'a> Traffic<'a> { } } + /// Returns payload space available to send a packet. Returns 0 if not ready or full + pub fn send_allowed(&self, keys: &KeyState) -> usize { + // TODO: test for full output buffer + match self.state { + TrafState::Write { len, .. } => { + keys.max_enc_payload(self.buf.len() - len) + } + TrafState::Idle => { + keys.max_enc_payload(self.buf.len()) + } + _ => 0 + } + } + /// Returns the number of bytes consumed. pub fn input( &mut self, keys: &mut KeyState, remote_version: &mut RemoteVersion, @@ -201,13 +216,13 @@ impl<'a> Traffic<'a> { match self.state { | TrafState::InPayload { .. } | TrafState::BorrowPayload { .. } - | TrafState::Idle // TODO, is this wise? => { self.state = TrafState::Idle; Ok(()) } _ => { - /* Just ignore it */ + // Just ignore it + warn!("done_payload called without payload"); Ok(()) } } -- GitLab