From fe6a06a788098e43ab824569dfa229320033ef06 Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Wed, 14 Jun 2023 23:14:41 +0800
Subject: [PATCH] Get rid of ParseContext for SSHEncode

Instead we set a flag in MethodPubkey which is the only use case for it.

Split verify() into ed25519 and rsa specific functions
---
 async/src/agent.rs |  5 +---
 src/cliauth.rs     |  5 ++--
 src/kex.rs         |  7 +++---
 src/packets.rs     | 22 ++++++++--------
 src/servauth.rs    |  2 +-
 src/sign.rs        | 63 +++++++++++++++++++++++++---------------------
 src/sshwire.rs     | 31 +++--------------------
 7 files changed, 57 insertions(+), 78 deletions(-)

diff --git a/async/src/agent.rs b/async/src/agent.rs
index 7d99274..74cbea3 100644
--- a/async/src/agent.rs
+++ b/async/src/agent.rs
@@ -120,10 +120,7 @@ impl AgentClient {
     }
 
     async fn request(&mut self, r: AgentRequest<'_>) -> Result<AgentResponse> {
-        let mut ctx = sunset::packets::ParseContext::new();
-        ctx.method_pubkey_force_sig_bool = true;
-
-        let l = sshwire::write_ssh_ctx(&mut self.buf, &Blob(r), ctx)?;
+        let l = sshwire::write_ssh(&mut self.buf, &Blob(r))?;
         let b = &self.buf[..l];
 
         trace!("agent request {:?}", b.hex_dump());
diff --git a/src/cliauth.rs b/src/cliauth.rs
index 537f5c9..a4c9abb 100644
--- a/src/cliauth.rs
+++ b/src/cliauth.rs
@@ -192,16 +192,15 @@ impl CliAuth {
                     sig_algo,
                     pubkey: pubkey.clone(),
                     sig: None,
+                    force_sig: true,
                 }),
             };
 
             let msg = auth::AuthSigMsg::new(&sig_packet, sess_id);
