From 92975adbca12848f3844f6da942479e66953ad8a Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Mon, 11 Dec 2023 22:37:36 +0800
Subject: [PATCH] Add tests for allowed packet types during KEX

---
 src/conn.rs | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/conn.rs b/src/conn.rs
index 272ab6d..2e193f7 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -176,7 +176,17 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
         let r = sshwire::packet_from_bytes(payload, &self.parse_ctx);
 
         match r {
-            Ok(p) => self.dispatch_packet(p, s, b).await,
+            Ok(p) => {
+                let num = p.message_num() as u8;
+                let a = self.dispatch_packet(p, s, b).await;
+                match a {
+                    | Err(Error::SSHProtoError)
+                    | Err(Error::PacketWrong)
+                    => debug!("Error handling {num} packet"),
+                    _ => (),
+                }
+                a
+            }
             Err(Error::UnknownPacket { number }) => {
                 trace!("Unimplemented packet type {number}");
                 s.send(packets::Unimplemented { seq })?;
@@ -192,6 +202,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
     /// Check that a packet is received in the correct state
     fn check_packet(&self, p: &Packet) -> Result<()> {
         let r = if self.is_first_kex() && self.kex.is_strict() {
+            // Strict Kex doesn't allow even packets like Ignore or Debug
             match p.category() {
                 packets::Category::Kex => Ok(()),
                 _ => {
@@ -199,7 +210,18 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
                     Err(Error::SSHProtoError)
                 },
             }
+        } else if !matches!(self.kex, Kex::Idle) {
+            // Normal KEX only allows certain packets
+            match p.category() {
+                packets::Category::All => Ok(()),
+                packets::Category::Kex => Ok(()),
+                _ => {
+                    debug!("Invalid packet during kex");
+                    Err(Error::SSHProtoError)
+                },
+            }
         } else {
+            // No KEX in progress, check for post-auth packets
             match p.category() {
                 packets::Category::All => Ok(()),
                 packets::Category::Kex => Ok(()),
@@ -221,8 +243,6 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
             }
         };
 
-        // TODO: reject other packets while kex is in progress?
-
         if r.is_err() {
             error!("Received unexpected packet {}",
                 p.message_num() as u8);
-- 
GitLab