From fe6a06a788098e43ab824569dfa229320033ef06 Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Wed, 14 Jun 2023 23:14:41 +0800 Subject: [PATCH] Get rid of ParseContext for SSHEncode Instead we set a flag in MethodPubkey which is the only use case for it. Split verify() into ed25519 and rsa specific functions --- async/src/agent.rs | 5 +--- src/cliauth.rs | 5 ++-- src/kex.rs | 7 +++--- src/packets.rs | 22 ++++++++-------- src/servauth.rs | 2 +- src/sign.rs | 63 +++++++++++++++++++++++++--------------------- src/sshwire.rs | 31 +++-------------------- 7 files changed, 57 insertions(+), 78 deletions(-) diff --git a/async/src/agent.rs b/async/src/agent.rs index 7d99274..74cbea3 100644 --- a/async/src/agent.rs +++ b/async/src/agent.rs @@ -120,10 +120,7 @@ impl AgentClient { } async fn request(&mut self, r: AgentRequest<'_>) -> Result<AgentResponse> { - let mut ctx = sunset::packets::ParseContext::new(); - ctx.method_pubkey_force_sig_bool = true; - - let l = sshwire::write_ssh_ctx(&mut self.buf, &Blob(r), ctx)?; + let l = sshwire::write_ssh(&mut self.buf, &Blob(r))?; let b = &self.buf[..l]; trace!("agent request {:?}", b.hex_dump()); diff --git a/src/cliauth.rs b/src/cliauth.rs index 537f5c9..a4c9abb 100644 --- a/src/cliauth.rs +++ b/src/cliauth.rs @@ -192,16 +192,15 @@ impl CliAuth { sig_algo, pubkey: pubkey.clone(), sig: None, + force_sig: true, }), }; let msg = auth::AuthSigMsg::new(&sig_packet, sess_id); - let mut ctx = ParseContext::default(); - ctx.method_pubkey_force_sig_bool = true; if key.is_agent() { Ok(b.agent_sign(key, &msg).await?) } else { - key.sign(&&msg, Some(&ctx)) + key.sign(&&msg) } } else { Err(Error::bug()) diff --git a/src/kex.rs b/src/kex.rs index b7e3cfb..03e3654 100644 --- a/src/kex.rs +++ b/src/kex.rs @@ -558,7 +558,7 @@ impl SharedSecret { // TODO: error message on signature failure. let h: &[u8] = kex_out.h.as_ref(); trace!("verify h {}", h.hex_dump()); - algos.hostsig.verify(&p.k_s.0, &h, &p.sig.0, None)?; + algos.hostsig.verify(&p.k_s.0, &h, &p.sig.0)?; debug!("Hostkey signature is valid"); if matches!(b.valid_hostkey(&p.k_s.0), Ok(true)) { Ok(kex_out) @@ -596,7 +596,7 @@ impl SharedSecret { let k_s = Blob(hostkey.pubkey()); trace!("sign kexreply h {}", ko.h.as_slice().hex_dump()); - let sig = hostkey.sign(&ko.h.as_slice(), None)?; + let sig = hostkey.sign(&ko.h.as_slice())?; let sig: Signature = (&sig).into(); let sig = Blob(sig); s.send(packets::KexDHReply { k_s, q_s, sig }) @@ -971,10 +971,11 @@ mod tests { assert!(matches!(ts.next().unwrap(), Packet::NewKeys(_))); let s = &mut tc.sender(); - let f = cli.handle_kexdhreply(&serv_dhrep, s, cb); + let f = cli.handle_kexdhreply(&serv_dhrep, s, cb, true); let f = crate::non_async(f).unwrap(); f.unwrap(); assert!(matches!(tc.next().unwrap(), Packet::NewKeys(_))); + assert!(matches!(tc.next(), None)); let (cout, calgos) = if let Kex::NewKeys { output, algos } = cli { (output, algos) diff --git a/src/packets.rs b/src/packets.rs index c83c63b..f39361d 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -228,6 +228,9 @@ pub struct MethodPubKey<'a> { pub sig_algo: &'a str, pub pubkey: Blob<PubKey<'a>>, pub sig: Option<Blob<Signature<'a>>>, + // Set when serializing to create a signature. Will set the "signature present" + // boolean to TRUE even without a signature (signature is appended later). + pub force_sig: bool, } impl<'a> MethodPubKey<'a> { @@ -239,6 +242,7 @@ impl<'a> MethodPubKey<'a> { sig_algo, pubkey: Blob(pubkey), sig, + force_sig: false, }) } @@ -256,8 +260,7 @@ impl SSHEncode for MethodPubKey<'_> { // string signature // Signature bool will be set when signing - let force_sig_bool = s.ctx().map_or(false, |c| c.method_pubkey_force_sig_bool); - let sig = self.sig.is_some() || force_sig_bool; + let sig = self.sig.is_some() || self.force_sig; sig.enc(s)?; self.sig_algo.enc(s)?; self.pubkey.enc(s)?; @@ -277,7 +280,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for MethodPubKey<'a> { } else { None }; - Ok(Self { sig_algo, pubkey, sig }) + Ok(Self { sig_algo, pubkey, sig, force_sig: false }) } } @@ -792,14 +795,8 @@ impl Debug for Unknown<'_> { /// Use this so the parser can select the correct enum variant to decode. #[derive(Default, Clone, Debug)] pub struct ParseContext { - // Beware that currently .ctx() is not used by `length_enc()` or `Blob`, - // so if ParseContext needs to modify output length it may not work correctly. - pub cli_auth_type: Option<auth::AuthType>, - // Used by auth_sig_msg() - pub method_pubkey_force_sig_bool: bool, - // Set to true if an unknown variant is encountered. // Packet length checks should be omitted in that case. pub(crate) seen_unknown: bool, @@ -809,7 +806,6 @@ impl ParseContext { pub fn new() -> Self { ParseContext { cli_auth_type: None, - method_pubkey_force_sig_bool: false, seen_unknown: false, } } @@ -1031,7 +1027,7 @@ mod tests { test_roundtrip(&p); // again with a sig - let owned_sig = k.sign(&"hello", None).unwrap(); + let owned_sig = k.sign(&"hello").unwrap(); let sig: Signature = (&owned_sig).into(); let sig_algo = sig.algorithm_name().unwrap(); let sig = Some(Blob(sig)); @@ -1039,6 +1035,7 @@ mod tests { sig_algo, pubkey: Blob(k.pubkey()), sig, + force_sig: false, }); let p = UserauthRequest { username: "matt".into(), @@ -1109,7 +1106,8 @@ mod tests { )), sig: Some(Blob(Signature::Ed25519(Ed25519Sig { sig: BinString(b"sighere") - }))) + }))), + force_sig: false, })}.into(); let mut buf1 = vec![88; 1000]; diff --git a/src/servauth.rs b/src/servauth.rs index 81326f6..838d3c9 100644 --- a/src/servauth.rs +++ b/src/servauth.rs @@ -143,7 +143,7 @@ impl ServAuth { }; let msg = auth::AuthSigMsg::new(&p, sess_id); - match sig_type.verify(&m.pubkey.0, &&msg, sig, None) { + match sig_type.verify(&m.pubkey.0, &&msg, sig) { Ok(()) => true, Err(e) => { trace!("sig failed {e}"); false}, } diff --git a/src/sign.rs b/src/sign.rs index 419812c..da6ffc8 100644 --- a/src/sign.rs +++ b/src/sign.rs @@ -11,8 +11,7 @@ use ed25519_dalek::{Signer, Verifier}; use zeroize::ZeroizeOnDrop; use crate::*; -use packets::ParseContext; -use packets::{Ed25519PubKey, PubKey, Signature}; +use packets::{Ed25519PubKey, Ed25519Sig, PubKey, Signature}; use sshnames::*; use sshwire::{BinString, Blob, SSHEncode}; @@ -66,7 +65,6 @@ impl SigType { pubkey: &PubKey, msg: &dyn SSHEncode, sig: &Signature, - parse_ctx: Option<&ParseContext>, ) -> Result<()> { // Check that the signature type is known let sig_type = sig.sig_type().map_err(|_| Error::BadSig)?; @@ -85,32 +83,12 @@ impl SigType { match (self, pubkey, sig) { (SigType::Ed25519, PubKey::Ed25519(k), Signature::Ed25519(s)) => { - let k: &[u8; 32] = &k.key.0; - let k: salty::PublicKey = k.try_into().map_err(|_| Error::BadKey)?; - let s: &[u8; 64] = s.sig.0.try_into().map_err(|_| Error::BadSig)?; - let s: salty::Signature = s.into(); - k.verify_parts(&s, |h| { - sshwire::hash_ser(h, msg, parse_ctx) - .map_err(|_| salty::Error::ContextTooLong) - }) - .map_err(|_| Error::BadSig) + Self::verify_ed25519(k, msg, s) } #[cfg(feature = "rsa")] (SigType::RSA, PubKey::RSA(k), Signature::RSA(s)) => { - let verifying_key = - rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::new_with_prefix( - k.key.clone(), - ); - let s: Box<[u8]> = s.sig.0.into(); - let signature = s.into(); - - let mut h = sha2::Sha256::new(); - sshwire::hash_ser(&mut h, msg, parse_ctx)?; - verifying_key.verify_digest(h, &signature).map_err(|e| { - trace!("RSA signature failed: {e}"); - Error::BadSig - }) + Self::verify_rsa(k, msg, s) } _ => { @@ -123,6 +101,36 @@ impl SigType { } } } + + + fn verify_ed25519(k: &Ed25519PubKey, msg: &dyn SSHEncode, s: &Ed25519Sig) -> Result<()> { + let k: &[u8; 32] = &k.key.0; + let k: salty::PublicKey = k.try_into().map_err(|_| Error::BadKey)?; + let s: &[u8; 64] = s.sig.0.try_into().map_err(|_| Error::BadSig)?; + let s: salty::Signature = s.into(); + k.verify_parts(&s, |h| { + sshwire::hash_ser(h, msg) + .map_err(|_| salty::Error::ContextTooLong) + }) + .map_err(|_| Error::BadSig) + } + + #[cfg(feature = "rsa")] + fn verify_rsa(k: &packets::RSAPubKey, msg: &dyn SSHEncode, s: &packets::RSASig) -> Result<()> { + let verifying_key = + rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::new_with_prefix( + k.key.clone(), + ); + let s: Box<[u8]> = s.sig.0.into(); + let signature = s.into(); + + let mut h = sha2::Sha256::new(); + sshwire::hash_ser(&mut h, msg)?; + verifying_key.verify_digest(h, &signature).map_err(|e| { + trace!("RSA signature failed: {e}"); + Error::BadSig + }) + } } pub enum OwnedSig { @@ -289,7 +297,6 @@ impl SignKey { pub(crate) fn sign( &self, msg: &impl SSHEncode, - parse_ctx: Option<&ParseContext>, ) -> Result<OwnedSig> { let sig: OwnedSig = match self { SignKey::Ed25519(k) => { @@ -297,7 +304,7 @@ impl SignKey { let sig = dalek::hazmat::raw_sign_byupdate::<sha2::Sha512, _>( &exk, |h| { - sshwire::hash_ser(h, msg, parse_ctx) + sshwire::hash_ser(h, msg) .map_err(|_| dalek::SignatureError::new()) }, &k.verifying_key(), @@ -313,7 +320,7 @@ impl SignKey { k.clone(), ); let mut h = sha2::Sha256::new(); - sshwire::hash_ser(&mut h, msg, parse_ctx)?; + sshwire::hash_ser(&mut h, msg)?; let sig = signing_key.try_sign_digest(h).map_err(|e| { trace!("RSA signing failed: {e:?}"); Error::bug() diff --git a/src/sshwire.rs b/src/sshwire.rs index 2f68e08..4c1c6f0 100644 --- a/src/sshwire.rs +++ b/src/sshwire.rs @@ -28,9 +28,6 @@ use packets::{Packet, ParseContext}; /// A generic destination for serializing, used similarly to `serde::Serializer` pub trait SSHSink { fn push(&mut self, v: &[u8]) -> WireResult<()>; - fn ctx(&self) -> Option<&ParseContext> { - None - } } /// A generic source for a packet, used similarly to `serde::Deserializer` @@ -138,27 +135,18 @@ pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) -> pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize> { - let mut s = EncodeBytes { target, pos: 0, parse_ctx: None }; + let mut s = EncodeBytes { target, pos: 0 }; value.enc(&mut s)?; Ok(s.pos) } -pub fn write_ssh_ctx<T>(target: &mut [u8], value: &T, ctx: ParseContext) -> Result<usize> -where - T: SSHEncode -{ - let mut s = EncodeBytes { target, pos: 0, parse_ctx: Some(ctx) }; - value.enc(&mut s)?; - Ok(s.pos) -} - /// Hashes the SSH wire format representation of `value`, with a `u32` length prefix. pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate, value: &dyn SSHEncode) -> Result<()> { let len: u32 = length_enc(value)?; hash_ctx.digest_update(&len.to_be_bytes()); - hash_ser(hash_ctx, value, None) + hash_ser(hash_ctx, value) } /// Hashes the SSH wire format representation of `value` @@ -166,10 +154,9 @@ pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate, /// Will only fail if `value.enc()` can return an error. pub fn hash_ser(hash_ctx: &mut impl SSHWireDigestUpdate, value: &dyn SSHEncode, - parse_ctx: Option<&ParseContext>, ) -> Result<()> { - let mut s = EncodeHash { hash_ctx, parse_ctx: parse_ctx.cloned() }; + let mut s = EncodeHash { hash_ctx }; value.enc(&mut s)?; Ok(()) } @@ -185,7 +172,6 @@ fn length_enc(value: &dyn SSHEncode) -> WireResult<u32> struct EncodeBytes<'a> { target: &'a mut [u8], pos: usize, - parse_ctx: Option<ParseContext>, } impl SSHSink for EncodeBytes<'_> { @@ -198,10 +184,6 @@ impl SSHSink for EncodeBytes<'_> { self.pos = end; Ok(()) } - - fn ctx(&self) -> Option<&ParseContext> { - self.parse_ctx.as_ref() - } } struct EncodeLen { @@ -217,7 +199,6 @@ impl SSHSink for EncodeLen { struct EncodeHash<'a> { hash_ctx: &'a mut dyn SSHWireDigestUpdate, - parse_ctx: Option<ParseContext>, } impl SSHSink for EncodeHash<'_> { @@ -225,10 +206,6 @@ impl SSHSink for EncodeHash<'_> { self.hash_ctx.digest_update(v); Ok(()) } - - fn ctx(&self) -> Option<&ParseContext> { - self.parse_ctx.as_ref() - } } struct DecodeBytes<'a> { @@ -729,7 +706,7 @@ pub(crate) mod tests { // hash_ser let mut hash_ctx = Sha256::new(); hash_ctx.update(&(w1 as u32).to_be_bytes()); - hash_ser(&mut hash_ctx, &input, None).unwrap(); + hash_ser(&mut hash_ctx, &input).unwrap(); let digest3 = hash_ctx.finalize(); assert_eq!(digest3, digest2); } -- GitLab