From aebd94e60c65bce3e08cb570ebd717ec8c739648 Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Fri, 7 Apr 2023 23:13:21 +0800 Subject: [PATCH] Use trait object for SSHSink in SSHEncode Saves around 5kB in picow demo, functions can be shared between EncodeBytes/EncodeLen/EncodeHash --- async/src/agent.rs | 3 +- src/auth.rs | 3 +- src/kex.rs | 4 +-- src/namelist.rs | 9 +++-- src/packets.rs | 19 ++++++----- src/sign.rs | 2 +- src/sshwire.rs | 72 +++++++++++++-------------------------- src/test.rs | 2 +- sshwire-derive/src/lib.rs | 6 ++-- 9 files changed, 45 insertions(+), 75 deletions(-) diff --git a/async/src/agent.rs b/async/src/agent.rs index c0a40f9..9fc0868 100644 --- a/async/src/agent.rs +++ b/async/src/agent.rs @@ -48,8 +48,7 @@ enum AgentRequest<'a> { } impl SSHEncode for AgentRequest<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { match self { Self::SignRequest(a) => { let n = AgentMessageNum::SSH_AGENTC_SIGN_REQUEST as u8; diff --git a/src/auth.rs b/src/auth.rs index e07e5c0..66bbed6 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -28,8 +28,7 @@ pub struct AuthSigMsg<'a> { } impl SSHEncode for &AuthSigMsg<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn sshwire::SSHSink) -> WireResult<()> { self.sess_id.enc(s)?; let m = packets::MessageNumber::SSH_MSG_USERAUTH_REQUEST as u8; diff --git a/src/kex.rs b/src/kex.rs index bc663e8..136fb74 100644 --- a/src/kex.rs +++ b/src/kex.rs @@ -288,9 +288,9 @@ impl Kex { Ok(()) } - fn make_kexinit<'a>(cookie: &KexCookie, conf: &'a AlgoConfig) -> Packet<'a> { + fn make_kexinit<'a>(cookie: &'a KexCookie, conf: &'a AlgoConfig) -> Packet<'a> { packets::KexInit { - cookie: cookie.clone(), + cookie: cookie, kex: (&conf.kexs).into(), hostsig: (&conf.hostsig).into(), cipher_c2s: (&conf.ciphers).into(), diff --git a/src/namelist.rs b/src/namelist.rs index 318eeda..78b9a40 100644 --- a/src/namelist.rs +++ b/src/namelist.rs @@ -55,17 +55,16 @@ impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { /// Serialize the list of names with comma separators impl SSHEncode for &LocalNames { - fn enc<S>(&self, e: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let names = self.0.as_slice(); // space for names and commas let strlen = names.iter().map(|n| n.len()).sum::<usize>() + names.len().saturating_sub(1); - (strlen as u32).enc(e)?; + (strlen as u32).enc(s)?; for i in 0..names.len() { - names[i].as_bytes().enc(e)?; + names[i].as_bytes().enc(s)?; if i < names.len() - 1 { - b','.enc(e)?; + b','.enc(s)?; } } Ok(()) diff --git a/src/packets.rs b/src/packets.rs index b027891..42fe5d6 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -30,7 +30,7 @@ use sshwire::{SSHEncodeEnum, SSHDecodeEnum}; #[derive(Debug, SSHEncode, SSHDecode)] pub struct KexInit<'a> { - pub cookie: [u8; 16], + pub cookie: &'a [u8; 16], pub kex: NameList<'a>, /// A list of signature algorithms /// @@ -136,7 +136,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for ExtInfo<'a> { } impl SSHEncode for ExtInfo<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { if let Some(ref algs) = self.server_sig_algs { 1u32.enc(s)?; SSH_EXT_SERVER_SIG_ALGS.enc(s)?; @@ -239,8 +239,7 @@ pub struct MethodPubKey<'a> { } impl SSHEncode for MethodPubKey<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { // byte SSH_MSG_USERAUTH_REQUEST // string user name // string service name @@ -706,9 +705,12 @@ pub struct DirectTcpip<'a> { } -// Placeholder for unknown method names. These are sometimes non-fatal and -// need to be handled by the relevant code, for example newly invented pubkey types -// This is deliberately not Serializable, we only receive it. +/// Placeholder for unknown method names. +/// +///These are sometimes non-fatal and +/// need to be handled by the relevant code, for example newly invented pubkey types +/// This is deliberately not SSHEncode, we only receive it. sshwire-derive will +/// automatically create instances. #[derive(Clone, PartialEq)] pub struct Unknown<'a>(pub &'a [u8]); @@ -796,8 +798,7 @@ impl TryFrom<u8> for MessageNumber { } impl SSHEncode for Packet<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let t = self.message_num() as u8; t.enc(s)?; match self { diff --git a/src/sign.rs b/src/sign.rs index ad60f3a..c6f9e71 100644 --- a/src/sign.rs +++ b/src/sign.rs @@ -48,7 +48,7 @@ impl SigType { /// Returns `Ok(())` on success pub fn verify( - &self, pubkey: &PubKey, msg: &impl SSHEncode, sig: &Signature, parse_ctx: Option<&ParseContext>) -> Result<()> { + &self, pubkey: &PubKey, msg: &dyn SSHEncode, sig: &Signature, parse_ctx: Option<&ParseContext>) -> Result<()> { // Check that the signature type is known let sig_type = sig.sig_type().map_err(|_| Error::BadSig)?; diff --git a/src/sshwire.rs b/src/sshwire.rs index f62c5d7..70b89b9 100644 --- a/src/sshwire.rs +++ b/src/sshwire.rs @@ -46,7 +46,7 @@ pub trait SSHEncode { /// /// The state of the `SSHSink` is undefined after an error is returned, data may /// have been partially encoded. - fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink; + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()>; } /// For enums with an externally provided name @@ -133,9 +133,7 @@ pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) -> Ok(T::dec(&mut s)?) } -pub fn write_ssh<T>(target: &mut [u8], value: &T) -> Result<usize> -where - T: SSHEncode +pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize> { let mut s = EncodeBytes { target, pos: 0, parse_ctx: None }; value.enc(&mut s)?; @@ -152,10 +150,8 @@ where } /// Hashes the SSH wire format representation of `value`, with a `u32` length prefix. -pub fn hash_ser_length<T>(hash_ctx: &mut impl DigestUpdate, - value: &T) -> Result<()> -where - T: SSHEncode, +pub fn hash_ser_length(hash_ctx: &mut impl DigestUpdate, + value: &dyn SSHEncode) -> Result<()> { let len: u32 = length_enc(value)?; hash_ctx.digest_update(&len.to_be_bytes()); @@ -163,12 +159,10 @@ where } /// Hashes the SSH wire format representation of `value` -pub fn hash_ser<T>(hash_ctx: &mut impl DigestUpdate, - value: &T, +pub fn hash_ser(hash_ctx: &mut impl DigestUpdate, + value: &dyn SSHEncode, parse_ctx: Option<&ParseContext>, ) -> Result<()> -where - T: SSHEncode, { let mut s = EncodeHash { hash_ctx, parse_ctx: parse_ctx.cloned() }; value.enc(&mut s)?; @@ -176,9 +170,7 @@ where } /// Returns `WireError::NoRoom` if larger than `u32` -fn length_enc<T>(value: &T) -> WireResult<u32> -where - T: SSHEncode, +fn length_enc(value: &dyn SSHEncode) -> WireResult<u32> { let mut s = EncodeLen { pos: 0 }; value.enc(&mut s)?; @@ -293,8 +285,7 @@ impl Debug for BinString<'_> { } impl SSHEncode for BinString<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (self.0.len() as u32).enc(s)?; self.0.enc(s) } @@ -374,8 +365,7 @@ impl Display for TextString<'_> { } impl SSHEncode for TextString<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (self.0.len() as u32).enc(s)?; self.0.enc(s) } @@ -415,8 +405,7 @@ impl<B: SSHEncode + Debug> Debug for Blob<B> { } impl<B: SSHEncode> SSHEncode for Blob<B> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let len: u32 = sshwire::length_enc(&self.0)?; len.enc(s)?; self.0.enc(s) @@ -455,46 +444,40 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { /////////////////////////////////////////////// impl SSHEncode for u8 { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { s.push(&[*self]) } } impl SSHEncode for bool { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (*self as u8).enc(s) } } impl SSHEncode for u32 { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { s.push(&self.to_be_bytes()) } } // no length prefix impl SSHEncode for &[u8] { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { // data s.push(self) } } // no length prefix -impl<const N: usize> SSHEncode for [u8; N] { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { - s.push(self) +impl<const N: usize> SSHEncode for &[u8; N] { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { + s.push(self.as_slice()) } } impl SSHEncode for &str { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let v = self.as_bytes(); // length prefix (v.len() as u32).enc(s)?; @@ -503,8 +486,7 @@ impl SSHEncode for &str { } impl<T: SSHEncode> SSHEncode for Option<T> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { if let Some(t) = self.as_ref() { t.enc(s)?; } @@ -513,8 +495,7 @@ impl<T: SSHEncode> SSHEncode for Option<T> { } impl SSHEncode for &AsciiStr{ - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let v = self.as_bytes(); BinString(v).enc(s) } @@ -527,10 +508,7 @@ impl<'de> SSHDecode<'de> for bool { } } -// #[inline] seems to decrease code size somehow - impl<'de> SSHDecode<'de> for u8 { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u8>())?; @@ -539,7 +517,6 @@ impl<'de> SSHDecode<'de> for u8 { } impl<'de> SSHDecode<'de> for u32 { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u32>())?; @@ -562,7 +539,6 @@ pub fn try_as_ascii_str(t: &[u8]) -> WireResult<&str> { } impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let len = u32::dec(s)?; @@ -580,13 +556,11 @@ impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr { } } -impl<'de, const N: usize> SSHDecode<'de> for [u8; N] { +impl<'de, const N: usize> SSHDecode<'de> for &'de [u8; N] { fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { - // TODO is there a better way? Or can we return a slice? - let mut l = [0u8; N]; - l.copy_from_slice(s.take(N)?); - Ok(l) + // OK unwrap: take() fails if the length is short + Ok(s.take(N)?.try_into().unwrap()) } } diff --git a/src/test.rs b/src/test.rs index 0a5f0ce..0fe0a02 100644 --- a/src/test.rs +++ b/src/test.rs @@ -35,7 +35,7 @@ mod tests { #[test] fn roundtrip_kexinit() { let k = KexInit { - cookie: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + cookie: &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], kex: "kex".try_into().unwrap(), hostsig: "hostkey,another".try_into().unwrap(), cipher_c2s: "chacha20-poly1305@openssh.com,aes128-ctr".try_into().unwrap(), diff --git a/sshwire-derive/src/lib.rs b/sshwire-derive/src/lib.rs index a871a43..db8039f 100644 --- a/sshwire-derive/src/lib.rs +++ b/sshwire-derive/src/lib.rs @@ -212,9 +212,8 @@ fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> { fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> { gen.impl_for("crate::sshwire::SSHEncode") .generate_fn("enc") - .with_generic_deps("E", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) - .with_arg("s", "&mut E") + .with_arg("s", "&mut dyn crate::sshwire::SSHSink") .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { match &body.fields { @@ -263,9 +262,8 @@ fn encode_enum( gen.impl_for("crate::sshwire::SSHEncode") .generate_fn("enc") - .with_generic_deps("S", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) - .with_arg("s", "&mut S") + .with_arg("s", "&mut dyn crate::sshwire::SSHSink") .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) { -- GitLab