diff --git a/async/src/cmdline_client.rs b/async/src/cmdline_client.rs index 66e54819e41cc15d41903d2d7f9d5e26b3ce586f..e2c97f5e8a1fd1f075834c0488eb0e59d2d59719 100644 --- a/async/src/cmdline_client.rs +++ b/async/src/cmdline_client.rs @@ -13,6 +13,7 @@ use sunset_embassy::*; use std::collections::VecDeque; use embassy_sync::channel::{Channel, Sender, Receiver}; +use embassy_sync::signal::Signal; use embedded_io::asynch::{Read as _, Write as _}; use tokio::io::AsyncReadExt; @@ -22,6 +23,7 @@ use tokio::signal::unix::{signal, SignalKind}; use futures::{select_biased, future::Fuse}; use futures::FutureExt; +use pretty_hex::PrettyHex; use crate::*; use crate::AgentClient; @@ -176,8 +178,8 @@ impl<'a> CmdlineRunner<'a> { } async fn chan_run(io: ChanInOut<'a, CmdlineHooks<'a>, UnusedServ>, - io_err: Option<ChanIn<'a, CmdlineHooks<'a>, UnusedServ>>) -> Result<()> { - trace!("chan_run top"); + io_err: Option<ChanIn<'a, CmdlineHooks<'a>, UnusedServ>>, + mut pty_guard: Option<RawPtyGuard>) -> Result<()> { // out let fo = async { let mut io = io.clone(); @@ -221,21 +223,66 @@ impl<'a> CmdlineRunner<'a> { } }; + let terminate = Signal::<SunsetRawMutex, ()>::new(); + // in let fi = async { let mut io = io.clone(); let mut si = crate::stdin().map_err(|_| Error::msg("opening stdin failed"))?; + let mut esc = if pty_guard.is_some() { + Some(Escaper::new()) + } else { + None + }; + loop { // TODO buffers let mut buf = [0u8; 1000]; let l = si.read(&mut buf).await.map_err(|_| Error::ChannelEOF)?; - io.write_all(&buf[..l]).await?; + + let buf = &buf[..l]; + + if let Some(ref mut esc) = esc { + let a = esc.escape(buf); + match a { + EscapeAction::None => (), + EscapeAction::Output { extra } => { + if let Some(e) = extra { + io.write_all(&[e]).await?; + } + io.write_all(buf).await?; + } + EscapeAction::Terminate => { + info!("Terminated"); + terminate.signal(()); + return Ok(()) + } + EscapeAction::Suspend => { + // disabled for the time being, doesn't resume OK. + // perhaps a bad interaction with dup_async(), + // maybe the new guard needs to be on the dup-ed + // FDs? + () + + // pty_guard = None; + // nix::sys::signal::raise(nix::sys::signal::Signal::SIGTSTP) + // .unwrap_or_else(|e| { + // warn!("Failed to stop: {e:?}"); + // }); + // // suspended here until resumed externally + // set_pty_guard(&mut pty_guard); + // continue; + } + } + } else { + io.write_all(buf).await?; + } + } #[allow(unreachable_code)] Ok::<_, sunset::Error>(()) }; - // output needs to complete when the channel is closed let fi = embassy_futures::select::select(fi, io.until_closed()); @@ -252,7 +299,8 @@ impl<'a> CmdlineRunner<'a> { // x // }); - let _ = embassy_futures::join::join3(fe, fi, fo).await; + let io_done = embassy_futures::join::join3(fe, fi, fo); + let _ = embassy_futures::select::select(io_done, terminate.wait()).await; // TODO handle errors from the join? Ok(()) } @@ -294,8 +342,12 @@ impl<'a> CmdlineRunner<'a> { Msg::Opened => { let st = core::mem::replace(&mut self.state, CmdlineState::Authed); if let CmdlineState::Opening { io, extin } = st { - chanio.set(Self::chan_run(io.clone(), extin.clone()).fuse()); + let r = Self::chan_run(io.clone(), extin.clone(), self.pty_guard.take()) + .fuse(); + chanio.set(r); self.state = CmdlineState::Ready { io, extin }; + } else { + warn!("Unexpected Msg::Opened") } } Msg::Exited => { @@ -327,13 +379,7 @@ impl<'a> CmdlineRunner<'a> { debug_assert!(matches!(self.state, CmdlineState::Authed)); let (io, extin) = if self.want_pty { - // switch to raw pty mode - match raw_pty() { - Ok(p) => self.pty_guard = Some(p), - Err(e) => { - warn!("Failed getting raw pty: {e:?}"); - } - } + set_pty_guard(&mut self.pty_guard); let io = cli.open_session_pty().await?; (io, None) } else { @@ -365,6 +411,93 @@ impl<'a> CmdlineRunner<'a> { } } +fn set_pty_guard(pty_guard: &mut Option<RawPtyGuard>) { + match raw_pty() { + Ok(p) => *pty_guard = Some(p), + Err(e) => { + warn!("Failed getting raw pty: {e:?}"); + } + } +} + +#[derive(Debug, PartialEq)] +enum EscapeAction { + None, + // an extra character of output to prepend + Output { extra: Option<u8> }, + Terminate, + Suspend, +} + +#[derive(Debug)] +enum Escaper { + Idle, + Newline, + Escape, +} + +impl Escaper { + fn new() -> Self { + // start as if we had received a '\r' + Self::Newline + } + + /// Handle ~. escape sequences. + fn escape(&mut self, buf: &[u8]) -> EscapeAction { + // Only handle single input keystrokes. Provides some protection against + // pasting escape sequences too. + + let mut newline = false; + if buf.len() == 1 { + let c = buf[0]; + newline = c == b'\r'; + + match self { + Self::Newline if c == b'~' => { + *self = Self::Escape; + return EscapeAction::None + } + Self::Escape => { + // handle the actual escape character + match c { + b'~' => { + // output the single '~' in buf. + *self = Self::Idle; + return EscapeAction::Output { extra: None } + } + b'.' => { + *self = Self::Idle; + return EscapeAction::Terminate + } + // ctrl-z, suspend + 0x1a => { + *self = Self::Idle; + return EscapeAction::Suspend + } + // fall through to reset below + _ => (), + } + } + _ => (), + } + } + + // Reset escaping state + let extra = match self { + // output the '~' that was previously consumed + Self::Escape => Some(b'~'), + _ => None, + }; + if newline { + *self = Self::Newline + } else { + *self = Self::Idle + } + + EscapeAction::Output { extra } + } +} + impl<'a> CmdlineHooks<'a> { /// Notify the `CmdlineClient` that the main SSH session has exited. /// @@ -448,3 +581,55 @@ impl<'a> Debug for CmdlineHooks<'a> { } } +#[cfg(test)] +pub(crate) mod tests { + use crate::cmdline_client::*; + + #[test] + fn escaping() { + // None expect_action is shorthand for ::Output + let seqs = vec![ + ("~.", Some(EscapeAction::Terminate), ""), + ("\r~.", Some(EscapeAction::Terminate), "\r"), + ("~~.", None, "~."), + ("~~~.", None, "~~."), + ("\r\r~.", Some(EscapeAction::Terminate), "\r\r"), + ("a~/~.", None, "a~/~."), + ("a~/\r~.", Some(EscapeAction::Terminate), "a~/\r"), + ("~\r~.", Some(EscapeAction::Terminate), "~\r"), + ("~\r~ ", None, "~\r~ "), + ]; + for (inp, expect_action, expect) in seqs.iter() { + let mut out = vec![]; + let mut esc = Escaper::new(); + let mut last_action = None; + println!("input \"{}\"", inp.escape_default()); + for i in inp.chars() { + let i: u8 = i.try_into().unwrap(); + let e = esc.escape(&[i]); + + if let EscapeAction::Output { ref extra } = e { + if let Some(extra) = extra { + out.push(*extra); + } + out.push(i) + } + + last_action = Some(e); + } + assert_eq!(out.as_slice(), expect.as_bytes()); + + let last_action = last_action.unwrap(); + if let Some(expect_action) = expect_action { + assert_eq!(&last_action, expect_action); + } else { + match last_action { + EscapeAction::Output { .. } => (), + _ => panic!("Unexpected action {last_action:?}"), + } + } + } + } + +} +