-            let mut ctx = ParseContext::default();
-            ctx.method_pubkey_force_sig_bool = true;
             if key.is_agent() {
                 Ok(b.agent_sign(key, &msg).await?)
             } else {
-                key.sign(&&msg, Some(&ctx))
+                key.sign(&&msg)
             }
         } else {
             Err(Error::bug())
diff --git a/src/kex.rs b/src/kex.rs
index b7e3cfb..03e3654 100644
--- a/src/kex.rs
+++ b/src/kex.rs
@@ -558,7 +558,7 @@ impl SharedSecret {
         // TODO: error message on signature failure.
         let h: &[u8] = kex_out.h.as_ref();
         trace!("verify  h {}", h.hex_dump());
-        algos.hostsig.verify(&p.k_s.0, &h, &p.sig.0, None)?;
+        algos.hostsig.verify(&p.k_s.0, &h, &p.sig.0)?;
         debug!("Hostkey signature is valid");
         if matches!(b.valid_hostkey(&p.k_s.0), Ok(true)) {
             Ok(kex_out)
@@ -596,7 +596,7 @@ impl SharedSecret {
 
         let k_s = Blob(hostkey.pubkey());
         trace!("sign kexreply h {}", ko.h.as_slice().hex_dump());
-        let sig = hostkey.sign(&ko.h.as_slice(), None)?;
+        let sig = hostkey.sign(&ko.h.as_slice())?;
         let sig: Signature = (&sig).into();
         let sig = Blob(sig);
         s.send(packets::KexDHReply { k_s, q_s, sig })
@@ -971,10 +971,11 @@ mod tests {
         assert!(matches!(ts.next().unwrap(), Packet::NewKeys(_)));
 
         let s = &mut tc.sender();
-        let f = cli.handle_kexdhreply(&serv_dhrep, s, cb);
+        let f = cli.handle_kexdhreply(&serv_dhrep, s, cb, true);
         let f = crate::non_async(f).unwrap();
         f.unwrap();
         assert!(matches!(tc.next().unwrap(), Packet::NewKeys(_)));
+        assert!(matches!(tc.next(), None));
 
         let (cout, calgos) = if let Kex::NewKeys { output, algos } = cli {
             (output, algos)
diff --git a/src/packets.rs b/src/packets.rs
index c83c63b..f39361d 100644
--- a/src/packets.rs
+++ b/src/packets.rs
@@ -228,6 +228,9 @@ pub struct MethodPubKey<'a> {
     pub sig_algo: &'a str,
     pub pubkey: Blob<PubKey<'a>>,
     pub sig: Option<Blob<Signature<'a>>>,
+    // Set when serializing to create a signature. Will set the "signature present"
+    // boolean to TRUE even without a signature (signature is appended later).
+    pub force_sig: bool,
 }
 
 impl<'a> MethodPubKey<'a> {
@@ -239,6 +242,7 @@ impl<'a> MethodPubKey<'a> {
             sig_algo,
             pubkey: Blob(pubkey),
             sig,
+            force_sig: false,
         })
 
     }
@@ -256,8 +260,7 @@ impl SSHEncode for MethodPubKey<'_> {
         // string    signature
 
         // Signature bool will be set when signing
-        let force_sig_bool = s.ctx().map_or(false, |c| c.method_pubkey_force_sig_bool);
-        let sig = self.sig.is_some() || force_sig_bool;
+        let sig = self.sig.is_some() || self.force_sig;
         sig.enc(s)?;
         self.sig_algo.enc(s)?;
         self.pubkey.enc(s)?;
@@ -277,7 +280,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for MethodPubKey<'a> {
         } else {
             None
         };
-        Ok(Self { sig_algo, pubkey, sig })
+        Ok(Self { sig_algo, pubkey, sig, force_sig: false })
     }
 }
 
@@ -792,14 +795,8 @@ impl Debug for Unknown<'_> {
 /// Use this so the parser can select the correct enum variant to decode.
 #[derive(Default, Clone, Debug)]
 pub struct ParseContext {
-    // Beware that currently .ctx() is not used by `length_enc()` or `Blob`,
-    // so if ParseContext needs to modify output length it may not work correctly.
-
     pub cli_auth_type: Option<auth::AuthType>,
 
-    // Used by auth_sig_msg()
-    pub method_pubkey_force_sig_bool: bool,
-
     // Set to true if an unknown variant is encountered.
     // Packet length checks should be omitted in that case.
     pub(crate) seen_unknown: bool,
@@ -809,7 +806,6 @@ impl ParseContext {
     pub fn new() -> Self {
         ParseContext {
             cli_auth_type: None,
-            method_pubkey_force_sig_bool: false,
             seen_unknown: false,
         }
     }
@@ -1031,7 +1027,7 @@ mod tests {
         test_roundtrip(&p);
 
         // again with a sig
-        let owned_sig = k.sign(&"hello", None).unwrap();
+        let owned_sig = k.sign(&"hello").unwrap();
         let sig: Signature = (&owned_sig).into();
         let sig_algo = sig.algorithm_name().unwrap();
         let sig = Some(Blob(sig));
@@ -1039,6 +1035,7 @@ mod tests {
             sig_algo,
             pubkey: Blob(k.pubkey()),
             sig,
+            force_sig: false,
         });
         let p = UserauthRequest {
             username: "matt".into(),
@@ -1109,7 +1106,8 @@ mod tests {
                 )),
                 sig: Some(Blob(Signature::Ed25519(Ed25519Sig {
                     sig: BinString(b"sighere")
-                })))
+                }))),
+                force_sig: false,
             })}.into();
 
         let mut buf1 = vec![88; 1000];
diff --git a/src/servauth.rs b/src/servauth.rs
index 81326f6..838d3c9 100644
--- a/src/servauth.rs
+++ b/src/servauth.rs
@@ -143,7 +143,7 @@ impl ServAuth {
         };
 
         let msg = auth::AuthSigMsg::new(&p, sess_id);
