diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 30bbc84a11cdd7575a81d24976e04f20c7714573..b2c866743b4efce043dad1b700e4bbae3b2ebb0e 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -10,6 +10,7 @@ use heapless::{Deque, String, Vec}; use crate::*; use config::*; +use conn::Dispatched; use packets::{ChannelReqType, ChannelOpenFailure, ChannelRequest, Packet, ChannelOpen, ChannelOpenType, ChannelData, ChannelDataExt}; use traffic::TrafSend; use sshwire::{BinString, TextString}; @@ -293,20 +294,18 @@ impl Channels { Ok(()) } - /// Incoming packet handling - // TODO: protocol errors etc should perhaps be less fatal, - // ssh implementations are usually imperfect. - pub async fn dispatch( + // Some returned errors will be caught by caller and returned as SSH messages + async fn dispatch_inner( &mut self, packet: Packet<'_>, s: &mut TrafSend<'_>, b: &mut Behaviour<'_>, - ) -> Result<Option<ChanEventMaker>> { + ) -> Result<Dispatched> { trace!("chan dispatch"); + let mut disp = Dispatched(None); let r = match packet { Packet::ChannelOpen(p) => { self.dispatch_open(&p, s, b)?; - Ok(None) } Packet::ChannelOpenConfirmation(p) => { @@ -327,9 +326,8 @@ impl Channels { } ch.state = ChanState::Normal; } - Ok(None) } - _ => Err(Error::SSHProtoError), + _ => return Err(Error::SSHProtoError), } } @@ -337,17 +335,15 @@ impl Channels { let ch = self.get(p.num)?; if ch.send.is_some() { // TODO: or just warn? - Err(Error::SSHProtoError) + return Err(Error::SSHProtoError) } else { self.remove(p.num)?; // TODO event - Ok(None) } } Packet::ChannelWindowAdjust(p) => { let send = self.get_mut(p.num)?.send.as_mut().trap()?; send.window = send.window.saturating_add(p.adjust as usize); - Ok(None) } Packet::ChannelData(p) => { self.get(p.num)?; @@ -357,7 +353,7 @@ impl Channels { } self.pending_input = Some(PendInput { chan: p.num, len: p.data.0.len() }); let di = DataIn { num: p.num, ext: None, offset: p.data_offset(), len: p.data.0.len() }; - Ok(Some(ChanEventMaker::DataIn(di))) + disp = Dispatched(Some(di)); } Packet::ChannelDataExt(p) => { self.get(p.num)?; @@ -368,38 +364,48 @@ impl Channels { self.pending_input = Some(PendInput { chan: p.num, len: p.data.0.len() }); let di = DataIn { num: p.num, ext: Some(p.code), offset: p.data_offset(), len: p.data.0.len() }; trace!("{di:?}"); - Ok(Some(ChanEventMaker::DataIn(di))) } Packet::ChannelEof(p) => { self.get(p.num)?; - Ok(Some(ChanEventMaker::Eof { num: p.num })) } Packet::ChannelClose(_p) => { // todo!(); error!("ignoring channel close"); - Ok(None) } Packet::ChannelRequest(p) => { self.dispatch_request(&p, s, b)?; - Ok(None) } Packet::ChannelSuccess(_p) => { trace!("channel success, TODO"); - Ok(None) } Packet::ChannelFailure(_p) => { todo!(); } - _ => Error::bug_msg("unreachable") + _ => Error::bug_msg("unreachable")? }; + Ok(disp) + } + + /// Incoming packet handling + // TODO: protocol errors etc should perhaps be less fatal, + // ssh implementations are usually imperfect. + pub async fn dispatch( + &mut self, + packet: Packet<'_>, + s: &mut TrafSend<'_>, + b: &mut Behaviour<'_>, + ) -> Result<Dispatched> { + + let r = self.dispatch_inner(packet, s, b).await; + match r { Err(Error::BadChannel) => { warn!("Ignoring bad channel number"); - Ok(None) + Ok(Dispatched(None)) } - Ok(ev) => Ok(ev), // TODO: close channel on error? or on SSHProtoError? Err(any) => Err(any), + Ok(disp) => Ok(disp), } } } diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index 6afc57e853265c26a1d41483c2684a2e2a67957b..50c92f999abe1564612ceedf549608e27855cadc 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -96,6 +96,8 @@ pub(crate) enum EventMaker { CliAuthed, } +pub(crate) struct Dispatched(pub Option<channel::DataIn>); + impl<'a> Conn<'a> { pub fn new_client() -> Result<Self> { Self::new(ClientServer::Client(client::Client::new())) @@ -181,7 +183,7 @@ impl<'a> Conn<'a> { Err(Error::UnknownPacket { number }) => { trace!("Unimplemented packet type {number}"); s.send(packets::Unimplemented { seq })?; - Ok(Dispatched { event: None }) + Ok(Dispatched(None)) } Err(e) => return Err(e), } @@ -193,7 +195,7 @@ impl<'a> Conn<'a> { // TODO: perhaps could consolidate packet allowed checks into a separate function // to run first? trace!("Incoming {packet:#?}"); - let mut event = None; + let mut disp = Dispatched(None); match packet { Packet::KexInit(_) => { if matches!(self.state, ConnState::InKex { .. }) { @@ -304,7 +306,6 @@ impl<'a> Conn<'a> { if matches!(self.state, ConnState::PreAuth) { self.state = ConnState::Authed; cli.auth_success(&mut self.parse_ctx, s, b.client()?)?; - event = Some(EventMaker::CliAuthed); } else { debug!("Received UserauthSuccess unrequested") } @@ -344,38 +345,11 @@ impl<'a> Conn<'a> { | Packet::ChannelFailure(_) // TODO: maybe needs a conn or cliserv argument. => { - let chev = self.channels.dispatch(packet, s, b).await?; - event = chev.map(|c| EventMaker::Channel(c)) + disp = self.channels.dispatch(packet, s, b).await?; } }; - Ok(Dispatched { event }) + Ok(disp) } - - /// creates an `Event` that borrows data from the payload. Some `Event` variants don't - /// require payload data, the payload is not required in that case. - /// Those variants are allowed to return `resp` packets from `dispatch()` - pub(crate) fn make_event<'p>(&mut self, payload: Option<&'p [u8]>, ev: EventMaker) - -> Result<Option<Event<'p>>> { - let p = payload.map(|pl| sshwire::packet_from_bytes(pl, &self.parse_ctx)).transpose()?; - let r = match ev { - EventMaker::Channel(ChanEventMaker::DataIn(_)) => { - // caller should have handled it instead - return Err(Error::bug()) - } - EventMaker::Channel(cev) => { - let c = cev.make(p.trap()?); - c.map(|c| Event::Channel(c)) - } - EventMaker::CliAuthed => Some(Event::CliAuthed), - }; - Ok(r) - } - -} - -// TODO: delete this -pub(crate) struct Dispatched { - pub event: Option<EventMaker>, } #[cfg(test)] diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index d512f9b95063ab2a7efff0e0b59c2fa1d33e1cd0..079c674c3ae74021ef4fca99687ecba039ac5d36 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -12,7 +12,7 @@ use crate::{*, channel::ChanEvent}; use encrypt::KeyState; use traffic::{TrafIn, TrafOut, TrafSend}; -use conn::{Conn, Dispatched, EventMaker, Event}; +use conn::{Conn, Dispatched}; use channel::ChanEventMaker; pub struct Runner<'a> { @@ -84,15 +84,23 @@ impl<'a> Runner<'a> { Ok(r) } + pub async fn progress(&mut self, b: &mut Behaviour<'_>) -> Result<()> { + let done = self.progress_inner(b).await?; + + if done { + self.wake(); + } + Ok(()) + } + /// Drives connection progress, handling received payload and sending /// other packets as required. This must be polled/awaited regularly. /// Optionally returns `Event` which provides channel or session /// event to the application. /// [`done_payload()`] must be called after any `Ok` result. - pub async fn progress<'f>(&'f mut self, behaviour: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { - let mut s = self.traf_out.sender(&mut self.keys); - let em = if let Some((payload, seq)) = self.traf_in.payload() { + pub async fn progress_inner(&'a mut self, behaviour: &mut Behaviour<'_>) -> Result<bool> { + if let Some((payload, seq)) = self.traf_in.payload() { // Lifetimes here are a bit subtle. // `payload` has self.traffic lifetime, used until `handle_payload` // completes. @@ -100,49 +108,21 @@ impl<'a> Runner<'a> { // by the send_packet(). // After that progress() can perform more send_packet() itself. + let mut s = self.traf_out.sender(&mut self.keys); + let d = self.conn.handle_payload(payload, seq, &mut s, behaviour).await?; - self.traf_in.handled_payload()?; - if d.event.is_none() { - // switch to using the buffer for output. + if let Some(d) = d.0 { + self.traf_in.handled_payload()?; + self.traf_in.set_channel_input(d)?; + false + } else { self.traf_in.done_payload()?; + self.conn.progress(&mut s, behaviour).await?; + true } - d.event - } else { - None - }; - - // We split return values into Event/EventMaker to work around - // the payload borrow range extending too long. - // Polonius would solve this. We can't use polonius-the-crab - // because we're calling async functions. - // "Borrow checker extends borrow range in code with early return" - // https://github.com/rust-lang/rust/issues/54663 - let ev = if let Some(em) = em { - trace!("em"); - match em { - EventMaker::Channel(ChanEventMaker::DataIn(di)) => { - trace!("chanmaaker {di:?}"); - self.traf_in.done_payload()?; - self.traf_in.set_channel_input(di)?; - // TODO: channel wakers - None - } - _ => { - // Some(payload) is only required for some variants in make_event() - panic!("delete this codepath") - } - } - } else { - trace!("no em, conn progress"); - self.conn.progress(&mut s, behaviour).await?; - self.wake(); - None - }; - trace!("prog event {ev:?}"); - - Ok(ev) + } } pub fn done_payload(&mut self) -> Result<()> { diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 9381675d170d11f8e99106f813ccff09b3e0c931..bc5bcb84c652dcb7d66beca4e0c7b9854e4f0df6 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -405,7 +405,7 @@ impl<'a> TrafOut<'a> { } - pub fn sender(&mut self, keys: &'a mut KeyState) -> TrafSend { + pub fn sender(&'a mut self, keys: &'a mut KeyState) -> TrafSend { TrafSend::new(self, keys) }