From dc766f0a6befd810529f85828808f6ababec218c Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Thu, 24 Nov 2022 22:59:44 +0800
Subject: [PATCH] More sanity check for Blob length, add tests

---
 src/packets.rs | 40 +++++++++++++++++++++++
 src/sshwire.rs | 89 ++++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 116 insertions(+), 13 deletions(-)

diff --git a/src/packets.rs b/src/packets.rs
index cb63731..c469b31 100644
--- a/src/packets.rs
+++ b/src/packets.rs
@@ -893,6 +893,46 @@ mod tests {
         let ctx = ParseContext::default();
         let p2 = packet_from_bytes(&buf1, &ctx).unwrap();
         trace!("broken: {p2:#?}");
+        assert!(matches!(p2,
+            Packet::ChannelOpen(ChannelOpen { ty: ChannelOpenType::Unknown(_), ..})
+        ));
+    }
+
+    #[test]
+    /// Tests recovery from unknown variants in a blob when decoding.
+    fn unknown_variant_in_blob() {
+        init_test_log();
+        let p: Packet = UserauthRequest {
+            username: "matt".into(), service: "connection",
+            method: AuthMethod::PubKey(MethodPubKey {
+                sig_algo: "something",
+                pubkey: Blob(PubKey::Ed25519(
+                    Ed25519PubKey { key: BinString(b"zzzz") }
+                )),
+                sig: Some(Blob(Signature::Ed25519(Ed25519Sig {
+                    sig: BinString(b"sighere")
+                })))
+            })}.into();
+
+        let mut buf1 = vec![88; 1000];
+        let l = write_ssh(&mut buf1, &p).unwrap();
+        buf1.truncate(l);
+        // change a byte in the "ssh-ed25519" variant string
+        buf1[60] = 'F' as u8;
+        trace!("broken: {:?}", buf1.hex_dump());
+        let ctx = ParseContext::default();
+        let p2 = packet_from_bytes(&buf1, &ctx).unwrap();
+        trace!("broken: {p2:#?}");
+        assert!(matches!(p2,
+            Packet::UserauthRequest(UserauthRequest {
+                method: AuthMethod::PubKey(MethodPubKey {
+                    pubkey: Blob(PubKey::Unknown(Unknown(b"ssF-ed25519"))),
+                    sig: Some(Blob(Signature::Ed25519(_))),
+                    ..
+                }),
+                ..
+            })
+        ));
     }
 
     #[test]
diff --git a/src/sshwire.rs b/src/sshwire.rs
index c564bd8..0c5b7fc 100644
--- a/src/sshwire.rs
+++ b/src/sshwire.rs
@@ -391,20 +391,22 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> {
 
         // Sanity check the length matched
         let used_len = pos2 - pos1;
-        if used_len == len {
-            Ok(Blob(inner))
-        } else {
-            let extra = len.checked_sub(used_len).ok_or_else(|| {
-                trace!("inner consumed past length of SSH Blob. \
-                    Expected {} bytes, got {} bytes {}..{}",
-                    len, pos2-pos1, pos1, pos2);
-                WireError::SSHProtoError
-            })?;
-            // Skip over unconsumed bytes in the blob.
-            // This can occur with Unknown variants
-            s.take(extra)?;
-            Ok(Blob(inner))
+        if used_len != len {
+            trace!("SSH blob length differs. \
+                Expected {} bytes, got {} bytes {}..{}",
+                len, used_len, pos1, pos2);
+            let extra = len.checked_sub(used_len).ok_or(WireError::SSHProtoError)?;
+
+            if s.ctx().seen_unknown {
+                // Skip over unconsumed bytes in the blob.
+                // This can occur with Unknown variants
+                trace!("Difference is OK, seen_unknown");
+                s.take(extra)?;
+            } else {
+                return Err(WireError::SSHProtoError)
+            }
         }
+        Ok(Blob(inner))
     }
 }
 
@@ -640,4 +642,65 @@ pub(crate) mod tests {
         ctx.cli_auth_type = Some(auth::AuthType::PubKey);
         test_roundtrip_context(&p, &ctx);
     }
+
+    // Some other blob decoding tests are in packets module
+
+    #[test]
+    fn wrong_blob_size() {
+        let p1 = Blob(BinString(b"hello"));
+
+        let mut buf1 = vec![88; 1000];
+        let l = write_ssh(&mut buf1, &p1).unwrap();
+        // some leeway
+        buf1.truncate(l+5);
+        // make the length one extra
+        buf1[3] += 1;
+        let r: Result<Blob<BinString>, _> = read_ssh(&buf1, None);
+        assert!(matches!(r.unwrap_err(), Error::SSHProtoError));
+
+        let mut buf1 = vec![88; 1000];
+        let l = write_ssh(&mut buf1, &p1).unwrap();
+        // some leeway
+        buf1.truncate(l+5);
+        // make the length one short
+        buf1[3] -= 1;
+        let r: Result<Blob<BinString>, _> = read_ssh(&buf1, None);
+        assert!(matches!(r.unwrap_err(), Error::SSHProtoError));
+    }
+
+    #[test]
+    fn wrong_packet_size() {
+        let p1 = packets::NewKeys {};
+        let p1: Packet = p1.into();
+        let mut ctx = ParseContext::new();
+
+        let mut buf1 = vec![88; 1000];
+        let l = write_ssh(&mut buf1, &p1).unwrap();
+
+        // too long
+        buf1.truncate(l+1);
+        let r = packet_from_bytes(&buf1, &ctx);
+        assert!(matches!(r.unwrap_err(), Error::WrongPacketLength));
+
+        // success
+        buf1.truncate(l);
+        packet_from_bytes(&buf1, &ctx).unwrap();
+
+        // short
+        buf1.truncate(l-1);
+        let r = packet_from_bytes(&buf1, &ctx);
+        assert!(matches!(r.unwrap_err(), Error::RanOut));
+
+    }
+
+    #[test]
+    fn overflow_encode() {
+        let mut buf1 = vec![22; 7];
+
+        assert_eq!(write_ssh(&mut buf1, &"").unwrap(), 4);
+        assert_eq!(write_ssh(&mut buf1, &"a").unwrap(), 5);
+        assert_eq!(write_ssh(&mut buf1, &"aa").unwrap(), 6);
+        assert_eq!(write_ssh(&mut buf1, &"aaa").unwrap(), 7);
+        assert!(matches!(write_ssh(&mut buf1, &"aaaa").unwrap_err(), Error::NoRoom));
+    }
 }
-- 
GitLab