From 9282ace5afd2b97c621e5fcdec76f4dc3ccb3487 Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Fri, 3 Jun 2022 14:36:33 +0800 Subject: [PATCH] Use ascii --- Cargo.lock | 7 +++ smol/examples/con1.rs | 6 ++ sshproto/Cargo.toml | 4 +- sshproto/examples/kex1.rs | 24 +++---- sshproto/src/async_behaviour.rs | 2 +- sshproto/src/behaviour.rs | 17 +++-- sshproto/src/block_behaviour.rs | 2 +- sshproto/src/channel.rs | 2 +- sshproto/src/cliauth.rs | 7 ++- sshproto/src/client.rs | 4 +- sshproto/src/conn.rs | 7 +-- sshproto/src/error.rs | 3 + sshproto/src/namelist.rs | 50 ++++++++------- sshproto/src/packets.rs | 74 +++++++++++----------- sshproto/src/sign.rs | 4 +- sshproto/src/sshwire.rs | 108 ++++++++++++++++++++++++++++++-- sshproto/src/test.rs | 20 +++--- sshwire_derive/src/lib.rs | 19 +++--- 18 files changed, 246 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c978dbd..ada6444 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6f8c380fa28aa1b36107cd97f0196474bb7241bb95a453c5c01a15ac74b2eac" +[[package]] +name = "ascii" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbf56136a5198c7b01a49e3afcbef6cf84597273d298f54432926024107b0109" + [[package]] name = "async-trait" version = "0.1.53" @@ -341,6 +347,7 @@ version = "0.1.0" dependencies = [ "aes", "anyhow", + "ascii", "async-trait", "chacha20", "ctr", diff --git a/smol/examples/con1.rs b/smol/examples/con1.rs index 74f3d30..cd97cc2 100644 --- a/smol/examples/con1.rs +++ b/smol/examples/con1.rs @@ -67,6 +67,12 @@ fn parse_args() -> Result<Args> { fn main() -> Result<()> { let args = parse_args()?; + use std::panic; + + panic::set_hook(Box::new(|_| { + println!("Custom panic hook"); + })); + // time crate won't read TZ if we're threaded, in case someone // tries to mutate shared state with setenv. diff --git a/sshproto/Cargo.toml b/sshproto/Cargo.toml index 9df5f76..ff92e13 100644 --- a/sshproto/Cargo.toml +++ b/sshproto/Cargo.toml @@ -11,6 +11,7 @@ snafu = { version = "0.7", default-features = false, features = ["rust_1_46"] } log = { version = "0.4" } heapless = "0.7.10" no-panic = "0.1" +ascii = { version = "1.0", default-features = false } # TODO: needs changing for embedded platforms rand = { version = "0.8", default-features = false } @@ -25,7 +26,6 @@ ssh-key = { version = "0.4", default-features = false, features = ["ed25519", "e chacha20 = "0.9" poly1305 = "0.7" -# for debugging pretty-hex = { version = "0.3", default-features = false } pin-utils = "0.1" @@ -57,7 +57,7 @@ branch = "mobilecoin" [features] default = [ "getrandom" ] -std = ["async-trait"] +std = ["async-trait", "snafu/std"] # tokio-queue = ["dep:tokio"] getrandom = ["rand/getrandom"] diff --git a/sshproto/examples/kex1.rs b/sshproto/examples/kex1.rs index 69d1a10..a2d3499 100644 --- a/sshproto/examples/kex1.rs +++ b/sshproto/examples/kex1.rs @@ -18,9 +18,9 @@ fn main() -> Result<()> { fn do_userauth() -> Result<()> { let p: Packet = packets::UserauthRequest { - username: "matt", + username: "matt".into(), service: "con", - method: AuthMethod::Password(packets::MethodPassword { change: false, password: "123" }), + method: AuthMethod::Password(packets::MethodPassword { change: false, password: "123".into() }), }.into(); let mut buf = vec![0; 2000]; @@ -42,16 +42,16 @@ fn do_kexinit() -> Result<()> { let k = KexInit { // cookie: &[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16], cookie: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - kex: "hello,more".into(), - hostkey: "hello,more".into(), - cipher_c2s: "hello,more".into(), - cipher_s2c: "hello,more".into(), - mac_c2s: "hi".into(), - mac_s2c: "hello,more".into(), - comp_c2s: "hello,more".into(), - comp_s2c: "hello,more".into(), - lang_c2s: "hello,more".into(), - lang_s2c: "hello,more".into(), + kex: "hello,more".try_into().unwrap(), + hostkey: "hello,more".try_into().unwrap(), + cipher_c2s: "hello,more".try_into().unwrap(), + cipher_s2c: "hello,more".try_into().unwrap(), + mac_c2s: "hi".try_into().unwrap(), + mac_s2c: "hello,more".try_into().unwrap(), + comp_c2s: "hello,more".try_into().unwrap(), + comp_s2c: "hello,more".try_into().unwrap(), + lang_c2s: "hello,more".try_into().unwrap(), + lang_s2c: "hello,more".try_into().unwrap(), first_follows: false, reserved: 0, }; diff --git a/sshproto/src/async_behaviour.rs b/sshproto/src/async_behaviour.rs index afea4f9..5ad2bf3 100644 --- a/sshproto/src/async_behaviour.rs +++ b/sshproto/src/async_behaviour.rs @@ -102,7 +102,7 @@ pub trait AsyncCliBehaviour { /// Language may be empty, is provided by the server. #[allow(unused)] async fn show_banner(&self, banner: &str, language: &str) { - info!("Got banner:\n{}", banner.escape_default()); + info!("Got banner:\n{:?}", banner.escape_default()); } // TODO: postauth channel callbacks } diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index f48af86..ed7bdc9 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -18,6 +18,7 @@ use packets::{self,Packet}; use runner::{self,Runner}; use channel::ChanMsg; use conn::RespPackets; +use sshwire::TextString; // TODO: "Bh" is an ugly abbreviation. Naming is hard. @@ -153,11 +154,15 @@ impl<'a> CliBehaviour<'a> { self.inner.authenticated().await } - pub(crate) async fn show_banner(&self, banner: &str, language: &str) { - self.inner.show_banner(banner, language).await + pub(crate) async fn show_banner(&self, banner: TextString<'_>, language: TextString<'_>) -> Result<()> { + let banner = banner.as_str().map_err(|e| { warn!("Bad banner {:?}", banner); e})?; + let language = language.as_str()?; + self.inner.show_banner(banner, language).await; + Ok(()) } } +// no-std blocking variant #[cfg(not(feature = "std"))] impl<'a> CliBehaviour<'a> { pub(crate) async fn username(&mut self) -> BhResult<ResponseString>{ @@ -181,8 +186,12 @@ impl<'a> CliBehaviour<'a> { self.inner.authenticated() } - pub(crate) async fn show_banner(&self, banner: &str, language: &str) { - self.inner.show_banner(banner, language) + // TODO: make ascii/utf8 a feature + pub(crate) async fn show_banner(&self, banner: TextString<'_>, language: TextString<'_>) -> Result<()> { + let banner = banner.as_ascii().map_err(|e| { warn!("Bad banner {:?}", banner); e})?; + let language = language.as_ascii()?; + self.inner.show_banner(banner, language); + Ok(()) } } diff --git a/sshproto/src/block_behaviour.rs b/sshproto/src/block_behaviour.rs index bf642f7..e169a21 100644 --- a/sshproto/src/block_behaviour.rs +++ b/sshproto/src/block_behaviour.rs @@ -109,7 +109,7 @@ pub trait BlockCliBehaviour { /// Language may be empty, is provided by the server. #[allow(unused)] fn show_banner(&self, banner: &str, language: &str) { - info!("Got banner:\n{}", banner.escape_default()); + info!("Got banner:\n{:?}", banner.escape_default()); } // TODO: postauth channel callbacks } diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 428d5e2..9cd0859 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -237,7 +237,7 @@ impl Req { todo!("serialize modes") } ReqDetails::Exec(cmd) => { - ChannelReqType::Exec(packets::Exec { command: &cmd }) + ChannelReqType::Exec(packets::Exec { command: cmd.as_str().into() }) } ReqDetails::WinChange(rt) => ChannelReqType::WinChange(rt.clone()), ReqDetails::Break(rt) => ChannelReqType::Break(rt.clone()), diff --git a/sshproto/src/cliauth.rs b/sshproto/src/cliauth.rs index 089f626..4790088 100644 --- a/sshproto/src/cliauth.rs +++ b/sshproto/src/cliauth.rs @@ -47,6 +47,7 @@ impl Req { username: &'b str, parse_ctx: &mut ParseContext, ) -> Result<Packet<'b>> { + let username = username.into(); let p = match self { Req::PubKey { key, .. } => { // already checked by make_pubkey_req() @@ -64,7 +65,7 @@ impl Req { service: SSH_SERVICE_CONNECTION, method: packets::AuthMethod::Password(packets::MethodPassword { change: false, - password: pw, + password: pw.as_str().into(), }), }.into() } @@ -113,7 +114,7 @@ impl CliAuth { resp.push(p.into()).trap()?; let p: Packet = packets::UserauthRequest { - username: &self.username, + username: self.username.as_str().into(), service: SSH_SERVICE_CONNECTION, method: packets::AuthMethod::None, }.into(); @@ -161,7 +162,7 @@ impl CliAuth { }) = p { let sig_packet = UserauthRequest { - username, + username: *username, service, method: AuthMethod::PubKey(MethodPubKey { sig_algo, diff --git a/sshproto/src/client.rs b/sshproto/src/client.rs index b82ac06..e295a1b 100644 --- a/sshproto/src/client.rs +++ b/sshproto/src/client.rs @@ -39,6 +39,8 @@ impl Client { } pub(crate) async fn banner(&mut self, banner: &packets::UserauthBanner<'_>, b: &mut CliBehaviour<'_>) { - b.show_banner(banner.message, banner.lang).await + if let Err(e) = b.show_banner(banner.message, banner.lang).await { + warn!("Banner not shown: {e}") + } } } diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index aa3d2ce..28726ca 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -266,14 +266,11 @@ impl<'a> Conn<'a> { warn!("Received SSH unimplemented message"); } Packet::DebugPacket(p) => { - warn!( - "SSH debug message from remote host: '{}'", - p.message.escape_default() - ); + warn!("SSH debug message from remote host: '{:?}'", p.message); } Packet::Disconnect(p) => { // TODO: SSH2_DISCONNECT_BY_APPLICATION is normal, sent by openssh client. - info!("Received disconnect: {}", p.desc.escape_default()); + info!("Received disconnect: {:?}", p.desc); } Packet::UserauthRequest(_p) => { // TODO: this is server only diff --git a/sshproto/src/error.rs b/sshproto/src/error.rs index 3a7b8b8..7df141c 100644 --- a/sshproto/src/error.rs +++ b/sshproto/src/error.rs @@ -30,6 +30,9 @@ pub enum Error { /// Not a UTF8 string BadString, + /// Not a valid SSH ascii string + BadName, + /// Decryption failure or integrity mismatch BadDecrypt, diff --git a/sshproto/src/namelist.rs b/sshproto/src/namelist.rs index 3180172..3c3d52d 100644 --- a/sshproto/src/namelist.rs +++ b/sshproto/src/namelist.rs @@ -5,15 +5,17 @@ use { log::{debug, error, info, log, trace, warn}, }; +use ascii::{AsciiStr, AsciiChar::Comma}; + use sshwire_derive::{SSHEncode, SSHDecode}; use crate::*; -use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink}; +use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink, BinString, try_as_ascii}; /// A comma separated string, can be decoded or encoded. /// Used for remote name lists. #[derive(SSHEncode, SSHDecode, Debug)] -pub struct StringNames<'a>(pub &'a str); +pub struct StringNames<'a>(pub &'a AsciiStr); /// A list of names, can only be encoded. Used for local name lists, comes /// from local fixed lists @@ -35,12 +37,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { where S: SSHSource<'de>, { - let i = StringNames::dec(s)?; - if i.0.is_ascii() { - Ok(NameList::String(i)) - } else { - Err(Error::BadString) - } + Ok(NameList::String(StringNames::dec(s)?)) } } @@ -63,21 +60,25 @@ impl SSHEncode for LocalNames<'_> { } } -impl<'a> From<&'a str> for StringNames<'a> { - fn from(s: &'a str) -> Self { - Self(s) +// for tests +impl<'a> TryFrom<&'a str> for StringNames<'a> { + type Error = (); + fn try_from(s: &'a str) -> Result<Self, Self::Error> { + Ok(Self(AsciiStr::from_ascii(s).map_err(|_| ())?)) } } +impl<'a> TryFrom<&'a str> for NameList<'a> { + type Error = (); + fn try_from(s: &'a str) -> Result<Self, Self::Error> { + Ok(NameList::String(s.try_into()?)) + } +} + impl<'a> From<&'a [&'static str]> for LocalNames<'a> { fn from(s: &'a [&'static str]) -> Self { Self(s) } } -impl<'a> From<&'a str> for NameList<'a> { - fn from(s: &'a str) -> Self { - NameList::String(s.into()) - } -} impl<'a> From<&LocalNames<'a>> for NameList<'a> { fn from(s: &LocalNames<'a>) -> Self { NameList::Local(LocalNames(s.0)) @@ -127,7 +128,7 @@ impl<'a> NameList<'a> { impl<'a> StringNames<'a> { /// Returns the first name in this namelist that matches one of the provided options fn first_string_match(&self, options: &LocalNames) -> Option<&'static str> { - for n in self.0.split(',') { + for n in self.0.split(Comma) { for o in options.0.iter() { if n == *o { return Some(*o); @@ -140,7 +141,7 @@ impl<'a> StringNames<'a> { /// Returns the first of "options" that is in this namelist fn first_options_match(&self, options: &LocalNames) -> Option<&'static str> { for o in options.0.iter() { - for n in self.0.split(',') { + for n in self.0.split(Comma) { if n == *o { return Some(*o); } @@ -151,11 +152,11 @@ impl<'a> StringNames<'a> { fn first(&self) -> &str { // unwrap is OK, split() always returns an item - self.0.split(',').next().unwrap() + self.0.split(Comma).next().unwrap().as_str() } fn has_algo(&self, algo: &str) -> bool { - self.0.split(',').any(|a| a == algo) + self.0.split(Comma).any(|a| a == algo) } } @@ -179,8 +180,8 @@ mod tests { #[test] fn test_match() { - let r1 = NameList::String("rho,cog".into()); - let r2 = NameList::String("woe".into()); + let r1 = NameList::String("rho,cog".try_into().unwrap()); + let r2 = NameList::String("woe".try_into().unwrap()); let l1 = LocalNames(&["rho", "cog"]); let l2 = LocalNames(&["cog", "rho"]); let l3 = LocalNames(&["now", "woe"]); @@ -228,7 +229,7 @@ mod tests { for t in tests.iter() { let l = NameList::Local(LocalNames(t)); let x = t.join(","); - let s = NameList::String(StringNames(&x)); + let s: NameList = x.as_str().try_into().unwrap(); assert_eq!(l.first(), s.first()); if t.len() == 0{ assert_eq!(l.first(), ""); @@ -241,7 +242,8 @@ mod tests { #[test] fn test_has_algo() { fn n(list: &str, has: &str) -> bool { - NameList::String(StringNames(list)).has_algo(has).unwrap() + let s: NameList = list.try_into().unwrap(); + s.has_algo(has).unwrap() } assert_eq!(n("", ""), true); assert_eq!(n("", "one"), false); diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index 2d4341a..291f561 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -12,13 +12,14 @@ use { }; use heapless::String; +use pretty_hex::PrettyHex; use sshwire_derive::*; use crate::*; use namelist::NameList; use sshnames::*; -use sshwire::{BinString, Blob}; +use sshwire::{BinString, TextString, Blob}; use sign::{SigType, OwnedSig}; use sshwire::{SSHEncode, SSHEncodeEnum, SSHDecode, SSHDecodeEnum, SSHSource, SSHSink}; @@ -52,15 +53,15 @@ pub struct Ignore {} #[derive(Debug, SSHEncode, SSHDecode)] pub struct DebugPacket<'a> { pub always_display: bool, - pub message: &'a str, + pub message: TextString<'a>, pub lang: &'a str, } #[derive(Debug, SSHEncode, SSHDecode)] pub struct Disconnect<'a> { pub reason: u32, - pub desc: &'a str, - pub lang: &'a str, + pub desc: TextString<'a>, + pub lang: TextString<'a>, } #[derive(Debug, SSHEncode, SSHDecode)] @@ -92,7 +93,7 @@ pub struct ServiceAccept<'a> { #[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthRequest<'a> { - pub username: &'a str, + pub username: TextString<'a>, pub service: &'a str, pub method: AuthMethod<'a>, } @@ -155,14 +156,14 @@ pub struct UserauthPkOk<'a> { #[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthPwChangeReq<'a> { - pub prompt: &'a str, - pub lang: &'a str, + pub prompt: TextString<'a>, + pub lang: TextString<'a>, } #[derive(SSHEncode, SSHDecode)] pub struct MethodPassword<'a> { pub change: bool, - pub password: &'a str, + pub password: TextString<'a>, } // Don't print password @@ -221,8 +222,8 @@ pub struct UserauthSuccess {} #[derive(Debug, SSHEncode, SSHDecode)] pub struct UserauthBanner<'a> { - pub message: &'a str, - pub lang: &'a str, + pub message: TextString<'a>, + pub lang: TextString<'a>, } #[derive(SSHEncode, SSHDecode, Debug, Clone, PartialEq)] @@ -238,11 +239,11 @@ pub enum PubKey<'a> { impl<'a> PubKey<'a> { /// The algorithm name presented. May be invalid. - pub fn algorithm_name(&self) -> &'a str { + pub fn algorithm_name(&self) -> Result<&'a str, &Unknown<'a>> { match self { - PubKey::Ed25519(_) => SSH_NAME_ED25519, - PubKey::RSA(_) => SSH_NAME_RSA, - PubKey::Unknown(u) => u.0, + PubKey::Ed25519(_) => Ok(SSH_NAME_ED25519), + PubKey::RSA(_) => Ok(SSH_NAME_RSA), + PubKey::Unknown(u) => Err(u), } } } @@ -287,11 +288,11 @@ pub enum Signature<'a> { impl<'a> Signature<'a> { /// The algorithm name presented. May be invalid. - pub fn algorithm_name(&self) -> &'a str { + pub fn algorithm_name(&self) -> Result<&'a str, &Unknown<'a>> { match self { - Signature::Ed25519(_) => SSH_NAME_ED25519, - Signature::RSA256(_) => SSH_NAME_RSA_SHA256, - Signature::Unknown(u) => u.0, + Signature::Ed25519(_) => Ok(SSH_NAME_ED25519), + Signature::RSA256(_) => Ok(SSH_NAME_RSA_SHA256), + Signature::Unknown(u) => Err(u), } } @@ -392,7 +393,7 @@ pub struct ChannelOpenConfirmation { pub struct ChannelOpenFailure<'a> { pub num: u32, pub reason: u32, - pub desc: &'a str, + pub desc: TextString<'a>, pub lang: &'a str, } @@ -477,12 +478,12 @@ pub enum ChannelReqType<'a> { #[derive(Debug, SSHEncode, SSHDecode)] pub struct Exec<'a> { - pub command: &'a str, + pub command: TextString<'a>, } #[derive(Debug, SSHEncode, SSHDecode)] pub struct Pty<'a> { - pub term: &'a str, + pub term: TextString<'a>, pub cols: u32, pub rows: u32, pub width: u32, @@ -517,7 +518,7 @@ pub struct ExitStatus { pub struct ExitSignal<'a> { pub signal: &'a str, pub core: bool, - pub error: &'a str, + pub error: TextString<'a>, pub lang: &'a str, } @@ -528,17 +529,17 @@ pub struct Break { #[derive(Debug, SSHEncode, SSHDecode)] pub struct ForwardedTcpip<'a> { - pub address: &'a str, + pub address: TextString<'a>, pub port: u32, - pub origin: &'a str, + pub origin: TextString<'a>, pub origin_port: u32, } #[derive(Debug, SSHEncode, SSHDecode)] pub struct DirectTcpip<'a> { - pub address: &'a str, + pub address: TextString<'a>, pub port: u32, - pub origin: &'a str, + pub origin: TextString<'a>, pub origin_port: u32, } @@ -547,13 +548,16 @@ pub struct DirectTcpip<'a> { // need to be handled by the relevant code, for example newly invented pubkey types // This is deliberately not Serializable, we only receive it. #[derive(Debug, Clone, PartialEq)] -pub struct Unknown<'a>(pub &'a str); +pub struct Unknown<'a>(pub &'a [u8]); impl core::fmt::Display for Unknown<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.escape_default().fmt(f) + if let Ok(s) = sshwire::try_as_ascii_str(self.0) { + f.write_str(s) + } else { + write!(f, "non-ascii {:?}", self.0.hex_dump()) + } } - } /// State to be passed to decoding. @@ -749,8 +753,8 @@ mod tests { // with None sig let s = sign::tests::make_ed25519_signkey(); let p = UserauthRequest { - username: "matt", - service: "conn", + username: "matt".into(), + service: "conn".into(), method: s.pubkey().try_into().unwrap(), }.into(); test_roundtrip(&p); @@ -761,7 +765,7 @@ mod tests { }); let sig = Some(Blob(sig)); let p = UserauthRequest { - username: "matt", + username: "matt".into(), service: "conn", method: s.pubkey().try_into().unwrap(), }.into(); @@ -776,9 +780,9 @@ mod tests { initial_window: 50000, max_packet: 20000, ch: ChannelOpenType::DirectTcpip(DirectTcpip { - address: "localhost", + address: "localhost".into(), port: 4444, - origin: "somewhere", + origin: "somewhere".into(), origin_port: 0, }), }); @@ -821,7 +825,7 @@ mod tests { num: 0, initial_window: 200000, max_packet: 88200, - ch: ChannelOpenType::Unknown(Unknown("audio-stream")) + ch: ChannelOpenType::Unknown(Unknown(b"audio-stream")) }); let mut buf1 = vec![88; 1000]; write_ssh(&mut buf1, &p).unwrap(); diff --git a/sshproto/src/sign.rs b/sshproto/src/sign.rs index 7c30fcf..cc79cbd 100644 --- a/sshproto/src/sign.rs +++ b/sshproto/src/sign.rs @@ -54,7 +54,7 @@ impl SigType { // This would also get caught by SignatureMismatch below // but that error message is intended for mismatch key vs sig. if discriminant(&sig_type) != discriminant(self) { - warn!("Received {} signature, expecting {}", + warn!("Received {:?} signature, expecting {}", sig.algorithm_name(), self.algorithm_name()); return Err(Error::BadSignature) } @@ -85,7 +85,7 @@ impl SigType { } _ => { - warn!("Signature \"{}\" doesn't match key type \"{}\"", + warn!("Signature \"{:?}\" doesn't match key type \"{:?}\"", sig.algorithm_name(), pubkey.algorithm_name(), ); diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 4283030..080c6ce 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -7,6 +7,9 @@ use { use core::str; use core::convert::AsRef; use core::fmt::{self,Debug}; +use pretty_hex::PrettyHex; + +use ascii::{AsAsciiStr, AsciiChar, AsciiStr}; use crate::*; use packets::{Packet, ParseContext}; @@ -42,8 +45,8 @@ pub trait SSHDecode<'de>: Sized { /// Decodes enums with an externally provided name pub trait SSHDecodeEnum<'de>: Sized { - /// `var` is the variant name to decode - fn dec_enum<S>(s: &mut S, var: &'de str) -> Result<Self> where S: SSHSource<'de>; + /// `var` is the variant name to decode, as raw bytes off the wire. + fn dec_enum<S>(s: &mut S, var: &'de [u8]) -> Result<Self> where S: SSHSource<'de>; } /////////////////////////////////////////////// @@ -211,6 +214,69 @@ impl<'de> SSHDecode<'de> for BinString<'de> { } +/// A text string that may be presented to a user. +/// The SSH protocol defines it to be UTF-8, though +/// in some applications it can be treated as ascii-only. +/// The library treats it as an opaque `&[u8]`, leaving +/// decoding to the `Behaviour`. + +/// Note that SSH protocol identifiers in `Packet` etc +/// are `&str` rather than `TextString`, and always defined as ASCII. +#[derive(Clone,PartialEq,Copy)] +pub struct TextString<'a>(pub &'a [u8]); + +impl<'a> TextString<'a> { + /// Returns the utf8 decoded string, using [`core::str::from_utf8`] + /// Don't call this if you are avoiding including utf8 routines in + /// the binary. + pub fn as_str(&self) -> Result<&'a str> { + core::str::from_utf8(self.0).map_err(|_| Error::BadString) + } + + pub fn as_ascii(&self) -> Result<&'a str> { + self.0.as_ascii_str().map_err(|_| Error::BadString).map(|s| s.as_str()) + } +} + +impl<'a> AsRef<[u8]> for TextString<'a> { + fn as_ref(&self) -> &'a [u8] { + self.0 + } +} + +impl<'a> From<&'a str> for TextString<'a> { + fn from(s: &'a str) -> Self { + TextString(s.as_bytes()) + } +} + +impl<'a> Debug for TextString<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let s = core::str::from_utf8(self.0); + if let Ok(s) = s { + write!(f, "TextString(\"{}\")", s.escape_default()) + } else { + write!(f, "TextString(not utf8!, {:#?})", self.0.hex_dump()) + } + } +} + +impl SSHEncode for TextString<'_> { + fn enc<S>(&self, s: &mut S) -> Result<()> + where S: sshwire::SSHSink { + (self.0.len() as u32).enc(s)?; + self.0.enc(s) + } +} + +impl<'de> SSHDecode<'de> for TextString<'de> { + fn dec<S>(s: &mut S) -> Result<Self> + where S: sshwire::SSHSource<'de> { + let len = u32::dec(s)? as usize; + Ok(TextString(s.take(len)?)) + } +} + // A wrapper for a u32 prefixed data structure `B`, such as a public key blob pub struct Blob<B>(pub B); @@ -313,6 +379,14 @@ impl<T: SSHEncode> SSHEncode for Option<T> { } } +impl SSHEncode for &AsciiStr{ + fn enc<S>(&self, s: &mut S) -> Result<()> + where S: SSHSink { + let v = self.as_bytes(); + BinString(v).enc(s) + } +} + impl<'de> SSHDecode<'de> for bool { fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de> { @@ -340,13 +414,36 @@ impl<'de> SSHDecode<'de> for u32 { } } +/// Decodes a SSH name string. Must be ascii +/// without control characters. RFC4251 section 6. +pub fn try_as_ascii<'a>(t: &'a [u8]) -> Result<&'a AsciiStr> { + let n = t.as_ascii_str().map_err(|_| Error::BadName)?; + if n.chars().any(|ch| ch.is_ascii_control() || ch == AsciiChar::DEL) { + return Err(Error::BadName); + } + Ok(n) +} + +pub fn try_as_ascii_str<'a>(t: &'a [u8]) -> Result<&'a str> { + try_as_ascii(t).map(AsciiStr::as_str) +} + impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { #[inline] fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de> { let len = u32::dec(s)?; let t = s.take(len as usize)?; - str::from_utf8(t).map_err(|_| Error::BadString) + try_as_ascii_str(t) + } +} + +impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr { + fn dec<S>(s: &mut S) -> Result<&'de AsciiStr> + where + S: SSHSource<'de>, { + let b: BinString = SSHDecode::dec(s)?; + try_as_ascii(b.0) } } @@ -360,6 +457,7 @@ impl<'de, const N: usize> SSHDecode<'de> for [u8; N] { } } + #[cfg(test)] pub(crate) mod tests { use crate::*; @@ -433,8 +531,8 @@ pub(crate) mod tests { let mut ctx = ParseContext::new(); let p = Userauth60::PwChangeReq(UserauthPwChangeReq { - prompt: "change the password", - lang: "", + prompt: "change the password".into(), + lang: "".into(), }).into(); let mut pw = ResponseString::new(); pw.push_str("123").unwrap(); diff --git a/sshproto/src/test.rs b/sshproto/src/test.rs index e70fb3d..fa51115 100644 --- a/sshproto/src/test.rs +++ b/sshproto/src/test.rs @@ -35,16 +35,16 @@ mod tests { fn roundtrip_kexinit() { let k = KexInit { cookie: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - kex: "kex".into(), - hostkey: "hostkey,another".into(), - cipher_c2s: "chacha20-poly1305@openssh.com,aes128-ctr".into(), - cipher_s2c: "blowfish".into(), - mac_c2s: "hmac-sha1".into(), - mac_s2c: "hmac-md5".into(), - comp_c2s: "none".into(), - comp_s2c: "".into(), - lang_c2s: "".into(), - lang_s2c: "".into(), + kex: "kex".try_into().unwrap(), + hostkey: "hostkey,another".try_into().unwrap(), + cipher_c2s: "chacha20-poly1305@openssh.com,aes128-ctr".try_into().unwrap(), + cipher_s2c: "blowfish".try_into().unwrap(), + mac_c2s: "hmac-sha1".try_into().unwrap(), + mac_s2c: "hmac-md5".try_into().unwrap(), + comp_c2s: "none".try_into().unwrap(), + comp_s2c: "".try_into().unwrap(), + lang_c2s: "".try_into().unwrap(), + lang_s2c: "".try_into().unwrap(), first_follows: true, reserved: 0x6148291e, }; diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs index 8efe55f..ff39dc4 100644 --- a/sshwire_derive/src/lib.rs +++ b/sshwire_derive/src/lib.rs @@ -393,12 +393,12 @@ fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> { if let FieldAtt::VariantName(enum_field) = a { // Read the extra field on the wire that isn't directly included in the struct named_enums.insert(enum_field.to_string()); - fn_body.push_parsed(format!("let enum_name_{enum_field} = crate::sshwire::SSHDecode::dec(s)?;"))?; + fn_body.push_parsed(format!("let enum_name_{enum_field}: BinString = crate::sshwire::SSHDecode::dec(s)?;"))?; } } let fname = &f.0; if named_enums.contains(&fname.to_string()) { - fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname})?;"))?; + fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname}.0)?;"))?; } else { fn_body.push_parsed(format!("let field_{fname} = crate::sshwire::SSHDecode::dec(s)?;"))?; } @@ -477,9 +477,9 @@ fn decode_enum_variant_prefix( .with_return_type("Result<Self>") .body(|fn_body| { fn_body - .push_parsed("let variant = crate::sshwire::SSHDecode::dec(s)?;")?; + .push_parsed("let variant: crate::sshwire::BinString = crate::sshwire::SSHDecode::dec(s)?;")?; fn_body.push_parsed( - "crate::sshwire::SSHDecodeEnum::dec_enum(s, variant)", + "crate::sshwire::SSHDecodeEnum::dec_enum(s, variant.0)", )?; Ok(()) }) @@ -494,10 +494,13 @@ fn decode_enum_names( .generate_fn("dec_enum") .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"]) .with_arg("s", "&mut S") - .with_arg("variant", "&'de str") + .with_arg("variant", "&'de [u8]") .with_return_type("Result<Self>") .body(|fn_body| { - fn_body.push_parsed("let r = match variant")?; + // Some(ascii_string), or None + fn_body.push_parsed("let var_str = crate::sshwire::try_as_ascii_str(variant).ok();")?; + + fn_body.push_parsed("let r = match var_str")?; fn_body.group(Delimiter::Brace, |match_arm| { let mut unknown_arm = None; for var in &body.variants { @@ -505,13 +508,13 @@ fn decode_enum_names( if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) { // create the Unknown fallthrough but it will be at the end of the match list let mut m = StreamBuilder::new(); - m.push_parsed(format!("unk => Self::{}(Unknown(unk))", var.name))?; + m.push_parsed(format!("_ => Self::{}(Unknown(variant))", var.name))?; if unknown_arm.replace(m).is_some() { return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None}) } } else { let var_name = field_att_var_names(&var.name, atts)?; - match_arm.push_parsed(format!("{} => ", var_name))?; + match_arm.push_parsed(format!("Some({}) => ", var_name))?; match_arm.group(Delimiter::Brace, |var_body| { match var.fields { Fields::Unit => { -- GitLab