From 92a25c84547dff26e4b504dba7187a354beb10b2 Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Sat, 25 Jun 2022 13:39:07 +0800
Subject: [PATCH] fix compute_keys enc vs dec, better tests

---
 sshproto/src/encrypt.rs | 187 +++++++++++++++++++++++++++++++---------
 1 file changed, 144 insertions(+), 43 deletions(-)

diff --git a/sshproto/src/encrypt.rs b/sshproto/src/encrypt.rs
index 5d4f758..5806a22 100644
--- a/sshproto/src/encrypt.rs
+++ b/sshproto/src/encrypt.rs
@@ -1,3 +1,5 @@
+//! Handles encryption/decryption and framing a payload in a SSH packet.
+
 #[allow(unused_imports)]
 use {
     crate::error::{Error, Result, TrapBug},
@@ -79,7 +81,6 @@ impl KeyState {
     pub fn encrypt<'b>(
         &mut self, payload_len: usize, buf: &'b mut [u8],
     ) -> Result<usize, Error> {
-        trace!("encrypt p len {}", payload_len);
         let e = self.keys.encrypt(payload_len, buf, self.seq_encrypt.0);
         self.seq_encrypt += 1;
         e
@@ -88,6 +89,37 @@ impl KeyState {
     pub fn size_block_dec(&self) -> usize {
         self.keys.dec.size_block()
     }
+
+    /// Returns the maximum payload that can fit in an available buffer
+    /// after header, encryption, padding, mac
+    pub fn max_enc_payload(&self, total_avail: usize) -> usize {
+        // mac is independent of the rest
+        let total_avail = total_avail.saturating_sub(self.keys.integ_enc.size_out());
+
+        let overhead = SSH_LENGTH_SIZE + 1 + SSH_MIN_PADLEN;
+        let mut space = total_avail;
+
+        // multiple of block length
+        let enc_len = if self.keys.enc.is_aead() {
+            total_avail.saturating_sub(SSH_LENGTH_SIZE)
+        } else {
+            total_avail
+        };
+
+        // round down to block size
+        let extra_block = enc_len % self.keys.enc.size_block();
+        if extra_block != 0 {
+            space = space.saturating_sub(extra_block);
+        }
+
+        space = space.saturating_sub(overhead);
+
+        if space + overhead < SSH_MIN_PACKET_SIZE {
+            0
+        } else {
+            space
+        }
+    }
 }
 
 pub(crate) struct Keys {
@@ -181,7 +213,7 @@ impl Keys {
         let integ_dec = {
             let k = Self::compute_key(
                 i_d,
-                algos.integ_enc.key_len(),
+                algos.integ_dec.key_len(),
                 &mut key,
                 &mut hash,
                 k,
@@ -632,15 +664,15 @@ mod tests {
     use crate::sshnames::SSH_NAME_CURVE25519;
     #[allow(unused_imports)]
     use pretty_hex::PrettyHex;
+    use sha2::Sha256;
 
     // setting `corrupt` tests that incorrect mac is detected
-    fn do_roundtrips(keys: &mut KeyState, corrupt: bool) {
-        // for i in 0usize..40 {
+    fn do_roundtrips(keys_enc: &mut KeyState, keys_dec: &mut KeyState, corrupt: bool) {
         for i in 0usize..40 {
             let mut v: std::vec::Vec<u8> = (0u8..i as u8 + 60).collect();
             let orig_payload = v[SSH_PAYLOAD_START..SSH_PAYLOAD_START + i].to_vec();
 
-            let written = keys.encrypt(i, v.as_mut_slice()).unwrap();
+            let written = keys_enc.encrypt(i, v.as_mut_slice()).unwrap();
 
             v.truncate(written);
 
@@ -649,10 +681,10 @@ mod tests {
                 v[SSH_PAYLOAD_START] ^= 4;
             }
 
-            let l = keys.decrypt_first_block(v.as_mut_slice()).unwrap() as usize;
+            let l = keys_dec.decrypt_first_block(v.as_mut_slice()).unwrap() as usize;
             assert_eq!(l, v.len());
 
-            let dec = keys.decrypt(v.as_mut_slice());
+            let dec = keys_dec.decrypt(v.as_mut_slice());
 
             if corrupt {
                 assert!(matches!(dec, Err(Error::BadDecrypt)));
@@ -668,61 +700,130 @@ mod tests {
     #[test]
     fn roundtrip_nocipher() {
         // check padding works
-        let mut keys = KeyState::new_cleartext();
-        do_roundtrips(&mut keys, false);
+        let mut ke = KeyState::new_cleartext();
+        let mut kd = KeyState::new_cleartext();
+        do_roundtrips(&mut ke, &mut kd, false);
     }
 
     #[test]
     #[should_panic]
     fn roundtrip_nocipher_corrupt() {
         // test the test, cleartext has no mac
-        let mut keys = KeyState::new_cleartext();
-        do_roundtrips(&mut keys, true);
+        let mut ke = KeyState::new_cleartext();
+        let mut kd = KeyState::new_cleartext();
+        do_roundtrips(&mut ke, &mut kd, true);
     }
 
-    #[test]
-    fn algo_roundtrips() {
-        use sha2::Sha256;
-        init_test_log();
+    // returns combinations of ciphers as Some(), as well as a single
+    // None for no-cipher
+    fn algo_combos() -> impl Iterator<Item = Option<kex::Algos>> {
+        // TODO make this combinatorial
+        // order is enc, dec
+        const COMBOS: [(Cipher, Integ, Cipher, Integ); 4] = [
+            (Cipher::Aes256Ctr, Integ::HmacSha256,
+                Cipher::Aes256Ctr, Integ::HmacSha256),
 
-        let combos = [
-            (Cipher::Aes256Ctr, Integ::HmacSha256),
-            (Cipher::ChaPoly, Integ::ChaPoly),
-        ];
+            (Cipher::ChaPoly, Integ::ChaPoly,
+                Cipher::ChaPoly, Integ::ChaPoly),
+
+            (Cipher::Aes256Ctr, Integ::HmacSha256,
+                Cipher::ChaPoly, Integ::ChaPoly),
 
-        for (c, i) in combos {
-            let mut algos = kex::Algos {
+            (Cipher::ChaPoly, Integ::ChaPoly,
+                Cipher::Aes256Ctr, Integ::HmacSha256),
+        ];
+        COMBOS.iter().map(|(ce, ie, cd, id)| {
+            Some(kex::Algos {
                 kex: kex::SharedSecret::from_name(SSH_NAME_CURVE25519).unwrap(),
                 hostsig: sign::SigType::Ed25519,
-                cipher_enc: c.clone(),
-                cipher_dec: c.clone(),
-                integ_enc: i.clone(),
-                integ_dec: i.clone(),
+                cipher_enc: ce.clone(),
+                cipher_dec: cd.clone(),
+                integ_enc: ie.clone(),
+                integ_dec: id.clone(),
                 discard_next: false,
                 is_client: false,
-            };
+            })
+        })
+        // and plaintext
+        .chain(core::iter::once(None))
+    }
 
-            trace!("Trying cipher {c:?} integ {i:?}");
+    #[test]
+    fn algo_roundtrips() {
+        init_test_log();
 
-            // arbitrary keys
-            let hash = algos.kex.hash();
-            let h = SessId::from_slice(&Sha256::digest("some exchange hash".as_bytes())).unwrap();
-            let sess_id = SessId::from_slice(&Sha256::digest("some sessid".as_bytes())).unwrap();
-            let sharedkey = "hello".as_bytes();
-            let mut newkeys =
-                Keys::new_from(sharedkey, &h, &sess_id, &algos).unwrap();
+        for mut algos in algo_combos() {
+
+            let mut keys_enc = KeyState::new_cleartext();
+            let mut keys_dec = KeyState::new_cleartext();
+            if let Some(ref mut algos) = algos {
+                // arbitrary keys
+                let h = SessId::from_slice(&Sha256::digest("some exchange hash".as_bytes())).unwrap();
+                let sess_id = SessId::from_slice(&Sha256::digest("some sessid".as_bytes())).unwrap();
+                let sharedkey = "hello".as_bytes();
+
+                trace!("algos enc {algos:?}");
+                let newkeys = Keys::new_from(sharedkey, &h, &sess_id, &algos).unwrap();
+                keys_enc.rekey(newkeys);
+
+                // client and server enc/dec keys are derived differently, we need them
+                // to match for this test
+                algos.is_client = !algos.is_client;
+                core::mem::swap(&mut algos.cipher_enc, &mut algos.cipher_dec);
+                core::mem::swap(&mut algos.integ_enc, &mut algos.integ_dec);
+                trace!("algos dec {algos:?}");
+                let newkeys_b = Keys::new_from(sharedkey, &h, &sess_id, &algos).unwrap();
+                keys_dec.rekey(newkeys_b);
+
+            } else {
+                trace!("Trying cleartext");
+            }
+
+            do_roundtrips(&mut keys_enc, &mut keys_dec, false);
+            // corrupt test only for non-plaintext
+            if algos.is_some() {
+                do_roundtrips(&mut keys_enc, &mut keys_dec, true);
+            }
+        }
+    }
 
-            // client and server enc/dec keys are derived differently, we need them
-            // to match for this test
-            algos.is_client = !algos.is_client;
-            let newkeys_b = Keys::new_from(sharedkey, &h, &sess_id, &algos).unwrap();
-            newkeys.dec = newkeys_b.dec;
-            newkeys.integ_dec = newkeys_b.integ_dec;
+    #[test]
+    fn max_enc_payload() {
+        init_test_log();
+        for algos in algo_combos() {
 
             let mut keys = KeyState::new_cleartext();
-            keys.rekey(newkeys);
-            do_roundtrips(&mut keys, false);
-            do_roundtrips(&mut keys, true);
+            if let Some(algos) = algos {
+                // arbitrary keys
+                let h = SessId::from_slice(&Sha256::digest("some exchange hash".as_bytes())).unwrap();
+                let sess_id = SessId::from_slice(&Sha256::digest("some sessid".as_bytes())).unwrap();
+                let sharedkey = "hello".as_bytes();
+                let newkeys =
+                    Keys::new_from(sharedkey, &h, &sess_id, &algos).unwrap();
+
+                keys.rekey(newkeys);
+                trace!("algos {algos:?}");
+                trace!("integ {}", keys.keys.integ_enc.size_out());
+            } else {
+                trace!("cleartext");
+            }
+
+            let mut buf = [0u8; 100];
+
+            for i in 1..80 {
+                let p = keys.max_enc_payload(i);
+                trace!("i {i} p {p}");
+                if p > 0 {
+                    let l = keys.encrypt(p, &mut buf).unwrap();
+                    trace!("i {i} p {p} l {l}");
+                    assert!(l <= i);
+                    assert!(l >= i.saturating_sub(keys.keys.enc.size_block()));
+
+                    // check a larger payload would bump the packet size
+                    let l = keys.encrypt(p+1, &mut buf).unwrap();
+                    assert!(l > i);
+                }
+            }
         }
     }
 }
-- 
GitLab