diff --git a/async/examples/con1.rs b/async/examples/con1.rs index b4dba8e13c9825fc887684cd9f69babec7bb9fe5..264eb1d34bf1698c65cd67e20c2ebdd2942988e8 100644 --- a/async/examples/con1.rs +++ b/async/examples/con1.rs @@ -141,14 +141,16 @@ async fn run(args: &Args) -> Result<()> { let work = vec![0; 3000]; // TODO: better lifetime rather than leaking - let work = Box::leak(Box::new(work)); + let work = Box::leak(Box::new(work)).as_mut_slice(); + let tx = vec![0; 3000]; + let tx = Box::leak(Box::new(tx)).as_mut_slice(); let mut cli = door_async::CmdlineClient::new(args.username.as_ref().unwrap()); for i in &args.identityfile { cli.add_authkey(read_key(&i).with_context(|| format!("loading key {i}"))?); } - let mut door = SSHClient::new(work.as_mut_slice(), Box::new(cli))?; + let mut door = SSHClient::new(work, tx, Box::new(cli))?; let mut s = door.socket(); moro::async_scope!(|scope| { diff --git a/async/src/client.rs b/async/src/client.rs index 463441c5ae0118cd77d15cbfabc871a947267dab..cf63026d73cb8aea66e63d813fd3012a558ca5b6 100644 --- a/async/src/client.rs +++ b/async/src/client.rs @@ -30,9 +30,11 @@ pub struct SSHClient<'a> { } impl<'a> SSHClient<'a> { - pub fn new(buf: &'a mut [u8], behaviour: Box<dyn AsyncCliBehaviour+Send>) -> Result<Self> { + pub fn new(inbuf: &'a mut [u8], + outbuf: &'a mut [u8], + behaviour: Box<dyn AsyncCliBehaviour+Send>) -> Result<Self> { let b = Behaviour::new_async_client(behaviour); - let runner = Runner::new_client(buf, b)?; + let runner = Runner::new_client(inbuf, outbuf, b)?; let door = AsyncDoor::new(runner); Ok(Self { door diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 1d952fc3ab330a5fd5781f29894e571a21c11f0c..4c99e4db97e55c89b2eb6993cc41b5d85ece53ae 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -31,15 +31,16 @@ pub struct Runner<'a> { } impl<'a> Runner<'a> { - /// `iobuf` must be sized to fit the largest SSH packet allowed. + /// `inbuf` must be sized to fit the largest SSH packet allowed. pub fn new_client( - iobuf: &'a mut [u8], + inbuf: &'a mut [u8], + outbuf: &'a mut [u8], behaviour: Behaviour<'a>, ) -> Result<Runner<'a>, Error> { let conn = Conn::new_client()?; let runner = Runner { conn, - traffic: traffic::Traffic::new(iobuf), + traffic: traffic::Traffic::new(outbuf, inbuf), keys: KeyState::new_cleartext(), output_waker: None, input_waker: None, diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 73e05808fa6bba780f47e7eaa67a825fdc9f79ad..551e3bcf6b7e897879b3f292cf5bbf2803326527 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -24,15 +24,35 @@ pub(crate) struct Traffic<'a> { /// When reading only contains a single SSH packet at a time. /// Writing may contain multiple SSH packets to write out, encrypted /// in-place as they are written to `buf`. - buf: &'a mut [u8], - state: TrafState, + tx_buf: &'a mut [u8], + rx_buf: &'a mut [u8], + + tx_state: TxState, + rx_state: RxState, } -/// State machine for reads/writes sharing [`Traffic::buf`] +/// State machine for writes #[derive(Debug)] -enum TrafState { +enum TxState { - /// Awaiting read or write, buffer is unused + /// Awaiting write, buffer is unused + Idle, + + /// Writing to the socket. Buffer is encrypted in-place. + /// Should never be left in `idx==len` state, + /// instead should transition to Idle + Write { + /// Cursor position in the buffer + idx: usize, + /// Buffer available to write + len: usize, + }, +} + +#[derive(Debug)] +enum RxState { + + /// Awaiting read, buffer is unused Idle, /// Reading initial encrypted block for packet length. idx > 0. ReadInitial { idx: usize }, @@ -56,16 +76,6 @@ enum TrafState { /// length of buffer, end of channel data len: usize, }, - - /// Writing to the socket. Buffer is encrypted in-place. - /// Should never be left in `idx==len` state, - /// instead should transition to Idle - Write { - /// Cursor position in the buffer - idx: usize, - /// Buffer available to write - len: usize, - }, } #[derive(Debug)] @@ -96,26 +106,29 @@ impl<'a> PacketMaker<'a> { } impl<'a> Traffic<'a> { - pub fn new(buf: &'a mut [u8]) -> Self { - Traffic { buf, state: TrafState::Idle } + pub fn new(rx_buf: &'a mut [u8], tx_buf: &'a mut [u8]) -> Self { + Traffic { tx_buf, rx_buf, + tx_state: TxState::Idle, + rx_state: RxState::Idle, + } } pub fn ready_input(&self) -> bool { - match self.state { - TrafState::Idle - | TrafState::ReadInitial { .. } - | TrafState::Read { .. } => true, - TrafState::ReadComplete { .. } - | TrafState::InPayload { .. } - | TrafState::BorrowPayload { .. } - | TrafState::InChannelData { .. } - | TrafState::Write { .. } => false, + match self.rx_state { + RxState::Idle + | RxState::ReadInitial { .. } + | RxState::Read { .. } => true, + RxState::ReadComplete { .. } + | RxState::InPayload { .. } + | RxState::BorrowPayload { .. } + | RxState::InChannelData { .. } + => false, } } pub fn output_pending(&self) -> bool { - match self.state { - TrafState::Write { .. } => true, + match self.tx_state { + TxState::Write { .. } => true, _ => false } } @@ -123,24 +136,20 @@ impl<'a> Traffic<'a> { /// A simple test if a packet can be sent. `send_allowed` should be used /// for more general situations pub fn can_output(&self) -> bool { - match self.state { - TrafState::Write { .. } - | TrafState::Idle => true, - _ => false - } + // TODO don't use this + true } /// Returns payload space available to send a packet. Returns 0 if not ready or full pub fn send_allowed(&self, keys: &KeyState) -> usize { // TODO: test for full output buffer - match self.state { - TrafState::Write { len, .. } => { - keys.max_enc_payload(self.buf.len() - len) + match self.tx_state { + TxState::Write { len, .. } => { + keys.max_enc_payload(self.tx_buf.len() - len) } - TrafState::Idle => { - keys.max_enc_payload(self.buf.len()) + TxState::Idle => { + keys.max_enc_payload(self.tx_buf.len()) } - _ => 0 } } @@ -150,8 +159,8 @@ impl<'a> Traffic<'a> { buf: &[u8], ) -> Result<usize, Error> { let mut inlen = 0; - trace!("state {:?} input {}", self.state, buf.len()); - if remote_version.version().is_none() && matches!(self.state, TrafState::Idle) { + trace!("state {:?} input {}", self.rx_state, buf.len()); + if remote_version.version().is_none() && matches!(self.rx_state, RxState::Idle) { // Handle initial version string let l; l = remote_version.consume(buf)?; @@ -160,17 +169,17 @@ impl<'a> Traffic<'a> { let buf = &buf[inlen..]; inlen += self.fill_input(keys, buf)?; - trace!("after inlen {inlen} state {:?}", self.state); + trace!("after inlen {inlen} state {:?}", self.rx_state); Ok(inlen) } /// Returns a reference to the decrypted payload buffer if ready, /// and the `seq` of that packet. pub(crate) fn payload(&mut self) -> Option<(&[u8], u32)> { - match self.state { - | TrafState::InPayload { len, seq } + match self.rx_state { + | RxState::InPayload { len, seq } => { - let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; + let payload = &self.rx_buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; Some((payload, seq)) } _ => None, @@ -178,11 +187,11 @@ impl<'a> Traffic<'a> { } pub(crate) fn payload_reborrow(&mut self) -> Option<&[u8]> { - match self.state { - | TrafState::InPayload { len, .. } - | TrafState::BorrowPayload { len, .. } + match self.rx_state { + | RxState::InPayload { len, .. } + | RxState::BorrowPayload { len, .. } => { - let payload = &self.buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; + let payload = &self.rx_buf[SSH_PAYLOAD_START..SSH_PAYLOAD_START + len]; Some(payload) } _ => None, @@ -192,11 +201,11 @@ impl<'a> Traffic<'a> { /// Called when `payload()` has been handled once, can still be /// `payload_reborrow()`ed later. pub(crate) fn handled_payload(&mut self) -> Result<(), Error> { - match self.state { - | TrafState::InPayload { len, .. } - | TrafState::BorrowPayload { len } + match self.rx_state { + | RxState::InPayload { len, .. } + | RxState::BorrowPayload { len } => { - self.state = TrafState::BorrowPayload { len }; + self.rx_state = RxState::BorrowPayload { len }; Ok(()) } _ => Err(Error::bug()) @@ -205,11 +214,11 @@ impl<'a> Traffic<'a> { /// Called when `payload()` and `payload_reborrow()` are complete. pub(crate) fn done_payload(&mut self) -> Result<(), Error> { - match self.state { - | TrafState::InPayload { .. } - | TrafState::BorrowPayload { .. } + match self.rx_state { + | RxState::InPayload { .. } + | RxState::BorrowPayload { .. } => { - self.state = TrafState::Idle; + self.rx_state = RxState::Idle; Ok(()) } _ => { @@ -221,18 +230,18 @@ impl<'a> Traffic<'a> { } pub fn send_version(&mut self, buf: &[u8]) -> Result<(), Error> { - if !matches!(self.state, TrafState::Idle) { + if !matches!(self.tx_state, TxState::Idle) { return Err(Error::bug()); } - if buf.len() + 2 > self.buf.len() { + if buf.len() + 2 > self.tx_buf.len() { return Err(Error::NoRoom); } - self.buf[..buf.len()].copy_from_slice(buf); - self.buf[buf.len()] = ident::CR; - self.buf[buf.len()+1] = ident::LF; - self.state = TrafState::Write { idx: 0, len: buf.len() + 2 }; + self.tx_buf[..buf.len()].copy_from_slice(buf); + self.tx_buf[buf.len()] = ident::CR; + self.tx_buf[buf.len()+1] = ident::LF; + self.tx_state = TxState::Write { idx: 0, len: buf.len() + 2 }; Ok(()) } @@ -241,18 +250,18 @@ impl<'a> Traffic<'a> { trace!("send_packet {:?}", p.message_num()); // Either a fresh buffer or appending to write - let (idx, len) = match self.state { - TrafState::Idle => (0, 0), - TrafState::Write { idx, len } => (idx, len), + let (idx, len) = match self.tx_state { + TxState::Idle => (0, 0), + TxState::Write { idx, len } => (idx, len), _ => { - trace!("bad state {:?}", self.state); + trace!("bad state {:?}", self.tx_state); Err(Error::bug())? } }; // Use the remainder of our buffer to write the packet. Payload starts // after the length and padding bytes which get filled by encrypt() - let wbuf = &mut self.buf[len..]; + let wbuf = &mut self.tx_buf[len..]; if wbuf.len() < SSH_PAYLOAD_START { return Err(Error::NoRoom) } @@ -262,28 +271,28 @@ impl<'a> Traffic<'a> { // Encrypt in place let elen = keys.encrypt(plen, wbuf)?; - self.state = TrafState::Write { idx, len: len+elen }; + self.tx_state = TxState::Write { idx, len: len+elen }; Ok(()) } /// Write any pending output, returning the size written pub fn output(&mut self, buf: &mut [u8]) -> usize { - let r = match self.state { - TrafState::Write { ref mut idx, len } => { + let r = match self.tx_state { + TxState::Write { ref mut idx, len } => { let wlen = (len - *idx).min(buf.len()); - buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); + buf[..wlen].copy_from_slice(&self.tx_buf[*idx..*idx + wlen]); *idx += wlen; if *idx == len { // all done, read the next packet - self.state = TrafState::Idle + self.tx_state = TxState::Idle } wlen } _ => 0, }; - trace!("output state now {:?}", self.state); + trace!("output state now {:?}", self.tx_state); r } @@ -297,75 +306,75 @@ impl<'a> Traffic<'a> { // Fill the initial block from either Idle with input, // partial initial block - if let Some(idx) = match self.state { - TrafState::Idle if r.len() > 0 => Some(0), - TrafState::ReadInitial { idx } => Some(idx), + if let Some(idx) = match self.rx_state { + RxState::Idle if r.len() > 0 => Some(0), + RxState::ReadInitial { idx } => Some(idx), _ => None, } { let need = (size_block - idx).clamp(0, r.len()); let x; (x, r) = r.split_at(need); - let w = &mut self.buf[idx..idx + need]; + let w = &mut self.rx_buf[idx..idx + need]; w.copy_from_slice(x); - self.state = TrafState::ReadInitial { idx: idx + need } + self.rx_state = RxState::ReadInitial { idx: idx + need } } // Have enough input now to decrypt the packet length - if let TrafState::ReadInitial { idx } = self.state { + if let RxState::ReadInitial { idx } = self.rx_state { if idx >= size_block { - let w = &mut self.buf[..size_block]; + let w = &mut self.rx_buf[..size_block]; let total_len = keys.decrypt_first_block(w)? as usize; - if total_len > self.buf.len() { + if total_len > self.rx_buf.len() { // TODO: Or just BadDecrypt could make more sense if // it were packet corruption/decryption failure return Err(Error::BigPacket { size: total_len }); } - self.state = TrafState::Read { idx, expect: total_len } + self.rx_state = RxState::Read { idx, expect: total_len } } } // Know expected length, read until the end of the packet. // We have already validated that expect_len <= buf_size - if let TrafState::Read { ref mut idx, expect } = self.state { + if let RxState::Read { ref mut idx, expect } = self.rx_state { let need = (expect - *idx).min(r.len()); let x; (x, r) = r.split_at(need); - let w = &mut self.buf[*idx..*idx + need]; + let w = &mut self.rx_buf[*idx..*idx + need]; w.copy_from_slice(x); *idx += need; if *idx == expect { - self.state = TrafState::ReadComplete { len: expect } + self.rx_state = RxState::ReadComplete { len: expect } } } - if let TrafState::ReadComplete { len } = self.state { - let w = &mut self.buf[0..len]; + if let RxState::ReadComplete { len } = self.rx_state { + let w = &mut self.rx_buf[0..len]; let seq = keys.recv_seq(); let payload_len = keys.decrypt(w)?; - self.state = TrafState::InPayload { len: payload_len, seq } + self.rx_state = RxState::InPayload { len: payload_len, seq } } Ok(buf.len() - r.len()) } pub fn ready_channel_input(&self) -> Option<(u32, Option<u32>)> { - match self.state { - TrafState::InChannelData { chan, ext, .. } => Some((chan, ext)), + match self.rx_state { + RxState::InChannelData { chan, ext, .. } => Some((chan, ext)), _ => None, } } pub fn set_channel_input(&mut self, di: channel::DataIn) -> Result<()> { - trace!("traf chan input state {:?}", self.state); - match self.state { - TrafState::Idle => { + trace!("traf chan input state {:?}", self.rx_state); + match self.rx_state { + RxState::Idle => { let idx = SSH_PAYLOAD_START + di.offset; - self.state = TrafState::InChannelData { chan: di.num, ext: di.ext, idx, len: idx + di.len }; + self.rx_state = RxState::InChannelData { chan: di.num, ext: di.ext, idx, len: idx + di.len }; // error!("set input {:?}", self.state); - trace!("all buf {:?}", self.buf[..32].hex_dump()); + trace!("all buf {:?}", self.rx_buf[..32].hex_dump()); trace!("set chan input offset {} idx {} {:?}", di.offset, idx, - self.buf[idx..idx + di.len].hex_dump()); + self.rx_buf[idx..idx + di.len].hex_dump()); Ok(()) } _ => Err(Error::bug()), @@ -380,22 +389,22 @@ impl<'a> Traffic<'a> { ext: Option<u32>, buf: &mut [u8], ) -> (usize, bool) { - trace!("channel input {chan} {ext:?} st {:?}", self.state); + trace!("channel input {chan} {ext:?} st {:?}", self.rx_state); - match self.state { - TrafState::InChannelData { chan: c, ext: e, ref mut idx, len } + match self.rx_state { + RxState::InChannelData { chan: c, ext: e, ref mut idx, len } if (c, e) == (chan, ext) => { if *idx > len { error!("bad idx {} len {} e {:?} c {}", *idx, len, e, c); } let wlen = (len - *idx).min(buf.len()); - buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); + buf[..wlen].copy_from_slice(&self.rx_buf[*idx..*idx + wlen]); // info!("idx {} += wlen {} = {}", *idx, wlen, *idx+wlen); *idx += wlen; if *idx == len { // all done. - self.state = TrafState::Idle; + self.rx_state = RxState::Idle; (wlen, true) } else { (wlen, false)