diff --git a/src/conn.rs b/src/conn.rs index 7f52e6e8178352aec99b946e695218ca9b2067c3..272ab6dad49ee9163eb028b3d959830375133d0c 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -191,22 +191,32 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { /// Check that a packet is received in the correct state fn check_packet(&self, p: &Packet) -> Result<()> { - let r = match p.category() { - packets::Category::All => Ok(()), - packets::Category::Kex => Ok(()), - packets::Category::Auth => { - match self.state { - | ConnState::PreAuth - | ConnState::Authed - => Ok(()), - _ => Err(Error::SSHProtoError), - } + let r = if self.is_first_kex() && self.kex.is_strict() { + match p.category() { + packets::Category::Kex => Ok(()), + _ => { + debug!("Non-kex packet during strict kex"); + Err(Error::SSHProtoError) + }, } - packets::Category::Sess => { - match self.state { - ConnState::Authed - => Ok(()), - _ => Err(Error::SSHProtoError), + } else { + match p.category() { + packets::Category::All => Ok(()), + packets::Category::Kex => Ok(()), + packets::Category::Auth => { + match self.state { + | ConnState::PreAuth + | ConnState::Authed + => Ok(()), + _ => Err(Error::SSHProtoError), + } + } + packets::Category::Sess => { + match self.state { + ConnState::Authed + => Ok(()), + _ => Err(Error::SSHProtoError), + } } } }; @@ -221,6 +231,10 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { r } + fn is_first_kex(&self) -> bool { + self.sess_id.is_none() + } + async fn dispatch_packet( &mut self, packet: Packet<'_>, s: &mut TrafSend<'_, '_>, b: &mut Behaviour<'_, C, S>, ) -> Result<Dispatched, Error> { @@ -237,6 +251,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { self.cliserv.is_client(), &self.algo_conf, &self.remote_version, + self.is_first_kex(), s, )?; } @@ -256,7 +271,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> { return Err(Error::SSHProtoError); } - self.kex.handle_kexdhreply(&p, s, b.client()?, self.sess_id.is_none()).await?; + self.kex.handle_kexdhreply(&p, s, b.client()?, self.is_first_kex()).await?; } Packet::NewKeys(_) => { self.kex.handle_newkeys(&mut self.sess_id, s)?; diff --git a/src/kex.rs b/src/kex.rs index 877c99e5870e1cb2242c18444d1c68a89d8659f1..69d5d318e86c3d32562bf338c6a55678ee5c2a94 100644 --- a/src/kex.rs +++ b/src/kex.rs @@ -269,6 +269,7 @@ impl Kex { pub fn handle_kexinit( &mut self, remote_kexinit: packets::KexInit, is_client: bool, algo_conf: &AlgoConfig, remote_version: &RemoteVersion, + first_kex: bool, s: &mut TrafSend, ) -> Result<()> { // Reply if we haven't already received one. This will bump the state to Kex::KexInit @@ -285,6 +286,13 @@ impl Kex { let algos = Self::algo_negotiation(is_client, &remote_kexinit, algo_conf)?; debug!("{algos}"); + + if first_kex && algos.strict_kex { + if s.recv_seq() != 1 { + debug!("kexinit has strict kex but wasn't first packet"); + return error::PacketWrong.fail(); + } + } if is_client { let p = algos.kex.make_kexdhinit()?; s.send(p)?; @@ -521,6 +529,14 @@ impl Kex { strict_kex, }) } + + pub fn is_strict(&self) -> bool { + match self { + Kex::KexDH { algos: Algos { strict_kex: true, ..}, .. } => true, + Kex::NewKeys { algos: Algos { strict_kex: true, ..}, .. } => true, + _ => false, + } + } } #[derive(Debug, ZeroizeOnDrop)] diff --git a/src/traffic.rs b/src/traffic.rs index d691ac5615b25b585b3ee47138fa5f1d960dabf6..e6c6b5a1ca92b890756fb6b5ae22db43fea915a5 100644 --- a/src/traffic.rs +++ b/src/traffic.rs @@ -452,5 +452,10 @@ impl<'s, 'a> TrafSend<'s, 'a> { pub fn can_output(&self) -> bool { self.out.can_output() } + + /// Returns the current receive sequence number + pub fn recv_seq(&self) -> u32 { + self.keys.seq_decrypt.0 + } }