diff --git a/async/examples/serv1.rs b/async/examples/serv1.rs index 1603792f9136a7dce5b081343caaf3da0ee2a606..7139e69f3379bf1540fadea329a522e762262ce5 100644 --- a/async/examples/serv1.rs +++ b/async/examples/serv1.rs @@ -147,9 +147,10 @@ impl ServBehaviour for DemoServer { } } - fn sess_req_shell(&mut self, chan: u32) -> bool { + fn sess_shell(&mut self, chan: u32) -> bool { let r = !self.want_shell && self.sess == Some(chan); self.want_shell = true; + trace!("req want shell"); r } @@ -176,6 +177,8 @@ fn run_session<'a, R: Send>(args: &'a Args, scope: &'a moro::Scope<'a, '_, R>, m loop { serv.progress(&mut app).await.context("progress loop")?; if app.want_shell && !app.shell_started { + trace!("make shell"); + app.shell_started = true; if let Some(ch) = app.sess { let ch = ch.clone(); diff --git a/async/src/async_channel.rs b/async/src/async_channel.rs index ad745f4743732f51549a8b52338bf199057e408a..93b5f366d730a0cc3f5cb3ca57febaf3983b611b 100644 --- a/async/src/async_channel.rs +++ b/async/src/async_channel.rs @@ -99,6 +99,7 @@ impl<'a> AsyncRead for ChanInOut<'a> { cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { + trace!("poll read {}", self.chan); let this = self.deref_mut(); chan_poll_read(&mut this.door, this.chan, None, cx, buf, &mut this.rlfut) } @@ -162,6 +163,7 @@ impl<'a> AsyncWrite for ChanInOut<'a> { cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, IoError>> { + trace!("poll write {}", self.chan); let this = self.deref_mut(); chan_poll_write(&mut this.door, this.chan, None, cx, buf, &mut this.wlfut) } diff --git a/async/src/async_door.rs b/async/src/async_door.rs index 02583f13e09dbcaa3f96b690bd62df69381fe581..9dd764cc375e9ddf9fe462936cee16d94dc8c8fa 100644 --- a/async/src/async_door.rs +++ b/async/src/async_door.rs @@ -73,10 +73,12 @@ impl<'a> AsyncDoor<'a> { let inner = inner.deref_mut(); inner.runner.progress(b).await?; + trace!("pre wakers {:?}", inner.chan_read_wakers); if let Some(ce) = inner.runner.ready_channel_input() { inner.chan_read_wakers.remove(&ce) .map(|w| wakers.push(w)); } + trace!("pos wakers {:?}", inner.chan_read_wakers); // Pending HashMap::drain_filter // https://github.com/rust-lang/rust/issues/59618 diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index 71e580b926ba8d69bbad8f2bb18237f20b8e275f..19d48d2ad03e249a96c5165caa17aebfb2cc5bd3 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -200,12 +200,12 @@ pub trait ServBehaviour: Sync+Send { } #[allow(unused)] - fn sess_req_shell(&mut self, chan: u32) -> bool { + fn sess_shell(&mut self, chan: u32) -> bool { false } #[allow(unused)] - fn sess_req_exec(&mut self, chan: u32, cmd: &str) -> bool { + fn sess_exec(&mut self, chan: u32, cmd: TextString) -> bool { false } diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 8fa2bc25130a811e87f3773bd856dc80d5638f5b..8a37e626b87e15284ce48f544e2bedb9d141e7b6 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -8,7 +8,7 @@ use core::mem; use heapless::{Deque, String, Vec}; -use crate::*; +use crate::{*, sshwire::SSHEncodeEnum}; use config::*; use conn::Dispatched; use packets::{ChannelReqType, ChannelOpenFailure, ChannelRequest, Packet, ChannelOpen, ChannelOpenType, ChannelData, ChannelDataExt}; @@ -16,12 +16,11 @@ use traffic::TrafSend; use sshwire::{BinString, TextString}; use sshnames::*; +/// The result of a channel open request. pub enum ChanOpened { Success, - - /// A channel open response will be sent later + /// A channel open response will be sent later (for eg TCP open) Defer, - Failure(ChanFail), } @@ -83,7 +82,7 @@ impl Channels { Ok((ch.as_ref().unwrap(), p)) } - /// Returns a `Channel` for a local number, any state. + /// Returns a `Channel` for a local number, any state including `InOpen`. pub fn get_any(&self, num: u32) -> Result<&Channel> { self.ch .get(num as usize) @@ -246,13 +245,13 @@ impl Channels { ChannelOpenType::Session => { // unwrap: earlier test ensures b.server() succeeds let mut bserv = b.server().unwrap(); - bserv.open_session(ch.recv.num) + bserv.open_session(ch.num()) } ChannelOpenType::ForwardedTcpip(t) => { - b.open_tcp_forwarded(ch.recv.num, t) + b.open_tcp_forwarded(ch.num(), t) } ChannelOpenType::DirectTcpip(t) => { - b.open_tcp_direct(ch.recv.num, t) + b.open_tcp_direct(ch.num(), t) } ChannelOpenType::Unknown(_) => { unreachable!() @@ -261,11 +260,11 @@ impl Channels { match r { ChanOpened::Success => { - s.send(ch.open_done())?; + s.send(ch.open_done()?)?; }, ChanOpened::Failure(f) => { - let n = ch.recv.num; - self.remove(n); + let n = ch.num(); + self.remove(n)?; return Err(f.into()) } ChanOpened::Defer => { @@ -278,20 +277,32 @@ impl Channels { pub fn dispatch_request(&mut self, p: &packets::ChannelRequest, - _s: &mut TrafSend, - _b: &mut Behaviour<'_>, + s: &mut TrafSend, + b: &mut Behaviour<'_>, ) -> Result<()> { - let ch = match self.get(p.num) { - Ok(ch) => ch, - Err(Error::BadChannel) => { - debug!("request {p:?} channel is unknown"); - return Ok(()) - }, - Err(e) => unreachable!(), + if let Ok(ch) = self.get(p.num) { + // only servers accept requests + let success = if let Ok(b) = b.server() { + ch.dispatch_server_request(p, s, b).unwrap_or_else(|e| { + debug!("Error in channel req handling for {p:?}, {e:?}"); + false + }) + } else { + false }; - - Ok(()) + if p.want_reply { + let num = ch.send_num()?; + if success { + s.send(packets::ChannelSuccess { num })?; + } else { + s.send(packets::ChannelFailure { num })?; + } + } + } else { + debug!("Ignoring request to unknown channel: {p:#?}"); + } + Ok(()) } // Some returned errors will be caught by caller and returned as SSH messages @@ -583,29 +594,70 @@ impl Channel { } } + /// Local channel number + pub(crate) fn num(&self) -> u32 { + self.recv.num + } + + /// Remote channel number, fails if channel is in progress opening + pub(crate) fn send_num(&self) -> Result<u32> { + Ok(self.send.as_ref().trap()?.num) + } + fn request(&mut self, req: ReqDetails, s: &mut TrafSend) -> Result<()> { let num = self.send.as_ref().trap()?.num; let r = Req { num, details: req }; s.send(r.packet()?) } - pub(crate) fn number(&self) -> u32 { - self.recv.num - } - /// Returns an open confirmation reply packet to send. /// Must be called with state of `InOpen`. - fn open_done<'p>(&mut self) -> Packet<'p> { + fn open_done<'p>(&mut self) -> Result<Packet<'p>> { debug_assert!(matches!(self.state, ChanState::InOpen)); self.state = ChanState::Normal; - packets::ChannelOpenConfirmation { - num: self.recv.num, + let p = packets::ChannelOpenConfirmation { + num: self.send_num()?, // unwrap: state is InOpen sender_num: self.send.as_ref().unwrap().num, initial_window: self.recv.window as u32, max_packet: self.recv.max_packet as u32, - }.into() + }.into(); + Ok(p) + } + + fn dispatch_server_request(&self, + p: &packets::ChannelRequest, + s: &mut TrafSend, + b: &mut dyn ServBehaviour, + ) -> Result<bool> { + + if !matches!(self.ty, ChanType::Session) { + return Ok(false) + } + + match &p.req { + ChannelReqType::Shell => { + Ok(b.sess_shell(self.num())) + } + ChannelReqType::Exec(ex) => { + Ok(b.sess_exec(self.num(), ex.command)) + } + // TODO need to convert packet to channel Pty + // ChannelReqType::Pty(pty) => { + // let cpty = pty.into(); + // Ok(b.sess_pty(self.num(), &cpty)) + // } + _ => { + if let ChannelReqType::Unknown(u) = &p.req { + warn!("Unknown channel req type \"{}\"", u) + } else { + // OK unwrap: tested for Unknown + warn!("Unhandled channel req \"{}\"", p.req.variant_name().unwrap()) + }; + Ok(false) + } + } } fn finished_input(&mut self, len: usize ) { @@ -639,8 +691,6 @@ impl Channel { Ok(None) } } - - } pub struct ChanMsg { diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index d255cfbf6e068c44c289fac66b88c8aba40dd036..600a4f15a36db36dd942508c25ea5dc0df2a23f0 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -591,11 +591,19 @@ pub struct ParseContext { // Used by sign_encode() pub method_pubkey_force_sig_bool: bool, + + // Set to true if an unknown variant is encountered. + // Packet length checks should be omitted in that case. + pub(crate) seen_unknown: bool, } impl ParseContext { pub fn new() -> Self { - ParseContext { cli_auth_type: None, method_pubkey_force_sig_bool: false } + ParseContext { + cli_auth_type: None, + method_pubkey_force_sig_bool: false, + seen_unknown: false, + } } } diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 31602ca065b6b38dffb7e88378a4e36a0c3dd15c..814d8a36dc72758afad683d1b26c0b6df32d0358 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -166,7 +166,7 @@ impl<'a> Runner<'a> { init_req.push(channel::ReqDetails::Shell).trap()?; } let (ch, p) = self.conn.channels.open(packets::ChannelOpenType::Session, init_req)?; - let chan = ch.number(); + let chan = ch.num(); self.traf_out.send_packet(p, &mut self.keys)?; self.wake(); Ok(chan) diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 357a488118b36c6d42e9b23b162cb5bed1740600..8f80128a3fc8b1174f6983506c30a89d9ccb1adf 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -31,7 +31,7 @@ pub trait SSHSink { pub trait SSHSource<'de> { fn take(&mut self, len: usize) -> WireResult<&'de [u8]>; fn pos(&self) -> usize; - fn ctx(&self) -> &ParseContext; + fn ctx(&mut self) -> &mut ParseContext; } /// Encodes the type in SSH wire format @@ -101,12 +101,16 @@ pub type WireResult<T> = core::result::Result<T, WireError>; /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> { - let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() }; + let ctx = ParseContext { seen_unknown: false, .. ctx.clone()}; + let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx }; let p = Packet::dec(&mut s)?; - if s.pos() == b.len() { - Ok(p) - } else { + + if s.pos() != b.len() && !s.ctx().seen_unknown { + // No length check if the packet had an unknown variant + // - we skipped parsing the rest of the packet. Err(Error::WrongPacketLength) + } else { + Ok(p) } } @@ -222,8 +226,8 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> { self.pos } - fn ctx(&self) -> &ParseContext { - &self.parse_ctx + fn ctx(&mut self) -> &mut ParseContext { + &mut self.parse_ctx } } diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 110a17ec325f5d7fe21e027a7cbf075b6a6f3a71..a648911aae9f54f00138b0ed5633163333d70476 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -96,10 +96,11 @@ impl<'a> TrafIn<'a> { pub fn ready_input(&self) -> bool { trace!("ready_input state {:?}", self.state); match self.state { - RxState::Idle + | RxState::Idle | RxState::ReadInitial { .. } - | RxState::Read { .. } => true, - RxState::ReadComplete { .. } + | RxState::Read { .. } + => true, + | RxState::ReadComplete { .. } | RxState::InPayload { .. } | RxState::InChannelData { .. } => false, diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs index 4b354ff6b27c4a3074da9ebc93485bf3129dbf93..5a7d95c240a56cd4d789dd3599580362e6e9e88a 100644 --- a/sshwire_derive/src/lib.rs +++ b/sshwire_derive/src/lib.rs @@ -508,7 +508,7 @@ fn decode_enum_names( if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) { // create the Unknown fallthrough but it will be at the end of the match list let mut m = StreamBuilder::new(); - m.push_parsed(format!("_ => Self::{}(Unknown(variant))", var.name))?; + m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown(variant))}}", var.name))?; if unknown_arm.replace(m).is_some() { return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None}) }