diff --git a/src/conn.rs b/src/conn.rs
index 7f52e6e8178352aec99b946e695218ca9b2067c3..272ab6dad49ee9163eb028b3d959830375133d0c 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -191,22 +191,32 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
 
     /// Check that a packet is received in the correct state
     fn check_packet(&self, p: &Packet) -> Result<()> {
-        let r = match p.category() {
-            packets::Category::All => Ok(()),
-            packets::Category::Kex => Ok(()),
-            packets::Category::Auth => {
-                match self.state {
-                    | ConnState::PreAuth
-                    | ConnState::Authed
-                    => Ok(()),
-                    _ => Err(Error::SSHProtoError),
-                }
+        let r = if self.is_first_kex() && self.kex.is_strict() {
+            match p.category() {
+                packets::Category::Kex => Ok(()),
+                _ => {
+                    debug!("Non-kex packet during strict kex");
+                    Err(Error::SSHProtoError)
+                },
             }
-            packets::Category::Sess => {
-                match self.state {
-                    ConnState::Authed
-                    => Ok(()),
-                    _ => Err(Error::SSHProtoError),
+        } else {
+            match p.category() {
+                packets::Category::All => Ok(()),
+                packets::Category::Kex => Ok(()),
+                packets::Category::Auth => {
+                    match self.state {
+                        | ConnState::PreAuth
+                        | ConnState::Authed
+                        => Ok(()),
+                        _ => Err(Error::SSHProtoError),
+                    }
+                }
+                packets::Category::Sess => {
+                    match self.state {
+                        ConnState::Authed
+                        => Ok(()),
+                        _ => Err(Error::SSHProtoError),
+                    }
                 }
             }
         };
@@ -221,6 +231,10 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
         r
     }
 
+    fn is_first_kex(&self) -> bool {
+        self.sess_id.is_none()
+    }
+
     async fn dispatch_packet(
         &mut self, packet: Packet<'_>, s: &mut TrafSend<'_, '_>, b: &mut Behaviour<'_, C, S>,
     ) -> Result<Dispatched, Error> {
@@ -237,6 +251,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
                     self.cliserv.is_client(),
                     &self.algo_conf,
                     &self.remote_version,
+                    self.is_first_kex(),
                     s,
                 )?;
             }
@@ -256,7 +271,7 @@ impl<C: CliBehaviour, S: ServBehaviour> Conn<C, S> {
                     return Err(Error::SSHProtoError);
                 }
 
-                self.kex.handle_kexdhreply(&p, s, b.client()?, self.sess_id.is_none()).await?;
+                self.kex.handle_kexdhreply(&p, s, b.client()?, self.is_first_kex()).await?;
             }
             Packet::NewKeys(_) => {
                 self.kex.handle_newkeys(&mut self.sess_id, s)?;
diff --git a/src/kex.rs b/src/kex.rs
index 877c99e5870e1cb2242c18444d1c68a89d8659f1..69d5d318e86c3d32562bf338c6a55678ee5c2a94 100644
--- a/src/kex.rs
+++ b/src/kex.rs
@@ -269,6 +269,7 @@ impl Kex {
     pub fn handle_kexinit(
         &mut self, remote_kexinit: packets::KexInit, is_client: bool, algo_conf: &AlgoConfig,
         remote_version: &RemoteVersion,
+        first_kex: bool,
         s: &mut TrafSend,
     ) -> Result<()> {
         // Reply if we haven't already received one. This will bump the state to Kex::KexInit
@@ -285,6 +286,13 @@ impl Kex {
 
         let algos = Self::algo_negotiation(is_client, &remote_kexinit, algo_conf)?;
         debug!("{algos}");
+
+        if first_kex && algos.strict_kex {
+            if s.recv_seq() != 1 {
+                debug!("kexinit has strict kex but wasn't first packet");
+                return error::PacketWrong.fail();
+            }
+        }
         if is_client {
             let p = algos.kex.make_kexdhinit()?;
             s.send(p)?;
@@ -521,6 +529,14 @@ impl Kex {
             strict_kex,
         })
     }
+
+    pub fn is_strict(&self) -> bool {
+        match self {
+            Kex::KexDH { algos: Algos { strict_kex: true, ..}, .. } => true,
+            Kex::NewKeys { algos: Algos { strict_kex: true, ..}, .. } => true,
+            _ => false,
+        }
+    }
 }
 
 #[derive(Debug, ZeroizeOnDrop)]
diff --git a/src/traffic.rs b/src/traffic.rs
index d691ac5615b25b585b3ee47138fa5f1d960dabf6..e6c6b5a1ca92b890756fb6b5ae22db43fea915a5 100644
--- a/src/traffic.rs
+++ b/src/traffic.rs
@@ -452,5 +452,10 @@ impl<'s, 'a> TrafSend<'s, 'a> {
     pub fn can_output(&self) -> bool {
         self.out.can_output()
     }
+
+    /// Returns the current receive sequence number
+    pub fn recv_seq(&self) -> u32 {
+        self.keys.seq_decrypt.0
+    }
 }