diff --git a/sshproto/src/error.rs b/sshproto/src/error.rs index b311a9a9b7b02ddc9a73a911673abe0d1a2cd825..51abae2bc2e67ed34d61abaa9a08e1468110a355 100644 --- a/sshproto/src/error.rs +++ b/sshproto/src/error.rs @@ -56,6 +56,9 @@ pub enum Error { /// Bad channel number BadChannel, + /// SSH packet contents doesn't match length + WrongPacketLength, + // Used for unknown key types etc. #[snafu(display("{what} is not available"))] NotAvailable { what: &'static str }, diff --git a/sshproto/src/kex.rs b/sshproto/src/kex.rs index 95fd08cee5ca681550a08428bc1874fedbecfba6..da4146b3c1005fb485b5524107429fbc5dce4851 100644 --- a/sshproto/src/kex.rs +++ b/sshproto/src/kex.rs @@ -603,8 +603,8 @@ mod tests { /// Round trip a `Packet` fn reencode<'a>(out_buf: &'a mut [u8], p: Packet, ctx: &ParseContext) -> Packet<'a> { - sshwire::write_ssh(out_buf, &p).unwrap(); - sshwire::packet_from_bytes(out_buf, &ctx).unwrap() + let l = sshwire::write_ssh(out_buf, &p).unwrap(); + sshwire::packet_from_bytes(&out_buf[..l], &ctx).unwrap() } #[test] diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 12c2524f7ba857d2dd85b4543d50bfb670a21bc6..4aef45e0f2de4184a0685e96ed33b745f2465a76 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -25,6 +25,7 @@ pub trait SSHSink { pub trait SSHSource<'de> { fn take(&mut self, len: usize) -> WireResult<&'de [u8]>; + fn pos(&self) -> usize; fn ctx(&self) -> &ParseContext; } @@ -67,6 +68,8 @@ pub enum WireError { PacketWrong, + SSHProtoError, + UnknownPacket { number: u8 }, } @@ -77,6 +80,7 @@ impl From<WireError> for Error { WireError::RanOut => Error::RanOut, WireError::BadString => Error::BadString, WireError::BadName => Error::BadName, + WireError::SSHProtoError => Error::SSHProtoError, WireError::PacketWrong => Error::PacketWrong, WireError::UnknownVariant => Error::bug_err_msg("Can't encode Unknown"), WireError::UnknownPacket { number } => Error::UnknownPacket { number }, @@ -91,7 +95,12 @@ pub type WireResult<T> = core::result::Result<T, WireError>; /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> { let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() }; - Ok(Packet::dec(&mut s)?) + let p = Packet::dec(&mut s)?; + if s.pos() == b.len() { + Ok(p) + } else { + Err(Error::WrongPacketLength) + } } pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) -> Result<T> { @@ -199,6 +208,10 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> { Ok(t) } + fn pos(&self) -> usize { + self.pos + } + fn ctx(&self) -> &ParseContext { &self.parse_ctx } @@ -315,7 +328,7 @@ impl<'de> SSHDecode<'de> for TextString<'de> { } } -// A wrapper for a u32 prefixed data structure `B`, such as a public key blob +/// A wrapper for a u32 prefixed data structure `B`, such as a public key blob pub struct Blob<B>(pub B); impl<B> AsRef<B> for Blob<B> { @@ -353,9 +366,14 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { fn dec<S>(s: &mut S) -> WireResult<Self> where S: sshwire::SSHSource<'de> { let len = u32::dec(s)?; + let pos1 = s.pos(); let inner = SSHDecode::dec(s)?; - // TODO verify length matches - Ok(Blob(inner)) + let pos2 = s.pos(); + if (pos2 - pos1) == len as usize { + Ok(Blob(inner)) + } else { + Err(WireError::SSHProtoError) + } } } diff --git a/sshproto/src/test.rs b/sshproto/src/test.rs index fa511157fed32b40147ce8822693009643a1873a..b0ebe7e9568d6dbd29f41a4c96d98f0e6dd6bc36 100644 --- a/sshproto/src/test.rs +++ b/sshproto/src/test.rs @@ -14,11 +14,11 @@ mod tests { fn test_roundtrip_packet(p: &Packet) -> Result<(), Error> { let mut buf1 = vec![99; 500]; - let _w1 = sshwire::write_ssh(&mut buf1, p)?; + let w1 = sshwire::write_ssh(&mut buf1, p)?; let ctx = ParseContext::new(); - let p2 = sshwire::packet_from_bytes(&buf1, &ctx)?; + let p2 = sshwire::packet_from_bytes(&buf1[..w1], &ctx)?; let mut buf2 = vec![99; 500]; let _w2 = sshwire::write_ssh(&mut buf2, &p2)?;