-        match sig_type.verify(&m.pubkey.0, &&msg, sig, None) {
+        match sig_type.verify(&m.pubkey.0, &&msg, sig) {
             Ok(()) => true,
             Err(e) => { trace!("sig failed  {e}"); false},
         }
diff --git a/src/sign.rs b/src/sign.rs
index 419812c..da6ffc8 100644
--- a/src/sign.rs
+++ b/src/sign.rs
@@ -11,8 +11,7 @@ use ed25519_dalek::{Signer, Verifier};
 use zeroize::ZeroizeOnDrop;
 
 use crate::*;
-use packets::ParseContext;
-use packets::{Ed25519PubKey, PubKey, Signature};
+use packets::{Ed25519PubKey, Ed25519Sig, PubKey, Signature};
 use sshnames::*;
 use sshwire::{BinString, Blob, SSHEncode};
 
@@ -66,7 +65,6 @@ impl SigType {
         pubkey: &PubKey,
         msg: &dyn SSHEncode,
         sig: &Signature,
-        parse_ctx: Option<&ParseContext>,
     ) -> Result<()> {
         // Check that the signature type is known
         let sig_type = sig.sig_type().map_err(|_| Error::BadSig)?;
@@ -85,32 +83,12 @@ impl SigType {
 
         match (self, pubkey, sig) {
             (SigType::Ed25519, PubKey::Ed25519(k), Signature::Ed25519(s)) => {
-                let k: &[u8; 32] = &k.key.0;
-                let k: salty::PublicKey = k.try_into().map_err(|_| Error::BadKey)?;
-                let s: &[u8; 64] = s.sig.0.try_into().map_err(|_| Error::BadSig)?;
-                let s: salty::Signature = s.into();
-                k.verify_parts(&s, |h| {
-                    sshwire::hash_ser(h, msg, parse_ctx)
-                        .map_err(|_| salty::Error::ContextTooLong)
-                })
-                .map_err(|_| Error::BadSig)
+                Self::verify_ed25519(k, msg, s)
             }
 
             #[cfg(feature = "rsa")]
             (SigType::RSA, PubKey::RSA(k), Signature::RSA(s)) => {
-                let verifying_key =
-                    rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::new_with_prefix(
-                        k.key.clone(),
-                    );
-                let s: Box<[u8]> = s.sig.0.into();
-                let signature = s.into();
-
-                let mut h = sha2::Sha256::new();
-                sshwire::hash_ser(&mut h, msg, parse_ctx)?;
-                verifying_key.verify_digest(h, &signature).map_err(|e| {
-                    trace!("RSA signature failed: {e}");
-                    Error::BadSig
-                })
+                Self::verify_rsa(k, msg, s)
             }
 
             _ => {
@@ -123,6 +101,36 @@ impl SigType {
             }
         }
     }
+
+        
+    fn verify_ed25519(k: &Ed25519PubKey, msg: &dyn SSHEncode, s: &Ed25519Sig) -> Result<()> {
+        let k: &[u8; 32] = &k.key.0;
+        let k: salty::PublicKey = k.try_into().map_err(|_| Error::BadKey)?;
+        let s: &[u8; 64] = s.sig.0.try_into().map_err(|_| Error::BadSig)?;
+        let s: salty::Signature = s.into();
+        k.verify_parts(&s, |h| {
+            sshwire::hash_ser(h, msg)
+                .map_err(|_| salty::Error::ContextTooLong)
+        })
+        .map_err(|_| Error::BadSig)
+    }
+
+    #[cfg(feature = "rsa")]
+    fn verify_rsa(k: &packets::RSAPubKey, msg: &dyn SSHEncode, s: &packets::RSASig) -> Result<()> {
+        let verifying_key =
+            rsa::pkcs1v15::VerifyingKey::<sha2::Sha256>::new_with_prefix(
+                k.key.clone(),
+            );
+        let s: Box<[u8]> = s.sig.0.into();
+        let signature = s.into();
+
+        let mut h = sha2::Sha256::new();
+        sshwire::hash_ser(&mut h, msg)?;
+        verifying_key.verify_digest(h, &signature).map_err(|e| {
+            trace!("RSA signature failed: {e}");
+            Error::BadSig
+        })
+    }
 }
 
 pub enum OwnedSig {
@@ -289,7 +297,6 @@ impl SignKey {
     pub(crate) fn sign(
         &self,
         msg: &impl SSHEncode,
-        parse_ctx: Option<&ParseContext>,
     ) -> Result<OwnedSig> {
         let sig: OwnedSig = match self {
             SignKey::Ed25519(k) => {
@@ -297,7 +304,7 @@ impl SignKey {
                 let sig = dalek::hazmat::raw_sign_byupdate::<sha2::Sha512, _>(
                     &exk,
                     |h| {
-                        sshwire::hash_ser(h, msg, parse_ctx)
+                        sshwire::hash_ser(h, msg)
                             .map_err(|_| dalek::SignatureError::new())
                     },
                     &k.verifying_key(),
@@ -313,7 +320,7 @@ impl SignKey {
                         k.clone(),
                     );
                 let mut h = sha2::Sha256::new();
-                sshwire::hash_ser(&mut h, msg, parse_ctx)?;
+                sshwire::hash_ser(&mut h, msg)?;
                 let sig = signing_key.try_sign_digest(h).map_err(|e| {
                     trace!("RSA signing failed: {e:?}");
                     Error::bug()
diff --git a/src/sshwire.rs b/src/sshwire.rs
index 2f68e08..4c1c6f0 100644
--- a/src/sshwire.rs
+++ b/src/sshwire.rs
@@ -28,9 +28,6 @@ use packets::{Packet, ParseContext};
 /// A generic destination for serializing, used similarly to `serde::Serializer`
 pub trait SSHSink {
     fn push(&mut self, v: &[u8]) -> WireResult<()>;
-    fn ctx(&self) -> Option<&ParseContext> {
-        None
-    }
 }
 
 /// A generic source for a packet, used similarly to `serde::Deserializer`
@@ -138,27 +135,18 @@ pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) ->
 
 pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize>
 {
-    let mut s = EncodeBytes { target, pos: 0, parse_ctx: None };
+    let mut s = EncodeBytes { target, pos: 0 };
     value.enc(&mut s)?;
     Ok(s.pos)
 }
 
-pub fn write_ssh_ctx<T>(target: &mut [u8], value: &T, ctx: ParseContext) -> Result<usize>
-where
-    T: SSHEncode
-{
-    let mut s = EncodeBytes { target, pos: 0, parse_ctx: Some(ctx) };
-        value.enc(&mut s)?;
-    Ok(s.pos)
-}
-
 /// Hashes the SSH wire format representation of `value`, with a `u32` length prefix.
 pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate,
     value: &dyn SSHEncode) -> Result<()>
 {
     let len: u32 = length_enc(value)?;
     hash_ctx.digest_update(&len.to_be_bytes());
-    hash_ser(hash_ctx, value, None)
+    hash_ser(hash_ctx, value)
 }
 
 /// Hashes the SSH wire format representation of `value`
@@ -166,10 +154,9 @@ pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate,
 /// Will only fail if `value.enc()` can return an error.
 pub fn hash_ser(hash_ctx: &mut impl SSHWireDigestUpdate,
     value: &dyn SSHEncode,
-    parse_ctx: Option<&ParseContext>,
     ) -> Result<()>
 {
-    let mut s = EncodeHash { hash_ctx, parse_ctx: parse_ctx.cloned() };
+    let mut s = EncodeHash { hash_ctx };
     value.enc(&mut s)?;
     Ok(())
 }
@@ -185,7 +172,6 @@ fn length_enc(value: &dyn SSHEncode) -> WireResult<u32>
 struct EncodeBytes<'a> {
     target: &'a mut [u8],
     pos: usize,
-    parse_ctx: Option<ParseContext>,
 }
 
 impl SSHSink for EncodeBytes<'_> {
@@ -198,10 +184,6 @@ impl SSHSink for EncodeBytes<'_> {
         self.pos = end;
         Ok(())
     }
-
-    fn ctx(&self) -> Option<&ParseContext> {
-        self.parse_ctx.as_ref()
-    }
 }
 
 struct EncodeLen {
@@ -217,7 +199,6 @@ impl SSHSink for EncodeLen {
 
 struct EncodeHash<'a> {
     hash_ctx: &'a mut dyn SSHWireDigestUpdate,
-    parse_ctx: Option<ParseContext>,
 }
 
 impl SSHSink for EncodeHash<'_> {
@@ -225,10 +206,6 @@ impl SSHSink for EncodeHash<'_> {
         self.hash_ctx.digest_update(v);
         Ok(())
     }
-
-    fn ctx(&self) -> Option<&ParseContext> {
-        self.parse_ctx.as_ref()
-    }
 }
 
 struct DecodeBytes<'a> {
@@ -729,7 +706,7 @@ pub(crate) mod tests {
         // hash_ser
         let mut hash_ctx = Sha256::new();
         hash_ctx.update(&(w1 as u32).to_be_bytes());
-        hash_ser(&mut hash_ctx, &input, None).unwrap();
+        hash_ser(&mut hash_ctx, &input).unwrap();
         let digest3 = hash_ctx.finalize();
         assert_eq!(digest3, digest2);
     }
-- 
GitLab