diff --git a/embassy/demos/std/src/main.rs b/embassy/demos/std/src/main.rs index ecdde261e422368fe8651b385dde2fd65c374cd3..83a6ca6242a606e744f975513ed96fe38b7b2746 100644 --- a/embassy/demos/std/src/main.rs +++ b/embassy/demos/std/src/main.rs @@ -171,11 +171,11 @@ impl DemoShell { loop { let mut b = [0u8; 100]; - let lr = serv.read_channel(chan, None, &mut b).await?; + let lr = serv.read_channel_stdin(chan, &mut b).await?; let b = &mut b[..lr]; for c in b.iter_mut() { - if *c >= b'0' && *c <= b'9' { - *c = b'0' + (b'9' - *c) + if *c >= b'1' && *c <= b'9' { + *c = b'1' + (b'9' - *c) } } let lw = serv.write_channel(chan, None, b).await?; @@ -215,6 +215,8 @@ static EXECUTOR: StaticCell<Executor> = StaticCell::new(); fn main() { env_logger::builder() .filter_level(log::LevelFilter::Trace) + .filter_module("sunset::runner", log::LevelFilter::Info) + .filter_module("sunset::traffic", log::LevelFilter::Info) .filter_module("async_io", log::LevelFilter::Info) .filter_module("polling", log::LevelFilter::Info) .format_timestamp_nanos() diff --git a/embassy/src/embassy_sunset.rs b/embassy/src/embassy_sunset.rs index 8191dbd07390a3935b2c03b828ddecf747fefd3f..fb6635b2a787d1e3409791ac0d33f1ef5084f428 100644 --- a/embassy/src/embassy_sunset.rs +++ b/embassy/src/embassy_sunset.rs @@ -115,7 +115,7 @@ impl<'a> EmbassySunset<'a> { fn wake_channels(&self, inner: &mut Inner) { - if let Some((chan, _ext)) = inner.runner.ready_channel_input() { + if let Some((chan, _ext, _len)) = inner.runner.ready_channel_input() { // TODO: if there isn't any waker waiting, should we just drop the packet? inner.chan_read_wakers[chan as usize].wake() } @@ -216,17 +216,18 @@ impl<'a> EmbassySunset<'a> { }).await } - // TODO: should there be a variant that polls for either normal/ext, and - // returns it as a flag? - pub async fn read_channel(&self, ch: u32, ext: Option<u32>, buf: &mut [u8]) -> Result<usize> { + /// Reads normal channel data. If extended data is pending it will be discarded. + pub async fn read_channel_stdin(&self, ch: u32, buf: &mut [u8]) -> Result<usize> { if ch as usize > MAX_CHANNELS { return Err(Error::BadChannel) } self.poll_inner(|inner, cx| { - let l = inner.runner.channel_input(ch, ext, buf); + let l = inner.runner.channel_input(ch, None, buf); if let Ok(0) = l { // 0 bytes read, pending inner.chan_read_wakers[ch as usize].register(cx.waker()); + // discard any `ext` input for this channel + inner.runner.discard_channel_input(ch); Poll::Pending } else { Poll::Ready(l) @@ -251,3 +252,4 @@ impl<'a> EmbassySunset<'a> { }).await } } + diff --git a/embassy/src/server.rs b/embassy/src/server.rs index 2017c9792ce34cea90b055c55a848793f7a1006b..b02f2d2855b6a12ea2ce4712d8ba445b8b37e70c 100644 --- a/embassy/src/server.rs +++ b/embassy/src/server.rs @@ -28,8 +28,8 @@ impl<'a> SSHServer<'a> { self.sunset.run(socket, b).await } - pub async fn read_channel(&self, ch: u32, ext: Option<u32>, buf: &mut [u8]) -> Result<usize> { - self.sunset.read_channel(ch, ext, buf).await + pub async fn read_channel_stdin(&self, ch: u32, buf: &mut [u8]) -> Result<usize> { + self.sunset.read_channel_stdin(ch, buf).await } pub async fn write_channel(&self, ch: u32, ext: Option<u32>, buf: &[u8]) -> Result<usize> { diff --git a/src/channel.rs b/src/channel.rs index eafddbe7e0247f974f66748c4b5657e61870caf4..a6663e87634faf8783a5c1088a84a95756dabde4 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -222,6 +222,10 @@ impl Channels { self.get(num).map_or(Some(0), |c| c.send_allowed()) } + pub(crate) fn valid_send(&self, num: u32, ext: Option<u32>) -> bool { + self.get(num).map_or(false, |c| c.valid_send(ext)) + } + fn dispatch_open( &mut self, p: &ChannelOpen<'_>, @@ -738,6 +742,12 @@ impl Channel { self.send.as_ref().map(|s| usize::max(s.window, s.max_packet)) } + pub(crate) fn valid_send(&self, ext: Option<u32>) -> bool { + // TODO: later we should only allow non-pty "session" channels + // to have ext, for stderr only. + true + } + /// Returns a window adjustment packet if required fn check_window_adjust(&mut self) -> Result<Option<Packet>> { let send = self.send.as_mut().trap()?; diff --git a/src/runner.rs b/src/runner.rs index 3b22fa34c34250013cafa153d775892a3e1a2c7a..46d6d89a4d542d54687950cb4c96631947e32692 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -234,11 +234,18 @@ impl<'a> Runner<'a> { Ok(len) } + /// Discards any channel input data pending for `chan`, regardless of whether + /// normal or `ext`. + pub fn discard_channel_input(&mut self, chan: u32) { + self.traf_in.discard_channel_input(chan) + } + /// When channel data is ready, returns a tuple - /// `Some((channel, ext))` where `ext` is `None` for stdout, `Some(exttype)` - /// for extended types (like stderr). - /// Returns `None` if none ready. - pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { + /// `Some((channel, ext, len))` where `ext` is `None` for stdout + /// or `Some(sshnames::SSH_EXTENDED_DATA_STDERR)` for stderr. + /// `len` is the amount of data ready remaining to read, will always be non-zero. + /// Returns `None` if no data ready. + pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>, usize)> { self.traf_in.ready_channel_input() } @@ -260,6 +267,12 @@ impl<'a> Runner<'a> { self.conn.channels.send_allowed(chan).map(|s| s.min(payload_space)) } + /// Returns `true` if the channel and `ext` are currently valid for writing. + /// Note that they may not be ready to send output. + pub fn valid_channel_send(&self, chan: u32, ext: Option<u32>) -> bool { + self.conn.channels.valid_send(chan, ext) + } + /// Must be called when an application has finished with a channel. /// /// Channel numbers will not be re-used without calling this, so diff --git a/src/traffic.rs b/src/traffic.rs index 36a9f22066b342f86ec45360e910772bd19a56bd..f0a4ed8605827861a1eeb9490a81452af9b06632 100644 --- a/src/traffic.rs +++ b/src/traffic.rs @@ -83,7 +83,7 @@ enum RxState { ext: Option<u32>, /// read index of channel data. should transition to Idle once `idx==len` idx: usize, - /// length of buffer, end of channel data + /// length of channel data len: usize, }, } @@ -216,9 +216,14 @@ impl<'a> TrafIn<'a> { Ok(buf.len() - r.len()) } - pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { + /// Returns `(channel, ext, length)` + pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>, usize)> { match self.state { - RxState::InChannelData { chan, ext, .. } => Some((chan, ext)), + RxState::InChannelData { chan, ext, idx, len } => { + let rem = len - idx; + debug_assert!(rem > 0); + Some((chan, ext, rem)) + }, _ => None, } } @@ -271,6 +276,16 @@ impl<'a> TrafIn<'a> { } } + pub fn discard_channel_input(&mut self, chan: u32) { + match self.state { + RxState::InChannelData { chan: c, .. } + if c == chan => { + self.state = RxState::Idle; + } + _ => () + } + } + } impl<'a> TrafOut<'a> {