From 92975adbca12848f3844f6da942479e66953ad8a Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Mon, 11 Dec 2023 22:37:36 +0800 Subject: [PATCH] Add tests for allowed packet types during KEX --- src/conn.rs | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/conn.rs b/src/conn.rs index 272ab6d..2e193f7 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -176,7 +176,17 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { let r = sshwire::packet_from_bytes(payload, &self.parse_ctx); match r { - Ok(p) => self.dispatch_packet(p, s, b).await, + Ok(p) => { + let num = p.message_num() as u8; + let a = self.dispatch_packet(p, s, b).await; + match a { + | Err(Error::SSHProtoError) + | Err(Error::PacketWrong) + => debug!("Error handling {num} packet"), + _ => (), + } + a + } Err(Error::UnknownPacket { number }) => { trace!("Unimplemented packet type {number}"); s.send(packets::Unimplemented { seq })?; @@ -192,6 +202,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { /// Check that a packet is received in the correct state fn check_packet(&self, p: &Packet) -> Result<()> { let r = if self.is_first_kex() && self.kex.is_strict() { + // Strict Kex doesn't allow even packets like Ignore or Debug match p.category() { packets::Category::Kex => Ok(()), _ => { @@ -199,7 +210,18 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { Err(Error::SSHProtoError) }, } + } else if !matches!(self.kex, Kex::Idle) { + // Normal KEX only allows certain packets + match p.category() { + packets::Category::All => Ok(()), + packets::Category::Kex => Ok(()), + _ => { + debug!("Invalid packet during kex"); + Err(Error::SSHProtoError) + }, + } } else { + // No KEX in progress, check for post-auth packets match p.category() { packets::Category::All => Ok(()), packets::Category::Kex => Ok(()), @@ -221,8 +243,6 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { } }; - // TODO: reject other packets while kex is in progress? - if r.is_err() { error!("Received unexpected packet {}", p.message_num() as u8); -- GitLab