From 04b4d17f77679db4d8bb97ead2212a1268d1ea79 Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Thu, 17 Nov 2022 23:25:45 +0800 Subject: [PATCH] Implement embassy channel read/write Added a generic poll_inner() to embassy_sunset Added ChannelEOF error return type instead of None for EOF Added app_done for channel handling code Remove some unneeded trait functions in std demo --- async/src/server.rs | 3 +- embassy/demos/std/src/main.rs | 63 ++++++++-------- embassy/src/embassy_sunset.rs | 136 ++++++++++++++++++++-------------- embassy/src/server.rs | 8 ++ src/behaviour.rs | 11 ++- src/channel.rs | 49 +++++++++--- src/error.rs | 7 ++ src/packets.rs | 13 +--- src/runner.rs | 74 +++++++++++++----- src/traffic.rs | 4 +- 10 files changed, 238 insertions(+), 130 deletions(-) diff --git a/async/src/server.rs b/async/src/server.rs index 850b4e6..7dd7e5b 100644 --- a/async/src/server.rs +++ b/async/src/server.rs @@ -51,7 +51,8 @@ impl<'a> SSHServer<'a> { } pub async fn channel(&mut self, ch: u32) -> Result<(ChanInOut<'a>, Option<ChanExtOut<'a>>)> { - let ty = self.sunset.with_runner(|r| r.channel_type(ch)).await?; + // TODO: what was this for? + // let ty = self.sunset.with_runner(|r| r.channel_type(ch)).await?; let inout = ChanInOut::new(ch, &self.sunset); // TODO ext let ext = None; diff --git a/embassy/demos/std/src/main.rs b/embassy/demos/std/src/main.rs index c7e5c4b..d8231a8 100644 --- a/embassy/demos/std/src/main.rs +++ b/embassy/demos/std/src/main.rs @@ -10,10 +10,11 @@ use core::future::Future; use core::todo; use embassy_executor::{Spawner, Executor}; use embassy_sync::mutex::Mutex; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_sync::blocking_mutex::raw::{NoopRawMutex, CriticalSectionRawMutex}; +use embassy_sync::signal::Signal; use embassy_net::tcp::TcpSocket; use embassy_net::{Stack, StackResources, ConfigStrategy}; -use embassy_futures::join::join3; +use embassy_futures::join::join; use embedded_io::asynch::{Read, Write}; use static_cell::StaticCell; @@ -27,6 +28,7 @@ use rand::rngs::OsRng; use rand::RngCore; use sunset::*; +use sunset::error::TrapBug; use sunset_embassy::SSHServer; use crate::tuntap::TunTapDevice; @@ -107,6 +109,8 @@ struct DemoServer { sess: Option<u32>, want_shell: bool, shell_started: bool, + + notify: Signal<CriticalSectionRawMutex, ()>, } impl DemoServer { @@ -119,6 +123,7 @@ impl DemoServer { keys, want_shell: false, shell_started: false, + notify: Signal::new(), }) } } @@ -128,37 +133,11 @@ impl ServBehaviour for DemoServer { Ok(&self.keys) } - - fn have_auth_password(&self, user: TextString) -> bool { - true - } - - fn have_auth_pubkey(&self, user: TextString) -> bool { - false - } - fn auth_unchallenged(&mut self, username: TextString) -> bool { info!("Allowing auth for user {:?}", username.as_str()); true } - fn auth_password(&mut self, user: TextString, password: TextString) -> bool { - user.as_str().unwrap_or("") == "matt" && password.as_str().unwrap_or("") == "pw" - } - - // fn auth_pubkey(&mut self, user: TextString, pubkey: &PubKey) -> bool { - // if user.as_str().unwrap_or("") != "matt" { - // return false - // } - - // // key is tested1 - // pubkey.matches_openssh("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMkNdReJERy1rPGqdfTN73TnayPR+lTNhdZvOgkAOs5x") - // .unwrap_or_else(|e| { - // warn!("Failed loading openssh key: {e}"); - // false - // }) - // } - fn open_session(&mut self, chan: u32) -> ChanOpened { if self.sess.is_some() { ChanOpened::Failure(ChanFail::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED) @@ -171,6 +150,7 @@ impl ServBehaviour for DemoServer { fn sess_shell(&mut self, chan: u32) -> bool { let r = !self.want_shell && self.sess == Some(chan); self.want_shell = true; + self.notify.signal(()); trace!("req want shell"); r } @@ -180,6 +160,25 @@ impl ServBehaviour for DemoServer { } } +async fn shell_fut<'f>(serv: &SSHServer<'f>, app: &Mutex<NoopRawMutex, DemoServer>) -> Result<()> +{ + let session = async { + // self.notify.wait()?; + let chan = app.lock().await.sess.trap()?; + + loop { + let mut b = [0u8; 100]; + let lr = serv.read_channel(chan, None, &mut b).await?; + let lw = serv.write_channel(chan, None, &b[..lr]).await?; + if lr != lw { + trace!("read/write mismatch {} {}", lr, lw); + } + } + Ok(()) + }; + session.await +} + async fn session(socket: &mut TcpSocket<'_>) -> sunset::Result<()> { let mut app = DemoServer::new()?; @@ -189,9 +188,15 @@ async fn session(socket: &mut TcpSocket<'_>) -> sunset::Result<()> { let serv = &serv; let app = Mutex::<NoopRawMutex, _>::new(app); + + let session = shell_fut(serv, &app); + let app = &app as &Mutex::<NoopRawMutex, dyn ServBehaviour>; + let run = serv.run(socket, app); + + join(run, session).await; - serv.run(socket, app).await + Ok(()) } static EXECUTOR: StaticCell<Executor> = StaticCell::new(); diff --git a/embassy/src/embassy_sunset.rs b/embassy/src/embassy_sunset.rs index b28c362..b3ab0fe 100644 --- a/embassy/src/embassy_sunset.rs +++ b/embassy/src/embassy_sunset.rs @@ -4,7 +4,7 @@ use { }; use core::future::{poll_fn, Future}; -use core::task::Poll; +use core::task::{Poll, Context}; use embassy_sync::waitqueue::WakerRegistration; use embassy_sync::mutex::Mutex; @@ -15,14 +15,19 @@ use embassy_net::tcp::TcpSocket; use pin_utils::pin_mut; -use sunset::{Runner, Result, Behaviour, ServBehaviour, CliBehaviour}; +use sunset::{Runner, Result, Error, Behaviour, ServBehaviour, CliBehaviour}; use sunset::config::MAX_CHANNELS; pub(crate) struct Inner<'a> { - pub runner: Runner<'a>, + runner: Runner<'a>, - pub chan_read_wakers: [WakerRegistration; MAX_CHANNELS], - pub chan_write_wakers: [WakerRegistration; MAX_CHANNELS], + chan_read_wakers: [WakerRegistration; MAX_CHANNELS], + + chan_write_wakers: [WakerRegistration; MAX_CHANNELS], + /// this is set `true` when the associated `chan_write_wakers` entry + /// was set for an ext write. This is needed because ext writes + /// require more buffer, so have different wake conditions. + ext_write_waker: [bool; MAX_CHANNELS], } pub struct EmbassySunset<'a> { @@ -37,6 +42,7 @@ impl<'a> EmbassySunset<'a> { runner, chan_read_wakers: Default::default(), chan_write_wakers: Default::default(), + ext_write_waker: Default::default(), }; let inner = Mutex::new(inner); @@ -103,19 +109,19 @@ impl<'a> EmbassySunset<'a> { fn wake_channels(&self, inner: &mut Inner) { + if let Some((chan, _ext)) = inner.runner.ready_channel_input() { + inner.chan_read_wakers[chan as usize].wake() + } - if let Some((chan, _ext)) = inner.runner.ready_channel_input() { - inner.chan_read_wakers[chan as usize].wake() - } - - for chan in 0..MAX_CHANNELS { - if inner.runner.ready_channel_send(chan as u32).unwrap_or(0) > 0 { - inner.chan_write_wakers[chan].wake() - } + for chan in 0..MAX_CHANNELS { + let ext = inner.ext_write_waker[chan]; + if inner.runner.ready_channel_send(chan as u32, ext).unwrap_or(0) > 0 { + inner.chan_write_wakers[chan].wake() } + } } - // XXX could we have a concrete NoopRawMutex instead of M? + // XXX should we have a concrete NoopRawMutex instead of M? pub async fn progress<M, B: ?Sized>(&self, b: &Mutex<M, B>) -> Result<()> @@ -129,8 +135,6 @@ impl<'a> EmbassySunset<'a> { { { let mut b = b.lock().await; - warn!("progress locked"); - // XXX: unsure why we need this explicit type let b: &mut B = &mut b; let mut b: Behaviour = b.into(); inner.runner.progress(&mut b).await?; @@ -141,30 +145,23 @@ impl<'a> EmbassySunset<'a> { } // inner dropped } - warn!("progress unlocked"); // idle until input is received // TODO do we also want to wake in other situations? self.progress_notify.wait().await; + Ok(()) } - pub async fn read(&self, buf: &mut [u8]) -> Result<usize> { + async fn poll_inner<F, T>(&self, mut f: F) -> T + where F: FnMut(&mut Inner, &mut Context) -> Poll<T> { poll_fn(|cx| { // Attempt to lock .inner let i = self.inner.lock(); pin_mut!(i); let r = match i.poll(cx) { Poll::Ready(mut inner) => { - match inner.runner.output(buf) { - // no output ready - Ok(0) => { - inner.runner.set_output_waker(cx.waker()); - Poll::Pending - } - Ok(n) => Poll::Ready(Ok(n)), - Err(e) => Poll::Ready(Err(e)), - } + f(&mut inner, cx) } Poll::Pending => { // .inner lock is busy @@ -176,38 +173,67 @@ impl<'a> EmbassySunset<'a> { .await } + pub async fn read(&self, buf: &mut [u8]) -> Result<usize> { + self.poll_inner(|inner, cx| { + match inner.runner.output(buf) { + // no output ready + Ok(0) => { + inner.runner.set_output_waker(cx.waker()); + Poll::Pending + } + Ok(n) => Poll::Ready(Ok(n)), + Err(e) => Poll::Ready(Err(e)), + } + }).await + } + pub async fn write(&self, buf: &[u8]) -> Result<usize> { - poll_fn(|cx| { - let i = self.inner.lock(); - pin_mut!(i); - let r = match i.poll(cx) { - Poll::Ready(mut inner) => { - if inner.runner.ready_input() { - match inner.runner.input(buf) { - Ok(0) => { - inner.runner.set_input_waker(cx.waker()); - Poll::Pending - }, - Ok(n) => Poll::Ready(Ok(n)), - Err(e) => Poll::Ready(Err(e)), - } - } else { + self.poll_inner(|inner, cx| { + if inner.runner.ready_input() { + match inner.runner.input(buf) { + Ok(0) => { + inner.runner.set_input_waker(cx.waker()); Poll::Pending - } + }, + Ok(n) => Poll::Ready(Ok(n)), + Err(e) => Poll::Ready(Err(e)), } - Poll::Pending => { - // .inner lock is busy - Poll::Pending - } - }; - if r.is_ready() { - // wake up .progress() to handle the input - self.progress_notify.signal(()) + } else { + Poll::Pending } - r - }) - .await + }).await + } + + pub async fn read_channel(&self, ch: u32, ext: Option<u32>, buf: &mut [u8]) -> Result<usize> { + if ch as usize > MAX_CHANNELS { + return Err(Error::BadChannel) + } + self.poll_inner(|inner, cx| { + let l = inner.runner.channel_input(ch, ext, buf); + if let Ok(0) = l { + // 0 bytes read, pending + inner.chan_read_wakers[ch as usize].register(cx.waker()); + Poll::Pending + } else { + Poll::Ready(l) + } + }).await } - // pub async fn read_channel(&self, buf: &mut [u8]) -> Result<usize> { + pub async fn write_channel(&self, ch: u32, ext: Option<u32>, buf: &[u8]) -> Result<usize> { + if ch as usize > MAX_CHANNELS { + return Err(Error::BadChannel) + } + self.poll_inner(|inner, cx| { + let l = inner.runner.channel_send(ch, ext, buf); + if let Ok(0) = l { + // 0 bytes written, pending + inner.chan_write_wakers[ch as usize].register(cx.waker()); + inner.ext_write_waker[ch as usize] = ext.is_some(); + Poll::Pending + } else { + Poll::Ready(l) + } + }).await + } } diff --git a/embassy/src/server.rs b/embassy/src/server.rs index 5ecec6c..b7e59f6 100644 --- a/embassy/src/server.rs +++ b/embassy/src/server.rs @@ -28,4 +28,12 @@ impl<'a> SSHServer<'a> { { self.sunset.run(socket, b).await } + + pub async fn read_channel(&self, ch: u32, ext: Option<u32>, buf: &mut [u8]) -> Result<usize> { + self.sunset.read_channel(ch, ext, buf).await + } + + pub async fn write_channel(&self, ch: u32, ext: Option<u32>, buf: &[u8]) -> Result<usize> { + self.sunset.write_channel(ch, ext, buf).await + } } diff --git a/src/behaviour.rs b/src/behaviour.rs index c2a2768..62adb23 100644 --- a/src/behaviour.rs +++ b/src/behaviour.rs @@ -176,9 +176,16 @@ pub trait ServBehaviour: Sync+Send { // be loaded on the stack rather than kept in memory for the whole lifetime. fn hostkeys(&mut self) -> BhResult<&[sign::SignKey]>; + #[allow(unused)] // TODO: or return a slice of enums - fn have_auth_password(&self, username: TextString) -> bool; - fn have_auth_pubkey(&self, username: TextString) -> bool; + fn have_auth_password(&self, username: TextString) -> bool { + false + } + + #[allow(unused)] + fn have_auth_pubkey(&self, username: TextString) -> bool { + false + } #[allow(unused)] /// Return true to allow the user to log in with no authentication diff --git a/src/channel.rs b/src/channel.rs index ed68aa3..eafddbe 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -83,7 +83,7 @@ impl Channels { } /// Returns a `Channel` for a local number, any state including `InOpen`. - pub fn get_any(&self, num: u32) -> Result<&Channel> { + fn get_any(&self, num: u32) -> Result<&Channel> { self.ch .get(num as usize) // out of range @@ -94,7 +94,7 @@ impl Channels { } /// Returns a `Channel` for a local number. Excludes `InOpen` state. - pub fn get(&self, num: u32) -> Result<&Channel> { + fn get(&self, num: u32) -> Result<&Channel> { let ch = self.get_any(num)?; if matches!(ch.state, ChanState::InOpen) { @@ -104,7 +104,7 @@ impl Channels { } } - pub fn get_mut(&mut self, num: u32) -> Result<&mut Channel> { + fn get_mut(&mut self, num: u32) -> Result<&mut Channel> { let ch = self .ch .get_mut(num as usize) @@ -121,11 +121,27 @@ impl Channels { } } + /// Must be called when an application has finished with a channel. + pub fn done(&mut self, num: u32) -> Result<()> { + self.get_mut(num)?.app_done = true; + Ok(()) + } + fn remove(&mut self, num: u32) -> Result<()> { // TODO any checks? - *self.ch.get_mut(num as usize).ok_or(Error::BadChannel)? = None; - todo!(); - // Ok(()) + let ch = self.ch.get_mut(num as usize).ok_or(Error::BadChannel)?; + if let Some(c) = ch { + if c.app_done { + trace!("removing channel {}", num); + *ch = None; + } else { + c.state = ChanState::PendingDone; + trace!("not removing channel {}, not finished", num); + } + Ok(()) + } else{ + Err(Error::BadChannel) + } } /// Returns the first available channel @@ -155,14 +171,17 @@ impl Channels { Ok(ch.as_mut().unwrap()) } - /// Returns the channel data packet to send, and the length of data consumed. - /// Caller has already checked valid length with send_allowed() + /// Returns the channel data packet to send. + /// Caller has already checked valid length with send_allowed(). + /// Don't call with zero length data. pub(crate) fn send_data<'b>( &mut self, num: u32, ext: Option<u32>, data: &'b [u8], ) -> Result<Packet<'b>> { + debug_assert!(data.len() > 0); + let send = self.get_mut(num)?.send.as_mut().trap()?; if data.len() > send.max_packet || data.len() > send.window { return Err(Error::bug()); @@ -376,7 +395,7 @@ impl Channels { let di = DataIn { num: p.num, ext: None, - offset: p.data_offset(), + offset: ChannelData::DATA_OFFSET, len: p.data.0.len(), }; disp = Dispatched(Some(di)); @@ -392,7 +411,7 @@ impl Channels { let di = DataIn { num: p.num, ext: Some(p.code), - offset: p.data_offset(), + offset: ChannelDataExt::DATA_OFFSET, len: p.data.0.len(), }; trace!("{di:?}"); @@ -588,11 +607,11 @@ enum ChanState { init_req: InitReqs, }, Normal, - RecvEof, - // TODO: recvclose state probably shouldn't be possible, we remove it straight away? RecvClose, + /// The channel is unused and ready to close after a call to `done()` + PendingDone, } pub(crate) struct Channel { @@ -611,6 +630,11 @@ pub(crate) struct Channel { pending_adjust: usize, full_window: usize, + + /// Set once application has called `done()`. The channel + /// will only be removed from the list + /// (allowing channel number re-use) if `app_done` is set + app_done: bool, } impl Channel { @@ -629,6 +653,7 @@ impl Channel { send: None, pending_adjust: 0, full_window: config::DEFAULT_WINDOW, + app_done: false, } } diff --git a/src/error.rs b/src/error.rs index 18431ba..d347250 100644 --- a/src/error.rs +++ b/src/error.rs @@ -59,6 +59,13 @@ pub enum Error { /// SSH packet contents doesn't match length WrongPacketLength, + /// Channel EOF + /// + /// This is an expected error when a SSH channel completes. Can be returned + /// by channel read/write functions. Any further calls in the same direction + /// and with the same `ext`) will fail similarly. + ChannelEOF, + // Used for unknown key types etc. #[snafu(display("{what} is not available"))] NotAvailable { what: &'static str }, diff --git a/src/packets.rs b/src/packets.rs index 0f16f63..cb63731 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -434,10 +434,8 @@ pub struct ChannelData<'a> { } impl ChannelData<'_> { - // offset into a packet of the raw data - pub(crate) fn data_offset(&self) -> usize { - 9 - } + // offset into a packet payload, includes packet type byte + pub const DATA_OFFSET: usize = 9; } #[derive(Debug,SSHEncode, SSHDecode)] @@ -448,11 +446,8 @@ pub struct ChannelDataExt<'a> { } impl ChannelDataExt<'_> { - // offset into a packet payload - pub(crate) fn data_offset(&self) -> usize { - // offset into a packet of the raw data - 13 - } + // offset into a packet payload, includes packet type byte + pub const DATA_OFFSET: usize = 13; } #[derive(Debug,SSHEncode, SSHDecode)] diff --git a/src/runner.rs b/src/runner.rs index 10d3db7..413aff8 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -9,6 +9,7 @@ use core::task::{Poll, Waker}; use pretty_hex::PrettyHex; use crate::*; +use packets::{ChannelDataExt, ChannelData}; use encrypt::KeyState; use traffic::{TrafIn, TrafOut, TrafSend}; @@ -81,8 +82,14 @@ impl<'a> Runner<'a> { Ok(runner) } - /// Drives connection progress, handling received payload and sending - /// other packets as required. This must be polled/awaited regularly. + /// Drives connection progress, handling received payload and queueing + /// other packets to send as required. + /// + /// This must be polled/awaited regularly, passing in `behaviour`. + /// + /// This method is async but will not block unless the `Behaviour` implementation + /// does so. Note that some computationally intensive operations may be performed + /// during key exchange. pub async fn progress(&mut self, behaviour: &mut Behaviour<'_>) -> Result<()> { let mut s = self.traf_out.sender(&mut self.keys); // Handle incoming packets @@ -176,23 +183,27 @@ impl<'a> Runner<'a> { Ok(chan) } - pub fn channel_type(&self, chan: u32) -> Result<channel::ChanType> { - self.conn.channels.get(chan).map(|c| c.ty) - } + // pub fn channel_type(&self, chan: u32) -> Result<channel::ChanType> { + // self.conn.channels.get(chan).map(|c| c.ty) + // } /// Send data from this application out the wire. - /// Returns `Some` the length of `buf` consumed, or `None` on EOF + /// Returns `Ok(len)` consumed, `Err(Error::ChannelEof)` on EOF, + /// or other errors. pub fn channel_send( &mut self, chan: u32, ext: Option<u32>, buf: &[u8], - ) -> Result<Option<usize>> { - let len = self.ready_channel_send(chan); + ) -> Result<usize> { + if buf.len() == 0 { + return Ok(0) + } + let len = self.ready_channel_send(chan, ext.is_some()); let len = match len { - Some(l) if l == 0 => return Ok(Some(0)), + Some(l) if l == 0 => return Ok(0), Some(l) => l, - None => return Ok(None), + None => return Err(Error::ChannelEOF), }; let len = len.min(buf.len()); @@ -200,10 +211,13 @@ impl<'a> Runner<'a> { let p = self.conn.channels.send_data(chan, ext, &buf[..len])?; self.traf_out.send_packet(p, &mut self.keys)?; self.wake(); - Ok(Some(len)) + Ok(len) } - /// Receive data coming from the wire into this application + /// Receive data coming from the wire into this application. + /// Returns `Ok(len)` received, `Err(Error::ChannelEof)` on EOF, + /// or other errors. + /// TODO: EOF is unimplemented pub fn channel_input( &mut self, chan: u32, @@ -213,9 +227,9 @@ impl<'a> Runner<'a> { trace!("runner chan in"); let (len, complete) = self.traf_in.channel_input(chan, ext, buf); if complete { - let p = self.conn.channels.finished_input(chan)?; - if let Some(p) = p { - self.traf_out.send_packet(p, &mut self.keys)?; + let wind_adjust = self.conn.channels.finished_input(chan)?; + if let Some(wind_adjust) = wind_adjust { + self.traf_out.send_packet(wind_adjust, &mut self.keys)?; } self.wake(); } @@ -234,11 +248,29 @@ impl<'a> Runner<'a> { self.conn.channels.have_recv_eof(chan) } - // Returns None on channel closed - pub fn ready_channel_send(&self, chan: u32) -> Option<usize> { + // Returns the maximum data that may be sent to a channel, or + // `None` on channel closed + pub fn ready_channel_send(&self, chan: u32, is_ext: bool) -> Option<usize> { // minimum of buffer space and channel window available - let buf_space = self.traf_out.send_allowed(&self.keys); - self.conn.channels.send_allowed(chan).map(|s| s.min(buf_space)) + let payload_space = self.traf_out.send_allowed(&self.keys); + let offset = if is_ext { + ChannelDataExt::DATA_OFFSET + } else { + ChannelData::DATA_OFFSET + }; + let payload_space = payload_space.saturating_sub(offset); + self.conn.channels.send_allowed(chan).map(|s| s.min(payload_space)) + } + + /// Must be called when an application has finished with a channel. + /// + /// Channel numbers will not be re-used without calling this, so + /// failing to call this can result in running out of channels. + /// + /// Any further calls using the same channel number may result + /// in data from a different channel re-using the same number. + pub fn channel_done(&mut self, chan: u32) -> Result<()> { + self.conn.channels.done(chan) } pub fn term_window_change(&self, _chan: u32, _wc: packets::WinChange) -> Result<()> { @@ -272,5 +304,9 @@ impl<'a> Runner<'a> { } } } +} +#[cfg(test)] +mod tests { + // TODO: test send_allowed() limits } diff --git a/src/traffic.rs b/src/traffic.rs index f480b80..de8db59 100644 --- a/src/traffic.rs +++ b/src/traffic.rs @@ -253,9 +253,7 @@ impl<'a> TrafIn<'a> { match self.state { RxState::InChannelData { chan: c, ext: e, ref mut idx, len } if (c, e) == (chan, ext) => { - if *idx > len { - error!("bad idx {} len {} e {:?} c {}", *idx, len, e, c); - } + debug_assert!(len >= *idx); let wlen = (len - *idx).min(buf.len()); buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); // info!("idx {} += wlen {} = {}", *idx, wlen, *idx+wlen); -- GitLab