diff --git a/async/src/agent.rs b/async/src/agent.rs index c0a40f960185699a4abbb0794f0147656d6340e5..9fc0868520e371ee755b9ed950563658565e199f 100644 --- a/async/src/agent.rs +++ b/async/src/agent.rs @@ -48,8 +48,7 @@ enum AgentRequest<'a> { } impl SSHEncode for AgentRequest<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { match self { Self::SignRequest(a) => { let n = AgentMessageNum::SSH_AGENTC_SIGN_REQUEST as u8; diff --git a/src/auth.rs b/src/auth.rs index e07e5c0bdadeba6297a2de42e88caa52101753d9..66bbed6a91a5218b955e3c31b9716abe0bd415a4 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -28,8 +28,7 @@ pub struct AuthSigMsg<'a> { } impl SSHEncode for &AuthSigMsg<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn sshwire::SSHSink) -> WireResult<()> { self.sess_id.enc(s)?; let m = packets::MessageNumber::SSH_MSG_USERAUTH_REQUEST as u8; diff --git a/src/kex.rs b/src/kex.rs index bc663e89e71aa129a824a7aca81da3a35a671e2f..136fb743f7751c30b1b0f971eb65b3851426417f 100644 --- a/src/kex.rs +++ b/src/kex.rs @@ -288,9 +288,9 @@ impl Kex { Ok(()) } - fn make_kexinit<'a>(cookie: &KexCookie, conf: &'a AlgoConfig) -> Packet<'a> { + fn make_kexinit<'a>(cookie: &'a KexCookie, conf: &'a AlgoConfig) -> Packet<'a> { packets::KexInit { - cookie: cookie.clone(), + cookie: cookie, kex: (&conf.kexs).into(), hostsig: (&conf.hostsig).into(), cipher_c2s: (&conf.ciphers).into(), diff --git a/src/namelist.rs b/src/namelist.rs index 318eedae35cf90703ff81afdc68fbf0f2f931942..78b9a4011afc6d0274b479f0916c6231e84e3815 100644 --- a/src/namelist.rs +++ b/src/namelist.rs @@ -55,17 +55,16 @@ impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { /// Serialize the list of names with comma separators impl SSHEncode for &LocalNames { - fn enc<S>(&self, e: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let names = self.0.as_slice(); // space for names and commas let strlen = names.iter().map(|n| n.len()).sum::<usize>() + names.len().saturating_sub(1); - (strlen as u32).enc(e)?; + (strlen as u32).enc(s)?; for i in 0..names.len() { - names[i].as_bytes().enc(e)?; + names[i].as_bytes().enc(s)?; if i < names.len() - 1 { - b','.enc(e)?; + b','.enc(s)?; } } Ok(()) diff --git a/src/packets.rs b/src/packets.rs index b0278913092f20bbeb512a7f7490a8048ce39876..42fe5d6d3d3fd587084402d3be2d814799024e72 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -30,7 +30,7 @@ use sshwire::{SSHEncodeEnum, SSHDecodeEnum}; #[derive(Debug, SSHEncode, SSHDecode)] pub struct KexInit<'a> { - pub cookie: [u8; 16], + pub cookie: &'a [u8; 16], pub kex: NameList<'a>, /// A list of signature algorithms /// @@ -136,7 +136,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for ExtInfo<'a> { } impl SSHEncode for ExtInfo<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { if let Some(ref algs) = self.server_sig_algs { 1u32.enc(s)?; SSH_EXT_SERVER_SIG_ALGS.enc(s)?; @@ -239,8 +239,7 @@ pub struct MethodPubKey<'a> { } impl SSHEncode for MethodPubKey<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { // byte SSH_MSG_USERAUTH_REQUEST // string user name // string service name @@ -706,9 +705,12 @@ pub struct DirectTcpip<'a> { } -// Placeholder for unknown method names. These are sometimes non-fatal and -// need to be handled by the relevant code, for example newly invented pubkey types -// This is deliberately not Serializable, we only receive it. +/// Placeholder for unknown method names. +/// +///These are sometimes non-fatal and +/// need to be handled by the relevant code, for example newly invented pubkey types +/// This is deliberately not SSHEncode, we only receive it. sshwire-derive will +/// automatically create instances. #[derive(Clone, PartialEq)] pub struct Unknown<'a>(pub &'a [u8]); @@ -796,8 +798,7 @@ impl TryFrom<u8> for MessageNumber { } impl SSHEncode for Packet<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let t = self.message_num() as u8; t.enc(s)?; match self { diff --git a/src/sign.rs b/src/sign.rs index ad60f3ac81d34a9ad7a70cbce1e09671286f944e..c6f9e717ae1138d82059a4fdef78ad5fb54a4e58 100644 --- a/src/sign.rs +++ b/src/sign.rs @@ -48,7 +48,7 @@ impl SigType { /// Returns `Ok(())` on success pub fn verify( - &self, pubkey: &PubKey, msg: &impl SSHEncode, sig: &Signature, parse_ctx: Option<&ParseContext>) -> Result<()> { + &self, pubkey: &PubKey, msg: &dyn SSHEncode, sig: &Signature, parse_ctx: Option<&ParseContext>) -> Result<()> { // Check that the signature type is known let sig_type = sig.sig_type().map_err(|_| Error::BadSig)?; diff --git a/src/sshwire.rs b/src/sshwire.rs index f62c5d775c62690f81df91fe0034a041f999dec9..70b89b98c127fbb1b5d9c60f76d0953023a9f11c 100644 --- a/src/sshwire.rs +++ b/src/sshwire.rs @@ -46,7 +46,7 @@ pub trait SSHEncode { /// /// The state of the `SSHSink` is undefined after an error is returned, data may /// have been partially encoded. - fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink; + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()>; } /// For enums with an externally provided name @@ -133,9 +133,7 @@ pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) -> Ok(T::dec(&mut s)?) } -pub fn write_ssh<T>(target: &mut [u8], value: &T) -> Result<usize> -where - T: SSHEncode +pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize> { let mut s = EncodeBytes { target, pos: 0, parse_ctx: None }; value.enc(&mut s)?; @@ -152,10 +150,8 @@ where } /// Hashes the SSH wire format representation of `value`, with a `u32` length prefix. -pub fn hash_ser_length<T>(hash_ctx: &mut impl DigestUpdate, - value: &T) -> Result<()> -where - T: SSHEncode, +pub fn hash_ser_length(hash_ctx: &mut impl DigestUpdate, + value: &dyn SSHEncode) -> Result<()> { let len: u32 = length_enc(value)?; hash_ctx.digest_update(&len.to_be_bytes()); @@ -163,12 +159,10 @@ where } /// Hashes the SSH wire format representation of `value` -pub fn hash_ser<T>(hash_ctx: &mut impl DigestUpdate, - value: &T, +pub fn hash_ser(hash_ctx: &mut impl DigestUpdate, + value: &dyn SSHEncode, parse_ctx: Option<&ParseContext>, ) -> Result<()> -where - T: SSHEncode, { let mut s = EncodeHash { hash_ctx, parse_ctx: parse_ctx.cloned() }; value.enc(&mut s)?; @@ -176,9 +170,7 @@ where } /// Returns `WireError::NoRoom` if larger than `u32` -fn length_enc<T>(value: &T) -> WireResult<u32> -where - T: SSHEncode, +fn length_enc(value: &dyn SSHEncode) -> WireResult<u32> { let mut s = EncodeLen { pos: 0 }; value.enc(&mut s)?; @@ -293,8 +285,7 @@ impl Debug for BinString<'_> { } impl SSHEncode for BinString<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (self.0.len() as u32).enc(s)?; self.0.enc(s) } @@ -374,8 +365,7 @@ impl Display for TextString<'_> { } impl SSHEncode for TextString<'_> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (self.0.len() as u32).enc(s)?; self.0.enc(s) } @@ -415,8 +405,7 @@ impl<B: SSHEncode + Debug> Debug for Blob<B> { } impl<B: SSHEncode> SSHEncode for Blob<B> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: sshwire::SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let len: u32 = sshwire::length_enc(&self.0)?; len.enc(s)?; self.0.enc(s) @@ -455,46 +444,40 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { /////////////////////////////////////////////// impl SSHEncode for u8 { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { s.push(&[*self]) } } impl SSHEncode for bool { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { (*self as u8).enc(s) } } impl SSHEncode for u32 { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { s.push(&self.to_be_bytes()) } } // no length prefix impl SSHEncode for &[u8] { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { // data s.push(self) } } // no length prefix -impl<const N: usize> SSHEncode for [u8; N] { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { - s.push(self) +impl<const N: usize> SSHEncode for &[u8; N] { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { + s.push(self.as_slice()) } } impl SSHEncode for &str { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let v = self.as_bytes(); // length prefix (v.len() as u32).enc(s)?; @@ -503,8 +486,7 @@ impl SSHEncode for &str { } impl<T: SSHEncode> SSHEncode for Option<T> { - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { if let Some(t) = self.as_ref() { t.enc(s)?; } @@ -513,8 +495,7 @@ impl<T: SSHEncode> SSHEncode for Option<T> { } impl SSHEncode for &AsciiStr{ - fn enc<S>(&self, s: &mut S) -> WireResult<()> - where S: SSHSink { + fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> { let v = self.as_bytes(); BinString(v).enc(s) } @@ -527,10 +508,7 @@ impl<'de> SSHDecode<'de> for bool { } } -// #[inline] seems to decrease code size somehow - impl<'de> SSHDecode<'de> for u8 { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u8>())?; @@ -539,7 +517,6 @@ impl<'de> SSHDecode<'de> for u8 { } impl<'de> SSHDecode<'de> for u32 { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u32>())?; @@ -562,7 +539,6 @@ pub fn try_as_ascii_str(t: &[u8]) -> WireResult<&str> { } impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { - #[inline] fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let len = u32::dec(s)?; @@ -580,13 +556,11 @@ impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr { } } -impl<'de, const N: usize> SSHDecode<'de> for [u8; N] { +impl<'de, const N: usize> SSHDecode<'de> for &'de [u8; N] { fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { - // TODO is there a better way? Or can we return a slice? - let mut l = [0u8; N]; - l.copy_from_slice(s.take(N)?); - Ok(l) + // OK unwrap: take() fails if the length is short + Ok(s.take(N)?.try_into().unwrap()) } } diff --git a/src/test.rs b/src/test.rs index 0a5f0ce42034ba72e106df94d80c6ddcffa5036c..0fe0a02e8b91a8fb5f0681b91e758330aea09215 100644 --- a/src/test.rs +++ b/src/test.rs @@ -35,7 +35,7 @@ mod tests { #[test] fn roundtrip_kexinit() { let k = KexInit { - cookie: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + cookie: &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], kex: "kex".try_into().unwrap(), hostsig: "hostkey,another".try_into().unwrap(), cipher_c2s: "chacha20-poly1305@openssh.com,aes128-ctr".try_into().unwrap(), diff --git a/sshwire-derive/src/lib.rs b/sshwire-derive/src/lib.rs index a871a43e3795b27c41c2f60f9a8ee7b44f649f46..db8039fd79b44fe071b53a17f8c87ab0362efa43 100644 --- a/sshwire-derive/src/lib.rs +++ b/sshwire-derive/src/lib.rs @@ -212,9 +212,8 @@ fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> { fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> { gen.impl_for("crate::sshwire::SSHEncode") .generate_fn("enc") - .with_generic_deps("E", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) - .with_arg("s", "&mut E") + .with_arg("s", "&mut dyn crate::sshwire::SSHSink") .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { match &body.fields { @@ -263,9 +262,8 @@ fn encode_enum( gen.impl_for("crate::sshwire::SSHEncode") .generate_fn("enc") - .with_generic_deps("S", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) - .with_arg("s", "&mut S") + .with_arg("s", "&mut dyn crate::sshwire::SSHSink") .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {