diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 9cd0859f0f06573169563ccf9fb18ef89a6c6af7..cb161a30d31b205c9b5520edabf00441dbe3bbec 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -74,7 +74,8 @@ impl Channels { fn remove(&mut self, num: u32) -> Result<()> { // TODO any checks? *self.ch.get_mut(num as usize).ok_or(Error::BadChannel)? = None; - Ok(()) + Err(Error::otherbug()) + // Ok(()) } // incoming packet handling diff --git a/sshproto/src/error.rs b/sshproto/src/error.rs index 7df141c5bdebd923aa4492df45860f3356b02f5f..04bca5a182495484f31e35305a06837771f8a4e9 100644 --- a/sshproto/src/error.rs +++ b/sshproto/src/error.rs @@ -92,8 +92,12 @@ pub enum Error { // This state should not be reached, previous logic should have prevented it. // Create this using [`Error::bug()`] or [`.trap()`](TrapBug::trap). - #[snafu(display("Program bug {location}"))] - Bug { location: snafu::Location }, + // #[snafu(display("Program bug {location}"))] + // Bug { location: snafu::Location }, + /// Program bug + Bug, + + OtherBug { location: snafu::Location }, } impl Error { @@ -101,18 +105,36 @@ impl Error { Error::Custom { msg: m } } - #[track_caller] #[cold] /// Panics in debug builds, returns [`Error::Bug`] in release. // TODO: this should return a Result since it's always used as Err(Error::bug()) pub fn bug() -> Error { + // Easier to track the source of errors in development, + // but release builds shouldn't panic. + if cfg!(debug_assertions) { + panic!("Hit a bug"); + } else { + // let caller = core::panic::Location::caller(); + Error::Bug + // { + // location: snafu::Location::new( + // caller.file(), + // caller.line(), + // caller.column(), + // ), + // } + } + } + + pub fn otherbug() -> Error { // Easier to track the source of errors in development, // but release builds shouldn't panic. if cfg!(debug_assertions) { panic!("Hit a bug"); } else { let caller = core::panic::Location::caller(); - Error::Bug { + Error::OtherBug + { location: snafu::Location::new( caller.file(), caller.line(), @@ -124,10 +146,8 @@ impl Error { /// Like [`bug()`] but with a message /// The message can be used instead of a code comment, is logged at `debug` level. - #[track_caller] #[cold] - /// TODO: is the generic `T` going to make it bloat? - pub fn bug_args<T>(args: Arguments) -> Result<T, Error> { + pub fn bug_args(args: Arguments) -> Error { // Easier to track the source of errors in development, // but release builds shouldn't panic. if cfg!(debug_assertions) { @@ -136,21 +156,26 @@ impl Error { debug!("Hit a bug: {args}"); // TODO: this bloats binaries with full paths // https://github.com/rust-lang/rust/issues/95529 is having function - let caller = core::panic::Location::caller(); - Err(Error::Bug { - location: snafu::Location::new( - caller.file(), - caller.line(), - caller.column(), - ), - }) + // let caller = core::panic::Location::caller(); + Error::Bug + // { + // location: snafu::Location::new( + // caller.file(), + // caller.line(), + // caller.column(), + // ), + // } } } - #[track_caller] #[cold] /// TODO: is the generic `T` going to make it bloat? pub fn bug_msg<T>(msg: &str) -> Result<T, Error> { + Err(Self::bug_args(format_args!("{}", msg))) + } + + #[cold] + pub fn bug_err_msg(msg: &str) -> Error { Self::bug_args(format_args!("{}", msg)) } @@ -162,12 +187,10 @@ pub trait TrapBug<T> { /// `.trap()` should be used like `.unwrap()`, in situations /// never expected to fail. Instead it calls [`Error::bug()`]. /// (or debug builds may panic) - #[track_caller] fn trap(self) -> Result<T, Error>; /// Like [`trap()`] but with a message, calls [`Error::bug_msg()`] /// The message can be used instead of a comment. - #[track_caller] fn trap_msg(self, args: Arguments) -> Result<T, Error>; } @@ -185,7 +208,7 @@ impl<T, E> TrapBug<T> for Result<T, E> { if let Ok(i) = self { Ok(i) } else { - Error::bug_args(args) + Err(Error::bug_args(args)) } } } @@ -204,7 +227,7 @@ impl<T> TrapBug<T> for Option<T> { if let Some(i) = self { Ok(i) } else { - Error::bug_args(args) + Err(Error::bug_args(args)) } } } diff --git a/sshproto/src/namelist.rs b/sshproto/src/namelist.rs index 3c3d52dbe34b9bac8c95995e593e169851be1c00..f459fb8941d065e7620a4bcd052b9e0c4782a6e3 100644 --- a/sshproto/src/namelist.rs +++ b/sshproto/src/namelist.rs @@ -10,7 +10,7 @@ use ascii::{AsciiStr, AsciiChar::Comma}; use sshwire_derive::{SSHEncode, SSHDecode}; use crate::*; -use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink, BinString, try_as_ascii}; +use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink, BinString, WireResult}; /// A comma separated string, can be decoded or encoded. /// Used for remote name lists. @@ -33,7 +33,7 @@ pub enum NameList<'a> { } impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { - fn dec<S>(s: &mut S) -> Result<NameList<'a>> + fn dec<S>(s: &mut S) -> WireResult<NameList<'a>> where S: SSHSource<'de>, { @@ -43,7 +43,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for NameList<'a> { /// Serialize the list of names with comma separators impl SSHEncode for LocalNames<'_> { - fn enc<S>(&self, e: &mut S) -> Result<()> + fn enc<S>(&self, e: &mut S) -> WireResult<()> where S: sshwire::SSHSink { let names = &self.0; // space for names and commas diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index 291f56118a6580d25e18142679b6834689bc5641..e34921c9ddd6a2c0eb150e158c5772ea33baa5aa 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -21,7 +21,8 @@ use namelist::NameList; use sshnames::*; use sshwire::{BinString, TextString, Blob}; use sign::{SigType, OwnedSig}; -use sshwire::{SSHEncode, SSHEncodeEnum, SSHDecode, SSHDecodeEnum, SSHSource, SSHSink}; +use sshwire::{SSHEncode, SSHDecode, SSHSource, SSHSink, WireResult, WireError}; +use sshwire::{SSHEncodeEnum, SSHDecodeEnum}; // Any `enum` needs to have special handling to select a variant when deserializing. // This is mostly done with `#[sshwire(...)]` attributes. @@ -135,14 +136,14 @@ pub enum Userauth60<'a> { } impl<'de: 'a, 'a> SSHDecode<'de> for Userauth60<'a> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { match s.ctx().cli_auth_type { Some(cliauth::AuthType::Password) => Ok(Self::PwChangeReq(SSHDecode::dec(s)?)), Some(cliauth::AuthType::PubKey) => Ok(Self::PkOk(SSHDecode::dec(s)?)), _ => { trace!("Wrong packet state for userauth60"); - return Err(Error::PacketWrong) + return Err(WireError::PacketWrong) } } } @@ -184,7 +185,7 @@ pub struct MethodPubKey<'a> { } impl SSHEncode for MethodPubKey<'_> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { let force_sig_bool = s.ctx().map_or(false, |c| c.method_pubkey_force_sig_bool); let sig = self.sig.is_some() || force_sig_bool; @@ -197,7 +198,7 @@ impl SSHEncode for MethodPubKey<'_> { } impl<'de: 'a, 'a> SSHDecode<'de> for MethodPubKey<'a> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: sshwire::SSHSource<'de> { let sig = bool::dec(s)?; let sig_algo = SSHDecode::dec(s)?; @@ -611,7 +612,7 @@ impl TryFrom<u8> for MessageNumber { } impl SSHEncode for Packet<'_> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { let t = self.message_num() as u8; t.enc(s)?; @@ -630,13 +631,13 @@ impl SSHEncode for Packet<'_> { } impl<'de: 'a, 'a> SSHDecode<'de> for Packet<'a> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let msg_num = u8::dec(s)?; let ty = MessageNumber::try_from(msg_num); let ty = match ty { Ok(t) => t, - Err(_) => return Err(Error::UnknownPacket { number: msg_num }) + Err(_) => return Err(WireError::UnknownPacket { number: msg_num }) }; // Decode based on the message number diff --git a/sshproto/src/sign.rs b/sshproto/src/sign.rs index cc79cbda309416aed65666c2cb14656d213330b1..819a701f558d585547018888c0f8ff7bc743509d 100644 --- a/sshproto/src/sign.rs +++ b/sshproto/src/sign.rs @@ -62,7 +62,7 @@ impl SigType { match (self, pubkey, sig) { (SigType::Ed25519, PubKey::Ed25519(k), Signature::Ed25519(s)) => { - let k = dalek::PublicKey::from_bytes(k.key.0).map_err(|_| Error::BadKey)?; + let k = dalek::PublicKey::from_bytes(k.key.0).map_err(|_| Error::BadSignature)?; let s = dalek::Signature::from_bytes(s.sig.0).map_err(|_| Error::BadSignature)?; k.verify(message, &s).map_err(|_| Error::BadSignature) } diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 080c6cee2fc323372ffb7d2bdc7963e331a0c16e..12c2524f7ba857d2dd85b4543d50bfb670a21bc6 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -8,6 +8,7 @@ use core::str; use core::convert::AsRef; use core::fmt::{self,Debug}; use pretty_hex::PrettyHex; +use snafu::{prelude::*, Location}; use ascii::{AsAsciiStr, AsciiChar, AsciiStr}; @@ -16,50 +17,86 @@ use packets::{Packet, ParseContext}; pub trait SSHSink { - fn push(&mut self, v: &[u8]) -> Result<()>; + fn push(&mut self, v: &[u8]) -> WireResult<()>; fn ctx(&self) -> Option<&ParseContext> { None } } pub trait SSHSource<'de> { - fn take(&mut self, len: usize) -> Result<&'de [u8]>; + fn take(&mut self, len: usize) -> WireResult<&'de [u8]>; fn ctx(&self) -> &ParseContext; } pub trait SSHEncode { - fn enc<S>(&self, s: &mut S) -> Result<()> where S: SSHSink; + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink; } /// For enums with an externally provided name pub trait SSHEncodeEnum { /// Returns the current variant, used for encoding parent structs. /// Fails if it is Unknown - fn variant_name(&self) -> Result<&'static str>; + fn variant_name(&self) -> WireResult<&'static str>; } /// Decodes `struct` and `enum`s without an externally provided enum name pub trait SSHDecode<'de>: Sized { - fn dec<S>(s: &mut S) -> Result<Self> where S: SSHSource<'de>; + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de>; } /// Decodes enums with an externally provided name pub trait SSHDecodeEnum<'de>: Sized { /// `var` is the variant name to decode, as raw bytes off the wire. - fn dec_enum<S>(s: &mut S, var: &'de [u8]) -> Result<Self> where S: SSHSource<'de>; + fn dec_enum<S>(s: &mut S, var: &'de [u8]) -> WireResult<Self> where S: SSHSource<'de>; } +/// A subset of [`Error`] for `SSHEncode` and `SSHDecode`. +/// Compiled code size is very sensitive to the size of this +/// enum so we avoid unused elements. +#[derive(Debug)] +pub enum WireError { + NoRoom, + + RanOut, + + BadString, + + BadName, + + UnknownVariant, + + PacketWrong, + + UnknownPacket { number: u8 }, +} + +impl From<WireError> for Error { + fn from(w: WireError) -> Self { + match w { + WireError::NoRoom => Error::NoRoom, + WireError::RanOut => Error::RanOut, + WireError::BadString => Error::BadString, + WireError::BadName => Error::BadName, + WireError::PacketWrong => Error::PacketWrong, + WireError::UnknownVariant => Error::bug_err_msg("Can't encode Unknown"), + WireError::UnknownPacket { number } => Error::UnknownPacket { number }, + } + } +} + +pub type WireResult<T> = core::result::Result<T, WireError>; + /////////////////////////////////////////////// /// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer. pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> { let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.clone() }; - Packet::dec(&mut s) + Ok(Packet::dec(&mut s)?) } pub fn read_ssh<'a, T: SSHDecode<'a>>(b: &'a [u8], ctx: Option<ParseContext>) -> Result<T> { let mut s = DecodeBytes { input: b, pos: 0, parse_ctx: ctx.unwrap_or_default() }; - T::dec(&mut s) + Ok(T::dec(&mut s)?) } pub fn write_ssh<T>(target: &mut [u8], value: &T) -> Result<usize> @@ -67,7 +104,7 @@ where T: SSHEncode, { let mut s = EncodeBytes { target, pos: 0 }; - value.enc(&mut s)?; + let r = value.enc(&mut s)?; Ok(s.pos) } @@ -76,7 +113,7 @@ pub fn hash_ser_length<T>(hash_ctx: &mut impl digest::DynDigest, where T: SSHEncode, { - let len = length_enc(value)? as u32; + let len: u32 = length_enc(value)?; hash_ctx.update(&len.to_be_bytes()); hash_ser(hash_ctx, None, value) } @@ -92,13 +129,14 @@ where Ok(()) } -pub fn length_enc<T>(value: &T) -> Result<usize> +/// Returns `WireError::NoRoom` if larger than `u32` +fn length_enc<T>(value: &T) -> WireResult<u32> where T: SSHEncode, { let mut s = EncodeLen { pos: 0 }; value.enc(&mut s)?; - Ok(s.pos) + s.pos.try_into().map_err(|e| WireError::NoRoom) } struct EncodeBytes<'a> { @@ -107,9 +145,9 @@ struct EncodeBytes<'a> { } impl SSHSink for EncodeBytes<'_> { - fn push(&mut self, v: &[u8]) -> Result<()> { + fn push(&mut self, v: &[u8]) -> WireResult<()> { if self.pos + v.len() > self.target.len() { - return Err(Error::NoRoom); + return Err(WireError::NoRoom); } self.target[self.pos..self.pos + v.len()].copy_from_slice(v); self.pos += v.len(); @@ -122,7 +160,7 @@ struct EncodeLen { } impl SSHSink for EncodeLen { - fn push(&mut self, v: &[u8]) -> Result<()> { + fn push(&mut self, v: &[u8]) -> WireResult<()> { self.pos += v.len(); Ok(()) } @@ -134,7 +172,7 @@ struct EncodeHash<'a> { } impl SSHSink for EncodeHash<'_> { - fn push(&mut self, v: &[u8]) -> Result<()> { + fn push(&mut self, v: &[u8]) -> WireResult<()> { self.hash_ctx.update(v); Ok(()) } @@ -151,9 +189,9 @@ struct DecodeBytes<'a> { } impl<'de> SSHSource<'de> for DecodeBytes<'de> { - fn take(&mut self, len: usize) -> Result<&'de [u8]> { + fn take(&mut self, len: usize) -> WireResult<&'de [u8]> { if len > self.input.len() { - return Err(Error::RanOut); + return Err(WireError::RanOut); } let t; (t, self.input) = self.input.split_at(len); @@ -198,7 +236,7 @@ impl<'a> Debug for BinString<'a> { } impl SSHEncode for BinString<'_> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: sshwire::SSHSink { (self.0.len() as u32).enc(s)?; self.0.enc(s) @@ -206,7 +244,7 @@ impl SSHEncode for BinString<'_> { } impl<'de> SSHDecode<'de> for BinString<'de> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: sshwire::SSHSource<'de> { let len = u32::dec(s)? as usize; Ok(BinString(s.take(len)?)) @@ -262,7 +300,7 @@ impl<'a> Debug for TextString<'a> { } impl SSHEncode for TextString<'_> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: sshwire::SSHSink { (self.0.len() as u32).enc(s)?; self.0.enc(s) @@ -270,7 +308,7 @@ impl SSHEncode for TextString<'_> { } impl<'de> SSHDecode<'de> for TextString<'de> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: sshwire::SSHSource<'de> { let len = u32::dec(s)? as usize; Ok(TextString(s.take(len)?)) @@ -294,23 +332,25 @@ impl<B: Clone> Clone for Blob<B> { impl<B: SSHEncode + Debug> Debug for Blob<B> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let len = sshwire::length_enc(&self.0) - .map_err(|_| core::fmt::Error)?; - write!(f, "Blob(len={len}, {:?})", self.0) + if let Ok(len) = sshwire::length_enc(&self.0) { + write!(f, "Blob(len={len}, {:?})", self.0) + } else { + write!(f, "Blob(len>u32, {:?})", self.0) + } } } impl<B: SSHEncode> SSHEncode for Blob<B> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: sshwire::SSHSink { - let len: u32 = sshwire::length_enc(&self.0)?.try_into().trap()?; + let len: u32 = sshwire::length_enc(&self.0)?; len.enc(s)?; self.0.enc(s) } } impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: sshwire::SSHSource<'de> { let len = u32::dec(s)?; let inner = SSHDecode::dec(s)?; @@ -322,21 +362,21 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> { /////////////////////////////////////////////// impl SSHEncode for u8 { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { s.push(&[*self]) } } impl SSHEncode for bool { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { (*self as u8).enc(s) } } impl SSHEncode for u32 { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { s.push(&self.to_be_bytes()) } @@ -344,7 +384,7 @@ impl SSHEncode for u32 { // no length prefix impl SSHEncode for &[u8] { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { // data s.push(&self) @@ -353,14 +393,14 @@ impl SSHEncode for &[u8] { // no length prefix impl<const N: usize> SSHEncode for [u8; N] { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { s.push(self) } } impl SSHEncode for &str { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { let v = self.as_bytes(); // length prefix @@ -370,7 +410,7 @@ impl SSHEncode for &str { } impl<T: SSHEncode> SSHEncode for Option<T> { - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { if let Some(t) = self.as_ref() { t.enc(s)?; @@ -380,7 +420,7 @@ impl<T: SSHEncode> SSHEncode for Option<T> { } impl SSHEncode for &AsciiStr{ - fn enc<S>(&self, s: &mut S) -> Result<()> + fn enc<S>(&self, s: &mut S) -> WireResult<()> where S: SSHSink { let v = self.as_bytes(); BinString(v).enc(s) @@ -388,7 +428,7 @@ impl SSHEncode for &AsciiStr{ } impl<'de> SSHDecode<'de> for bool { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { Ok(u8::dec(s)? != 0) } @@ -398,7 +438,7 @@ impl<'de> SSHDecode<'de> for bool { impl<'de> SSHDecode<'de> for u8 { #[inline] - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u8>())?; Ok(u8::from_be_bytes(t.try_into().unwrap())) @@ -407,7 +447,7 @@ impl<'de> SSHDecode<'de> for u8 { impl<'de> SSHDecode<'de> for u32 { #[inline] - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let t = s.take(core::mem::size_of::<u32>())?; Ok(u32::from_be_bytes(t.try_into().unwrap())) @@ -416,21 +456,21 @@ impl<'de> SSHDecode<'de> for u32 { /// Decodes a SSH name string. Must be ascii /// without control characters. RFC4251 section 6. -pub fn try_as_ascii<'a>(t: &'a [u8]) -> Result<&'a AsciiStr> { - let n = t.as_ascii_str().map_err(|_| Error::BadName)?; +pub fn try_as_ascii<'a>(t: &'a [u8]) -> WireResult<&'a AsciiStr> { + let n = t.as_ascii_str().map_err(|_| WireError::BadName)?; if n.chars().any(|ch| ch.is_ascii_control() || ch == AsciiChar::DEL) { - return Err(Error::BadName); + return Err(WireError::BadName); } Ok(n) } -pub fn try_as_ascii_str<'a>(t: &'a [u8]) -> Result<&'a str> { +pub fn try_as_ascii_str<'a>(t: &'a [u8]) -> WireResult<&'a str> { try_as_ascii(t).map(AsciiStr::as_str) } impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { #[inline] - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { let len = u32::dec(s)?; let t = s.take(len as usize)?; @@ -439,7 +479,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for &'a str { } impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr { - fn dec<S>(s: &mut S) -> Result<&'de AsciiStr> + fn dec<S>(s: &mut S) -> WireResult<&'de AsciiStr> where S: SSHSource<'de>, { let b: BinString = SSHDecode::dec(s)?; @@ -448,7 +488,7 @@ impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr { } impl<'de, const N: usize> SSHDecode<'de> for [u8; N] { - fn dec<S>(s: &mut S) -> Result<Self> + fn dec<S>(s: &mut S) -> WireResult<Self> where S: SSHSource<'de> { // TODO is there a better way? Or can we return a slice? let mut l = [0u8; N]; diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs index ff39dc446b69ad9cd8d2c9b8e04ce741bcf536b3..e6308bbce880c97b3ca90dea8ee404e8f3f8234e 100644 --- a/sshwire_derive/src/lib.rs +++ b/sshwire_derive/src/lib.rs @@ -201,7 +201,7 @@ fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> { .with_generic_deps("E", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) .with_arg("s", "&mut E") - .with_return_type("Result<()>") + .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { match &body.fields { Fields::Tuple(v) => { @@ -254,7 +254,7 @@ fn encode_enum( .with_generic_deps("S", ["crate::sshwire::SSHSink"]) .with_self_arg(FnSelfArg::RefSelf) .with_arg("s", "&mut S") - .with_return_type("Result<()>") + .with_return_type("crate::sshwire::WireResult<()>") .body(|fn_body| { if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) { fn_body.push_parsed("crate::sshwire::SSHEncode::enc(&self.variant_name()?, s)?;")?; @@ -283,7 +283,7 @@ fn encode_enum( Ok(()) })?; if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) { - rhs.push_parsed("return Error::bug_msg(\"Can't encode Unknown\")")?; + rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?; } else { rhs.push_parsed(format!("crate::sshwire::SSHEncode::enc(i, s)?;"))?; } @@ -331,7 +331,7 @@ fn encode_enum_names( gen.impl_for("crate::sshwire::SSHEncodeEnum") .generate_fn("variant_name") .with_self_arg(FnSelfArg::RefSelf) - .with_return_type("Result<&'static str>") + .with_return_type("crate::sshwire::WireResult<&'static str>") .body(|fn_body| { fn_body.push_parsed("let r = match self")?; fn_body.group(Delimiter::Brace, |match_arm| { @@ -343,7 +343,7 @@ fn encode_enum_names( let mut rhs = StreamBuilder::new(); let atts = take_field_atts(&var.attributes)?; if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) { - rhs.push_parsed("return Error::bug_msg(\"Can't encode Unknown\")")?; + rhs.push_parsed("return Err(crate::sshwire::WireError::UnknownVariant)")?; } else { rhs.push(field_att_var_names(&var.name, atts)?); } @@ -383,7 +383,7 @@ fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> { .generate_fn("dec") .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"]) .with_arg("s", "&mut S") - .with_return_type("Result<Self>") + .with_return_type("crate::sshwire::WireResult<Self>") .body(|fn_body| { let mut named_enums = HashSet::new(); if let Fields::Struct(v) = &body.fields { @@ -474,7 +474,7 @@ fn decode_enum_variant_prefix( .generate_fn("dec") .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"]) .with_arg("s", "&mut S") - .with_return_type("Result<Self>") + .with_return_type("crate::sshwire::WireResult<Self>") .body(|fn_body| { fn_body .push_parsed("let variant: crate::sshwire::BinString = crate::sshwire::SSHDecode::dec(s)?;")?; @@ -495,7 +495,7 @@ fn decode_enum_names( .with_generic_deps("S", ["crate::sshwire::SSHSource<'de>"]) .with_arg("s", "&mut S") .with_arg("variant", "&'de [u8]") - .with_return_type("Result<Self>") + .with_return_type("crate::sshwire::WireResult<Self>") .body(|fn_body| { // Some(ascii_string), or None fn_body.push_parsed("let var_str = crate::sshwire::try_as_ascii_str(variant).ok();")?;