From 2abc4627d0d0d7cc1b045b47370c04310c920a8f Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Thu, 24 Nov 2022 22:54:20 +0800
Subject: [PATCH] Avoid some array indexing, be careful of overflow

---
 src/channel.rs | 12 +++++-------
 src/encrypt.rs |  7 +++----
 src/sshwire.rs |  4 ++--
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/src/channel.rs b/src/channel.rs
index a6663e8..3ca27c0 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -67,7 +67,7 @@ impl Channels {
         ty: packets::ChannelOpenType<'b>,
         init_req: InitReqs,
     ) -> Result<(&Channel, Packet<'b>)> {
-        let num = self.unused_chan()?;
+        let (num, ch) = self.unused_chan()?;
 
         let chan = Channel::new(num, (&ty).into(), init_req);
         let p = packets::ChannelOpen {
@@ -77,7 +77,6 @@ impl Channels {
             ty,
         }
         .into();
-        let ch = &mut self.ch[num as usize];
         *ch = Some(chan);
         Ok((ch.as_ref().unwrap(), p))
     }
@@ -145,19 +144,19 @@ impl Channels {
     }
 
     /// Returns the first available channel
-    fn unused_chan(&self) -> Result<u32> {
+    fn unused_chan(&mut self) -> Result<(u32, &mut Option<Channel>)> {
         self.ch
-            .iter()
+            .iter_mut()
             .enumerate()
             .find_map(
-                |(i, ch)| if ch.as_ref().is_none() { Some(i as u32) } else { None },
+                |(i, ch)| if ch.as_mut().is_none() { Some((i as u32, ch)) } else { None },
             )
             .ok_or(Error::NoChannels)
     }
 
     /// Creates a new channel in InOpen state.
     fn reserve_chan(&mut self, co: &ChannelOpen<'_>) -> Result<&mut Channel> {
-        let num = self.unused_chan()?;
+        let (num, ch) = self.unused_chan()?;
         let mut chan = Channel::new(num, (&co.ty).into(), Vec::new());
         chan.send = Some(ChanDir {
             num: co.num,
@@ -166,7 +165,6 @@ impl Channels {
         });
         chan.state = ChanState::InOpen;
 
-        let ch = &mut self.ch[num as usize];
         *ch = Some(chan);
         Ok(ch.as_mut().unwrap())
     }
diff --git a/src/encrypt.rs b/src/encrypt.rs
index eba4f4b..f4ca753 100644
--- a/src/encrypt.rs
+++ b/src/encrypt.rs
@@ -253,8 +253,7 @@ impl Keys {
         debug_assert!(2 * hash_ctx.output_size() >= out.len());
 
         let l = len.min(hash_ctx.output_size());
-        let rest = &mut out[..];
-        let (k1, rest) = rest.split_at_mut(l);
+        let (k1, rest) = out.split_at_mut(l);
         let (k2, _) = rest.split_at_mut(len - l);
 
         hash_ctx.reset();
@@ -434,11 +433,11 @@ impl Keys {
             return Err(Error::NoRoom);
         }
 
-        buf[SSH_LENGTH_SIZE] = padlen as u8;
         // write the length
-        buf[0..SSH_LENGTH_SIZE]
+        buf[..SSH_LENGTH_SIZE]
             .copy_from_slice(&((len - SSH_LENGTH_SIZE) as u32).to_be_bytes());
         // write random padding
+        buf[SSH_LENGTH_SIZE] = padlen as u8;
         let pad_start = SSH_LENGTH_SIZE + 1 + payload_len;
         debug_assert_eq!(pad_start + padlen, len);
         random::fill_random(&mut buf[pad_start..pad_start + padlen])?;
diff --git a/src/sshwire.rs b/src/sshwire.rs
index 1c91946..93e5c18 100644
--- a/src/sshwire.rs
+++ b/src/sshwire.rs
@@ -169,7 +169,7 @@ struct EncodeBytes<'a> {
 
 impl SSHSink for EncodeBytes<'_> {
     fn push(&mut self, v: &[u8]) -> WireResult<()> {
-        if self.pos + v.len() > self.target.len() {
+        if self.target.len() - self.pos < v.len() {
             return Err(WireError::NoRoom);
         }
         self.target[self.pos..self.pos + v.len()].copy_from_slice(v);
@@ -184,7 +184,7 @@ struct EncodeLen {
 
 impl SSHSink for EncodeLen {
     fn push(&mut self, v: &[u8]) -> WireResult<()> {
-        self.pos += v.len();
+        self.pos = self.pos.checked_add(v.len()).ok_or(WireError::NoRoom)?;
         Ok(())
     }
 }
-- 
GitLab