From 3cc08506c6c5d6a8d86d4880a9adcc4044c57c3e Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Tue, 30 Aug 2022 00:11:31 +0800
Subject: [PATCH] channel req handling

---
 async/examples/serv1.rs    |   5 +-
 async/src/async_channel.rs |   2 +
 async/src/async_door.rs    |   2 +
 sshproto/src/behaviour.rs  |   4 +-
 sshproto/src/channel.rs    | 114 ++++++++++++++++++++++++++-----------
 sshproto/src/packets.rs    |  10 +++-
 sshproto/src/runner.rs     |   2 +-
 sshproto/src/sshwire.rs    |  18 +++---
 sshproto/src/traffic.rs    |   7 ++-
 sshwire_derive/src/lib.rs  |   2 +-
 10 files changed, 118 insertions(+), 48 deletions(-)

diff --git a/async/examples/serv1.rs b/async/examples/serv1.rs
index 1603792..7139e69 100644
--- a/async/examples/serv1.rs
+++ b/async/examples/serv1.rs
@@ -147,9 +147,10 @@ impl ServBehaviour for DemoServer {
         }
     }
 
-    fn sess_req_shell(&mut self, chan: u32) -> bool {
+    fn sess_shell(&mut self, chan: u32) -> bool {
         let r = !self.want_shell && self.sess == Some(chan);
         self.want_shell = true;
+        trace!("req want shell");
         r
     }
 
@@ -176,6 +177,8 @@ fn run_session<'a, R: Send>(args: &'a Args, scope: &'a moro::Scope<'a, '_, R>, m
                 loop {
                     serv.progress(&mut app).await.context("progress loop")?;
                     if app.want_shell && !app.shell_started {
+                        trace!("make shell");
+                        app.shell_started = true;
 
                         if let Some(ch) = app.sess {
                             let ch = ch.clone();
diff --git a/async/src/async_channel.rs b/async/src/async_channel.rs
index ad745f4..93b5f36 100644
--- a/async/src/async_channel.rs
+++ b/async/src/async_channel.rs
@@ -99,6 +99,7 @@ impl<'a> AsyncRead for ChanInOut<'a> {
         cx: &mut Context<'_>,
         buf: &mut ReadBuf,
     ) -> Poll<Result<(), IoError>> {
+        trace!("poll read {}", self.chan);
         let this = self.deref_mut();
         chan_poll_read(&mut this.door, this.chan, None, cx, buf, &mut this.rlfut)
     }
@@ -162,6 +163,7 @@ impl<'a> AsyncWrite for ChanInOut<'a> {
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<Result<usize, IoError>> {
+        trace!("poll write {}", self.chan);
         let this = self.deref_mut();
         chan_poll_write(&mut this.door, this.chan, None, cx, buf, &mut this.wlfut)
     }
diff --git a/async/src/async_door.rs b/async/src/async_door.rs
index 02583f1..9dd764c 100644
--- a/async/src/async_door.rs
+++ b/async/src/async_door.rs
@@ -73,10 +73,12 @@ impl<'a> AsyncDoor<'a> {
             let inner = inner.deref_mut();
             inner.runner.progress(b).await?;
 
+            trace!("pre wakers {:?}", inner.chan_read_wakers);
             if let Some(ce) = inner.runner.ready_channel_input() {
                 inner.chan_read_wakers.remove(&ce)
                 .map(|w| wakers.push(w));
             }
+            trace!("pos wakers {:?}", inner.chan_read_wakers);
 
             // Pending HashMap::drain_filter
             // https://github.com/rust-lang/rust/issues/59618
diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs
index 71e580b..19d48d2 100644
--- a/sshproto/src/behaviour.rs
+++ b/sshproto/src/behaviour.rs
@@ -200,12 +200,12 @@ pub trait ServBehaviour: Sync+Send {
     }
 
     #[allow(unused)]
-    fn sess_req_shell(&mut self, chan: u32) -> bool {
+    fn sess_shell(&mut self, chan: u32) -> bool {
         false
     }
 
     #[allow(unused)]
-    fn sess_req_exec(&mut self, chan: u32, cmd: &str) -> bool {
+    fn sess_exec(&mut self, chan: u32, cmd: TextString) -> bool {
         false
     }
 
diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs
index 8fa2bc2..8a37e62 100644
--- a/sshproto/src/channel.rs
+++ b/sshproto/src/channel.rs
@@ -8,7 +8,7 @@ use core::mem;
 
 use heapless::{Deque, String, Vec};
 
-use crate::*;
+use crate::{*, sshwire::SSHEncodeEnum};
 use config::*;
 use conn::Dispatched;
 use packets::{ChannelReqType, ChannelOpenFailure, ChannelRequest, Packet, ChannelOpen, ChannelOpenType, ChannelData, ChannelDataExt};
@@ -16,12 +16,11 @@ use traffic::TrafSend;
 use sshwire::{BinString, TextString};
 use sshnames::*;
 
+/// The result of a channel open request.
 pub enum ChanOpened {
     Success,
-
-    /// A channel open response will be sent later
+    /// A channel open response will be sent later (for eg TCP open)
     Defer,
-
     Failure(ChanFail),
 }
 
@@ -83,7 +82,7 @@ impl Channels {
         Ok((ch.as_ref().unwrap(), p))
     }
 
-    /// Returns a `Channel` for a local number, any state.
+    /// Returns a `Channel` for a local number, any state including `InOpen`.
     pub fn get_any(&self, num: u32) -> Result<&Channel> {
         self.ch
             .get(num as usize)
@@ -246,13 +245,13 @@ impl Channels {
             ChannelOpenType::Session => {
                 // unwrap: earlier test ensures b.server() succeeds
                 let mut bserv = b.server().unwrap();
-                bserv.open_session(ch.recv.num)
+                bserv.open_session(ch.num())
             }
             ChannelOpenType::ForwardedTcpip(t) => {
-                b.open_tcp_forwarded(ch.recv.num, t)
+                b.open_tcp_forwarded(ch.num(), t)
             }
             ChannelOpenType::DirectTcpip(t) => {
-                b.open_tcp_direct(ch.recv.num, t)
+                b.open_tcp_direct(ch.num(), t)
             }
             ChannelOpenType::Unknown(_) => {
                 unreachable!()
@@ -261,11 +260,11 @@ impl Channels {
 
         match r {
             ChanOpened::Success => {
-                s.send(ch.open_done())?;
+                s.send(ch.open_done()?)?;
             },
             ChanOpened::Failure(f) => {
-                let n = ch.recv.num;
-                self.remove(n);
+                let n = ch.num();
+                self.remove(n)?;
                 return Err(f.into())
             }
             ChanOpened::Defer => {
@@ -278,20 +277,32 @@ impl Channels {
 
     pub fn dispatch_request(&mut self,
         p: &packets::ChannelRequest,
-        _s: &mut TrafSend,
-        _b: &mut Behaviour<'_>,
+        s: &mut TrafSend,
+        b: &mut Behaviour<'_>,
         ) -> Result<()> {
-            let ch = match self.get(p.num) {
-                Ok(ch) => ch,
-                Err(Error::BadChannel) => {
-                    debug!("request {p:?} channel is unknown");
-                    return Ok(())
-                },
-                Err(e) => unreachable!(),
+        if let Ok(ch) = self.get(p.num) {
+            // only servers accept requests
+            let success = if let Ok(b) = b.server() {
+                ch.dispatch_server_request(p, s, b).unwrap_or_else(|e| {
+                    debug!("Error in channel req handling for {p:?}, {e:?}");
+                    false
+                })
+            } else {
+                false
             };
 
-
-            Ok(())
+            if p.want_reply {
+                let num = ch.send_num()?;
+                if success {
+                    s.send(packets::ChannelSuccess { num })?;
+                } else {
+                    s.send(packets::ChannelFailure { num })?;
+                }
+            }
+        } else {
+            debug!("Ignoring request to unknown channel: {p:#?}");
+        }
+        Ok(())
     }
 
     // Some returned errors will be caught by caller and returned as SSH messages
@@ -583,29 +594,70 @@ impl Channel {
         }
     }
 
+    /// Local channel number
+    pub(crate) fn num(&self) -> u32 {
+        self.recv.num
+    }
+
+    /// Remote channel number, fails if channel is in progress opening
+    pub(crate) fn send_num(&self) -> Result<u32> {
+        Ok(self.send.as_ref().trap()?.num)
+    }
+
     fn request(&mut self, req: ReqDetails, s: &mut TrafSend) -> Result<()> {
         let num = self.send.as_ref().trap()?.num;
         let r = Req { num, details: req };
         s.send(r.packet()?)
     }
 
-    pub(crate) fn number(&self) -> u32 {
-        self.recv.num
-    }
-
     /// Returns an open confirmation reply packet to send.
     /// Must be called with state of `InOpen`.
-    fn open_done<'p>(&mut self) -> Packet<'p> {
+    fn open_done<'p>(&mut self) -> Result<Packet<'p>> {
         debug_assert!(matches!(self.state, ChanState::InOpen));
 
         self.state = ChanState::Normal;
-        packets::ChannelOpenConfirmation {
-            num: self.recv.num,
+        let p = packets::ChannelOpenConfirmation {
+            num: self.send_num()?,
             // unwrap: state is InOpen
             sender_num: self.send.as_ref().unwrap().num,
             initial_window: self.recv.window as u32,
             max_packet: self.recv.max_packet as u32,
-        }.into()
+        }.into();
+        Ok(p)
+    }
+
+    fn dispatch_server_request(&self,
+        p: &packets::ChannelRequest,
+        s: &mut TrafSend,
+        b: &mut dyn ServBehaviour,
+        ) -> Result<bool> {
+
+        if !matches!(self.ty, ChanType::Session) {
+            return Ok(false)
+        }
+
+        match &p.req {
+            ChannelReqType::Shell => {
+                Ok(b.sess_shell(self.num()))
+            }
+            ChannelReqType::Exec(ex) => {
+                Ok(b.sess_exec(self.num(), ex.command))
+            }
+            // TODO need to convert packet to channel Pty
+            // ChannelReqType::Pty(pty) => {
+            //     let cpty = pty.into();
+            //     Ok(b.sess_pty(self.num(), &cpty))
+            // }
+            _ => {
+                if let ChannelReqType::Unknown(u) = &p.req {
+                    warn!("Unknown channel req type \"{}\"", u)
+                } else {
+                    // OK unwrap: tested for Unknown
+                    warn!("Unhandled channel req \"{}\"", p.req.variant_name().unwrap())
+                };
+                Ok(false)
+            }
+        }
     }
 
     fn finished_input(&mut self, len: usize ) {
@@ -639,8 +691,6 @@ impl Channel {
             Ok(None)
         }
     }
-
-
 }
 
 pub struct ChanMsg {
diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs
index d255cfb..600a4f1 100644
--- a/sshproto/src/packets.rs
+++ b/sshproto/src/packets.rs
@@ -591,11 +591,19 @@ pub struct ParseContext {
 
     // Used by sign_encode()
     pub method_pubkey_force_sig_bool: bool,
+
+    // Set to true if an unknown variant is encountered.
+    // Packet length checks should be omitted in that case.
+    pub(crate) seen_unknown: bool,
 }
 
 impl ParseContext {
     pub fn new() -> Self {
-        ParseContext { cli_auth_type: None, method_pubkey_force_sig_bool: false }
+        ParseContext {
+            cli_auth_type: None,
+            method_pubkey_force_sig_bool: false,
+            seen_unknown: false,
+        }
     }
 }
 
diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs
index 31602ca..814d8a3 100644
--- a/sshproto/src/runner.rs
+++ b/sshproto/src/runner.rs
@@ -166,7 +166,7 @@ impl<'a> Runner<'a> {
             init_req.push(channel::ReqDetails::Shell).trap()?;
         }
         let (ch, p) = self.conn.channels.open(packets::ChannelOpenType::Session, init_req)?;
-        let chan = ch.number();
+        let chan = ch.num();
         self.traf_out.send_packet(p, &mut self.keys)?;
         self.wake();
         Ok(chan)
diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs
index 357a488..8f80128 100644
--- a/sshproto/src/sshwire.rs
+++ b/sshproto/src/sshwire.rs
@@ -31,7 +31,7 @@ pub trait SSHSink {
 pub trait SSHSource<'de> {
     fn take(&mut self, len: usize) -> WireResult<&'de [u8]>;
     fn pos(&self) -> usize;
-    fn ctx(&self) -> &ParseContext;
+    fn ctx(&mut self) -> &mut ParseContext;
 }
 
 /// Encodes the type in SSH wire format
@@ -101,12 +101,16 @@ pub type WireResult<T> = core::result::Result<T, WireError>;
 
 /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer.
 pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> {
-    let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() };
+    let ctx = ParseContext { seen_unknown: false, .. ctx.clone()};
+    let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx };
     let p = Packet::dec(&mut s)?;
-    if s.pos() == b.len() {
-        Ok(p)
-    } else {
+
+    if s.pos() != b.len() && !s.ctx().seen_unknown {
+        // No length check if the packet had an unknown variant
+        // - we skipped parsing the rest of the packet.
         Err(Error::WrongPacketLength)
+    } else {
+        Ok(p)
     }
 }
 
@@ -222,8 +226,8 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> {
         self.pos
     }
 
-    fn ctx(&self) -> &ParseContext {
-        &self.parse_ctx
+    fn ctx(&mut self) -> &mut ParseContext {
+        &mut self.parse_ctx
     }
 }
 
diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs
index 110a17e..a648911 100644
--- a/sshproto/src/traffic.rs
+++ b/sshproto/src/traffic.rs
@@ -96,10 +96,11 @@ impl<'a> TrafIn<'a> {
     pub fn ready_input(&self) -> bool {
         trace!("ready_input state {:?}", self.state);
         match self.state {
-            RxState::Idle
+            | RxState::Idle
             | RxState::ReadInitial { .. }
-            | RxState::Read { .. } => true,
-            RxState::ReadComplete { .. }
+            | RxState::Read { .. }
+            => true,
+            | RxState::ReadComplete { .. }
             | RxState::InPayload { .. }
             | RxState::InChannelData { .. }
             => false,
diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs
index 4b354ff..5a7d95c 100644
--- a/sshwire_derive/src/lib.rs
+++ b/sshwire_derive/src/lib.rs
@@ -508,7 +508,7 @@ fn decode_enum_names(
                     if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
                         // create the Unknown fallthrough but it will be at the end of the match list
                         let mut m = StreamBuilder::new();
-                        m.push_parsed(format!("_ => Self::{}(Unknown(variant))", var.name))?;
+                        m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown(variant))}}", var.name))?;
                         if unknown_arm.replace(m).is_some() {
                             return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None})
                         }
-- 
GitLab