From 3cc08506c6c5d6a8d86d4880a9adcc4044c57c3e Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Tue, 30 Aug 2022 00:11:31 +0800 Subject: [PATCH] channel req handling --- async/examples/serv1.rs | 5 +- async/src/async_channel.rs | 2 + async/src/async_door.rs | 2 + sshproto/src/behaviour.rs | 4 +- sshproto/src/channel.rs | 114 ++++++++++++++++++++++++++----------- sshproto/src/packets.rs | 10 +++- sshproto/src/runner.rs | 2 +- sshproto/src/sshwire.rs | 18 +++--- sshproto/src/traffic.rs | 7 ++- sshwire_derive/src/lib.rs | 2 +- 10 files changed, 118 insertions(+), 48 deletions(-) diff --git a/async/examples/serv1.rs b/async/examples/serv1.rs index 1603792..7139e69 100644 --- a/async/examples/serv1.rs +++ b/async/examples/serv1.rs @@ -147,9 +147,10 @@ impl ServBehaviour for DemoServer { } } - fn sess_req_shell(&mut self, chan: u32) -> bool { + fn sess_shell(&mut self, chan: u32) -> bool { let r = !self.want_shell && self.sess == Some(chan); self.want_shell = true; + trace!("req want shell"); r } @@ -176,6 +177,8 @@ fn run_session<'a, R: Send>(args: &'a Args, scope: &'a moro::Scope<'a, '_, R>, m loop { serv.progress(&mut app).await.context("progress loop")?; if app.want_shell && !app.shell_started { + trace!("make shell"); + app.shell_started = true; if let Some(ch) = app.sess { let ch = ch.clone(); diff --git a/async/src/async_channel.rs b/async/src/async_channel.rs index ad745f4..93b5f36 100644 --- a/async/src/async_channel.rs +++ b/async/src/async_channel.rs @@ -99,6 +99,7 @@ impl<'a> AsyncRead for ChanInOut<'a> { cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { + trace!("poll read {}", self.chan); let this = self.deref_mut(); chan_poll_read(&mut this.door, this.chan, None, cx, buf, &mut this.rlfut) } @@ -162,6 +163,7 @@ impl<'a> AsyncWrite for ChanInOut<'a> { cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, IoError>> { + trace!("poll write {}", self.chan); let this = self.deref_mut(); chan_poll_write(&mut this.door, this.chan, None, cx, buf, &mut this.wlfut) } diff --git a/async/src/async_door.rs b/async/src/async_door.rs index 02583f1..9dd764c 100644 --- a/async/src/async_door.rs +++ b/async/src/async_door.rs @@ -73,10 +73,12 @@ impl<'a> AsyncDoor<'a> { let inner = inner.deref_mut(); inner.runner.progress(b).await?; + trace!("pre wakers {:?}", inner.chan_read_wakers); if let Some(ce) = inner.runner.ready_channel_input() { inner.chan_read_wakers.remove(&ce) .map(|w| wakers.push(w)); } + trace!("pos wakers {:?}", inner.chan_read_wakers); // Pending HashMap::drain_filter // https://github.com/rust-lang/rust/issues/59618 diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index 71e580b..19d48d2 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -200,12 +200,12 @@ pub trait ServBehaviour: Sync+Send { } #[allow(unused)] - fn sess_req_shell(&mut self, chan: u32) -> bool { + fn sess_shell(&mut self, chan: u32) -> bool { false } #[allow(unused)] - fn sess_req_exec(&mut self, chan: u32, cmd: &str) -> bool { + fn sess_exec(&mut self, chan: u32, cmd: TextString) -> bool { false } diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 8fa2bc2..8a37e62 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -8,7 +8,7 @@ use core::mem; use heapless::{Deque, String, Vec}; -use crate::*; +use crate::{*, sshwire::SSHEncodeEnum}; use config::*; use conn::Dispatched; use packets::{ChannelReqType, ChannelOpenFailure, ChannelRequest, Packet, ChannelOpen, ChannelOpenType, ChannelData, ChannelDataExt}; @@ -16,12 +16,11 @@ use traffic::TrafSend; use sshwire::{BinString, TextString}; use sshnames::*; +/// The result of a channel open request. pub enum ChanOpened { Success, - - /// A channel open response will be sent later + /// A channel open response will be sent later (for eg TCP open) Defer, - Failure(ChanFail), } @@ -83,7 +82,7 @@ impl Channels { Ok((ch.as_ref().unwrap(), p)) } - /// Returns a `Channel` for a local number, any state. + /// Returns a `Channel` for a local number, any state including `InOpen`. pub fn get_any(&self, num: u32) -> Result<&Channel> { self.ch .get(num as usize) @@ -246,13 +245,13 @@ impl Channels { ChannelOpenType::Session => { // unwrap: earlier test ensures b.server() succeeds let mut bserv = b.server().unwrap(); - bserv.open_session(ch.recv.num) + bserv.open_session(ch.num()) } ChannelOpenType::ForwardedTcpip(t) => { - b.open_tcp_forwarded(ch.recv.num, t) + b.open_tcp_forwarded(ch.num(), t) } ChannelOpenType::DirectTcpip(t) => { - b.open_tcp_direct(ch.recv.num, t) + b.open_tcp_direct(ch.num(), t) } ChannelOpenType::Unknown(_) => { unreachable!() @@ -261,11 +260,11 @@ impl Channels { match r { ChanOpened::Success => { - s.send(ch.open_done())?; + s.send(ch.open_done()?)?; }, ChanOpened::Failure(f) => { - let n = ch.recv.num; - self.remove(n); + let n = ch.num(); + self.remove(n)?; return Err(f.into()) } ChanOpened::Defer => { @@ -278,20 +277,32 @@ impl Channels { pub fn dispatch_request(&mut self, p: &packets::ChannelRequest, - _s: &mut TrafSend, - _b: &mut Behaviour<'_>, + s: &mut TrafSend, + b: &mut Behaviour<'_>, ) -> Result<()> { - let ch = match self.get(p.num) { - Ok(ch) => ch, - Err(Error::BadChannel) => { - debug!("request {p:?} channel is unknown"); - return Ok(()) - }, - Err(e) => unreachable!(), + if let Ok(ch) = self.get(p.num) { + // only servers accept requests + let success = if let Ok(b) = b.server() { + ch.dispatch_server_request(p, s, b).unwrap_or_else(|e| { + debug!("Error in channel req handling for {p:?}, {e:?}"); + false + }) + } else { + false }; - - Ok(()) + if p.want_reply { + let num = ch.send_num()?; + if success { + s.send(packets::ChannelSuccess { num })?; + } else { + s.send(packets::ChannelFailure { num })?; + } + } + } else { + debug!("Ignoring request to unknown channel: {p:#?}"); + } + Ok(()) } // Some returned errors will be caught by caller and returned as SSH messages @@ -583,29 +594,70 @@ impl Channel { } } + /// Local channel number + pub(crate) fn num(&self) -> u32 { + self.recv.num + } + + /// Remote channel number, fails if channel is in progress opening + pub(crate) fn send_num(&self) -> Result<u32> { + Ok(self.send.as_ref().trap()?.num) + } + fn request(&mut self, req: ReqDetails, s: &mut TrafSend) -> Result<()> { let num = self.send.as_ref().trap()?.num; let r = Req { num, details: req }; s.send(r.packet()?) } - pub(crate) fn number(&self) -> u32 { - self.recv.num - } - /// Returns an open confirmation reply packet to send. /// Must be called with state of `InOpen`. - fn open_done<'p>(&mut self) -> Packet<'p> { + fn open_done<'p>(&mut self) -> Result<Packet<'p>> { debug_assert!(matches!(self.state, ChanState::InOpen)); self.state = ChanState::Normal; - packets::ChannelOpenConfirmation { - num: self.recv.num, + let p = packets::ChannelOpenConfirmation { + num: self.send_num()?, // unwrap: state is InOpen sender_num: self.send.as_ref().unwrap().num, initial_window: self.recv.window as u32, max_packet: self.recv.max_packet as u32, - }.into() + }.into(); + Ok(p) + } + + fn dispatch_server_request(&self, + p: &packets::ChannelRequest, + s: &mut TrafSend, + b: &mut dyn ServBehaviour, + ) -> Result<bool> { + + if !matches!(self.ty, ChanType::Session) { + return Ok(false) + } + + match &p.req { + ChannelReqType::Shell => { + Ok(b.sess_shell(self.num())) + } + ChannelReqType::Exec(ex) => { + Ok(b.sess_exec(self.num(), ex.command)) + } + // TODO need to convert packet to channel Pty + // ChannelReqType::Pty(pty) => { + // let cpty = pty.into(); + // Ok(b.sess_pty(self.num(), &cpty)) + // } + _ => { + if let ChannelReqType::Unknown(u) = &p.req { + warn!("Unknown channel req type \"{}\"", u) + } else { + // OK unwrap: tested for Unknown + warn!("Unhandled channel req \"{}\"", p.req.variant_name().unwrap()) + }; + Ok(false) + } + } } fn finished_input(&mut self, len: usize ) { @@ -639,8 +691,6 @@ impl Channel { Ok(None) } } - - } pub struct ChanMsg { diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index d255cfb..600a4f1 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -591,11 +591,19 @@ pub struct ParseContext { // Used by sign_encode() pub method_pubkey_force_sig_bool: bool, + + // Set to true if an unknown variant is encountered. + // Packet length checks should be omitted in that case. + pub(crate) seen_unknown: bool, } impl ParseContext { pub fn new() -> Self { - ParseContext { cli_auth_type: None, method_pubkey_force_sig_bool: false } + ParseContext { + cli_auth_type: None, + method_pubkey_force_sig_bool: false, + seen_unknown: false, + } } } diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 31602ca..814d8a3 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -166,7 +166,7 @@ impl<'a> Runner<'a> { init_req.push(channel::ReqDetails::Shell).trap()?; } let (ch, p) = self.conn.channels.open(packets::ChannelOpenType::Session, init_req)?; - let chan = ch.number(); + let chan = ch.num(); self.traf_out.send_packet(p, &mut self.keys)?; self.wake(); Ok(chan) diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 357a488..8f80128 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -31,7 +31,7 @@ pub trait SSHSink { pub trait SSHSource<'de> { fn take(&mut self, len: usize) -> WireResult<&'de [u8]>; fn pos(&self) -> usize; - fn ctx(&self) -> &ParseContext; + fn ctx(&mut self) -> &mut ParseContext; } /// Encodes the type in SSH wire format @@ -101,12 +101,16 @@ pub type WireResult<T> = core::result::Result<T, WireError>; /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> { - let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() }; + let ctx = ParseContext { seen_unknown: false, .. ctx.clone()}; + let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx }; let p = Packet::dec(&mut s)?; - if s.pos() == b.len() { - Ok(p) - } else { + + if s.pos() != b.len() && !s.ctx().seen_unknown { + // No length check if the packet had an unknown variant + // - we skipped parsing the rest of the packet. Err(Error::WrongPacketLength) + } else { + Ok(p) } } @@ -222,8 +226,8 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> { self.pos } - fn ctx(&self) -> &ParseContext { - &self.parse_ctx + fn ctx(&mut self) -> &mut ParseContext { + &mut self.parse_ctx } } diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 110a17e..a648911 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -96,10 +96,11 @@ impl<'a> TrafIn<'a> { pub fn ready_input(&self) -> bool { trace!("ready_input state {:?}", self.state); match self.state { - RxState::Idle + | RxState::Idle | RxState::ReadInitial { .. } - | RxState::Read { .. } => true, - RxState::ReadComplete { .. } + | RxState::Read { .. } + => true, + | RxState::ReadComplete { .. } | RxState::InPayload { .. } | RxState::InChannelData { .. } => false, diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs index 4b354ff..5a7d95c 100644 --- a/sshwire_derive/src/lib.rs +++ b/sshwire_derive/src/lib.rs @@ -508,7 +508,7 @@ fn decode_enum_names( if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) { // create the Unknown fallthrough but it will be at the end of the match list let mut m = StreamBuilder::new(); - m.push_parsed(format!("_ => Self::{}(Unknown(variant))", var.name))?; + m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown(variant))}}", var.name))?; if unknown_arm.replace(m).is_some() { return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None}) } -- GitLab