diff --git a/Cargo.lock b/Cargo.lock index d56c88ea462bdb758e7c35a93cb4987834cc4619..ef71bfe9c6aeb92d161c3feb7531d297b3fa0d0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -318,6 +318,7 @@ dependencies = [ "argh", "async-trait", "door-sshproto", + "futures", "log", "parking_lot", "pin-utils", @@ -419,6 +420,95 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "futures" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" + +[[package]] +name = "futures-executor" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" + +[[package]] +name = "futures-macro" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" + +[[package]] +name = "futures-task" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" + +[[package]] +name = "futures-util" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.5" @@ -1034,6 +1124,12 @@ dependencies = [ "time", ] +[[package]] +name = "slab" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" + [[package]] name = "smallvec" version = "1.8.0" diff --git a/smol/Cargo.toml b/smol/Cargo.toml index 28baad706905f49be3192a6d1cf6ce211115610b..ee81e5f1a3acbf3f3f755e112f5a02e3660fd98f 100644 --- a/smol/Cargo.toml +++ b/smol/Cargo.toml @@ -11,13 +11,13 @@ rpassword = "6.0" argh = "0.1" # smol = { version = "1.2" } -# futures = "0.3" # futures-io = "0.3" # futures-micro = "0.5" # async-dup = "1.2" pin-utils = "0.1" tokio = { version = "1.17", features = ["full"] } +futures = "0.3" parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } async-trait = "0.1" diff --git a/smol/examples/con1.rs b/smol/examples/con1.rs index 74f3d30cb1275a2bb705010af69117ba13b4a2f9..59d10ddea192ff2028831f74a22dad54569317da 100644 --- a/smol/examples/con1.rs +++ b/smol/examples/con1.rs @@ -164,6 +164,10 @@ async fn run(args: &Args) -> Result<()> { loop { tokio::select! { e = &mut netio => break e.map(|_| ()).context("net loop"), + ev = door.progress(|ev| { + trace!("progress event {ev:?}"); + Ok(()) + }) => {} // q = door.next_request() => { // handle_request(&door, q).await // } diff --git a/smol/src/async_client.rs b/smol/src/async_client.rs index 4c02e1cb2c9e84b8da8b165513b2b028c55be736..0602748acc3cc1e8d591188e7d547dc4226b40c3 100644 --- a/smol/src/async_client.rs +++ b/smol/src/async_client.rs @@ -39,19 +39,16 @@ impl SimpleClient { #[async_trait(?Send)] impl door::AsyncCliBehaviour for SimpleClient { - async fn chan_handler<'f>( + async fn chan_handler( &mut self, resp: &mut RespPackets, - chan_msg: ChanMsg<'f>, + chan_msg: ChanMsg, ) -> Result<()> { if Some(chan_msg.num) != self.main_ch { return Err(Error::SSHProtoError); } match chan_msg.msg { - ChanMsgDetails::Data(buf) => { - let _ = tokio::io::stdout().write_all(buf).await; - } ChanMsgDetails::ExtData { .. } => {} ChanMsgDetails::Req { .. } => {} _ => {} diff --git a/smol/src/async_door.rs b/smol/src/async_door.rs index d38078636943fa44be3e612131263923a975a0dd..b57b19a1db59e9e780a09261a5f4a862958fcb8d 100644 --- a/smol/src/async_door.rs +++ b/smol/src/async_door.rs @@ -7,14 +7,15 @@ use core::task::{Context, Poll}; use pin_utils::pin_mut; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::sync::Mutex as TokioMutex; +use tokio::sync::Notify as TokioNotify; use std::io::Error as IoError; use std::io::ErrorKind; +use core::task::Waker; use std::sync::{Arc, Mutex, MutexGuard}; - -use parking_lot::lock_api::ArcMutexGuard; -use parking_lot::Mutex as ParkingLotMutex; +use futures::task::AtomicWaker; // TODO use anyhow::{anyhow, Context as _, Error, Result}; @@ -34,88 +35,65 @@ pub struct Inner<'a> { behaviour: Behaviour<'a>, } +#[derive(Clone)] pub struct AsyncDoor<'a> { - inner: Arc<ParkingLotMutex<Inner<'a>>>, - out_progress_fut: - Option<Pin<Box<dyn Future<Output = Result<(), DoorError>> + 'a>>>, -} + inner: Arc<TokioMutex<Inner<'a>>>, -impl Clone for AsyncDoor<'_> { - fn clone(&self) -> Self { - Self { inner: self.inner.clone(), out_progress_fut: None } - } + read_waker: Arc<AtomicWaker>, + write_waker: Arc<AtomicWaker>, + progress_notify: Arc<TokioNotify>, } impl<'a> AsyncDoor<'a> { pub fn new(runner: Runner<'a>, behaviour: Behaviour<'a>) -> Self { - let inner = Inner { runner, behaviour }; - Self { inner: Arc::new(ParkingLotMutex::new(inner)), out_progress_fut: None } + let inner = Arc::new(TokioMutex::new(Inner { runner, behaviour })); + let read_waker = Arc::new(AtomicWaker::new()); + let write_waker = Arc::new(AtomicWaker::new()); + let progress_notify = Arc::new(TokioNotify::new()); + Self { inner, read_waker, write_waker, progress_notify } } - // TODO this should go away, or perhaps pass the function down to the Behaviour - pub fn with_behaviour<F, R>(&self, f: F) -> R - where - F: FnOnce(&mut Behaviour<'a>) -> R, - { - f(&mut self.lock().behaviour) - } - - fn lock(&self) -> parking_lot::MutexGuard<Inner<'a>> { - self.inner.lock() + pub async fn progress<F, R>(&mut self, f: F) + -> Result<Option<R>> where F: FnOnce(door::Event) -> Result<R> { + { + self.progress_notify.notified().await; + let res = { + let mut inner = self.inner.lock().await; + let inner = inner.deref_mut(); + let ev = inner.runner.progress(&mut inner.behaviour).await.context("progess")?; + if let Some(ev) = ev { + f(ev).map(|r| Some(r)) + } else { + Ok(None) + } + }; + // self.read_waker.take().map(|w| w.wake()); + // self.write_waker.take().map(|w| w.wake()); + res + } } - - // fn poll_write_channel( - // self: Pin<&mut Self>, - // channel: u32, - // cx: &mut Context<'_>, - // buf: &[u8], - // ) -> Poll<Result<usize, IoError>> { - // } - } - impl<'a> AsyncRead for AsyncDoor<'a> { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<Result<(), IoError>> { trace!("poll_read"); - let r = if let Some(f) = self.out_progress_fut.as_mut() { - f.as_mut().poll(cx) - .map_err(|e| IoError::new(ErrorKind::Other, e)) + // try to lock, or return pending + self.read_waker.register(cx.waker()); + let mut inner = self.inner.try_lock(); + let runner = if let Ok(ref mut inner) = inner { + &mut inner.deref_mut().runner } else { - // TODO this blocks - let mut inner = ParkingLotMutex::lock_arc(&self.inner); - - // TODO: should this be conditional on the result of the poll? - inner.runner.set_output_waker(cx.waker().clone()); - // async move block to capture `inner` - let mut b = Box::pin(async move { - let inner = inner.deref_mut(); - inner.runner.out_progress(&mut inner.behaviour).await - }); - // let mut b = Box::pin(guard_wait(inner)); - let r = b.as_mut().poll(cx); - if let Poll::Pending = r { - self.out_progress_fut = Some(b); - } - r.map_err(|e| IoError::new(ErrorKind::Other, e)) - }?; - if let Poll::Pending = r { - return Poll::Pending; - } else { - self.out_progress_fut = None - } - - let runner = &mut self.lock().runner; + return Poll::Pending + }; let b = buf.initialize_unfilled(); let r = runner.output(b).map_err(|e| IoError::new(ErrorKind::Other, e)); - trace!("runner output {r:?}"); let r = match r { // sz=0 means EOF Ok(0) => Poll::Pending, @@ -126,7 +104,8 @@ impl<'a> AsyncRead for AsyncDoor<'a> { } Err(e) => Poll::Ready(Err(e)), }; - info!("finish poll_read {r:?}"); + drop(inner); + self.write_waker.take().map(|w| w.wake()); r } } @@ -138,9 +117,15 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { buf: &[u8], ) -> Poll<Result<usize, IoError>> { trace!("poll_write"); - // TODO: this lock is blocking - let runner = &mut self.lock().runner; - runner.set_input_waker(cx.waker().clone()); + + // try to lock, or return pending + self.write_waker.register(cx.waker()); + let mut inner = self.inner.try_lock(); + let runner = if let Ok(ref mut inner) = inner { + &mut inner.deref_mut().runner + } else { + return Poll::Pending + }; // TODO: should runner just have poll_write/poll_read? // TODO: is ready_input necessary? .input() should return size=0 @@ -153,7 +138,9 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { } else { Poll::Pending }; - trace!("poll_write {r:?}"); + drop(inner); + self.progress_notify.notify_one(); + // self.read_waker.take().map(|w| w.wake()); r } diff --git a/sshproto/Cargo.toml b/sshproto/Cargo.toml index 01b8fb6ab76868984b4c6f2f794bee111bb965ba..3e59c52593805f86a8804683e8e46e94c442fcb5 100644 --- a/sshproto/Cargo.toml +++ b/sshproto/Cargo.toml @@ -2,6 +2,8 @@ name = "door-sshproto" version = "0.1.0" edition = "2021" +categories = ["network-programming", "no-std"] +keywords = ["ssh"] [dependencies] sshwire_derive = { path = "../sshwire_derive" } @@ -11,9 +13,10 @@ snafu = { version = "0.7", default-features = false, features = ["rust_1_46"] } log = { version = "0.4" } heapless = "0.7.10" no-panic = "0.1" + +# allows avoiding utf8 for SSH identifier names ascii = { version = "1.0", default-features = false } -# TODO: needs changing for embedded platforms rand = { version = "0.8", default-features = false } rand_core = { version = "0.6", default-features = false } @@ -23,10 +26,14 @@ sha2 = { version = "0.10", default-features = false } hmac = "0.12" digest = "0.10" signature = { version = "1.4", default-features = false } -ssh-key = { version = "0.4", default-features = false, features = ["ed25519", "ecdsa", "sha2"] } chacha20 = "0.9" poly1305 = "0.7" +# ed25519/x25519 +salty = { version = "0.2", path = "/home/matt/3rd/rs/salty" } +# could be optional? though isn't linked if openssh keys aren't loaded +ssh-key = { version = "0.4", default-features = false, features = ["ed25519", "ecdsa", "sha2"] } +# for debug printing pretty-hex = { version = "0.3", default-features = false } pin-utils = "0.1" @@ -34,8 +41,6 @@ pin-utils = "0.1" # tokio = { version = "1.18", features = ["sync"], optional = true } async-trait = { version = "0.1", optional = true } -salty = { version = "0.2", path = "/home/matt/3rd/rs/salty" } - [features] default = [ "getrandom" ] std = ["async-trait", "snafu/std"] diff --git a/sshproto/src/async_behaviour.rs b/sshproto/src/async_behaviour.rs index 5ad2bf32e07cf847c6270a26588bc4abd5653f1c..729f07748581db7f27fe7b282cadf9847623c651 100644 --- a/sshproto/src/async_behaviour.rs +++ b/sshproto/src/async_behaviour.rs @@ -50,7 +50,7 @@ impl AsyncCliServ { } } - pub(crate) async fn chan_handler<'f>(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg<'f>) -> Result<()> { + pub(crate) async fn chan_handler(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg) -> Result<()> { match self { Self::Client(i) => i.chan_handler(resp, chan_msg).await, Self::Server(i) => i.chan_handler(resp, chan_msg), @@ -60,7 +60,7 @@ impl AsyncCliServ { #[async_trait(?Send)] pub trait AsyncCliBehaviour { - async fn chan_handler<'f>(&mut self, resp: &mut RespPackets, chan_msg: ChanMsg<'f>) -> Result<()>; + async fn chan_handler(&mut self, resp: &mut RespPackets, chan_msg: ChanMsg) -> Result<()>; /// Should not block fn progress(&mut self, runner: &mut Runner) -> Result<()> { Ok(()) } diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index ed7bdc99383386a4c6c3695b02af45477878358f..43db98e0c99cb2296dce8c6891b695a45613d1f1 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -74,7 +74,7 @@ impl Behaviour<'_> { self.inner.progress(runner) } - pub(crate) async fn chan_handler<'f>(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg<'f>) -> Result<()> { + pub(crate) async fn chan_handler(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg) -> Result<()> { self.inner.chan_handler(resp, chan_msg).await } @@ -108,7 +108,7 @@ impl<'a> Behaviour<'a> self.inner.progress(runner) } - pub(crate) async fn chan_handler<'f>(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg<'f>) -> Result<()> { + pub(crate) async fn chan_handler(&mut self, resp: &mut RespPackets<'_>, chan_msg: ChanMsg) -> Result<()> { self.inner.chan_handler(resp, chan_msg) } diff --git a/sshproto/src/block_behaviour.rs b/sshproto/src/block_behaviour.rs index e169a213091ed4231e4ba11ba5f649cd968435c2..1612b9f4130fe087ff7154d6b3f912ffd1026f31 100644 --- a/sshproto/src/block_behaviour.rs +++ b/sshproto/src/block_behaviour.rs @@ -50,7 +50,7 @@ impl BlockCliServ<'_> pub(crate) fn chan_handler<'f>( &mut self, resp: &mut RespPackets<'_>, - chan_msg: ChanMsg<'f>, + chan_msg: ChanMsg, ) -> Result<()> { match self { Self::Client(i) => i.chan_handler(resp, chan_msg), @@ -63,7 +63,7 @@ pub trait BlockCliBehaviour { fn chan_handler<'f>( &mut self, resp: &mut RespPackets, - chan_msg: ChanMsg<'f>, + chan_msg: ChanMsg, ) -> Result<()>; /// Should not block diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index cb161a30d31b205c9b5520edabf00441dbe3bbec..6ab37a54abb29f832a0e6e1ed051a7f3460a813d 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -10,17 +10,24 @@ use heapless::{Deque, String, Vec}; use crate::{conn::RespPackets, *}; use config::*; -use packets::{ChannelReqType, ChannelRequest, Packet, ChannelOpenType}; +use packets::{ChannelReqType, ChannelRequest, Packet, ChannelOpenType, ChannelData, ChannelDataExt}; +use sshwire::BinString; pub(crate) struct Channels { ch: [Option<Channel>; config::MAX_CHANNELS], + + /// The size of data last set with `ChanEvent::DataIn`. + pending_input: Option<PendInput>, } pub(crate) type InitReqs = Vec<ReqDetails, MAX_INIT_REQS>; impl Channels { pub fn new() -> Self { - Channels { ch: Default::default() } + Channels { + ch: Default::default(), + pending_input: None, + } } pub fn open<'b>( @@ -38,17 +45,7 @@ impl Channels { ) .ok_or(Error::NoChannels)?; - let chan = Channel { - state: ChanState::Opening { init_req }, - ty: (&ty).into(), - last_req: Deque::new(), - recv: ChanDir { - num, - max_packet: config::DEFAULT_MAX_PACKET, - window: config::DEFAULT_WINDOW, - }, - send: None, - }; + let chan = Channel::new(num, (&ty).into(), init_req); let p = packets::ChannelOpen { num, initial_window: chan.recv.window as u32, @@ -78,14 +75,46 @@ impl Channels { // Ok(()) } + /// Returns the channel data packet to send, and the length of data consumed + pub(crate) fn send_data<'b>(&mut self, num: u32, ext: Option<u32>, data: &'b [u8]) + -> Result<(Packet<'b>, usize)> { + let send_ch = self.get_chan(num)?.send.as_ref().trap()?.num; + // TODO: check: channel state, channel window, maxpacket + let len = data.len(); + let data = BinString(data); + let p = if let Some(code) = ext { + // TODO: check code is valid for this channel + packets::ChannelDataExt { num: send_ch, code, data }.into() + } else { + packets::ChannelData { num: send_ch, data }.into() + }; + Ok((p, len)) + } + + /// Informs the channel layer that an incoming packet has been read out, + /// so a window adjustment can be queued. + pub(crate) fn finished_input(&mut self, num: u32) -> Result<()> { + match self.pending_input { + Some(ref p) if p.chan == num => { + // TODO: send window adjustment + let len = p.len; + let ch = self.get_chan(num)?; + ch.finished_input(len); + self.pending_input = None; + Ok(()) + } + _ => Err(Error::bug()), + } + } + // incoming packet handling - pub async fn dispatch<'a>( + pub async fn dispatch( &mut self, - packet: &Packet<'a>, + packet: Packet<'_>, resp: &mut RespPackets<'_>, b: &mut Behaviour<'_>, - ) -> Result<()> { - trace!("chan dispatchh"); + ) -> Result<Option<ChanEventMaker>> { + trace!("chan dispatch"); let r = match packet { Packet::ChannelOpen(_p) => { todo!(); @@ -108,7 +137,7 @@ impl Channels { } ch.state = ChanState::Normal; } - Ok(()) + Ok(None) } _ => Err(Error::SSHProtoError), } @@ -118,21 +147,33 @@ impl Channels { if ch.send.is_some() { Err(Error::SSHProtoError) } else { - self.remove(p.num) + self.remove(p.num); + // TODO event + Ok(None) } } Packet::ChannelWindowAdjust(p) => { todo!(); } Packet::ChannelData(p) => { - b.chan_handler( - resp, - ChanMsg { num: p.num, msg: ChanMsgDetails::Data(p.data.0) }, - ) - .await + let ch = self.get_chan(p.num)?; + // TODO check we are expecting input + if self.pending_input.is_some() { + return Err(Error::bug()) + } + self.pending_input = Some(PendInput { chan: p.num, len: p.data.0.len() }); + let di = DataIn { num: p.num, ext: None, offset: p.data_offset(), len: p.data.0.len() }; + Ok(Some(ChanEventMaker::DataIn(di))) } - Packet::ChannelDataExt(_p) => { - todo!(); + Packet::ChannelDataExt(p) => { + let ch = self.get_chan(p.num)?; + // TODO check we are expecting input and ext is valid. + if self.pending_input.is_some() { + return Err(Error::bug()) + } + self.pending_input = Some(PendInput { chan: p.num, len: p.data.0.len() }); + let di = DataIn { num: p.num, ext: Some(p.code), offset: p.data_offset(), len: p.data.0.len() }; + Ok(Some(ChanEventMaker::DataIn(di))) } Packet::ChannelEof(_p) => { todo!(); @@ -140,12 +181,20 @@ impl Channels { Packet::ChannelClose(_p) => { todo!(); } - Packet::ChannelRequest(_p) => { - todo!(); + Packet::ChannelRequest(p) => { + match self.get_chan(p.num) { + Ok(ch) => Ok(Some(ChanEventMaker::Req)), + Err(ch) => { + if p.want_reply { + // TODO respond with an error + } + Ok(None) + } + } } Packet::ChannelSuccess(_p) => { trace!("channel success, TODO"); - Ok(()) + Ok(None) } Packet::ChannelFailure(_p) => { todo!(); @@ -155,9 +204,9 @@ impl Channels { match r { Err(Error::BadChannel) => { warn!("Ignoring bad channel number"); - Ok(()) + Ok(None) } - Ok(()) => Ok(()), + Ok(ev) => Ok(ev), // TODO: close channel on error? or on SSHProtoError? Err(any) => Err(any), } @@ -243,7 +292,7 @@ impl Req { ReqDetails::WinChange(rt) => ChannelReqType::WinChange(rt.clone()), ReqDetails::Break(rt) => ChannelReqType::Break(rt.clone()), }; - let p = ChannelRequest { num, want_reply, ch: ty }.into(); + let p = ChannelRequest { num, want_reply, req: ty }.into(); Ok(p) } } @@ -274,7 +323,10 @@ pub struct ChanDir { } pub enum ChanState { - // TODO: this is wasting half a kB. where else could we store it? + /// init_req are the request messages to be sent once the ChannelOpenConfirmation + /// is received + // TODO: this is wasting half a kB. where else could we store it? could + // the Behaviour own it? Or we don't store them here, just callback to the Behaviour. Opening { init_req: InitReqs }, Normal, DrainRead, @@ -290,9 +342,26 @@ pub struct Channel { recv: ChanDir, // filled after confirmation send: Option<ChanDir>, + + /// Accumulated bytes for the next window adjustment (inbound data direction) + pending_adjust: usize, } impl Channel { + fn new(num: u32, ty: ChanType, init_req: InitReqs) -> Self { + Channel { + state: ChanState::Opening { init_req }, + ty, + last_req: Deque::new(), + recv: ChanDir { + num, + max_packet: config::DEFAULT_MAX_PACKET, + window: config::DEFAULT_WINDOW, + }, + send: None, + pending_adjust: 0, + } + } fn request(&mut self, req: ReqDetails, resp: &mut RespPackets) -> Result<()> { let num = self.send.as_ref().trap()?.num; let r = Req { num, details: req }; @@ -303,33 +372,101 @@ impl Channel { pub(crate) fn number(&self) -> u32 { self.recv.num } + + fn finished_input(&mut self, len: usize ) { + self.pending_adjust = self.pending_adjust.saturating_add(len) + } } -pub struct ChanMsg<'a> { +pub struct ChanMsg { pub num: u32, - pub msg: ChanMsgDetails<'a>, + pub msg: ChanMsgDetails, } -pub enum ChanMsgDetails<'a> { - Data(&'a [u8]), - ExtData { ext: u32, data: &'a [u8] }, +pub enum ChanMsgDetails { + Data, + ExtData { ext: u32 }, // TODO: perhaps we don't need the storaged ReqDetails, just have the reqtype packet? Req(ReqDetails), // TODO closein/closeout/eof, etc. Should also return the exit status etc Close, } -pub enum ChanOut { - // Size written into [`channel_output()`](runner::Runner::channel_output) - // `buf` argument. - Data(usize), - // Size written into [`channel_output()`](runner::Runner::channel_output) - // `buf` argument. - ExtData { ext: u32, size: usize }, - // TODO: perhaps we don't need the storaged ReqDetails, just have the reqtype packet? - Req(ReqDetails), +#[derive(Debug)] +pub(crate) struct DataIn { + pub num: u32, + pub ext: Option<u32>, + pub offset: usize, + pub len: usize, +} + +/// An event returned from `Channel::dispatch()`. +/// Most are propagated to the application, `DataIn is caught by `runner` +#[derive(Debug)] +pub(crate) enum ChanEventMaker { + /// Channel data is ready with `channel_input()`. This breaks the `Packet` abstraction + /// by returning the offset into the payload buffer, used by `traffic`. + DataIn(DataIn), + + OpenSuccess { num: u32 }, + + // A ChannelRequest. Will be split into separate ChanEvent variants + // for each type. + Req, // TODO closein/closeout/eof, etc. Should also return the exit status etc - // TODO: responses to a previous ChanMsg - Close, + Close { num: u32 }, + // TODO: responses to a previous ChanMsg? +} + +impl ChanEventMaker { + // To be called on the same packet that created the ChanEventMaker. + pub fn make<'p>(&self, packet: Packet<'p>) -> Option<ChanEvent<'p>> { + match self { + // Datain is handled at the traffic level, not propagated as an Event + Self::DataIn(_) => { + debug!("DataIn should not be reached"); + None + } + Self::OpenSuccess { num } => Some(ChanEvent::OpenSuccess { num: *num }), + Self::Req => { + if let Packet::ChannelRequest(ChannelRequest { num, want_reply, req }) = packet { + match req { + ChannelReqType::Pty(pty) => Some(ChanEvent::ReqPty { num, want_reply, pty }), + _ => { + warn!("Unhandled {:?}", self); + None + } + } + } else { + None + } + } + Self::Close { num } => Some(ChanEvent::Close { num: *num }), + } + + } +} + +/// Application API +#[derive(Debug)] +pub enum ChanEvent<'a> { + // TODO: perhaps this one should go a level above since it isn't for existing channels? + OpenSuccess { num: u32 }, + + // TODO details + // OpenRequest { }, + + ReqPty { num: u32, want_reply: bool, pty: packets::Pty<'a> }, + + Req { num: u32, req: ChannelReqType<'a> }, + // TODO closein/closeout/eof, etc. Should also return the exit status etc + + Close { num: u32 }, + // TODO: responses to a previous ChanMsg? +} + +struct PendInput { + chan: u32, + len: usize, } diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index 28726ca0ae9845cd307949aaf29037b136b74b7e..0ff407e7a1a27212997fb21a2d46993f1dd3b11d 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -27,6 +27,10 @@ pub(crate) const MAX_RESPONSES: usize = 4; pub type RespPackets<'a> = heapless::Vec<PacketMaker<'a>, MAX_RESPONSES>; +pub(crate) enum Handled<'a> { + Response(RespPackets<'a>), +} + /// The core state of a SSH instance. pub struct Conn<'a> { state: ConnState, @@ -87,6 +91,16 @@ enum ConnState { // Cleanup ?? } +// Application API +#[derive(Debug)] +pub enum Event<'a> { + Channel(channel::ChanEvent<'a>), +} + +pub(crate) enum EventMaker { + Channel(channel::ChanEventMaker), +} + impl<'a> Conn<'a> { pub fn new_client() -> Result<Self> { Self::new(ClientServer::Client(client::Client::new())) @@ -161,21 +175,22 @@ impl<'a> Conn<'a> { /// Consumes an input payload which is a view into [`traffic::Traffic::buf`]. /// We queue response packets that can be sent (written into the same buffer) /// after `handle_payload()` runs. - pub(crate) async fn handle_payload( - &mut self, payload: &[u8], keys: &mut KeyState, b: &mut Behaviour<'_>, - ) -> Result<RespPackets<'_>, Error> { + pub(crate) async fn handle_payload<'p>( + &mut self, payload: &'p [u8], keys: &mut KeyState, b: &mut Behaviour<'_>, + ) -> Result<Dispatched<'_>, Error> { let p = sshwire::packet_from_bytes(payload, &self.parse_ctx)?; - let r = self.dispatch_packet(&p, keys, b).await; + let r = self.dispatch_packet(p, keys, b).await; r } - async fn dispatch_packet( - &mut self, packet: &Packet<'_>, keys: &mut KeyState, b: &mut Behaviour<'_>, - ) -> Result<RespPackets<'_>, Error> { + async fn dispatch_packet<'p>( + &mut self, packet: Packet<'p>, keys: &mut KeyState, b: &mut Behaviour<'_>, + ) -> Result<Dispatched<'_>, Error> { // TODO: perhaps could consolidate packet allowed checks into a separate function // to run first? trace!("Incoming {packet:#?}"); let mut resp = RespPackets::new(); + let mut ev = None; match packet { Packet::KexInit(_) => { if matches!(self.state, ConnState::InKex { .. }) { @@ -189,7 +204,7 @@ impl<'a> Conn<'a> { self.cliserv.is_client(), &self.algo_conf, &self.remote_version, - packet, + &packet, )?; if let Some(r) = r { resp.push(r.into()).trap()?; @@ -207,7 +222,7 @@ impl<'a> Conn<'a> { } else { let kex = core::mem::replace(&mut self.kex, kex::Kex::new()?); - *output = Some(kex.handle_kexdhinit(p, &self.sess_id)?); + *output = Some(kex.handle_kexdhinit(&p, &self.sess_id)?); let reply = output.as_ref().trap()?.make_kexdhreply()?; resp.push(reply.into()).trap()?; } @@ -224,7 +239,7 @@ impl<'a> Conn<'a> { } else { let kex = core::mem::replace(&mut self.kex, kex::Kex::new()?); - *output = Some(kex.handle_kexdhreply(p, &self.sess_id, &mut b.client()?).await?); + *output = Some(kex.handle_kexdhreply(&p, &self.sess_id, &mut b.client()?).await?); resp.push(Packet::NewKeys(packets::NewKeys {}).into()).trap()?; } } else { @@ -279,7 +294,7 @@ impl<'a> Conn<'a> { Packet::UserauthFailure(p) => { // TODO: client only if let ClientServer::Client(cli) = &mut self.cliserv { - cli.auth.failure(p, &mut b.client()?, &mut resp, &mut self.parse_ctx).await?; + cli.auth.failure(&p, &mut b.client()?, &mut resp, &mut self.parse_ctx).await?; } else { debug!("Received UserauthFailure as a server"); return Err(Error::SSHProtoError) @@ -310,7 +325,7 @@ impl<'a> Conn<'a> { Packet::UserauthBanner(p) => { // TODO: client only if let ClientServer::Client(cli) = &mut self.cliserv { - cli.banner(p, &mut b.client()?).await; + cli.banner(&p, &mut b.client()?).await; } else { debug!("Received banner as a server"); return Err(Error::SSHProtoError) @@ -319,7 +334,7 @@ impl<'a> Conn<'a> { Packet::Userauth60(p) => { // TODO: client only if let ClientServer::Client(cli) = &mut self.cliserv { - cli.auth.auth60(p, &mut resp, self.sess_id.as_ref().trap()?, &mut self.parse_ctx).await?; + cli.auth.auth60(&p, &mut resp, self.sess_id.as_ref().trap()?, &mut self.parse_ctx).await?; } else { debug!("Received userauth60 as a server"); return Err(Error::SSHProtoError) @@ -336,9 +351,42 @@ impl<'a> Conn<'a> { | Packet::ChannelRequest(_) | Packet::ChannelSuccess(_) | Packet::ChannelFailure(_) - // TODO: probably needs a conn or cliserv argument. - => self.channels.dispatch(packet, &mut resp, b).await?, + // TODO: maybe needs a conn or cliserv argument. + => { + let chev = self.channels.dispatch(packet, &mut resp, b).await?; + ev = chev.map(|c| EventMaker::Channel(c)) + } }; - Ok(resp) + if let Some(ev) = ev { + if resp.is_empty() { + Ok(Dispatched::Event(ev)) + } else { + Err(Error::bug()) + } + } else { + Ok(Dispatched::Resp(resp)) + } } + + pub(crate) fn make_event<'p>(&mut self, payload: &'p [u8], ev: EventMaker) + -> Result<Option<Event<'p>>> { + let p = sshwire::packet_from_bytes(payload, &self.parse_ctx)?; + match ev { + EventMaker::Channel(cev) => { + let c = cev.make(p); + Ok(c.map(|c| Event::Channel(c))) + } + } + } + +} + +// pub(crate) struct Dispatched<'r, 'e> { +// pub resp: RespPackets<'r>, +// pub event: Option<Event<'e>>, +// } + +pub(crate) enum Dispatched<'r> { + Resp(RespPackets<'r>), + Event(EventMaker), } diff --git a/sshproto/src/encrypt.rs b/sshproto/src/encrypt.rs index cfb018417d46438d7749f33971e470e216fd5398..5d4f7588eae88c5d4da507a87bc63a541e5e9c4c 100644 --- a/sshproto/src/encrypt.rs +++ b/sshproto/src/encrypt.rs @@ -35,7 +35,8 @@ const MAX_IV_LEN: usize = 32; /// Largest is chacha. Also applies to MAC keys const MAX_KEY_LEN: usize = 64; -/// Stateful [`Keys`], stores a sequence number as well +/// Stateful [`Keys`], stores a sequence number as well, a single instance +/// is kept for the entire session. pub(crate) struct KeyState { keys: Keys, // Packet sequence numbers. These must be transferred to subsequent KeyState @@ -83,9 +84,7 @@ impl KeyState { self.seq_encrypt += 1; e } - pub fn size_integ_dec(&self) -> usize { - self.keys.integ_dec.size_out() - } + pub fn size_block_dec(&self) -> usize { self.keys.dec.size_block() } @@ -239,8 +238,8 @@ impl Keys { /// total SSH packet (including length+mac) which is calculated /// from the decrypted first 4 bytes. /// Whether bytes `buf[4..block_size]` are decrypted depends on the cipher, they may be - /// handled later by [`decrypt`]. Bytes `buf[0..4]` may not be modified. - pub fn decrypt_first_block( + /// handled later by [`decrypt`]. Bytes `buf[0..4]` may be left unmodified. + fn decrypt_first_block( &mut self, buf: &mut [u8], seq: u32, ) -> Result<u32, Error> { if buf.len() < self.dec.size_block() { @@ -257,7 +256,6 @@ impl Keys { u32::from_be_bytes(buf[..SSH_LENGTH_SIZE].try_into().unwrap()) } }; - trace!("len {len}"); let total_len = len .checked_add((SSH_LENGTH_SIZE + self.integ_dec.size_out()) as u32) @@ -271,7 +269,7 @@ impl Keys { /// Ensures that the packet meets minimum length. /// The first block_size bytes may have been already decrypted by /// [`decrypt_first_block`] depending on the cipher. - pub fn decrypt(&mut self, buf: &mut [u8], seq: u32) -> Result<usize, Error> { + fn decrypt(&mut self, buf: &mut [u8], seq: u32) -> Result<usize, Error> { let size_block = self.dec.size_block(); let size_integ = self.integ_dec.size_out(); @@ -295,7 +293,7 @@ impl Keys { let (data, mac) = buf.split_at_mut(buf.len() - size_integ); - // TODO: ETM modes would check integrity here. + // ETM modes would check integrity here. match &mut self.dec { DecKey::ChaPoly(k) => { @@ -372,7 +370,7 @@ impl Keys { /// Encrypt a buffer in-place, adding packet size, padding, MAC etc. /// Returns the total length. /// Ensures that the packet meets minimum and other length requirements. - pub fn encrypt( + fn encrypt( &mut self, payload_len: usize, buf: &mut [u8], seq: u32, ) -> Result<usize, Error> { let size_block = self.enc.size_block(); @@ -438,7 +436,7 @@ impl Keys { pub(crate) enum Cipher { ChaPoly, Aes256Ctr, - // TODO Aes gcm etc + // TODO AesGcm etc } impl fmt::Display for Cipher { @@ -461,7 +459,7 @@ impl Cipher { } } - /// length in bytes + /// Length in bytes pub fn key_len(&self) -> usize { match self { Cipher::ChaPoly => SSHChaPoly::key_len(), @@ -469,7 +467,7 @@ impl Cipher { } } - /// length in bytes + /// Length in bytes pub fn iv_len(&self) -> usize { match self { Cipher::ChaPoly => 0, @@ -493,6 +491,10 @@ pub(crate) enum EncKey { NoCipher, } +// TODO: could probably unify EncKey and DecKey as "CipherKey". +// Ring had sealing/opening keys which are separate, but RustCrypto +// uses the same structs in both directions. + impl EncKey { /// Construct a key pub fn from_cipher<'a>( diff --git a/sshproto/src/lib.rs b/sshproto/src/lib.rs index c97210d3e1103578e1d21962826d6be2b676bdd4..70a5a2b86dfcdd59eecff171cc3f0ac429a743f7 100644 --- a/sshproto/src/lib.rs +++ b/sshproto/src/lib.rs @@ -44,6 +44,7 @@ mod block_behaviour; mod ssh_chapoly; pub mod sshwire; +// Application API pub use behaviour::{Behaviour, BhError, BhResult, ResponseString}; #[cfg(feature = "std")] pub use async_behaviour::{AsyncCliBehaviour,AsyncServBehaviour}; @@ -56,3 +57,4 @@ pub use sign::SignKey; pub use packets::PubKey; pub use error::{Error,Result}; pub use channel::{ChanMsg,ChanMsgDetails}; +pub use conn::Event; diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index e34921c9ddd6a2c0eb150e158c5772ea33baa5aa..ff927ad3a9f52a757af5fb4554087ef33fcae470 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -410,6 +410,13 @@ pub struct ChannelData<'a> { pub data: BinString<'a>, } +impl ChannelData<'_> { + // offset into a packet payload + pub(crate) fn data_offset(&self) -> usize { + 5 + } +} + #[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelDataExt<'a> { pub num: u32, @@ -417,6 +424,13 @@ pub struct ChannelDataExt<'a> { pub data: BinString<'a>, } +impl ChannelDataExt<'_> { + // offset into a packet payload + pub(crate) fn data_offset(&self) -> usize { + 9 + } +} + #[derive(Debug,SSHEncode, SSHDecode)] pub struct ChannelEof { pub num: u32, @@ -441,11 +455,11 @@ pub struct ChannelFailure { pub struct ChannelRequest<'a> { pub num: u32, - // channel_type is implicit in ch below - #[sshwire(variant_name = ch)] + // channel_type is implicit in req below + #[sshwire(variant_name = req)] pub want_reply: bool, - pub ch: ChannelReqType<'a>, + pub req: ChannelReqType<'a>, } #[derive(Debug, SSHEncode, SSHDecode)] diff --git a/sshproto/src/random.rs b/sshproto/src/random.rs index 9ae401208fc89ae2876f7ccc9c56ba69ef0ffcc9..c35b931ad0b411861ffd10b4089907e8715e3739 100644 --- a/sshproto/src/random.rs +++ b/sshproto/src/random.rs @@ -10,44 +10,6 @@ use core::num::Wrapping; #[cfg(feature = "getrandom")] pub type DoorRng = rand::rngs::OsRng; -#[cfg(feature = "fakerandom")] -pub type DoorRng = FakeRng; - -#[derive(Clone, Copy, Debug, Default)] -pub struct FakeRng { - state: Wrapping<u32>, -} - -impl CryptoRng for FakeRng {} - -impl RngCore for FakeRng { - fn next_u32(&mut self) -> u32 { - rand_core::impls::next_u32_via_fill(self) - } - - fn next_u64(&mut self) -> u64 { - rand_core::impls::next_u64_via_fill(self) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - dest.fill_with(|| { - self.state = Wrapping(14013u32) * self.state + Wrapping(2531011u32); - ((self.state>>16).0 & 0xFF) as u8 - }); - dest.fill(8) - - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { - Ok(dest.fill(8)) - // Ok(dest.fill_with(|| { - // self.state = Wrapping(14013u32) * self.state + Wrapping(2531011u32); - // ((self.state>>16).0 & 0xFF) as u8 - // })) - } - -} - pub fn fill_random(buf: &mut [u8]) -> Result<(), Error> { // TODO: can this return an error? let mut rng = DoorRng::default(); diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index ab9b394a22652abde76a11c3ffb575f2da550636..692fe12a25eeeaa8f50337fb2f1ad45010b7626b 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -8,10 +8,13 @@ use core::task::{Poll, Waker}; use pretty_hex::PrettyHex; -use crate::*; +use crate::{*, channel::ChanEvent}; use encrypt::KeyState; use traffic::Traffic; +use conn::{Dispatched, EventMaker, Event}; +use channel::ChanEventMaker; + pub struct Runner<'a> { conn: Conn<'a>, @@ -49,7 +52,7 @@ impl<'a> Runner<'a> { &mut self.conn.remote_version, buf, )?; - // payload is dispatched by out_progress() on the output side + // payload will be handled when progress() is called if self.traffic.payload().is_some() { trace!("payload some, waker {:?}", self.output_waker); if let Some(w) = self.output_waker.take() { @@ -60,35 +63,65 @@ impl<'a> Runner<'a> { Ok(size) } - // Drives connection progress, handling received payload and sending - // other packets as required - pub async fn out_progress(&mut self, b: &mut Behaviour<'_>) -> Result<(), Error> { - trace!("out_progress top"); - if let Some(payload) = self.traffic.payload() { - trace!("out_progress payload"); + /// Drives connection progress, handling received payload and sending + /// other packets as required. This must be polled/awaited regularly. + /// Optionally returns `Event` which provides channel or session + // event to the application. + pub async fn progress<'f>(&'f mut self, b: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { + let em = if let Some(payload) = self.traffic.payload() { // Lifetimes here are a bit subtle. // `payload` has self.traffic lifetime, used until `handle_payload` // completes. - // The `resp` from handle_payload() references self.conn, consumed + // The `resp` from handle_payload() references self.conn, consume // by the send_packet(). // After that progress() can perform more send_packet() itself. - let resp = self.conn.handle_payload(payload, &mut self.keys, b).await?; - debug!("done_payload"); - self.traffic.done_payload()?; - for r in resp { - r.send_packet(&mut self.traffic, &mut self.keys)?; + let r = self.conn.handle_payload(payload, &mut self.keys, b).await?; + match r { + Dispatched::Resp(resp) => { + debug!("done_payload"); + self.traffic.done_payload()?; + for r in resp { + r.send_packet(&mut self.traffic, &mut self.keys)?; + } + + None + } + Dispatched::Event(em) => { + Some(em) + } + } + } else { + None + }; + + // We split return values into Event/EventMaker to work around + // the payload borrow range extending too long. + // Polonius would solve this. We can't use polonius-the-crab + // because we're calling async functions. + // "Borrow checker extends borrow range in code with early return" + // https://github.com/rust-lang/rust/issues/54663 + let ev = if let Some(em) = em { + match em { + EventMaker::Channel(ChanEventMaker::DataIn(di)) => { + self.traffic.set_channel_input(di)?; + None + } + _ => { + let payload = self.traffic.payload().trap()?; + self.conn.make_event(payload, em)? + } } - } - self.conn.progress(&mut self.traffic, &mut self.keys, b).await?; - b.progress(self)?; + } else { + self.conn.progress(&mut self.traffic, &mut self.keys, b).await?; + None + }; - trace!("out_progress done"); - Ok(()) + Ok(ev) } - /// Write any pending output, returning the size written + /// Write any pending output to the wire, returning the size written pub fn output(&mut self, buf: &mut [u8]) -> Result<usize, Error> { let r = self.traffic.output(buf); if self.ready_input() { @@ -119,26 +152,42 @@ impl<'a> Runner<'a> { Ok(ch.number()) } - pub fn channel_input( + /// Send data from this application out the wire. + /// Must have already checked `ready_channel_send()`. + /// Returns the length of `buf` consumed. + pub fn channel_send( &mut self, chan: u32, - msg: channel::ChanMsg, + ext: Option<u32>, + buf: &[u8], ) -> Result<usize> { - todo!() + let (p, len) = self.conn.channels.send_data(chan, ext, buf)?; + self.traffic.send_packet(p, &mut self.keys)?; + Ok(len) } - pub fn channel_output( + /// Receive data coming from the wire into this application + pub fn channel_input( &mut self, chan: u32, + ext: Option<u32>, buf: &mut [u8], - ) -> Result<Poll<channel::ChanOut>> { - todo!() + ) -> Result<usize> { + let (len, complete) = self.traffic.channel_input(chan, ext, buf); + if complete { + self.conn.channels.finished_input(chan)?; + } + Ok(len) } pub fn ready_input(&self) -> bool { self.conn.initial_sent() && self.traffic.ready_input() } + pub fn ready_progress(&self) -> bool { + self.conn.initial_sent() && self.traffic.ready_input() + } + pub fn set_input_waker(&mut self, waker: Waker) { self.input_waker = Some(waker); } @@ -151,6 +200,16 @@ impl<'a> Runner<'a> { self.output_waker = Some(waker); } + pub fn ready_channel_input(&self, chan: u32, ext: Option<u32>) -> bool { + self.traffic.ready_channel_input(chan, ext) + } + + // TODO check the chan/ext are valid + pub fn ready_channel_send(&self, _chan: u32, _ext: Option<u32>) -> bool { + self.traffic.ready_channel_send() + // && self.conn.channels.ready_send_data(chan, ext) + } + // pub fn chan_pending(&self) -> bool { // self.conn.chan_pending() // } diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs index 4aef45e0f2de4184a0685e96ed33b745f2465a76..dbc0a744f69d7bb038e76d0c693718219b367025 100644 --- a/sshproto/src/sshwire.rs +++ b/sshproto/src/sshwire.rs @@ -113,7 +113,7 @@ where T: SSHEncode, { let mut s = EncodeBytes { target, pos: 0 }; - let r = value.enc(&mut s)?; + value.enc(&mut s)?; Ok(s.pos) } @@ -233,6 +233,7 @@ pub fn hash_mpint(hash_ctx: &mut dyn digest::DynDigest, m: &[u8]) { /// A SSH style binary string. Serialized as 32 bit length followed by the bytes /// of the slice. +/// Application API #[derive(Clone,PartialEq)] pub struct BinString<'a>(pub &'a [u8]); @@ -265,14 +266,16 @@ impl<'de> SSHDecode<'de> for BinString<'de> { } -/// A text string that may be presented to a user. +/// A text string that may be presented to a user or used +/// for things such as a password, username, exec command, tcp hostname, etc. /// The SSH protocol defines it to be UTF-8, though -/// in some applications it can be treated as ascii-only. +/// in some applications it could be treated as ascii-only. /// The library treats it as an opaque `&[u8]`, leaving /// decoding to the `Behaviour`. /// Note that SSH protocol identifiers in `Packet` etc /// are `&str` rather than `TextString`, and always defined as ASCII. +/// Application API #[derive(Clone,PartialEq,Copy)] pub struct TextString<'a>(pub &'a [u8]); diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 31136728a98e16c78acc53ec3d3b761a55d4982b..3a245235c268745ea38931fa9aff50c670c8f8bf 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -42,9 +42,20 @@ enum TrafState { ReadComplete { len: usize }, /// Decrypted complete input payload InPayload { len: usize }, + /// Decrypted incoming channel data + InChannelData { + /// channel number + chan: u32, + /// extended flag. usually None, or `Some(1)` for `SSH_EXTENDED_DATA_STDERR` + ext: Option<u32>, + /// read index of channel data. should transition to Idle once `idx==len` + idx: usize, + /// length of buffer, end of channel data + len: usize, + }, /// Writing to the socket. Buffer is encrypted in-place. - /// Should never be left in idx==len state, + /// Should never be left in `idx==len` state, /// instead should transition to Idle Write { /// Cursor position in the buffer @@ -93,6 +104,7 @@ impl<'a> Traffic<'a> { | TrafState::Read { .. } => true, TrafState::ReadComplete { .. } | TrafState::InPayload { .. } + | TrafState::InChannelData { .. } | TrafState::Write { .. } => false, } } @@ -177,7 +189,10 @@ impl<'a> Traffic<'a> { let (idx, len) = match self.state { TrafState::Idle => (0, 0), TrafState::Write { idx, len } => (idx, len), - _ => Err(Error::bug())?, + _ => { + trace!("bad state {:?}", self.state); + Err(Error::bug())? + } }; // Use the remainder of our buffer to write the packet. Payload starts @@ -276,4 +291,62 @@ impl<'a> Traffic<'a> { Ok(buf.len() - r.len()) } + + pub fn ready_channel_input(&self, chan: u32, ext: Option<u32>) -> bool { + match self.state { + TrafState::InChannelData { chan: c, ext: e, .. } + if (c, e) == (chan, ext) => true, + _ => false, + } + } + + pub fn ready_channel_send(&self) -> bool { + match self.state { + TrafState::Idle => true, + _ => false, + } + } + + pub fn set_channel_input(&mut self, di: channel::DataIn) -> Result<()> { + match self.state { + TrafState::Idle => { + let idx = SSH_PAYLOAD_START + di.offset; + self.state = TrafState::InChannelData { chan: di.num, ext: di.ext, idx, len: di.len }; + Ok(()) + } + _ => Err(Error::bug()), + } + } + + // Returns the length consumed, and a bool indicating whether the whole + // data packet has been completed. + pub fn channel_input( + &mut self, + chan: u32, + ext: Option<u32>, + buf: &mut [u8], + ) -> (usize, bool) { + if !matches!(self.state, TrafState::Idle) { + return (0, false) + } + + match self.state { + TrafState::InChannelData { chan: c, ext: e, ref mut idx, len } + if (c, e) == (chan, ext) => { + let wlen = (len - *idx).min(buf.len()); + buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); + *idx += wlen; + + if *idx == len { + // all done. + self.state = TrafState::Idle; + (wlen, true) + } else { + (wlen, false) + } + } + _ => (0, false) + } + } + }