From 70ed7501d9fbfa23aea7aba31cda64dbaafce7db Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Tue, 31 May 2022 22:28:56 +0800 Subject: [PATCH] Remove serde --- Cargo.lock | 1 - sshproto/Cargo.toml | 1 - sshproto/examples/kex1.rs | 10 +- sshproto/src/auth.rs | 5 +- sshproto/src/cliauth.rs | 7 +- sshproto/src/encrypt.rs | 2 +- sshproto/src/error.rs | 59 -- sshproto/src/kex.rs | 14 +- sshproto/src/lib.rs | 3 +- sshproto/src/namelist.rs | 58 +- sshproto/src/packets.rs | 556 ++----------------- sshproto/src/sign.rs | 14 +- sshproto/src/sshwire.rs | 208 +++++++- sshproto/src/test.rs | 12 +- sshproto/src/wireformat.rs | 1037 ------------------------------------ sshwire_derive/src/lib.rs | 2 +- 16 files changed, 287 insertions(+), 1702 deletions(-) delete mode 100644 sshproto/src/wireformat.rs diff --git a/Cargo.lock b/Cargo.lock index 596f31b..5a7b68e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -355,7 +355,6 @@ dependencies = [ "pretty-hex 0.3.0", "proptest", "rand", - "serde", "serde_json", "sha2 0.10.2", "simplelog", diff --git a/sshproto/Cargo.toml b/sshproto/Cargo.toml index d107753..71e5c34 100644 --- a/sshproto/Cargo.toml +++ b/sshproto/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] sshwire_derive = { path = "../sshwire_derive" } -serde = { version = "1.0", default-features = false, features = ["derive"]} snafu = { version = "0.7", default-features = false, features = ["rust_1_46"] } # TODO: check that log macro calls disappear in no_std builds log = { version = "0.4" } diff --git a/sshproto/examples/kex1.rs b/sshproto/examples/kex1.rs index 22a797b..69d1a10 100644 --- a/sshproto/examples/kex1.rs +++ b/sshproto/examples/kex1.rs @@ -5,7 +5,7 @@ use pretty_hex::PrettyHex; use door_sshproto::*; use door_sshproto::packets::*; -use door_sshproto::wireformat::BinString; +use door_sshproto::sshwire::BinString; use simplelog::{TestLogger,self,LevelFilter}; @@ -24,12 +24,12 @@ fn do_userauth() -> Result<()> { }.into(); let mut buf = vec![0; 2000]; - let written = door_sshproto::wireformat::write_ssh(&mut buf, &p)?; + let written = door_sshproto::sshwire::write_ssh(&mut buf, &p)?; buf.truncate(written); println!("buf {:?}", buf.hex_dump()); let ctx = ParseContext::new(); - let x: Packet = door_sshproto::wireformat::packet_from_bytes(&buf, &ctx)?; + let x: Packet = door_sshproto::sshwire::packet_from_bytes(&buf, &ctx)?; println!("{x:?}"); Ok(()) @@ -72,11 +72,11 @@ fn do_kexinit() -> Result<()> { let mut buf = vec![0; 2000]; - let written = door_sshproto::wireformat::write_ssh(&mut buf, &p)?; + let written = door_sshproto::sshwire::write_ssh(&mut buf, &p)?; buf.truncate(written); println!("{:?}", buf.hex_dump()); let ctx = ParseContext::new(); - let x: Packet = door_sshproto::wireformat::packet_from_bytes(&buf, &ctx)?; + let x: Packet = door_sshproto::sshwire::packet_from_bytes(&buf, &ctx)?; println!("fetched {x:?}"); // let cli= Client::new(); diff --git a/sshproto/src/auth.rs b/sshproto/src/auth.rs index 4924821..b48a907 100644 --- a/sshproto/src/auth.rs +++ b/sshproto/src/auth.rs @@ -6,7 +6,6 @@ use { use core::task::{Poll, Waker}; use heapless::{String, Vec}; -use serde::Serialize; use crate::*; use behaviour::CliBehaviour; @@ -16,12 +15,12 @@ use packets::ParseContext; use packets::{Packet, Signature, Userauth60}; use sign::SignKey; use sshnames::*; -use wireformat::BinString; +use sshwire::BinString; use sshwire_derive::SSHEncode; /// The message to be signed in a pubkey authentication message, /// RFC4252 Section 7. The packet is a UserauthRequest, with None sig. -#[derive(Serialize, SSHEncode)] +#[derive(SSHEncode)] pub(crate) struct AuthSigMsg<'a> { pub sess_id: BinString<'a>, diff --git a/sshproto/src/cliauth.rs b/sshproto/src/cliauth.rs index 4211052..089f626 100644 --- a/sshproto/src/cliauth.rs +++ b/sshproto/src/cliauth.rs @@ -18,7 +18,7 @@ use packets::{MessageNumber, AuthMethod, MethodPubKey, ParseContext, UserauthReq use packets::{Packet, Signature, Userauth60}; use sign::{SignKey, OwnedSig}; use sshnames::*; -use wireformat::{BinString, Blob}; +use sshwire::{BinString, Blob}; use kex::SessId; // pub for packets::ParseContext @@ -167,7 +167,6 @@ impl CliAuth { sig_algo, pubkey: pubkey.clone(), sig: None, - signing_now: true, }), }; @@ -176,7 +175,9 @@ impl CliAuth { msg_num: MessageNumber::SSH_MSG_USERAUTH_REQUEST as u8, u: sig_packet, }; - key.sign_serialize(&msg) + let mut ctx = ParseContext::default(); + ctx.method_pubkey_force_sig_bool = true; + key.sign_encode(&msg, Some(&ctx)) } else { Err(Error::bug()) } diff --git a/sshproto/src/encrypt.rs b/sshproto/src/encrypt.rs index bb822e4..8700bd5 100644 --- a/sshproto/src/encrypt.rs +++ b/sshproto/src/encrypt.rs @@ -15,7 +15,7 @@ use sha2::Digest as Sha2DigestForTrait; use crate::*; use kex::{self, SessId}; use sshnames::*; -use wireformat::hash_mpint; +use sshwire::hash_mpint; use ssh_chapoly::SSHChaPoly; // TODO: check that Ctr32 is sufficient. Should be OK with SSH rekeying. diff --git a/sshproto/src/error.rs b/sshproto/src/error.rs index 8c6c986..78d2760 100644 --- a/sshproto/src/error.rs +++ b/sshproto/src/error.rs @@ -5,7 +5,6 @@ use log::{debug, error, info, log, trace, warn}; use core::fmt::Arguments; use core::fmt; -use serde::de::{Expected, Unexpected}; use snafu::{prelude::*, Location}; use heapless::String; @@ -115,10 +114,6 @@ pub enum Error { #[snafu(display("Unknown {kind} method {name}"))] UnknownMethod { kind: &'static str, name: UnknownName }, - /// Serde invalid value - // internal - InvalidDeserializeU8 { value: u8 }, - /// Implementation behaviour error #[snafu(display("Failure from application: {msg}"))] BehaviourError { msg: &'static str }, @@ -261,60 +256,6 @@ impl From<BhError> for Error { } } - -// needed for docs. TODO cfg for doc? -// impl serde::de::StdError for Error {} - -// TODO: need to figure how to return our own Error variants from serde -// rather than using serde Error::custom(). -impl serde::ser::Error for Error { - fn custom<T>(msg: T) -> Self - where - T: core::fmt::Display, - { - trace!("custom ser error: {}", msg); - - Error::msg("ser error") - } -} - -impl serde::de::Error for Error { - fn custom<T>(msg: T) -> Self - where - T: core::fmt::Display, - { - trace!("custom de error: {}", msg); - - Error::msg("de error") - } - - fn invalid_value(unexp: Unexpected<'_>, exp: &dyn Expected) -> Self { - if let Unexpected::Unsigned(val) = unexp { - if val <= 255 { - return Error::InvalidDeserializeU8 { value: val as u8 }; - } - } - info!("Invalid input. Expected {} got {:?}", exp, unexp); - if let Unexpected::Str(_) = unexp { - return Error::BadString - } - Error::bug() - } - - fn unknown_variant(variant: &str, _expected: &'static [&'static str]) -> Self { - debug!("Unknown variant '{variant}' wasn't caught"); - Error::bug() - } -} - -pub struct ExpectedMessageNumber; - -impl Expected for ExpectedMessageNumber { - fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(formatter, "a known SSH message number") - } -} - #[cfg(test)] pub(crate) mod tests { use crate::error::*; diff --git a/sshproto/src/kex.rs b/sshproto/src/kex.rs index a5f17e3..c8cefd2 100644 --- a/sshproto/src/kex.rs +++ b/sshproto/src/kex.rs @@ -20,7 +20,7 @@ use namelist::LocalNames; use packets::{Packet, PubKey, Signature}; use sign::SigType; use sshnames::*; -use wireformat::{hash_mpint, BinString, Blob}; +use sshwire::{hash_mpint, BinString, Blob}; use sshwire::{hash_ser, hash_ser_length}; use behaviour::{CliBehaviour, Behaviour, ServBehaviour}; @@ -597,9 +597,9 @@ mod tests { } /// Round trip a `Packet` - fn reserialize<'a>(out_buf: &'a mut [u8], p: Packet, ctx: &ParseContext) -> Packet<'a> { - wireformat::write_ssh(out_buf, &p).unwrap(); - wireformat::packet_from_bytes(out_buf, &ctx).unwrap() + fn reencode<'a>(out_buf: &'a mut [u8], p: Packet, ctx: &ParseContext) -> Packet<'a> { + sshwire::write_ssh(out_buf, &p).unwrap(); + sshwire::packet_from_bytes(out_buf, &ctx).unwrap() } #[test] @@ -618,12 +618,12 @@ mod tests { let mut cli = kex::Kex::new().unwrap(); let mut serv = kex::Kex::new().unwrap(); - // reserialize so we end up with NameList::String not Local + // reencode so we end up with NameList::String not Local let ctx = ParseContext::new(); let si = serv.make_kexinit(&serv_conf); - let si = reserialize(&mut bufs, si, &ctx); + let si = reencode(&mut bufs, si, &ctx); let ci = cli.make_kexinit(&cli_conf); - let ci = reserialize(&mut bufc, ci, &ctx); + let ci = reencode(&mut bufc, ci, &ctx); serv.handle_kexinit(false, &serv_conf, &cli_version, &ci).unwrap(); cli.handle_kexinit(true, &cli_conf, &serv_version, &si).unwrap(); diff --git a/sshproto/src/lib.rs b/sshproto/src/lib.rs index b593b0a..c97210d 100644 --- a/sshproto/src/lib.rs +++ b/sshproto/src/lib.rs @@ -18,7 +18,6 @@ pub mod ident; pub mod kex; pub mod test; pub mod traffic; -pub mod wireformat; pub mod namelist; pub mod random; pub mod sshnames; @@ -43,7 +42,7 @@ mod termmodes; mod async_behaviour; mod block_behaviour; mod ssh_chapoly; -mod sshwire; +pub mod sshwire; pub use behaviour::{Behaviour, BhError, BhResult, ResponseString}; #[cfg(feature = "std")] diff --git a/sshproto/src/namelist.rs b/sshproto/src/namelist.rs index f24629b..79bb17a 100644 --- a/sshproto/src/namelist.rs +++ b/sshproto/src/namelist.rs @@ -5,24 +5,17 @@ use { log::{debug, error, info, log, trace, warn}, }; - -use serde::de; -use serde::de::{DeserializeSeed, SeqAccess, Visitor, Unexpected, Expected}; -use serde::ser::{SerializeSeq, SerializeTuple, Serializer}; -use serde::Deserializer; - -use serde::{Deserialize, Serialize}; use sshwire_derive::{SSHEncode, SSHDecode}; use crate::*; use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink}; -/// A comma separated string, can be deserialized or serialized. +/// A comma separated string, can be decoded or encoded. /// Used for remote name lists. -#[derive(Serialize, Deserialize, SSHEncode, SSHDecode, Debug)] +#[derive(SSHEncode, SSHDecode, Debug)] pub struct StringNames<'a>(pub &'a str); -/// A list of names, can only be serialized. Used for local name lists, comes +/// A list of names, can only be encoded. Used for local name lists, comes /// from local fixed lists /// Deliberately `'static` since it should only come from hardcoded local strings /// `SSH_NAME_*` in [`crate::sshnames`]. We don't validate string contents. @@ -30,28 +23,13 @@ pub struct StringNames<'a>(pub &'a str); pub struct LocalNames<'a>(pub &'a [&'static str]); /// The general form that can store either representation -#[derive(Serialize, SSHEncode, Debug)] -#[serde(untagged)] +#[derive(SSHEncode, Debug)] #[sshwire(no_variant_names)] pub enum NameList<'a> { String(StringNames<'a>), Local(LocalNames<'a>), } -impl<'de: 'a, 'a> Deserialize<'de> for NameList<'a> { - fn deserialize<D>(deserializer: D) -> Result<NameList<'a>, D::Error> - where - D: Deserializer<'de>, - { - let s = StringNames::deserialize(deserializer)?; - if s.0.is_ascii() { - Ok(NameList::String(s)) - } else { - Err(de::Error::invalid_value(Unexpected::Str(s.0), &"ASCII")) - } - } -} - impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { fn dec<S>(s: &mut S) -> Result<NameList<'a>> where @@ -66,31 +44,6 @@ impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { } } -/// Serialize the list of names with comma separators -impl<'a> Serialize for LocalNames<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - // This isn't quite right for a generic serde serializer - // but it's OK for our SSH serializer. Serde doesn't have - // an API to incrementally serialize string parts. - // See packets::tests::json() for a test. - let mut seq = serializer.serialize_seq(None)?; - let names = &self.0; - let strlen = names.iter().map(|n| n.len()).sum::<usize>() - + names.len().saturating_sub(1); - seq.serialize_element(&(strlen as u32))?; - for i in 0..names.len() { - seq.serialize_element(names[i].as_bytes())?; - if i < names.len() - 1 { - seq.serialize_element(&(',' as u8))?; - } - } - seq.end() - } -} - /// Serialize the list of names with comma separators impl SSHEncode for LocalNames<'_> { fn enc<S>(&self, e: &mut S) -> Result<()> @@ -219,7 +172,6 @@ impl<'a> LocalNames<'a> { #[cfg(test)] mod tests { use crate::namelist::*; - use crate::wireformat; use pretty_hex::PrettyHex; use std::vec::Vec; use crate::doorlog::init_test_log; @@ -256,7 +208,7 @@ mod tests { for t in tests.iter() { let n = NameList::Local(LocalNames(t)); let mut buf = vec![99; 30]; - let l = wireformat::write_ssh(&mut buf, &n).unwrap(); + let l = sshwire::write_ssh(&mut buf, &n).unwrap(); buf.truncate(l); let out1 = core::str::from_utf8(&buf).unwrap(); // check that a join with std gives the same result. diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index 0debbdf..031b167 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -1,8 +1,6 @@ -//! SSH protocol packets. A [`Packet`] can be serialized/deserialized to the -//! SSH Binary Packet Protocol using [`serde`] with [`crate::wireformat`]. +//! SSH protocol packets. A [`Packet`] can be encoded/decoded to the +//! SSH Binary Packet Protocol using [`crate::sshwire`]. //! -//! These are mostly container formats though there is some logic to determine -//! which enum variant needs deserializing for certain packet types. use core::borrow::BorrowMut; use core::cell::Cell; use core::fmt; @@ -14,37 +12,22 @@ use { }; use heapless::String; -use serde::de; -use serde::de::{ - DeserializeSeed, Error as DeError, Expected, MapAccess, SeqAccess, Visitor, -}; -use serde::ser::{ - Error as SerError, SerializeSeq, SerializeStruct, SerializeTuple, Serializer, -}; -use serde::Deserializer; - -use serde::{Deserialize, Serialize}; use sshwire_derive::*; use crate::*; use namelist::NameList; use sshnames::*; -use wireformat::{BinString, Blob}; +use sshwire::{BinString, Blob}; use sign::{SigType, OwnedSig}; use sshwire::{SSHEncode, SSHEncodeEnum, SSHDecode, SSHDecodeEnum, SSHSource, SSHSink}; -// Each struct needs one #[borrow] tag before one of the struct fields with a lifetime -// (eg `blob: BinString<'a>`). That avoids the cryptic error in derive: -// error[E0495]: cannot infer an appropriate lifetime for lifetime parameter `'de` due to conflicting requirements - // Any `enum` needs to have special handling to select a variant when deserializing. -// This is done in conjunction with [`wireformat::deserialize_enum`]. +// This is mostly done with `#[sshwire(...)]` attributes. -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct KexInit<'a> { pub cookie: [u8; 16], - #[serde(borrow)] pub kex: NameList<'a>, pub hostkey: NameList<'a>, // is actually a signature type, not a key type pub cipher_c2s: NameList<'a>, @@ -59,79 +42,71 @@ pub struct KexInit<'a> { pub reserved: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct NewKeys {} -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Ignore {} /// Named to avoid clashing with [`fmt::Debug`] -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct DebugPacket<'a> { pub always_display: bool, pub message: &'a str, pub lang: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Disconnect<'a> { pub reason: u32, pub desc: &'a str, pub lang: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Unimplemented { pub seq: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct KexDHInit<'a> { - #[serde(borrow)] pub q_c: BinString<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct KexDHReply<'a> { - #[serde(borrow)] pub k_s: Blob<PubKey<'a>>, pub q_s: BinString<'a>, pub sig: Blob<Signature<'a>>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct ServiceRequest<'a> { pub name: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct ServiceAccept<'a> { pub name: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthRequest<'a> { pub username: &'a str, pub service: &'a str, - // #[serde(deserialize_with = "wrap_unknown")] pub method: AuthMethod<'a>, } /// The method-specific part of a [`UserauthRequest`]. -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] #[sshwire(variant_prefix)] pub enum AuthMethod<'a> { - #[serde(borrow)] - #[serde(rename = "password")] #[sshwire(variant = "password")] Password(MethodPassword<'a>), - #[serde(rename = "publickey")] #[sshwire(variant = "publickey")] PubKey(MethodPubKey<'a>), - #[serde(rename = "none")] #[sshwire(variant = "none")] None, - #[serde(skip_serializing)] #[sshwire(unknown)] Unknown(Unknown<'a>), } @@ -145,16 +120,14 @@ impl<'a> TryFrom<PubKey<'a>> for AuthMethod<'a> { sig_algo, pubkey: Blob(pubkey), sig: None, - signing_now: false, })) } } -#[derive(Serialize, Deserialize, Debug, SSHEncode)] +#[derive(Debug, SSHEncode)] #[sshwire(no_variant_names)] pub enum Userauth60<'a> { - #[serde(borrow)] PkOk(UserauthPkOk<'a>), PwChangeReq(UserauthPwChangeReq<'a>), // TODO keyboard interactive @@ -174,34 +147,19 @@ impl<'de: 'a, 'a> SSHDecode<'de> for Userauth60<'a> { } } -impl<'a> Userauth60<'a> { - /// Special handling in [`wireformat`] - pub(crate) fn variant(ctx: &ParseContext) -> Result<&'static str> { - match ctx.cli_auth_type { - Some(cliauth::AuthType::Password) => Ok("PwChangeReq"), - Some(cliauth::AuthType::PubKey) => Ok("PkOk"), - _ => { - trace!("Wrong packet state for userauth60"); - return Err(Error::PacketWrong) - } - } - } -} - -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthPkOk<'a> { pub algo: &'a str, - #[serde(borrow)] pub key: Blob<PubKey<'a>>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthPwChangeReq<'a> { pub prompt: &'a str, pub lang: &'a str, } -#[derive(Serialize, Deserialize, SSHEncode, SSHDecode)] +#[derive(SSHEncode, SSHDecode)] pub struct MethodPassword<'a> { pub change: bool, pub password: &'a str, @@ -222,29 +180,6 @@ pub struct MethodPubKey<'a> { pub sig_algo: &'a str, pub pubkey: Blob<PubKey<'a>>, pub sig: Option<Blob<Signature<'a>>>, - - // Set internally when serializing for to create a signature, - // the wire format has true for the have_signature bool - // but no actual signature. - // Should only be used by cli_auth::auth_sig_msg(). - pub(crate) signing_now: bool, - -} - -impl<'a> Serialize for MethodPubKey<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - let mut seq = serializer.serialize_seq(None)?; - seq.serialize_element(&(self.sig.is_some() || self.signing_now))?; - seq.serialize_element(&self.sig_algo)?; - seq.serialize_element(&self.pubkey)?; - if let Some(s) = &self.sig { - seq.serialize_element(&s)?; - } - seq.end() - } } impl SSHEncode for MethodPubKey<'_> { @@ -271,85 +206,32 @@ impl<'de: 'a, 'a> SSHDecode<'de> for MethodPubKey<'a> { } else { None }; - Ok(Self { sig_algo, pubkey, sig, signing_now: false }) + Ok(Self { sig_algo, pubkey, sig }) } } -impl<'de: 'a, 'a> Deserialize<'de> for MethodPubKey<'a> { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - struct Vis; - - impl<'de> Visitor<'de> for Vis { - type Value = MethodPubKey<'de>; - - fn expecting( - &self, formatter: &mut core::fmt::Formatter, - ) -> core::fmt::Result { - formatter.write_str("MethodPubKey") - } - fn visit_seq<V>(self, mut seq: V) -> Result<MethodPubKey<'de>, V::Error> - where - V: SeqAccess<'de>, - { - let actual_sig = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("actual_sig flag"))?; - - let sig_algo = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("sig_algo"))?; - - let pubkey = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("pubkey"))?; - - let sig = if actual_sig { - Some( - seq.next_element()? - .ok_or_else(|| de::Error::missing_field("sig"))?, - ) - } else { - None - }; - - Ok(MethodPubKey { sig_algo, pubkey, sig, signing_now: false }) - } - } - deserializer.deserialize_seq(Vis) - } -} - -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthFailure<'a> { - #[serde(borrow)] pub methods: NameList<'a>, pub partial: bool, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthSuccess {} -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthBanner<'a> { - #[serde(borrow)] pub message: &'a str, pub lang: &'a str, } -#[derive(SSHEncode, SSHDecode, Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(SSHEncode, SSHDecode, Debug, Clone, PartialEq)] #[sshwire(variant_prefix)] pub enum PubKey<'a> { - #[serde(borrow)] - #[serde(rename = "ssh-ed25519")] #[sshwire(variant = "ssh-ed25519")] Ed25519(Ed25519PubKey<'a>), - #[serde(rename = "ssh-rsa")] #[sshwire(variant = "ssh-rsa")] RSA(RSAPubKey<'a>), - #[serde(skip_serializing)] #[sshwire(unknown)] Unknown(Unknown<'a>), } @@ -364,51 +246,26 @@ impl<'a> PubKey<'a> { } } } -// impl<'de: 'a, 'a> crate::sshwire::SSHDecode<'de> for PubKey<'a> { -// fn dec<S: crate::sshwire::SSHSource<'de>>(s: &mut S) -> Result<Self> { -// let variant = crate::sshwire::SSHDecode::dec(s)?; -// crate::sshwire::SSHDecodeEnum::dec_enum(s, variant) -// } -// } -// impl<'de: 'a, 'a> crate::sshwire::SSHDecodeEnum<'de> for PubKey<'a> { -// fn dec_enum<S: crate::sshwire::SSHSource<'de>>( -// s: &mut S, -// variant: &'de str, -// ) -> Result<Self> { -// let r = match variant { -// "ssh-ed25519" => Self::Ed25519(crate::sshwire::SSHDecode::dec(s)?), -// "ssh-rsa" => Self::RSA(crate::sshwire::SSHDecode::dec(s)?), -// unk => Self::Unknown(Unknown(unk)), -// }; -// Ok(r) -// } -// } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, SSHEncode, SSHDecode)] +#[derive(Debug, Clone, PartialEq, SSHEncode, SSHDecode)] pub struct Ed25519PubKey<'a> { - #[serde(borrow)] pub key: BinString<'a>, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, SSHEncode, SSHDecode)] +#[derive(Debug, Clone, PartialEq, SSHEncode, SSHDecode)] pub struct RSAPubKey<'a> { - #[serde(borrow)] pub e: BinString<'a>, pub n: BinString<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] #[sshwire(variant_prefix)] pub enum Signature<'a> { - #[serde(borrow)] - #[serde(rename = "ssh-ed25519")] #[sshwire(variant = "ssh-ed25519")] Ed25519(Ed25519Sig<'a>), - #[serde(rename = "rsa-sha2-256")] #[sshwire(variant = "rsa-sha2-256")] RSA256(RSA256Sig<'a>), - #[serde(skip_serializing)] #[sshwire(unknown)] Unknown(Unknown<'a>), } @@ -458,19 +315,17 @@ impl <'a> From<&'a OwnedSig> for Signature<'a> { } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Ed25519Sig<'a> { - #[serde(borrow)] pub sig: BinString<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct RSA256Sig<'a> { - #[serde(borrow)] pub sig: BinString<'a>, } -// #[derive(Serialize, Deserialize, Debug)] +// #[derive(Debug)] // pub struct GlobalRequest<'a> { // name: &'a str, // want_reply: bool, @@ -492,107 +347,23 @@ pub struct ChannelOpen<'a> { pub ch: ChannelOpenType<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub enum ChannelOpenType<'a> { - #[serde(rename = "session")] #[sshwire(variant = "session")] Session, - #[serde(rename = "forwarded-tcpip")] #[sshwire(variant = "forwarded-tcpip")] - #[serde(borrow)] ForwardedTcpip(ForwardedTcpip<'a>), - #[serde(rename = "direct-tcpip")] #[sshwire(variant = "direct-tcpip")] DirectTcpip(DirectTcpip<'a>), - // #[serde(rename = "x11")] // #[sshwire(variant = "x11")] // Session(X11<'a>), - // #[serde(rename = "auth-agent@openssh.com")] // #[sshwire(variant = "auth-agent@openssh.com")] // Session(Agent<'a>), - #[serde(skip_serializing)] #[sshwire(unknown)] Unknown(Unknown<'a>), } -impl<'a> Serialize for ChannelOpen<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - let mut seq = serializer.serialize_struct("ChannelOpen", 5)?; - let channel_type = match self.ch { - ChannelOpenType::Session => "session", - ChannelOpenType::ForwardedTcpip(_) => "forwarded-tcpip", - ChannelOpenType::DirectTcpip(_) => "direct-tcpip", - ChannelOpenType::Unknown(_) => return Err(S::Error::custom("unknown")), - }; - seq.serialize_field("channel_type", channel_type)?; - seq.serialize_field("num", &self.num)?; - seq.serialize_field("initial_window", &self.initial_window)?; - seq.serialize_field("max_packet", &self.initial_window)?; - seq.serialize_field("ch", &self.ch)?; - seq.end() - } -} - -impl<'de: 'a, 'a> Deserialize<'de> for ChannelOpen<'a> { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - struct Vis; - - impl<'de> Visitor<'de> for Vis { - type Value = ChannelOpen<'de>; - - fn expecting( - &self, formatter: &mut core::fmt::Formatter, - ) -> core::fmt::Result { - formatter.write_str("ChannelOpen") - } - - fn visit_map<V>(self, mut map: V) -> Result<ChannelOpen<'de>, V::Error> - where - V: MapAccess<'de>, - { - // a bit horrible - let mut _k: &'de str; - let _channel_type: &'de str; - let num; - let initial_window; - let max_packet; - let ch; - (_k, _channel_type) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("channel_type"))?; - (_k, num) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("num"))?; - (_k, initial_window) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("initial_window"))?; - (_k, max_packet) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("max_packet"))?; - (_k, ch) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("ch"))?; - - Ok(ChannelOpen { num, initial_window, max_packet, ch }) - } - } - // deserialize as a struct so wireformat can get the channel_type - // used to decode the ch enum. - deserializer.deserialize_struct( - "ChannelOpen", - &["channel_type", "num", "initial_window", "max_packet", "ch"], - Vis, - ) - } -} - -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelOpenConfirmation { pub num: u32, pub sender_num: u32, @@ -600,7 +371,7 @@ pub struct ChannelOpenConfirmation { pub max_packet: u32, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelOpenFailure<'a> { pub num: u32, pub reason: u32, @@ -608,43 +379,41 @@ pub struct ChannelOpenFailure<'a> { pub lang: &'a str, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelWindowAdjust { pub num: u32, pub adjust: u32, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelData<'a> { pub num: u32, - #[serde(borrow)] pub data: BinString<'a>, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelDataExt<'a> { pub num: u32, pub code: u32, - #[serde(borrow)] pub data: BinString<'a>, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelEof { pub num: u32, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelClose { pub num: u32, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelSuccess { pub num: u32, } -#[derive(Debug,Serialize,Deserialize, SSHEncode, SSHDecode)] +#[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelFailure { pub num: u32, } @@ -660,34 +429,24 @@ pub struct ChannelRequest<'a> { pub ch: ChannelReqType<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub enum ChannelReqType<'a> { - #[serde(rename = "shell")] #[sshwire(variant = "shell")] Shell, - #[serde(rename = "exec")] #[sshwire(variant = "exec")] - #[serde(borrow)] Exec(Exec<'a>), - #[serde(rename = "pty-req")] #[sshwire(variant = "pty-req")] Pty(Pty<'a>), - #[serde(rename = "subsystem")] #[sshwire(variant = "subsystem")] Subsystem(Subsystem<'a>), - #[serde(rename = "window-change")] #[sshwire(variant = "window-change")] WinChange(WinChange), - #[serde(rename = "signal")] #[sshwire(variant = "signal")] Signal(Signal<'a>), - #[serde(rename = "exit-status")] #[sshwire(variant = "exit-status")] ExitStatus(ExitStatus), - #[serde(rename = "exit-signal")] #[sshwire(variant = "exit-signal")] ExitSignal(ExitSignal<'a>), - #[serde(rename = "break")] #[sshwire(variant = "break")] Break(Break), // Other requests that aren't implemented at present: @@ -695,96 +454,16 @@ pub enum ChannelReqType<'a> { // x11-req // env // xon-xoff - #[serde(skip_serializing)] #[sshwire(unknown)] Unknown(Unknown<'a>), } -impl<'a> Serialize for ChannelRequest<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - let mut seq = serializer.serialize_struct("ChannelRequest", 5)?; - let channel_type = match self.ch { - ChannelReqType::Shell => "shell", - ChannelReqType::Exec(_) => "exec", - ChannelReqType::Pty(_) => "pty-req", - ChannelReqType::Subsystem(_) => "subsystem", - ChannelReqType::WinChange(_) => "window-change", - ChannelReqType::Signal(_) => "signal", - ChannelReqType::ExitStatus(_) => "exit-status", - ChannelReqType::ExitSignal(_) => "exit-signal", - ChannelReqType::Break(_) => "break", - ChannelReqType::Unknown(_) => return Err(S::Error::custom("unknown")), - }; - seq.serialize_field("num", &self.num)?; - seq.serialize_field("channel_type", channel_type)?; - seq.serialize_field("want_reply", &self.want_reply)?; - seq.serialize_field("ch", &self.ch)?; - seq.end() - } -} - -impl<'de: 'a, 'a> Deserialize<'de> for ChannelRequest<'a> { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - struct Vis; - - impl<'de> Visitor<'de> for Vis { - type Value = ChannelRequest<'de>; - - fn expecting( - &self, formatter: &mut core::fmt::Formatter, - ) -> core::fmt::Result { - formatter.write_str("ChannelRequest") - } - - fn visit_map<V>(self, mut map: V) -> Result<ChannelRequest<'de>, V::Error> - where - V: MapAccess<'de>, - { - // a bit horrible - let mut _k: &'de str; - let _channel_type: &'de str; - let num; - let want_reply; - let ch; - (_k, num) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("num"))?; - (_k, _channel_type) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("channel_type"))?; - (_k, want_reply) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("want_reply"))?; - (_k, ch) = map - .next_entry()? - .ok_or_else(|| de::Error::missing_field("ch"))?; - - Ok(ChannelRequest { num, want_reply, ch }) - } - } - // deserialize as a struct so wireformat can get the channel_type - // used to decode the ch enum. - deserializer.deserialize_struct( - "ChannelRequest", - &["num", "channel_type", "want_reply", "ch"], - Vis, - ) - } -} - - -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Exec<'a> { pub command: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Pty<'a> { pub term: &'a str, pub cols: u32, @@ -794,12 +473,12 @@ pub struct Pty<'a> { pub modes: BinString<'a>, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Subsystem<'a> { pub subsystem: &'a str, } -#[derive(Serialize, Deserialize, Debug, Clone, SSHEncode, SSHDecode)] +#[derive(Debug, Clone, SSHEncode, SSHDecode)] pub struct WinChange { pub cols: u32, pub rows: u32, @@ -807,17 +486,17 @@ pub struct WinChange { pub height: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct Signal<'a> { pub sig: &'a str, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct ExitStatus { pub status: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct ExitSignal<'a> { pub signal: &'a str, pub core: bool, @@ -825,12 +504,12 @@ pub struct ExitSignal<'a> { pub lang: &'a str, } -#[derive(Serialize, Deserialize, Debug, Clone, SSHEncode, SSHDecode)] +#[derive(Debug, Clone, SSHEncode, SSHDecode)] pub struct Break { pub length: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct ForwardedTcpip<'a> { pub address: &'a str, pub port: u32, @@ -838,7 +517,7 @@ pub struct ForwardedTcpip<'a> { pub origin_port: u32, } -#[derive(Serialize, Deserialize, Debug, SSHEncode, SSHDecode)] +#[derive(Debug, SSHEncode, SSHDecode)] pub struct DirectTcpip<'a> { pub address: &'a str, pub port: u32, @@ -850,15 +529,16 @@ pub struct DirectTcpip<'a> { // Placeholder for unknown method names. These are sometimes non-fatal and // need to be handled by the relevant code, for example newly invented pubkey types // This is deliberately not Serializable, we only receive it. -#[derive(Debug, Deserialize, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct Unknown<'a>(pub &'a str); -/// State to be passed to deserialisation. -/// Use this so the parser can select the correct enum variant to deserialize. +/// State to be passed to decoding. +/// Use this so the parser can select the correct enum variant to decode. #[derive(Default, Clone, Debug)] pub struct ParseContext { pub cli_auth_type: Option<cliauth::AuthType>, + // Used by sign_encode() pub method_pubkey_force_sig_bool: bool, } @@ -868,14 +548,6 @@ impl ParseContext { } } -/// State passed as the Deserializer seed. -pub(crate) struct PacketState { - pub ctx: ParseContext, - // Private fields that keep state during parsing. - // TODO Perhaps not actually necessary, could be removed and just pass ParseContext? - // pub(crate) ty: Cell<Option<MessageNumber>>, -} - // we have repeated `match` statements for the various packet types, use a macro macro_rules! messagetypes { ( @@ -910,56 +582,6 @@ impl TryFrom<u8> for MessageNumber { } } -impl<'de: 'a, 'a> Deserialize<'de> for Packet<'a> { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - struct Vis; - - impl<'de> Visitor<'de> for Vis { - type Value = Packet<'de>; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("struct Packet") - } - fn visit_seq<V>(self, mut seq: V) -> Result<Packet<'de>, V::Error> - where - V: SeqAccess<'de> - { - // First byte is always message number - let msg_num: u8 = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("message number"))?; - let ty = MessageNumber::try_from(msg_num); - let ty = match ty { - Ok(t) => t, - Err(_) => { - return Err(de::Error::invalid_value(de::Unexpected::Unsigned(msg_num as u64), - &self)); - } - }; - - // Decode based on the message number - let p = match ty { - // eg - // MessageNumber::SSH_MSG_KEXINIT => Packet::KexInit( - // ... - $( - MessageNumber::$SSH_MESSAGE_NAME => Packet::$SpecificPacketVariant( - seq.next_element()? - .ok_or_else(|| de::Error::missing_field("rest of packet"))? - ), - )* - }; - - Ok(p) - } - } - deserializer.deserialize_seq(Vis { }) - } -} - impl SSHEncode for Packet<'_> { fn enc<S>(&self, s: &mut S) -> Result<()> where S: SSHSink { @@ -1002,31 +624,6 @@ impl<'de: 'a, 'a> SSHDecode<'de> for Packet<'a> { } } -impl<'a> Serialize for Packet<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - let mut seq = serializer.serialize_seq(Some(2))?; - - let t = self.message_num() as u8; - seq.serialize_element(&t)?; - - match self { - // eg - // Packet::KexInit(p) => { - // ... - $( - Packet::$SpecificPacketVariant(p) => { - seq.serialize_element(p)?; - } - )* - }; - - seq.end() - } -} - /// Top level SSH packet enum #[derive(Debug)] pub enum Packet<'a> { @@ -1105,11 +702,10 @@ mod tests { use crate::doorlog::init_test_log; use crate::packets::*; use crate::sshnames::*; - use crate::wireformat::tests::{assert_serialize_equal, test_roundtrip}; - use crate::wireformat::{packet_from_bytes, write_ssh}; - use crate::{packets, wireformat}; + use crate::sshwire::tests::{assert_serialize_equal, test_roundtrip}; + use crate::sshwire::{packet_from_bytes, write_ssh}; + use crate::{packets, sshwire}; use pretty_hex::PrettyHex; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[test] /// check round trip of packet enums is right @@ -1122,14 +718,6 @@ mod tests { } } - fn json_roundtrip(p: &Packet) { - // let t = serde_json::to_string_pretty(p).unwrap(); - // trace!("json {t}"); - // let p2 = serde_json::from_str(&t).unwrap(); - - // assert_serialize_equal(p, &p2); - } - #[test] /// Tests MethodPubKey custom serde fn roundtrip_authpubkey() { @@ -1171,7 +759,6 @@ mod tests { }), }); test_roundtrip(&p); - json_roundtrip(&p); let p = Packet::ChannelOpen(ChannelOpen { num: 0, @@ -1180,7 +767,6 @@ mod tests { ch: ChannelOpenType::Session, }); test_roundtrip(&p); - json_roundtrip(&p); } #[test] @@ -1216,28 +802,4 @@ mod tests { let mut buf1 = vec![88; 1000]; write_ssh(&mut buf1, &p).unwrap(); } - - #[test] - /// See whether we work with another `Serializer`/`Deserializer`. - /// Not required, but might make `packets` more reusable without `wireformat`. - fn json() { - init_test_log(); - let p = Packet::Userauth60(Userauth60::PwChangeReq(UserauthPwChangeReq { - prompt: "change the password", - lang: "", - })); - json_roundtrip(&p); - - // Fails, namelist string sections are serialized piecewise, serde - // doesn't have any API to write strings in parts. It's fine for - // SSH format since we have no sequence delimiters. - - // let cli_conf = kex::AlgoConfig::new(true); - // let cli = kex::Kex::new().unwrap(); - // let p = cli.make_kexinit(&cli_conf); - // json_roundtrip(&p); - - // It seems BinString also has problems, haven't figured where the - // problem is. - } } diff --git a/sshproto/src/sign.rs b/sshproto/src/sign.rs index e6e2288..11daae0 100644 --- a/sshproto/src/sign.rs +++ b/sshproto/src/sign.rs @@ -12,8 +12,7 @@ use ed25519_dalek::{Verifier, Signer}; use crate::{*, packets::ParseContext}; use sshnames::*; use packets::{PubKey, Signature, Ed25519PubKey}; -use wireformat::BinString; -use sshwire::SSHEncode; +use sshwire::{BinString, SSHEncode}; use pretty_hex::PrettyHex; @@ -124,15 +123,12 @@ impl SignKey { k.try_into() } - // pub(crate) fn sign_serialize<'s>(&self, msg: &'s impl serde::Serialize) -> Result<OwnedSig> { - pub(crate) fn sign_serialize<'s>(&self, msg: &'s impl SSHEncode) -> Result<OwnedSig> { + pub(crate) fn sign_encode<'s>(&self, msg: &'s impl SSHEncode, parse_ctx: Option<&ParseContext>) -> Result<OwnedSig> { match self { SignKey::Ed25519(k) => { let exk: dalek::ExpandedSecretKey = (&k.secret).into(); exk.sign_parts(|h| { - let mut ctx = ParseContext::default(); - ctx.method_pubkey_force_sig_bool = true; - sshwire::hash_ser(h, Some(&ctx), msg).map_err(|_| dalek::SignatureError::new()) + sshwire::hash_ser(h, parse_ctx, msg).map_err(|_| dalek::SignatureError::new()) }, &k.public) .trap() .map(|s| s.into()) @@ -166,10 +162,8 @@ pub(crate) mod tests { use ed25519_dalek::Signer; use crate::sshnames::SSH_NAME_ED25519; - use crate::{packets, wireformat}; + use crate::packets; use crate::sign::*; - use crate::wireformat::tests::assert_serialize_equal; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::doorlog::init_test_log; pub(crate) fn make_ed25519_signkey() -> SignKey { diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 8c30c7f..63d3cc9 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -5,6 +5,8 @@ use { }; use core::str; +use core::convert::AsRef; +use core::fmt::{self,Debug}; use crate::*; use packets::{Packet, ParseContext}; @@ -49,17 +51,7 @@ pub trait SSHDecodeEnum<'de>: Sized { /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> { let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() }; - Packet::dec(&mut s).map_err(|e| { - // TODO better handling of this. Stuff it in PacketState. - // Also should return which MessageNumber failed in later parsing - if let Error::InvalidDeserializeU8 { value } = e { - // This assumes that the only deserialize that can hit - // invalid_value() is an unknown packet type. Seems safe at present. - Error::UnknownPacket { number: value } - } else { - e - } - }) + Packet::dec(&mut s) } pub fn write_ssh<T>(target: &mut [u8], value: &T) -> Result<usize> @@ -166,6 +158,96 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> { } } +// Hashes a slice to be treated as a mpint. Has u32 length prefix +// and an extra 0x00 byte if the MSB is set. +pub fn hash_mpint(hash_ctx: &mut dyn digest::DynDigest, m: &[u8]) { + let pad = m.len() > 0 && (m[0] & 0x80) != 0; + let l = m.len() as u32 + pad as u32; + hash_ctx.update(&l.to_be_bytes()); + if pad { + hash_ctx.update(&[0x00]); + } + hash_ctx.update(m); +} + +/////////////////////////////////////////////// + +/// A SSH style binary string. Serialized as 32 bit length followed by the bytes +/// of the slice. +#[derive(Clone,PartialEq)] +pub struct BinString<'a>(pub &'a [u8]); + +impl<'a> AsRef<[u8]> for BinString<'a> { + fn as_ref(&self) -> &'a [u8] { + self.0 + } +} + +impl<'a> Debug for BinString<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "BinString(len={})", self.0.len()) + } +} + +impl SSHEncode for BinString<'_> { + fn enc<S>(&self, s: &mut S) -> Result<()> + where S: sshwire::SSHSink { + (self.0.len() as u32).enc(s)?; + self.0.enc(s) + } +} + +impl<'de> SSHDecode<'de> for BinString<'de> { + fn dec<S>(s: &mut S) -> Result<Self> + where S: sshwire::SSHSource<'de> { + let len = u32::dec(s)? as usize; + Ok(BinString(s.take(len)?)) + } + +} + +// A wrapper for a u32 prefixed data structure `B`, such as a public key blob +pub struct Blob<B>(pub B); + +impl<B> AsRef<B> for Blob<B> { + fn as_ref(&self) -> &B { + &self.0 + } +} + +impl<B: Clone> Clone for Blob<B> { + fn clone(&self) -> Self { + Blob(self.0.clone()) + } +} + +impl<B: SSHEncode + Debug> Debug for Blob<B> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let len = sshwire::length_enc(&self.0) + .map_err(|_| core::fmt::Error)?; + write!(f, "Blob(len={len}, {:?})", self.0) + } +} + +impl<B: SSHEncode> SSHEncode for Blob<B> { + fn enc<S>(&self, s: &mut S) -> Result<()> + where S: sshwire::SSHSink { + let len: u32 = sshwire::length_enc(&self.0)?.try_into().trap()?; + len.enc(s)?; + self.0.enc(s) + } +} + +impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { + fn dec<S>(s: &mut S) -> Result<Self> + where S: sshwire::SSHSource<'de> { + let len = u32::dec(s)?; + let inner = SSHDecode::dec(s)?; + // TODO verify length matches + Ok(Blob(inner)) + } +} + /////////////////////////////////////////////// impl SSHEncode for u8 { @@ -233,9 +315,10 @@ impl<'de> SSHDecode<'de> for bool { } } -// TODO: inline seemed to help code size in wireformat? +// #[inline] seems to decrease code size somehow + impl<'de> SSHDecode<'de> for u8 { - // #[inline] + #[inline] fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u8>())?; @@ -244,7 +327,7 @@ impl<'de> SSHDecode<'de> for u8 { } impl<'de> SSHDecode<'de> for u32 { - // #[inline] + #[inline] fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u32>())?; @@ -253,7 +336,7 @@ impl<'de> SSHDecode<'de> for u32 { } impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { - // #[inline] + #[inline] fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de> { let len = u32::dec(s)?; @@ -271,3 +354,98 @@ impl<'de, const N: usize> SSHDecode<'de> for [u8; N] { Ok(l) } } + +#[cfg(test)] +pub(crate) mod tests { + use crate::*; + use doorlog::init_test_log; + use error::Error; + use packets::*; + use sshwire::*; + use pretty_hex::PrettyHex; + + /// Checks that two items serialize the same + pub fn assert_serialize_equal<'de, T: SSHEncode>(p1: &T, p2: &T) { + let mut buf1 = vec![99; 2000]; + let mut buf2 = vec![88; 1000]; + let l1 = write_ssh(&mut buf1, p1).unwrap(); + let l2 = write_ssh(&mut buf2, p2).unwrap(); + buf1.truncate(l1); + buf2.truncate(l2); + assert_eq!(buf1, buf2); + } + + #[test] + /// check that hash_ser_length() matches hashing a serialized message + fn test_hash_packet() { + use sha2::Sha256; + use digest::Digest; + let input = "hello"; + let mut buf = vec![99; 20]; + let w1 = write_ssh(&mut buf, &input).unwrap(); + buf.truncate(w1); + + // hash_ser_length + let mut hash_ctx = Sha256::new(); + hash_ser_length(&mut hash_ctx, &input).unwrap(); + let digest1 = hash_ctx.finalize(); + + let mut hash_ctx = Sha256::new(); + hash_ctx.update(&(w1 as u32).to_be_bytes()); + hash_ctx.update(&buf); + let digest2 = hash_ctx.finalize(); + + assert_eq!(digest1, digest2); + + // hash_ser + let mut hash_ctx = Sha256::new(); + hash_ctx.update(&(w1 as u32).to_be_bytes()); + hash_ser(&mut hash_ctx, None, &input).unwrap(); + let digest3 = hash_ctx.finalize(); + assert_eq!(digest3, digest2); + } + + pub fn test_roundtrip_context(p: &Packet, ctx: &ParseContext) { + let mut buf = vec![99; 200]; + let l = write_ssh(&mut buf, p).unwrap(); + buf.truncate(l); + trace!("wrote packet {:?}", buf.hex_dump()); + + let p2 = packet_from_bytes(&buf, &ctx).unwrap(); + trace!("returned packet {:#?}", p2); + assert_serialize_equal(p, &p2); + } + + /// With default context + pub fn test_roundtrip(p: &Packet) { + test_roundtrip_context(&p, &ParseContext::default()); + } + + /// Tests parsing a packet with a ParseContext. + #[test] + fn test_parse_context() { + init_test_log(); + let mut ctx = ParseContext::new(); + + let p = Userauth60::PwChangeReq(UserauthPwChangeReq { + prompt: "change the password", + lang: "", + }).into(); + let mut pw = ResponseString::new(); + pw.push_str("123").unwrap(); + ctx.cli_auth_type = Some(cliauth::AuthType::Password); + test_roundtrip_context(&p, &ctx); + + // PkOk is a more interesting case because the PubKey inside it is also + // an enum but that can identify its own enum variant. + let p = Userauth60::PkOk(UserauthPkOk { + algo: "ed25519", + key: Blob(PubKey::Ed25519(Ed25519PubKey { + key: BinString(&[0x11, 0x22, 0x33]), + })), + }).into(); + let s = sign::tests::make_ed25519_signkey(); + ctx.cli_auth_type = Some(cliauth::AuthType::PubKey); + test_roundtrip_context(&p, &ctx); + } +} diff --git a/sshproto/src/test.rs b/sshproto/src/test.rs index 14f184a..e70fb3d 100644 --- a/sshproto/src/test.rs +++ b/sshproto/src/test.rs @@ -2,12 +2,10 @@ mod tests { use crate::error::Error; use crate::packets::*; - use crate::wireformat::BinString; + use crate::sshwire::BinString; use crate::packets::{Packet,ParseContext}; - use crate::{packets, wireformat}; + use crate::{packets, sshwire}; use pretty_hex::PrettyHex; - use serde::de::Unexpected; - use serde::{Deserialize, Serialize}; use simplelog::{TestLogger,self,LevelFilter}; pub fn init_log() { @@ -16,14 +14,14 @@ mod tests { fn test_roundtrip_packet(p: &Packet) -> Result<(), Error> { let mut buf1 = vec![99; 500]; - let _w1 = wireformat::write_ssh(&mut buf1, &p)?; + let _w1 = sshwire::write_ssh(&mut buf1, p)?; let ctx = ParseContext::new(); - let p2 = wireformat::packet_from_bytes(&buf1, &ctx)?; + let p2 = sshwire::packet_from_bytes(&buf1, &ctx)?; let mut buf2 = vec![99; 500]; - let _w2 = wireformat::write_ssh(&mut buf2, &p2)?; + let _w2 = sshwire::write_ssh(&mut buf2, &p2)?; // println!("{p:?}"); // println!("{p2:?}"); // println!("{:?}", buf1.hex_dump()); diff --git a/sshproto/src/wireformat.rs b/sshproto/src/wireformat.rs deleted file mode 100644 index 5516973..0000000 --- a/sshproto/src/wireformat.rs +++ /dev/null @@ -1,1037 +0,0 @@ -//! SSH protocol serialization. -//! Implements enough of serde to handle the formats defined in [`crate::packets`] - -//! See [RFC4251](https://datatracker.ietf.org/doc/html/rfc4251) for encodings, -//! [RFC4253](https://datatracker.ietf.org/doc/html/rfc4253) and others for packet structure -#[allow(unused_imports)] -use { - crate::error::{Error, Result, TrapBug}, - log::{debug, error, info, log, trace, warn}, -}; - -use serde::de::value::{BorrowedStrDeserializer, SeqAccessDeserializer}; -use serde::{ - de::{self, value::MapAccessDeserializer, IntoDeserializer, MapAccess}, - de::{DeserializeSeed, EnumAccess, SeqAccess, VariantAccess, Visitor}, - ser, - ser::SerializeSeq, - Deserialize, Deserializer, Serialize, Serializer, -}; - -use pretty_hex::PrettyHex; - -use crate::packets::{Packet, PacketState, ParseContext}; -use crate::{packets::UserauthPkOk, *}; -use core::{cell::Cell, fmt::Binary}; -use core::convert::AsRef; -use core::fmt::{self,Debug}; -use core::slice; -use core::marker::PhantomData; - -use crate::sshwire::{SSHEncode, SSHDecode}; - -/// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. -pub fn packet_from_bytes<'a>( - b: &'a [u8], ctx: &ParseContext, -) -> Result<Packet<'a>> { - let mut ds = DeSSHBytes::from_bytes(b, ctx.clone()); - Packet::deserialize(&mut ds).map_err(|e| { - // TODO better handling of this. Stuff it in PacketState. - // Also should return which MessageNumber failed in later parsing - if let Error::InvalidDeserializeU8 { value } = e { - // This assumes that the only deserialize that can hit - // invalid_value() is an unknown packet type. Seems safe at present. - Error::UnknownPacket { number: value } - } else { - e - } - }) - // TODO check for trailing bytes, pos != b.len() -} - -// Hashes a slice to be treated as a mpint. Has u32 length prefix -// and an extra 0x00 byte if the MSB is set. -pub fn hash_mpint(hash_ctx: &mut dyn digest::DynDigest, m: &[u8]) { - let pad = m.len() > 0 && (m[0] & 0x80) != 0; - let l = m.len() as u32 + pad as u32; - hash_ctx.update(&l.to_be_bytes()); - if pad { - hash_ctx.update(&[0x00]); - } - hash_ctx.update(m); -} - -/// Writes a SSH packet to a buffer. Returns the length written. -pub fn write_ssh<T>(target: &mut [u8], value: &T) -> Result<usize> -where - T: Serialize, -{ - let mut serializer = SeSSHBytes::WriteBytes { target, pos: 0 }; - value.serialize(&mut serializer)?; - Ok(match serializer { - SeSSHBytes::WriteBytes { target: _, pos } => pos, - _ => 0, // TODO is there a better syntax here? we know it's always WriteBytes - }) -} - -/// Hashes the contents of a SSH packet, updating the provided context. -/// Adds a `u32` length prefix. -pub fn hash_ser_length<T>( - hash_ctx: &mut impl digest::DynDigest, value: &T, -) -> Result<()> -where - T: Serialize, -{ - // calculate the u32 length prefix - let len = SeSSHBytes::get_length(value)? as u32; - hash_ctx.update(&len.to_be_bytes()); - let mut serializer = SeSSHBytes::WriteHash { hash_ctx }; - // the rest of the packet - value.serialize(&mut serializer)?; - Ok(()) -} - -/// Hashes the contents of a `Serialize` item such as a public key. -/// No length prefix is added. -pub fn hash_ser<T>(hash_ctx: &mut impl digest::DynDigest, value: &T) -> Result<()> -where - T: Serialize, -{ - let mut serializer = SeSSHBytes::WriteHash { hash_ctx }; - value.serialize(&mut serializer)?; - Ok(()) -} - -type Res = Result<()>; - -/// A SSH style binary string. Serialized as 32 bit length followed by the bytes -/// of the slice. -#[derive(Deserialize,Clone,PartialEq)] -pub struct BinString<'a>(pub &'a [u8]); - -impl<'a> AsRef<[u8]> for BinString<'a> { - fn as_ref(&self) -> &'a [u8] { - self.0 - } -} - -impl<'a> Debug for BinString<'a> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "BinString(len={})", self.0.len()) - } -} - -impl<'a> Serialize for BinString<'a> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - serializer.serialize_bytes(self.0) - } -} - -impl SSHEncode for BinString<'_> { - fn enc<S>(&self, s: &mut S) -> Result<()> - where S: sshwire::SSHSink { - (self.0.len() as u32).enc(s)?; - self.0.enc(s) - } -} - -impl<'de> SSHDecode<'de> for BinString<'de> { - fn dec<S>(s: &mut S) -> Result<Self> - where S: sshwire::SSHSource<'de> { - let len = u32::dec(s)? as usize; - Ok(BinString(s.take(len)?)) - } - -} - -// a wrapper for a u32 prefixed data structure `B`, such as a public key blob -pub struct Blob<B>(pub B); - -impl<B> AsRef<B> for Blob<B> { - fn as_ref(&self) -> &B { - &self.0 - } -} - -impl<B: Clone> Clone for Blob<B> { - fn clone(&self) -> Self { - Blob(self.0.clone()) - } -} - -impl<B: SSHEncode + Serialize + Debug> Debug for Blob<B> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let len = sshwire::length_enc(&self.0) - .map_err(|_| ser::Error::custom(Error::bug()))?; - // let len = SeSSHBytes::get_length(&self.0) - // .map_err(|_| ser::Error::custom(Error::bug()))?; - write!(f, "Blob(len={len}, {:?})", self.0) - } -} - -impl<B: SSHEncode> SSHEncode for Blob<B> { - fn enc<S>(&self, s: &mut S) -> Result<()> - where S: sshwire::SSHSink { - let len: u32 = sshwire::length_enc(&self.0)?.try_into().trap()?; - len.enc(s)?; - self.0.enc(s) - } -} - -impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { - fn dec<S>(s: &mut S) -> Result<Self> - where S: sshwire::SSHSource<'de> { - let len = u32::dec(s)?; - let inner = SSHDecode::dec(s)?; - // TODO verify length matches - Ok(Blob(inner)) - } -} - -impl<B: Serialize> Serialize for Blob<B> { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: Serializer, - { - let mut seq = serializer.serialize_seq(None)?; - let len = SeSSHBytes::get_length(&self.0) - .map_err(|_| ser::Error::custom(Error::bug()))? as u32; - trace!("blob len {} {:x}", len, len); - seq.serialize_element(&len)?; - seq.serialize_element(&self.0)?; - seq.end() - } -} - -impl<'de, B: Deserialize<'de> + Serialize> Deserialize<'de> for Blob<B> { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: Deserializer<'de>, - { - struct Vis<'de, B> { - ph: PhantomData<B>, - lifetime: PhantomData<&'de ()>, - } - - impl<'de, B: Deserialize<'de> + Serialize> Visitor<'de> for Vis<'de, B> { - type Value = Blob<B>; - - fn expecting( - &self, formatter: &mut core::fmt::Formatter, - ) -> core::fmt::Result { - formatter.write_str("length prefixed blob") - } - fn visit_seq<V>(self, mut seq: V) -> Result<Blob<B>, V::Error> - where - V: SeqAccess<'de>, - { - let bloblen: u32 = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("length"))?; - - let inner: B = seq - .next_element()? - .ok_or_else(|| de::Error::missing_field("rest of packet"))?; - - // TODO: is there a better way to find the length consumed? - // If we could enforce that D is a DeSSHBytes we can look - // at the length... - // let gotlen = SeSSHBytes::get_length(&inner) - // .map_err(|_| de::Error::custom(Error::bug()))?; - // if bloblen as usize != gotlen { - // return Err(de::Error::custom(format_args!( - // "Expected {} of length {}, got {}", - // core::any::type_name::<B>(), - // bloblen, - // gotlen - // ))); - // } - Ok(Blob(inner)) - } - } - deserializer.deserialize_seq(Vis { ph: PhantomData, lifetime: PhantomData }) - } -} - -/// Serializer for the SSH wire protocol. Writes into a borrowed `&mut [u8]` buffer. -/// Optionally compute the hash of the packet or the length required. -enum SeSSHBytes<'a> { - WriteBytes { target: &'a mut [u8], pos: usize }, - WriteHash { hash_ctx: &'a mut dyn digest::DynDigest }, - Length { pos: usize }, -} - -impl SeSSHBytes<'_> { - /// Returns the length required to serialize `value` - pub fn get_length<S>(value: S) -> Result<usize> - where - S: Serialize, - { - let mut serializer = SeSSHBytes::Length { pos: 0 }; - value.serialize(&mut serializer)?; - let len = match serializer { - SeSSHBytes::Length { pos } => pos, - _ => 0, // TODO is there a better syntax here? we know it's always WriteBytes - }; - Ok(len) - } - - /// Appends serialized data - fn push(&mut self, v: &[u8]) -> Res { - panic!("push"); - match self { - SeSSHBytes::WriteBytes { target, ref mut pos } => { - if *pos + v.len() > target.len() { - return Err(Error::NoRoom); - } - target[*pos..*pos + v.len()].copy_from_slice(v); - *pos += v.len(); - } - SeSSHBytes::Length { ref mut pos } => { - *pos += v.len(); - } - SeSSHBytes::WriteHash { hash_ctx } => { - hash_ctx.update(v); - } - } - Ok(()) - } -} - -impl Serializer for &mut SeSSHBytes<'_> { - type Ok = (); - type Error = crate::error::Error; - - type SerializeSeq = Self; - type SerializeStruct = Self; - type SerializeTuple = Self; - type SerializeStructVariant = ser::Impossible<(), Error>; - type SerializeTupleStruct = ser::Impossible<(), Error>; - type SerializeTupleVariant = ser::Impossible<(), Error>; - type SerializeMap = ser::Impossible<(), Error>; - - fn serialize_bool(self, v: bool) -> Res { - self.serialize_u8(v as u8) - } - fn serialize_u8(self, v: u8) -> Res { - self.push(&[v]) - } - fn serialize_u32(self, v: u32) -> Res { - self.push(&v.to_be_bytes()) - } - /// Not actually used in any SSH packets, mentioned in the arch doc - fn serialize_u64(self, v: u64) -> Res { - self.push(&v.to_be_bytes()) - } - /// Serialize raw bytes with no prefix - fn serialize_bytes(self, v: &[u8]) -> Res { - self.serialize_u32(v.len() as u32)?; - self.push(v) - // todo!( - // "This is asymmetric with deserialize_bytes, but isn't currently being used." - // ) - } - fn serialize_str(self, v: &str) -> Res { - let b = v.as_bytes(); - self.serialize_u32(b.len() as u32)?; - self.push(b) - } - - fn serialize_some<T>(self, v: &T) -> Res - where - T: ?Sized + Serialize, - { - v.serialize(self) - } - // for truncated last option for publickey - fn serialize_none(self) -> Res { - Ok(()) - } - fn serialize_newtype_struct<T>(self, _name: &'static str, v: &T) -> Res - where - T: ?Sized + Serialize, - { - v.serialize(self) - } - fn serialize_newtype_variant<T>( - self, name: &'static str, _variant_index: u32, variant: &'static str, v: &T, - ) -> Res - where - T: ?Sized + Serialize, - { - match name { - "Userauth60" | "ChannelOpenType" | "ChannelReqType" => { - // Name is elsewhere, part of the parent struct or - // from ParseContext - } - "PubKey" | "Signature" | "AuthMethod" => { - // Name is immediately before the enum - self.serialize_str(variant)?; - } - _ => { - return Error::bug_args(format_args!("Mystery enum")) - } - }; - v.serialize(self) - } - - // for "none" variant - fn serialize_unit_variant( - self, name: &'static str, _variant_index: u32, variant: &'static str, - ) -> Res { - match name { - "ChannelType" => Ok(()), // eg "session" unit variant - _ => self.serialize_str(variant), - } - } - - fn serialize_seq(self, _len: Option<usize>) -> Result<Self> { - Ok(self) - } - fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self> { - Ok(self) - } - fn serialize_tuple(self, _len: usize) -> Result<Self> { - Ok(self) - } - - fn collect_str<T: ?Sized>(self, _: &T) -> Res { - Err(Error::NoSerializer) - } - fn serialize_i8(self, _: i8) -> Res { - Err(Error::NoSerializer) - } - fn serialize_i16(self, _: i16) -> Res { - Err(Error::NoSerializer) - } - fn serialize_i32(self, _: i32) -> Res { - Err(Error::NoSerializer) - } - fn serialize_i64(self, _: i64) -> Res { - Err(Error::NoSerializer) - } - fn serialize_u16(self, _: u16) -> Res { - Err(Error::NoSerializer) - } - fn serialize_f32(self, _: f32) -> Res { - Err(Error::NoSerializer) - } - fn serialize_f64(self, _: f64) -> Res { - Err(Error::NoSerializer) - } - // TODO: perhaps useful? - fn serialize_char(self, _: char) -> Res { - Err(Error::NoSerializer) - } - fn serialize_unit(self) -> Res { - Err(Error::NoSerializer) - } - fn serialize_unit_struct(self, _name: &'static str) -> Res { - Err(Error::NoSerializer) - } - fn serialize_tuple_struct( - self, _name: &'static str, _len: usize, - ) -> Result<Self::SerializeTupleStruct> { - Err(Error::NoSerializer) - } - fn serialize_tuple_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result<Self::SerializeTupleVariant> { - Err(Error::NoSerializer) - } - fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> { - Err(Error::NoSerializer) - } - fn serialize_struct_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result<Self::SerializeStructVariant> { - Err(Error::NoSerializer) - } -} - -impl ser::SerializeSeq for &mut SeSSHBytes<'_> { - type Ok = (); - type Error = crate::error::Error; - - fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - value.serialize(&mut **self) - } - - fn end(self) -> Result<(), Self::Error> { - Ok(()) - } -} - -impl ser::SerializeStruct for &mut SeSSHBytes<'_> { - type Ok = (); - type Error = crate::error::Error; - - fn serialize_field<T>( - &mut self, _key: &'static str, value: &T, - ) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - value.serialize(&mut **self) - } - - fn end(self) -> Result<(), Self::Error> { - Ok(()) - } -} - -impl ser::SerializeTuple for &mut SeSSHBytes<'_> { - type Ok = (); - type Error = crate::error::Error; - - fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error> - where - T: Serialize, - { - value.serialize(&mut **self) - } - - fn end(self) -> Result<(), Self::Error> { - Ok(()) - } -} - -/// Deserializer for the SSH wire protocol, from borrowed `&[u8]` -/// Implements enough of serde to handle the formats defined in [`crate::packets`] -pub(crate) struct DeSSHBytes<'a> { - input: &'a [u8], - pos: usize, - - parse_ctx: ParseContext, - - /// MapAccessDeSSH can capture the string value of a selected field - /// for use as an enum variant selector - capture_next_str: bool, - capture_str: Option<&'a str>, - - /// A variant name to be used for the next enum deserialization - next_variant: Option<&'a str>, -} - -impl<'de> DeSSHBytes<'de> { - // XXX: rename to new() ? - pub fn from_bytes(input: &'de [u8], ctx: ParseContext) -> Self { - DeSSHBytes { - input, - pos: 0, - parse_ctx: ctx, - capture_next_str: false, - capture_str: None, - next_variant: None, - } - } - - fn take(&mut self, len: usize) -> Result<&'de [u8]> { - panic!("take"); - if len > self.input.len() { - return Err(Error::RanOut); - } - let (t, rest) = self.input.split_at(len); - self.input = rest; - self.pos += len; - Ok(t) - } - - #[inline] - fn parse_u8(&mut self) -> Result<u8> { - let t = self.take(core::mem::size_of::<u8>())?; - let u = u8::from_be_bytes(t.try_into().unwrap()); - // println!("deser u8 {u}"); - Ok(u) - } - - #[inline] - fn parse_u32(&mut self) -> Result<u32> { - let t = self.take(core::mem::size_of::<u32>())?; - let u = u32::from_be_bytes(t.try_into().unwrap()); - // println!("deser u32 {u}"); - Ok(u) - } - - fn parse_u64(&mut self) -> Result<u64> { - let t = self.take(core::mem::size_of::<u64>())?; - Ok(u64::from_be_bytes(t.try_into().unwrap())) - } - - #[inline] - fn parse_str(&mut self) -> Result<&'de str> { - let len = self.parse_u32()?; - let t = self.take(len as usize)?; - let s = core::str::from_utf8(t).map_err(|_| Error::BadString)?; - Ok(s) - } -} - -impl<'de, 'a> Deserializer<'de> for &'a mut DeSSHBytes<'de> { - type Error = Error; - - fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - visitor.visit_bool(self.parse_u8()? != 0) - } - - fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - visitor.visit_u8(self.parse_u8()?) - } - - fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - visitor.visit_u32(self.parse_u32()?) - } - - fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - visitor.visit_u64(self.parse_u64()?) - } - - fn deserialize_str<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - let s = self.parse_str()?; - if self.capture_next_str { - debug_assert!(self.capture_str.is_none()); - self.capture_str = Some(s); - self.capture_next_str = false - } - visitor.visit_borrowed_str(s) - } - - fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - self.deserialize_str(visitor) - } - - /* deserialize_bytes() is like a string but with binary data. it has - a u32 prefix of the length. Fixed length byte arrays use _tuple() */ - fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - let len = self.parse_u32()?; - let t = self.take(len as usize)?; - visitor.visit_borrowed_bytes(t) - } - - fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - visitor.visit_seq(SeqAccessDeSSH { ds: self, len: Some(len) }) - } - - fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - visitor.visit_seq(SeqAccessDeSSH { ds: self, len: None }) - } - - fn deserialize_struct<V>( - self, name: &'static str, fields: &'static [&'static str], visitor: V, - ) -> Result<V::Value> - where - V: Visitor<'de>, - { - match name { - |"ChannelOpen" - |"ChannelRequest" - => { - // We need a struct deserializer to extract specific fields - let ma = MapAccessDeSSH::new(self, fields, "channel_type"); - let v = visitor.visit_map(ma)?; - debug_assert!(self.next_variant.is_none()); - Ok(v) - } - _ => { - // A simple deserialize_tuple is smaller - self.deserialize_tuple(fields.len(), visitor) - } - } - } - - fn deserialize_enum<V>( - self, name: &'static str, variants: &'static [&'static str], visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - let variant_name = match name { - "Userauth60" => packets::Userauth60::variant(&self.parse_ctx)?, - |"ChannelOpenType" - |"ChannelReqType" - => self.next_variant.take().trap()?, - "PubKey" | "Signature" | "AuthMethod" => { - // The variant is selected by the method name in the packet, - // using `#[serde(rename)]` in `packets` enum definition. - self.parse_str()? - } - _ => { - // A mystery enum has been added to packets.rs - return Error::bug_args(format_args!("Mystery enum")) - } - }; - - let unknown_variant = !variants.contains(&variant_name) || variant_name == "Unknown"; - - let stringenum = SSHStringEnum { - ds: self, variant_name, unknown_variant }; - - visitor.visit_enum(stringenum) - } - - fn deserialize_newtype_struct<V>( - self, _name: &'static str, visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn deserialize_tuple_struct<V>( - self, _name: &'static str, _len: usize, _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(Error::NoSerializer) - } - - fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(Error::NoSerializer) - } - // The remainder will fail. - serde::forward_to_deserialize_any! { - i8 i16 i32 i64 i128 u16 u128 f32 f64 char string - byte_buf unit_struct - map ignored_any - option - } - fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value> - where - V: Visitor<'de>, - { - Err(Error::NoSerializer) - } -} - -struct SeqAccessDeSSH<'a, 'b: 'a> { - ds: &'a mut DeSSHBytes<'b>, - len: Option<usize>, -} - -impl<'a, 'b: 'a> SeqAccess<'b> for SeqAccessDeSSH<'a, 'b> { - type Error = Error; - #[inline] - fn next_element_seed<V: DeserializeSeed<'b>>( - &mut self, seed: V, - ) -> Result<Option<V::Value>> { - if let Some(ref mut len) = self.len { - if *len > 0 { - *len -= 1; - Ok(Some(DeserializeSeed::deserialize(seed, &mut *self.ds)?)) - } else { - Ok(None) - } - } else { - Ok(Some(DeserializeSeed::deserialize(seed, &mut *self.ds)?)) - } - } - - fn size_hint(&self) -> Option<usize> { - self.len - } -} - -struct MapAccessDeSSH<'de, 'a> { - ds: &'a mut DeSSHBytes<'de>, - fields: &'static [&'static str], - pos: usize, - - // We want to use a field in a parent struct to choose the - // variant of a child enum. We record the field here - // and use it in deserialize_enum(). - // This assumes that no intervening enums are decoded before - // the desired one. - // Perhaps in future #[serde(flatten)] etc could be used instead. - variant_field: &'a str, -} - -impl<'de: 'a, 'a> MapAccessDeSSH<'de, 'a> { - fn new( - ds: &'a mut DeSSHBytes<'de>, fields: &'static [&'static str], - variant_field: &'a str, - ) -> Self { - debug_assert!(ds.next_variant.is_none()); - MapAccessDeSSH { ds, fields, pos: 0, variant_field } - } -} - -impl<'de: 'a, 'a> MapAccess<'de> for MapAccessDeSSH<'de, 'a> { - type Error = Error; - - // inline reduces code size - #[inline] - fn next_key_seed<S: DeserializeSeed<'de>>( - &mut self, seed: S, - ) -> Result<Option<S::Value>> { - if self.pos < self.fields.len() { - debug_assert!(self.ds.capture_str.is_none()); - debug_assert!(!self.ds.capture_next_str); - // The subsequent next_value_seed() should - // capture the string value if this is our - // capture_field. - self.ds.capture_next_str = self.variant_field == self.fields[self.pos]; - - // Return the field name as the key - let dsfield = - BorrowedStrDeserializer::<Error>::new(self.fields[self.pos]); - self.pos += 1; - Ok(Some(DeserializeSeed::deserialize(seed, dsfield)?)) - } else { - Ok(None) - } - } - - #[inline] - fn next_value_seed<S: DeserializeSeed<'de>>( - &mut self, seed: S, - ) -> Result<S::Value> { - let v = DeserializeSeed::deserialize(seed, &mut *self.ds)?; - - // Stash any captured value - let cap = self.ds.capture_str.take(); - debug_assert!(!self.ds.capture_next_str); - if cap.is_some() { - debug_assert!(self.ds.next_variant.is_none()); - self.ds.next_variant = cap; - } - Ok(v) - } - - fn size_hint(&self) -> Option<usize> { - Some(self.fields.len()) - } -} - -struct SSHStringEnum<'a, 'de: 'a> { - ds: &'a mut DeSSHBytes<'de>, - - /// Set to the variant name to choose from this enum. - variant_name: &'de str, - /// Set when the variant_name doesn't match any known. - /// Rather than failing, a "Unknown" variant is returned. - /// A Unknown variant is included in all the enum types that - /// could potentially receive known content. - unknown_variant: bool, -} - -// Figures which SSH string (eg "password") identifies the enum -impl<'de, 'a> EnumAccess<'de> for SSHStringEnum<'a, 'de> { - type Error = crate::error::Error; - type Variant = Self; - - fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)> - where - V: DeserializeSeed<'de>, - { - let variant = if self.unknown_variant { - "Unknown" - } else { - self.variant_name - }; - // mystery: why doesn't variant.into_deserializer() work? - let n = BorrowedStrDeserializer::<Error>::new(variant); - let n = seed.deserialize(n)?; - Ok((n, self)) - } -} - -// Creates a struct out of thin air with the given content -struct SyntheticNewtypeSeqAccess<'de> { - content: Option<&'de str> -} - -impl<'de: 'a, 'a> SeqAccess<'de> for SyntheticNewtypeSeqAccess<'de> { - type Error = Error; - fn next_element_seed<V: DeserializeSeed<'de>>( - &mut self, seed: V, - ) -> Result<Option<V::Value>> { - let content = self.content.take(); - let c = content.map(|c| { - let b = BorrowedStrDeserializer::<Error>::new(c); - seed.deserialize(b) - }) - .transpose()?; - Ok(c) - } - - fn size_hint(&self) -> Option<usize> { - Some(1) - } -} - -// Decodes a variant from an enum. -// We only use newtype and unit variants -impl<'de, 'a> VariantAccess<'de> for SSHStringEnum<'a, 'de> { - type Error = Error; - - fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> - where - T: DeserializeSeed<'de>, - { - if self.unknown_variant { - // Put the unknown variant name in an Unknown newtype - let u = SyntheticNewtypeSeqAccess { content: Some(self.variant_name) }; - let b = SeqAccessDeserializer::new(u); - seed.deserialize(b).into() - } else { - seed.deserialize(self.ds) - } - } - - fn unit_variant(self) -> Result<(), Self::Error> { - Ok(()) - } - - fn tuple_variant<V>( - self, _len: usize, _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(Error::NoSerializer) - } - - fn struct_variant<V>( - self, _fields: &'static [&'static str], _visitor: V, - ) -> Result<V::Value, Self::Error> - where - V: Visitor<'de>, - { - Err(Error::NoSerializer) - } -} - -#[cfg(test)] -pub(crate) mod tests { - use crate::doorlog::init_test_log; - use crate::error::Error; - use crate::packets::*; - use crate::wireformat::*; - use crate::*; - // use pretty_hex::PrettyHex; - - /// Checks that two items serialize the same - pub fn assert_serialize_equal<'de, T: Serialize+SSHEncode>(p1: &T, p2: &T) { - let mut buf1 = vec![99; 2000]; - let mut buf2 = vec![88; 1000]; - let l1 = write_ssh(&mut buf1, p1).unwrap(); - let l2 = write_ssh(&mut buf2, p2).unwrap(); - buf1.truncate(l1); - buf2.truncate(l2); - assert_eq!(buf1, buf2); - } - - #[test] - /// check that hash_ser_length() matches hashing a serialized message - fn test_hash_packet() { - use sha2::Sha256; - use digest::Digest; - let input = "hello"; - let mut buf = vec![99; 20]; - let w1 = wireformat::write_ssh(&mut buf, &input).unwrap(); - buf.truncate(w1); - - // hash_ser_length - let mut hash_ctx = Sha256::new(); - wireformat::hash_ser_length(&mut hash_ctx, &input).unwrap(); - let digest1 = hash_ctx.finalize(); - - let mut hash_ctx = Sha256::new(); - hash_ctx.update(&(w1 as u32).to_be_bytes()); - hash_ctx.update(&buf); - let digest2 = hash_ctx.finalize(); - - assert_eq!(digest1, digest2); - - // hash_ser - let mut hash_ctx = Sha256::new(); - hash_ctx.update(&(w1 as u32).to_be_bytes()); - wireformat::hash_ser(&mut hash_ctx, &input).unwrap(); - let digest3 = hash_ctx.finalize(); - assert_eq!(digest3, digest2); - } - - pub fn test_roundtrip_context(p: &Packet, ctx: &ParseContext) { - let mut buf = vec![99; 200]; - let l = write_ssh(&mut buf, p).unwrap(); - buf.truncate(l); - trace!("wrote packet {:?}", buf.hex_dump()); - - let p2 = packet_from_bytes(&buf, &ctx).unwrap(); - trace!("returned packet {:#?}", p2); - assert_serialize_equal(p, &p2); - } - - /// With default context - pub fn test_roundtrip(p: &Packet) { - test_roundtrip_context(&p, &ParseContext::default()); - } - - /// Tests parsing a packet with a ParseContext. - #[test] - fn test_parse_context() { - init_test_log(); - let mut ctx = ParseContext::new(); - - let p = Userauth60::PwChangeReq(UserauthPwChangeReq { - prompt: "change the password", - lang: "", - }).into(); - let mut pw = ResponseString::new(); - pw.push_str("123").unwrap(); - ctx.cli_auth_type = Some(cliauth::AuthType::Password); - test_roundtrip_context(&p, &ctx); - - // PkOk is a more interesting case because the PubKey inside it is also - // an enum but that can identify its own enum variant. - let p = Userauth60::PkOk(UserauthPkOk { - algo: "ed25519", - key: Blob(PubKey::Ed25519(Ed25519PubKey { - key: BinString(&[0x11, 0x22, 0x33]), - })), - }).into(); - let s = sign::tests::make_ed25519_signkey(); - ctx.cli_auth_type = Some(cliauth::AuthType::PubKey); - test_roundtrip_context(&p, &ctx); - } -} diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs index bf52d58..4ca2651 100644 --- a/sshwire_derive/src/lib.rs +++ b/sshwire_derive/src/lib.rs @@ -65,7 +65,7 @@ enum FieldAtt { /// eg `#[sshwire(variant_name = ch)]` for `ChannelRequest` VariantName(Ident), /// Any unknown variant name should be recorded here. - /// This variant can't be serialized + /// This variant can't be written out. /// `#[sshwire(unknown))]` CaptureUnknown, /// The name of a variant, used by the parent struct -- GitLab