diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 31ae28a450d6e810e6733b1f726b3ab4f26652a6..30bbc84a11cdd7575a81d24976e04f20c7714573 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -197,7 +197,7 @@ impl Channels { } fn dispatch_open(&mut self, p: &ChannelOpen<'_>, - s: &TrafSend, + s: &mut TrafSend, b: &mut Behaviour<'_>, ) -> Result<()> { @@ -218,7 +218,7 @@ impl Channels { // the caller will send failure messages if required fn dispatch_open_inner(&mut self, p: &ChannelOpen<'_>, - s: &TrafSend, + s: &mut TrafSend, b: &mut Behaviour<'_>, ) -> Result<(), DispatchOpenError> { @@ -277,7 +277,7 @@ impl Channels { pub fn dispatch_request(&mut self, p: &packets::ChannelRequest, - _s: &TrafSend, + _s: &mut TrafSend, _b: &mut Behaviour<'_>, ) -> Result<()> { let ch = match self.get(p.num) { @@ -299,7 +299,7 @@ impl Channels { pub async fn dispatch( &mut self, packet: Packet<'_>, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, b: &mut Behaviour<'_>, ) -> Result<Option<ChanEventMaker>> { trace!("chan dispatch"); @@ -576,7 +576,7 @@ impl Channel { } } - fn request(&mut self, req: ReqDetails, s: &TrafSend) -> Result<()> { + fn request(&mut self, req: ReqDetails, s: &mut TrafSend) -> Result<()> { let num = self.send.as_ref().trap()?.num; let r = Req { num, details: req }; s.send(r.packet()?) diff --git a/sshproto/src/cliauth.rs b/sshproto/src/cliauth.rs index b9a66083778756e761eb600722ca413fbadd0e9e..da61f5bb0a96641f68ac0b05585878575f1ed6e5 100644 --- a/sshproto/src/cliauth.rs +++ b/sshproto/src/cliauth.rs @@ -96,7 +96,7 @@ impl CliAuth { // May be called multiple times pub async fn start<'b>( &'b mut self, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, b: &mut dyn CliBehaviour, ) -> Result<()> { if let AuthState::Unstarted = self.state { @@ -182,7 +182,7 @@ impl CliAuth { auth60: &packets::Userauth60<'_>, sess_id: &SessId, parse_ctx: &mut ParseContext, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, ) -> Result<()> { parse_ctx.cli_auth_type = None; @@ -197,7 +197,7 @@ impl CliAuth { pkok: &UserauthPkOk<'_>, sess_id: &SessId, parse_ctx: &mut ParseContext, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, ) -> Result<()> { // We are only sending keys one at a time so they shouldn't // get out of sync. In future we could change it to send @@ -250,7 +250,7 @@ impl CliAuth { &'b mut self, failure: &packets::UserauthFailure<'_>, parse_ctx: &mut ParseContext, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, b: &mut dyn CliBehaviour, ) -> Result<()> { parse_ctx.cli_auth_type = None; diff --git a/sshproto/src/client.rs b/sshproto/src/client.rs index eef52c109acebf63008e115d9a93704464c4fa38..2b17b8f6bf1506cf3c7b9677266ba526fbe455a5 100644 --- a/sshproto/src/client.rs +++ b/sshproto/src/client.rs @@ -30,7 +30,7 @@ impl Client { pub(crate) fn auth_success(&mut self, parse_ctx: &mut ParseContext, - s: &TrafSend, + s: &mut TrafSend, b: &mut dyn CliBehaviour) -> Result<()> { parse_ctx.cli_auth_type = None; diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index 6c73b0fdbdc8d8083a9847223b885cd5957068f4..6afc57e853265c26a1d41483c2684a2e2a67957b 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -17,7 +17,7 @@ use client::Client; use encrypt::KeyState; use packets::{Packet,ParseContext}; use server::Server; -use traffic::{Traffic, TrafSend}; +use traffic::TrafSend; use channel::{Channels, ChanEvent, ChanEventMaker}; use config::MAX_CHANNELS; use kex::SessId; @@ -120,15 +120,15 @@ impl<'a> Conn<'a> { /// Updates `ConnState` and sends any packets required to progress the connection state. pub(crate) async fn progress<'b>( - &mut self, traffic: &mut Traffic<'b>, keys: &mut KeyState, + &mut self, + s: &mut TrafSend<'_>, b: &mut Behaviour<'_>, ) -> Result<(), Error> { debug!("progress conn state {:?}", self.state); - let s = TrafSend::new(traffic, keys); match self.state { ConnState::SendIdent => { - traffic.send_version(ident::OUR_VERSION)?; - let p = self.kex.send_kexinit(&self.algo_conf, &s)?; + s.send_version(ident::OUR_VERSION)?; + let p = self.kex.send_kexinit(&self.algo_conf, s)?; // TODO: first_follows would have a second packet here self.state = ConnState::ReceiveIdent } @@ -140,9 +140,9 @@ impl<'a> Conn<'a> { ConnState::PreAuth => { // TODO. need to figure how we'll do "unbounded" responses // and backpressure. can_output() should have a size check? - if traffic.can_output() { + if s.can_output() { if let ClientServer::Client(cli) = &mut self.cliserv { - cli.auth.start(&s, b.client()?).await?; + cli.auth.start(s, b.client()?).await?; } } // send userauth request @@ -167,12 +167,12 @@ impl<'a> Conn<'a> { } } - /// Consumes an input payload which is a view into [`traffic::Traffic::buf`]. + /// Consumes an input payload which is a view into [`traffic::Traffic::rxbuf`]. /// We queue response packets that can be sent (written into the same buffer) /// after `handle_payload()` runs. pub(crate) async fn handle_payload<'p>( &mut self, payload: &'p [u8], seq: u32, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, b: &mut Behaviour<'_>, ) -> Result<Dispatched, Error> { let r = sshwire::packet_from_bytes(payload, &self.parse_ctx); @@ -188,7 +188,7 @@ impl<'a> Conn<'a> { } async fn dispatch_packet<'p>( - &mut self, packet: Packet<'p>, s: &TrafSend<'_>, b: &mut Behaviour<'_>, + &mut self, packet: Packet<'p>, s: &mut TrafSend<'_>, b: &mut Behaviour<'_>, ) -> Result<Dispatched, Error> { // TODO: perhaps could consolidate packet allowed checks into a separate function // to run first? diff --git a/sshproto/src/kex.rs b/sshproto/src/kex.rs index 431f68d63b00614a6b92b2516348cc9b6a80f644..92143eca67063e2ae5365eab338e766dd2301cf8 100644 --- a/sshproto/src/kex.rs +++ b/sshproto/src/kex.rs @@ -202,7 +202,7 @@ impl Kex { pub fn handle_kexinit( &mut self, p: &packets::Packet, is_client: bool, algo_conf: &AlgoConfig, remote_version: &RemoteVersion, - s: &TrafSend, + s: &mut TrafSend, ) -> Result<()> { let remote_kexinit = if let Packet::KexInit(k) = p { k } else { return Err(Error::bug()) }; @@ -229,8 +229,8 @@ impl Kex { } } - pub fn send_kexinit<'a>(&self, conf: &'a AlgoConfig, s: &TrafSend) -> Result<()> { - s.send(packets::KexInit { + fn make_kexinit<'a>(&self, conf: &'a AlgoConfig) -> Packet<'a> { + packets::KexInit { cookie: self.our_cookie, kex: (&conf.kexs).into(), hostkey: (&conf.hostsig).into(), @@ -244,7 +244,11 @@ impl Kex { lang_s2c: (&EMPTY_LOCALNAMES).into(), first_follows: false, reserved: 0, - }) + }.into() + } + + pub fn send_kexinit(&self, conf: &AlgoConfig, s: &mut TrafSend) -> Result<()> { + s.send(self.make_kexinit(conf)) } fn make_kexdhinit(&self) -> Result<Packet> { @@ -256,9 +260,9 @@ impl Kex { } // returns kex output, consumes self. - pub fn handle_kexdhinit<'a>( + pub fn handle_kexdhinit( self, p: &packets::KexDHInit, sess_id: &Option<SessId>, - s: &TrafSend, b: &mut dyn ServBehaviour, + s: &mut TrafSend, b: &mut dyn ServBehaviour, ) -> Result<KexOutput> { if self.algos.as_ref().trap()?.is_client { return Err(Error::bug()); @@ -272,7 +276,7 @@ impl Kex { // consumes self. pub async fn handle_kexdhreply<'f>( self, p: &packets::KexDHReply<'f>, sess_id: &Option<SessId>, - s: &TrafSend<'_>, + s: &mut TrafSend<'_>, b: &mut dyn CliBehaviour, ) -> Result<KexOutput> { if !self.algos.as_ref().trap()?.is_client { @@ -442,7 +446,7 @@ impl SharedSecret { // server only. consumes kex. fn handle_kexdhinit<'a>( mut kex: Kex, p: &packets::KexDHInit, sess_id: &Option<SessId>, - s: &TrafSend, b: &mut dyn ServBehaviour, + s: &mut TrafSend, b: &mut dyn ServBehaviour, ) -> Result<KexOutput> { // let mut algos = kex.algos.take().trap()?; let mut algos = kex.algos.trap()?; @@ -450,29 +454,28 @@ impl SharedSecret { // TODO let fake_hostkey = PubKey::Ed25519(packets::Ed25519PubKey{ key: BinString(&[]) }); kex_hash.prefinish(&fake_hostkey, p.q_c.0, algos.kex.pubkey())?; - let (kex_pub, kex_out) = match algos.kex { - SharedSecret::KexCurve25519(ref k) => { - let pubkey: salty::agreement::PublicKey = k.ours.as_ref().trap()?.into(); + let (kex_out, kex_pub) = match algos.kex { + SharedSecret::KexCurve25519(_) => { let kex_out = KexCurve25519::secret(&mut algos, p.q_c.0, kex_hash, sess_id)?; - (&pubkey.to_bytes(), kex_out) + (kex_out, algos.kex.pubkey()) } }; - kex.send_kexdhreply(kex_pub, s, b)?; + Self::send_kexdhreply(&kex_out, kex_pub, algos.hostsig, s, b)?; Ok(kex_out) } // server only - pub fn send_kexdhreply(&self, kex_pub: &[u8], s: &TrafSend, b: &mut dyn ServBehaviour) -> Result<()> { + pub fn send_kexdhreply(ko: &KexOutput, kex_pub: &[u8], sig_type: SigType, s: &mut TrafSend, b: &mut dyn ServBehaviour) -> Result<()> { let q_s = BinString(kex_pub); // hostkeys list must contain the signature type - let key = b.hostkeys()?.iter().find(|k| k.can_sign(&self.algos.hostsig)).trap()?; + let key = b.hostkeys()?.iter().find(|k| k.can_sign(&sig_type)).trap()?; let k_s = Blob(key.pubkey()); - self.sig = Some(key.sign(&self.h.as_slice(), None)?); - let sig: Signature = self.sig.as_ref().unwrap().into(); + let sig = key.sign(&ko.h.as_slice(), None)?; + let sig: Signature = (&sig).into(); let sig = Blob(sig); - Ok(packets::KexDHReply { k_s, q_s, sig }.into()) + s.send(packets::KexDHReply { k_s, q_s, sig }) } fn pubkey(&self) -> &[u8] { @@ -485,21 +488,11 @@ impl SharedSecret { pub(crate) struct KexOutput { pub h: SessId, pub keys: Keys, - - // storage for kex packet reply content that outlives Kex - // in make_kexdhreply(). - - /// ephemeral public key octet string - kex_pub: Option<[u8; 32]>, - // the negotiated signature type - sig_type: SigType, - sig: Option<sign::OwnedSig>, } impl fmt::Debug for KexOutput { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("KexOutput") - .field("kex_pub", &self.kex_pub.is_some()) .finish_non_exhaustive() } } @@ -519,7 +512,9 @@ impl<'a> KexOutput { } pub(crate) struct KexCurve25519 { + // Initialised in `new()`, cleared after deriving the secret ours: Option<salty::agreement::SecretKey>, + // TODO: it would be nice to avoid having to store this separately, but seems difficult pubkey: [u8; 32], } @@ -537,17 +532,17 @@ impl KexCurve25519 { let mut s = [0u8; 32]; random::fill_random(s.as_mut_slice())?; // TODO: check that pure random bytes are OK - let ours = salty::agreement::SecretKey::from_seed(&s); + let ours = salty::agreement::SecretKey::from_seed(&mut s); let pubkey: salty::agreement::PublicKey = (&ours).into(); let pubkey = pubkey.to_bytes(); Ok(KexCurve25519 { ours: Some(ours), pubkey }) } - fn pubkey<'a>(&'a self) -> &'a [u8] { - self.pubkey.as_slice() + fn pubkey(&self) -> &[u8] { + &self.pubkey } - fn secret<'a>( + fn secret( algos: &mut Algos, theirs: &[u8], kex_hash: KexHash, sess_id: &Option<SessId>, ) -> Result<KexOutput> { @@ -649,6 +644,26 @@ mod tests { } } + // struct BlankTrafSend { + // buf: Vec<u8>, + // keys: encrypt::KeyState, + // } + + // impl BlankTrafSend { + // fn new() -> Self { + // Self { + // buf: vec![0u8, 3000], + // keys: encrypt::KeyState::new_cleartext(), + // } + // } + + // fn sender(&mut self) -> traffic::TrafSend { + // let mut t = traffic::TrafOut::new(&mut self.buf); + // t.sender(&mut self.keys) + // } + // } + + #[test] fn test_agree_kex() { init_test_log(); @@ -679,12 +694,18 @@ mod tests { let ci = cli.make_kexinit(&cli_conf); let ci = reencode(&mut bufc, ci, &ctx); - serv.handle_kexinit(false, &serv_conf, &cli_version, &ci).unwrap(); - cli.handle_kexinit(true, &cli_conf, &serv_version, &si).unwrap(); + // TODO fix this + + // let ts = BlankTrafSend::new(); + // let s = ts.sender(); + // serv.handle_kexinit(&ci, false, &serv_conf, &cli_version, &mut s).unwrap(); + // cli.handle_kexinit(&si, true, &cli_conf, &serv_version, &mut s).unwrap(); + + // let ci = cli.make_kexdhinit().unwrap(); + // let ci = if let Packet::KexDHInit(k) = ci { k } else { panic!() }; + // let sout = serv.handle_kexdhinit(&ci, &None, &mut s, sb).unwrap(); + - let ci = cli.make_kexdhinit().unwrap(); - let ci = if let Packet::KexDHInit(k) = ci { k } else { panic!() }; - let sout = serv.handle_kexdhinit(&ci, &None).unwrap(); // let kexreply = sout.make_kexdhreply(sb); // let kexreply = diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index dc313634e2d580143e25656470c167ea01bf1b9a..d512f9b95063ab2a7efff0e0b59c2fa1d33e1cd0 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -10,7 +10,7 @@ use pretty_hex::PrettyHex; use crate::{*, channel::ChanEvent}; use encrypt::KeyState; -use traffic::{Traffic, TrafSend}; +use traffic::{TrafIn, TrafOut, TrafSend}; use conn::{Conn, Dispatched, EventMaker, Event}; use channel::ChanEventMaker; @@ -18,8 +18,10 @@ use channel::ChanEventMaker; pub struct Runner<'a> { conn: Conn<'a>, - /// Binary packet handling to and from the network buffer - traffic: Traffic<'a>, + /// Binary packet handling from the network buffer + traf_in: TrafIn<'a>, + /// Binary packet handling to the network buffer + traf_out: TrafOut<'a>, /// Current encryption/integrity keys keys: KeyState, @@ -37,7 +39,8 @@ impl<'a> Runner<'a> { let conn = Conn::new_client()?; let runner = Runner { conn, - traffic: traffic::Traffic::new(outbuf, inbuf), + traf_in: TrafIn::new(inbuf), + traf_out: TrafOut::new(outbuf), keys: KeyState::new_cleartext(), output_waker: None, input_waker: None, @@ -53,7 +56,8 @@ impl<'a> Runner<'a> { let conn = Conn::new_server()?; let runner = Runner { conn, - traffic: traffic::Traffic::new(outbuf, inbuf), + traf_in: TrafIn::new(inbuf), + traf_out: TrafOut::new(outbuf), keys: KeyState::new_cleartext(), output_waker: None, input_waker: None, @@ -63,7 +67,7 @@ impl<'a> Runner<'a> { } pub fn input(&mut self, buf: &[u8]) -> Result<usize, Error> { - self.traffic.input( + self.traf_in.input( &mut self.keys, &mut self.conn.remote_version, buf, @@ -72,7 +76,7 @@ impl<'a> Runner<'a> { /// Write any pending output to the wire, returning the size written pub fn output(&mut self, buf: &mut [u8]) -> Result<usize, Error> { - let r = self.traffic.output(buf); + let r = self.traf_out.output(buf); if r > 0 { trace!("output() wake"); self.wake(); @@ -87,7 +91,8 @@ impl<'a> Runner<'a> { /// event to the application. /// [`done_payload()`] must be called after any `Ok` result. pub async fn progress<'f>(&'f mut self, behaviour: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { - let em = if let Some((payload, seq)) = self.traffic.payload() { + let mut s = self.traf_out.sender(&mut self.keys); + let em = if let Some((payload, seq)) = self.traf_in.payload() { // Lifetimes here are a bit subtle. // `payload` has self.traffic lifetime, used until `handle_payload` // completes. @@ -95,14 +100,12 @@ impl<'a> Runner<'a> { // by the send_packet(). // After that progress() can perform more send_packet() itself. - // TODO matt aug: trafsend should be constructed by traffic.split_send() or something. - let s = TrafSend::new(&mut self.traffic, &mut self.keys); - let d = self.conn.handle_payload(payload, seq, &s, behaviour).await?; - self.traffic.handled_payload()?; + let d = self.conn.handle_payload(payload, seq, &mut s, behaviour).await?; + self.traf_in.handled_payload()?; if d.event.is_none() { // switch to using the buffer for output. - self.traffic.done_payload()?; + self.traf_in.done_payload()?; } d.event @@ -121,20 +124,19 @@ impl<'a> Runner<'a> { match em { EventMaker::Channel(ChanEventMaker::DataIn(di)) => { trace!("chanmaaker {di:?}"); - self.traffic.done_payload()?; - self.traffic.set_channel_input(di)?; + self.traf_in.done_payload()?; + self.traf_in.set_channel_input(di)?; // TODO: channel wakers None } _ => { // Some(payload) is only required for some variants in make_event() - let payload = self.traffic.payload_reborrow(); - self.conn.make_event(payload, em)? + panic!("delete this codepath") } } } else { trace!("no em, conn progress"); - self.conn.progress(&mut self.traffic, &mut self.keys, behaviour).await?; + self.conn.progress(&mut s, behaviour).await?; self.wake(); None }; @@ -144,7 +146,7 @@ impl<'a> Runner<'a> { } pub fn done_payload(&mut self) -> Result<()> { - self.traffic.done_payload()?; + self.traf_in.done_payload()?; self.wake(); Ok(()) } @@ -184,7 +186,7 @@ impl<'a> Runner<'a> { } let (ch, p) = self.conn.channels.open(packets::ChannelOpenType::Session, init_req)?; let chan = ch.number(); - self.traffic.send_packet(p, &mut self.keys)?; + self.traf_out.send_packet(p, &mut self.keys)?; self.wake(); Ok(chan) } @@ -207,7 +209,7 @@ impl<'a> Runner<'a> { let len = len.min(buf.len()); let p = self.conn.channels.send_data(chan, ext, &buf[..len])?; - self.traffic.send_packet(p, &mut self.keys)?; + self.traf_out.send_packet(p, &mut self.keys)?; self.wake(); Ok(Some(len)) } @@ -220,11 +222,11 @@ impl<'a> Runner<'a> { buf: &mut [u8], ) -> Result<usize> { trace!("runner chan in"); - let (len, complete) = self.traffic.channel_input(chan, ext, buf); + let (len, complete) = self.traf_in.channel_input(chan, ext, buf); if complete { let p = self.conn.channels.finished_input(chan)?; if let Some(p) = p { - self.traffic.send_packet(p, &mut self.keys)?; + self.traf_out.send_packet(p, &mut self.keys)?; } self.wake(); } @@ -232,11 +234,11 @@ impl<'a> Runner<'a> { } pub fn ready_input(&self) -> bool { - self.conn.initial_sent() && self.traffic.ready_input() + self.conn.initial_sent() && self.traf_in.ready_input() } pub fn output_pending(&self) -> bool { - self.traffic.output_pending() + self.traf_out.output_pending() } pub fn set_input_waker(&mut self, waker: Waker) { @@ -248,7 +250,7 @@ impl<'a> Runner<'a> { } pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { - self.traffic.ready_channel_input() + self.traf_in.ready_channel_input() } pub fn channel_eof(&self, chan: u32) -> bool { @@ -258,7 +260,7 @@ impl<'a> Runner<'a> { // Returns None on channel closed pub fn ready_channel_send(&self, chan: u32) -> Option<usize> { // minimum of buffer space and channel window available - let buf_space = self.traffic.send_allowed(&self.keys); + let buf_space = self.traf_out.send_allowed(&self.keys); self.conn.channels.send_allowed(chan).map(|s| s.min(buf_space)) } diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index a8d123f21d59c54019629f28f3cee7e6b69fe947..9381675d170d11f8e99106f813ccff09b3e0c931 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -11,7 +11,7 @@ use crate::*; use crate::packets::Packet; use pretty_hex::PrettyHex; -pub(crate) struct Traffic<'a> { +pub(crate) struct TrafOut<'a> { // TODO: if smoltcp exposed both ends of a CircularBuffer to recv() // we could perhaps just work directly in smoltcp's provided buffer? // Would need changes to ciphers with block boundaries @@ -23,12 +23,25 @@ pub(crate) struct Traffic<'a> { /// Contains ciphertext or cleartext, encrypted/decrypted in-place. /// Writing may contain multiple SSH packets to write out, encrypted /// in-place as they are written to `buf`. - tx_buf: &'a mut [u8], - /// Only contains a single SSH packet at a time. - rx_buf: &'a mut [u8], + buf: &'a mut [u8], + state: TxState, +} - tx_state: TxState, - rx_state: RxState, +pub(crate) struct TrafIn<'a> { + // TODO: if smoltcp exposed both ends of a CircularBuffer to recv() + // we could perhaps just work directly in smoltcp's provided buffer? + // Would need changes to ciphers with block boundaries + + // TODO: decompression will need another buffer + /// Accumulated input or output buffer. + /// Should be sized to fit the largest packet allowed for input, or + /// sequence of packets to be sent at once (see [`conn::MAX_RESPONSES`]). + /// Contains ciphertext or cleartext, encrypted/decrypted in-place. + /// Writing may contain multiple SSH packets to write out, encrypted + /// in-place as they are written to `buf`. + /// Only contains a single SSH packet at a time. + buf: &'a mut [u8], + state: RxState, } /// State machine for writes @@ -78,16 +91,13 @@ enum RxState { }, } -impl<'a> Traffic<'a> { - pub fn new(rx_buf: &'a mut [u8], tx_buf: &'a mut [u8]) -> Self { - Traffic { tx_buf, rx_buf, - tx_state: TxState::Idle, - rx_state: RxState::Idle, - } +impl<'a> TrafIn<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { buf, state: RxState::Idle } } pub fn ready_input(&self) -> bool { - match self.rx_state { + match self.state { RxState::Idle | RxState::ReadInitial { .. } | RxState::Read { .. } => true, @@ -99,41 +109,14 @@ impl<'a> Traffic<'a> { } } - pub fn output_pending(&self) -> bool { - match self.tx_state { - TxState::Write { .. } => true, - _ => false - } - } - - /// A simple test if a packet can be sent. `send_allowed` should be used - /// for more general situations - pub fn can_output(&self) -> bool { - // TODO don't use this - true - } - - /// Returns payload space available to send a packet. Returns 0 if not ready or full - pub fn send_allowed(&self, keys: &KeyState) -> usize { - // TODO: test for full output buffer - match self.tx_state { - TxState::Write { len, .. } => { - keys.max_enc_payload(self.tx_buf.len() - len) - } - TxState::Idle => { - keys.max_enc_payload(self.tx_buf.len()) - } - } - } - /// Returns the number of bytes consumed. pub fn input( &mut self, keys: &mut KeyState, remote_version: &mut RemoteVersion, buf: &[u8], ) -> Result<usize, Error> { let mut inlen = 0; - trace!("state {:?} input {}", self.rx_state, buf.len()); - if remote_version.version().is_none() && matches!(self.rx_state, RxState::Idle) { + trace!("state {:?} input {}", self.state, buf.len()); + if remote_version.version().is_none() && matches!(self.state, RxState::Idle) { // Handle initial version string let l; l = remote_version.consume(buf)?; @@ -142,29 +125,16 @@ impl<'a> Traffic<'a> { let buf = &buf[inlen..]; inlen += self.fill_input(keys, buf)?; - trace!("after inlen {inlen} state {:?}", self.rx_state); + trace!("after inlen {inlen} state {:?}", self.state); Ok(inlen) } - /// Returns a reference to the decrypted payload buffer if ready, - /// and the `seq` of that packet. - pub(crate) fn payload(&mut self) -> Option<(&[u8], u32)> { - match self.rx_state { - | RxState::InPayload { len, seq } - => { - let payload = &self.rx_buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; - Some((payload, seq)) - } - _ => None, - } - } - pub(crate) fn payload_reborrow(&mut self) -> Option<&[u8]> { - match self.rx_state { + match self.state { | RxState::InPayload { len, .. } | RxState::BorrowPayload { len, .. } => { - let payload = &self.rx_buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; + let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; Some(payload) } _ => None, @@ -174,11 +144,11 @@ impl<'a> Traffic<'a> { /// Called when `payload()` has been handled once, can still be /// `payload_reborrow()`ed later. pub(crate) fn handled_payload(&mut self) -> Result<(), Error> { - match self.rx_state { + match self.state { | RxState::InPayload { len, .. } | RxState::BorrowPayload { len } => { - self.rx_state = RxState::BorrowPayload { len }; + self.state = RxState::BorrowPayload { len }; Ok(()) } _ => Err(Error::bug()) @@ -187,11 +157,11 @@ impl<'a> Traffic<'a> { /// Called when `payload()` and `payload_reborrow()` are complete. pub(crate) fn done_payload(&mut self) -> Result<(), Error> { - match self.rx_state { + match self.state { | RxState::InPayload { .. } | RxState::BorrowPayload { .. } => { - self.rx_state = RxState::Idle; + self.state = RxState::Idle; Ok(()) } _ => { @@ -202,71 +172,17 @@ impl<'a> Traffic<'a> { } } - pub fn send_version(&mut self, buf: &[u8]) -> Result<(), Error> { - if !matches!(self.tx_state, TxState::Idle) { - return Err(Error::bug()); - } - - if buf.len() + 2 > self.tx_buf.len() { - return Err(Error::NoRoom); - } - - self.tx_buf[..buf.len()].copy_from_slice(buf); - self.tx_buf[buf.len()] = ident::CR; - self.tx_buf[buf.len()+1] = ident::LF; - self.tx_state = TxState::Write { idx: 0, len: buf.len() + 2 }; - Ok(()) - } - - /// Serializes and and encrypts a packet to send - pub fn send_packet(&mut self, p: packets::Packet, keys: &mut KeyState) -> Result<()> { - trace!("send_packet {:?}", p.message_num()); - - // Either a fresh buffer or appending to write - let (idx, len) = match self.tx_state { - TxState::Idle => (0, 0), - TxState::Write { idx, len } => (idx, len), - _ => { - trace!("bad state {:?}", self.tx_state); - Err(Error::bug())? + /// Returns a reference to the decrypted payload buffer if ready, + /// and the `seq` of that packet. + pub(crate) fn payload(&mut self) -> Option<(&[u8], u32)> { + match self.state { + | RxState::InPayload { len, seq } + => { + let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; + Some((payload, seq)) } - }; - - // Use the remainder of our buffer to write the packet. Payload starts - // after the length and padding bytes which get filled by encrypt() - let wbuf = &mut self.tx_buf[len..]; - if wbuf.len() < SSH_PAYLOAD_START { - return Err(Error::NoRoom) + _ => None, } - let plen = sshwire::write_ssh(&mut wbuf[SSH_PAYLOAD_START..], &p)?; - trace!("Sending {p:?}"); - trace!("new {plen} {:?}", (&wbuf[SSH_PAYLOAD_START..SSH_PAYLOAD_START+plen]).hex_dump()); - - // Encrypt in place - let elen = keys.encrypt(plen, wbuf)?; - self.tx_state = TxState::Write { idx, len: len+elen }; - Ok(()) - - } - - /// Write any pending output, returning the size written - pub fn output(&mut self, buf: &mut [u8]) -> usize { - let r = match self.tx_state { - TxState::Write { ref mut idx, len } => { - let wlen = (len - *idx).min(buf.len()); - buf[..wlen].copy_from_slice(&self.tx_buf[*idx..*idx + wlen]); - *idx += wlen; - - if *idx == len { - // all done, read the next packet - self.tx_state = TxState::Idle - } - wlen - } - _ => 0, - }; - trace!("output state now {:?}", self.tx_state); - r } fn fill_input( @@ -279,7 +195,7 @@ impl<'a> Traffic<'a> { // Fill the initial block from either Idle with input, // partial initial block - if let Some(idx) = match self.rx_state { + if let Some(idx) = match self.state { RxState::Idle if r.len() > 0 => Some(0), RxState::ReadInitial { idx } => Some(idx), _ => None, @@ -287,67 +203,67 @@ impl<'a> Traffic<'a> { let need = (size_block - idx).clamp(0, r.len()); let x; (x, r) = r.split_at(need); - let w = &mut self.rx_buf[idx..idx + need]; + let w = &mut self.buf[idx..idx + need]; w.copy_from_slice(x); - self.rx_state = RxState::ReadInitial { idx: idx + need } + self.state = RxState::ReadInitial { idx: idx + need } } // Have enough input now to decrypt the packet length - if let RxState::ReadInitial { idx } = self.rx_state { + if let RxState::ReadInitial { idx } = self.state { if idx >= size_block { - let w = &mut self.rx_buf[..size_block]; + let w = &mut self.buf[..size_block]; let total_len = keys.decrypt_first_block(w)? as usize; - if total_len > self.rx_buf.len() { + if total_len > self.buf.len() { // TODO: Or just BadDecrypt could make more sense if // it were packet corruption/decryption failure return Err(Error::BigPacket { size: total_len }); } - self.rx_state = RxState::Read { idx, expect: total_len } + self.state = RxState::Read { idx, expect: total_len } } } // Know expected length, read until the end of the packet. // We have already validated that expect_len <= buf_size - if let RxState::Read { ref mut idx, expect } = self.rx_state { + if let RxState::Read { ref mut idx, expect } = self.state { let need = (expect - *idx).min(r.len()); let x; (x, r) = r.split_at(need); - let w = &mut self.rx_buf[*idx..*idx + need]; + let w = &mut self.buf[*idx..*idx + need]; w.copy_from_slice(x); *idx += need; if *idx == expect { - self.rx_state = RxState::ReadComplete { len: expect } + self.state = RxState::ReadComplete { len: expect } } } - if let RxState::ReadComplete { len } = self.rx_state { - let w = &mut self.rx_buf[0..len]; + if let RxState::ReadComplete { len } = self.state { + let w = &mut self.buf[..len]; let seq = keys.recv_seq(); let payload_len = keys.decrypt(w)?; - self.rx_state = RxState::InPayload { len: payload_len, seq } + self.state = RxState::InPayload { len: payload_len, seq } } Ok(buf.len() - r.len()) } pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { - match self.rx_state { + match self.state { RxState::InChannelData { chan, ext, .. } => Some((chan, ext)), _ => None, } } pub fn set_channel_input(&mut self, di: channel::DataIn) -> Result<()> { - trace!("traf chan input state {:?}", self.rx_state); - match self.rx_state { + trace!("traf chan input state {:?}", self.state); + match self.state { RxState::Idle => { let idx = SSH_PAYLOAD_START + di.offset; - self.rx_state = RxState::InChannelData { chan: di.num, ext: di.ext, idx, len: idx + di.len }; + self.state = RxState::InChannelData { chan: di.num, ext: di.ext, idx, len: idx + di.len }; // error!("set input {:?}", self.state); - trace!("all buf {:?}", self.rx_buf[..32].hex_dump()); + trace!("all buf {:?}", self.buf[..32].hex_dump()); trace!("set chan input offset {} idx {} {:?}", di.offset, idx, - self.rx_buf[idx..idx + di.len].hex_dump()); + self.buf[idx..idx + di.len].hex_dump()); Ok(()) } _ => Err(Error::bug()), @@ -362,22 +278,22 @@ impl<'a> Traffic<'a> { ext: Option<u32>, buf: &mut [u8], ) -> (usize, bool) { - trace!("channel input {chan} {ext:?} st {:?}", self.rx_state); + trace!("channel input {chan} {ext:?} st {:?}", self.state); - match self.rx_state { + match self.state { RxState::InChannelData { chan: c, ext: e, ref mut idx, len } if (c, e) == (chan, ext) => { if *idx > len { error!("bad idx {} len {} e {:?} c {}", *idx, len, e, c); } let wlen = (len - *idx).min(buf.len()); - buf[..wlen].copy_from_slice(&self.rx_buf[*idx..*idx + wlen]); + buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); // info!("idx {} += wlen {} = {}", *idx, wlen, *idx+wlen); *idx += wlen; if *idx == len { // all done. - self.rx_state = RxState::Idle; + self.state = RxState::Idle; (wlen, true) } else { (wlen, false) @@ -389,24 +305,140 @@ impl<'a> Traffic<'a> { } +impl<'a> TrafOut<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { buf, state: TxState::Idle } + } + + /// Serializes and and encrypts a packet to send + pub(crate) fn send_packet(&mut self, p: packets::Packet, keys: &mut KeyState) -> Result<()> { + trace!("send_packet {:?}", p.message_num()); + + // Either a fresh buffer or appending to write + let (idx, len) = match self.state { + TxState::Idle => (0, 0), + TxState::Write { idx, len } => (idx, len), + _ => { + trace!("bad state {:?}", self.state); + Err(Error::bug())? + } + }; + + // Use the remainder of our buffer to write the packet. Payload starts + // after the length and padding bytes which get filled by encrypt() + let wbuf = &mut self.buf[len..]; + if wbuf.len() < SSH_PAYLOAD_START { + return Err(Error::NoRoom) + } + let plen = sshwire::write_ssh(&mut wbuf[SSH_PAYLOAD_START..], &p)?; + trace!("Sending {p:?}"); + trace!("new {plen} {:?}", (&wbuf[SSH_PAYLOAD_START..SSH_PAYLOAD_START+plen]).hex_dump()); + + // Encrypt in place + let elen = keys.encrypt(plen, wbuf)?; + self.state = TxState::Write { idx, len: len+elen }; + Ok(()) + + } + + pub fn output_pending(&self) -> bool { + match self.state { + TxState::Write { .. } => true, + _ => false + } + } + + /// A simple test if a packet can be sent. `send_allowed` should be used + /// for more general situations + pub fn can_output(&self) -> bool { + // TODO don't use this + true + } + + /// Returns payload space available to send a packet. Returns 0 if not ready or full + pub fn send_allowed(&self, keys: &KeyState) -> usize { + // TODO: test for full output buffer + match self.state { + TxState::Write { len, .. } => { + keys.max_enc_payload(self.buf.len() - len) + } + TxState::Idle => { + keys.max_enc_payload(self.buf.len()) + } + } + } + + pub fn send_version(&mut self, buf: &[u8]) -> Result<(), Error> { + if !matches!(self.state, TxState::Idle) { + return Err(Error::bug()); + } + + if buf.len() + 2 > self.buf.len() { + return Err(Error::NoRoom); + } + + self.buf[..buf.len()].copy_from_slice(buf); + self.buf[buf.len()] = ident::CR; + self.buf[buf.len()+1] = ident::LF; + self.state = TxState::Write { idx: 0, len: buf.len() + 2 }; + Ok(()) + } + + /// Write any pending output, returning the size written + pub fn output(&mut self, buf: &mut [u8]) -> usize { + let r = match self.state { + TxState::Write { ref mut idx, len } => { + let wlen = (len - *idx).min(buf.len()); + buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); + *idx += wlen; + + if *idx == len { + // all done, read the next packet + self.state = TxState::Idle + } + wlen + } + _ => 0, + }; + trace!("output state now {:?}", self.state); + r + } + + + pub fn sender(&mut self, keys: &'a mut KeyState) -> TrafSend { + TrafSend::new(self, keys) + } + +} + +/// Convenience to pass TrafOut with keys pub(crate) struct TrafSend<'a> { - traffic: &'a mut Traffic<'a>, keys: &'a mut KeyState, + out: &'a mut TrafOut<'a>, } impl<'a> TrafSend<'a> { - pub fn new(traffic: &mut Traffic, keys: &mut KeyState) -> Self { + fn new(out: &'a mut TrafOut<'a>, keys: &'a mut KeyState) -> Self { Self { - traffic, + out, keys, } } - pub fn send<'p, P: Into<packets::Packet<'p>>>(&self, p: P) -> Result<()> { - self.traffic.send_packet(p.into(), self.keys) + pub fn send<'p, P: Into<packets::Packet<'p>>>(&mut self, p: P) -> Result<()> { + self.out.send_packet(p.into(), self.keys) } - pub fn rekey(&self, keys: encrypt::Keys) { + + pub fn rekey(&mut self, keys: encrypt::Keys) { self.keys.rekey(keys) } + + pub fn send_version(&mut self, buf: &[u8]) -> Result<(), Error> { + self.out.send_version(buf) + } + + pub fn can_output(&self) -> bool { + self.out.can_output() + } }