diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs index 1bca0e8a3cfcc6024190be0ab1184f95b25a3d35..1144b8e65c2695cfa8549b51535fc3ef6895f73c 100644 --- a/sshproto/src/channel.rs +++ b/sshproto/src/channel.rs @@ -89,10 +89,12 @@ impl Channels { /// Caller has already checked valid length with send_allowed() pub(crate) fn send_data<'b>(&mut self, num: u32, ext: Option<u32>, data: &'b [u8]) -> Result<Packet<'b>> { - let send = self.get(num)?.send.as_ref().trap()?; + let send = self.get_mut(num)?.send.as_mut().trap()?; if data.len() > send.max_packet || data.len() > send.window { return Err(Error::bug()) } + send.window -= data.len(); + let data = BinString(data); let p = if let Some(code) = ext { // TODO: check code is valid for this channel @@ -100,20 +102,20 @@ impl Channels { } else { packets::ChannelData { num: send.num, data }.into() }; + Ok(p) } /// Informs the channel layer that an incoming packet has been read out, /// so a window adjustment can be queued. - pub(crate) fn finished_input(&mut self, num: u32) -> Result<()> { + pub(crate) fn finished_input(&mut self, num: u32) -> Result<Option<Packet>> { match self.pending_input { Some(ref p) if p.chan == num => { - // TODO: send window adjustment let len = p.len; - let ch = self.get_mut(num)?; - ch.finished_input(len); + self.get_mut(num)?.finished_input(len); self.pending_input = None; - Ok(()) + + self.get_mut(num)?.check_window_adjust() } _ => Err(Error::bug()), } @@ -132,7 +134,6 @@ impl Channels { &mut self, packet: Packet<'_>, resp: &mut RespPackets<'_>, - b: &mut Behaviour<'_>, ) -> Result<Option<ChanEventMaker>> { trace!("chan dispatch"); let r = match packet { @@ -173,7 +174,9 @@ impl Channels { } } Packet::ChannelWindowAdjust(p) => { - todo!(); + let send = self.get_mut(p.num)?.send.as_mut().trap()?; + send.window = send.window.saturating_add(p.adjust as usize); + Ok(None) } Packet::ChannelData(p) => { let ch = self.get(p.num)?; @@ -193,6 +196,7 @@ impl Channels { } self.pending_input = Some(PendInput { chan: p.num, len: p.data.0.len() }); let di = DataIn { num: p.num, ext: Some(p.code), offset: p.data_offset(), len: p.data.0.len() }; + trace!("{di:?}"); Ok(Some(ChanEventMaker::DataIn(di))) } Packet::ChannelEof(p) => { @@ -371,6 +375,8 @@ pub struct Channel { /// Accumulated bytes for the next window adjustment (inbound data direction) pending_adjust: usize, + + full_window: usize, } impl Channel { @@ -388,6 +394,7 @@ impl Channel { }, send: None, pending_adjust: 0, + full_window: config::DEFAULT_WINDOW, } } fn request(&mut self, req: ReqDetails, resp: &mut RespPackets) -> Result<()> { @@ -419,6 +426,20 @@ impl Channel { fn send_allowed(&self) -> Option<usize> { self.send.as_ref().map(|s| usize::max(s.window, s.max_packet)) } + + /// Returns a window adjustment packet if required + fn check_window_adjust(&mut self) -> Result<Option<Packet>> { + let send = self.send.as_mut().trap()?; + if self.pending_adjust > self.full_window / 2 { + let adjust = self.pending_adjust as u32; + let p = packets::ChannelWindowAdjust { num: send.num, adjust }.into(); + Ok(Some(p)) + } else { + Ok(None) + } + } + + } pub struct ChanMsg { diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs index 286075f04f53b4a25e4f4e061ad50ccb050cbaff..69bcaceed97f66906ca94a61bde00e3650bcba4e 100644 --- a/sshproto/src/conn.rs +++ b/sshproto/src/conn.rs @@ -348,7 +348,7 @@ impl<'a> Conn<'a> { | Packet::ChannelFailure(_) // TODO: maybe needs a conn or cliserv argument. => { - let chev = self.channels.dispatch(packet, &mut resp, b).await?; + let chev = self.channels.dispatch(packet, &mut resp).await?; event = chev.map(|c| EventMaker::Channel(c)) } }; diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs index ff927ad3a9f52a757af5fb4554087ef33fcae470..6d63ae427b5fbd899502c898181062264f0b20c2 100644 --- a/sshproto/src/packets.rs +++ b/sshproto/src/packets.rs @@ -411,9 +411,9 @@ pub struct ChannelData<'a> { } impl ChannelData<'_> { - // offset into a packet payload + // offset into a packet of the raw data pub(crate) fn data_offset(&self) -> usize { - 5 + 9 } } @@ -427,7 +427,8 @@ pub struct ChannelDataExt<'a> { impl ChannelDataExt<'_> { // offset into a packet payload pub(crate) fn data_offset(&self) -> usize { - 9 + // offset into a packet of the raw data + 13 } } diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 54da7cb40119fbb8a90e13902387c6b03b192c5a..ae96403c14d5f92e8f399064d600e80d7afc5f53 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -204,7 +204,10 @@ impl<'a> Runner<'a> { trace!("runner chan in"); let (len, complete) = self.traffic.channel_input(chan, ext, buf); if complete { - self.conn.channels.finished_input(chan)?; + let p = self.conn.channels.finished_input(chan)?; + if let Some(p) = p { + self.traffic.send_packet(p, &mut self.keys)?; + } self.wake(); } Ok(len) diff --git a/sshproto/src/traffic.rs b/sshproto/src/traffic.rs index 09c60bef91f04b93af77c81da5539b75c6207268..1cab016fafc1cb28777e62c59e6fb8ab060c21c9 100644 --- a/sshproto/src/traffic.rs +++ b/sshproto/src/traffic.rs @@ -222,7 +222,7 @@ impl<'a> Traffic<'a> { } _ => { // Just ignore it - warn!("done_payload called without payload"); + // warn!("done_payload called without payload, st {:?}", self.state); Ok(()) } } @@ -367,7 +367,12 @@ impl<'a> Traffic<'a> { match self.state { TrafState::Idle => { let idx = SSH_PAYLOAD_START + di.offset; - self.state = TrafState::InChannelData { chan: di.num, ext: di.ext, idx, len: di.len }; + self.state = TrafState::InChannelData { chan: di.num, ext: di.ext, idx, len: idx + di.len }; + // error!("set input {:?}", self.state); + trace!("all buf {:?}", self.buf[..32].hex_dump()); + trace!("set chan input offset {} idx {} {:?}", + di.offset, idx, + self.buf[idx..idx + di.len].hex_dump()); Ok(()) } _ => Err(Error::bug()), @@ -386,9 +391,13 @@ impl<'a> Traffic<'a> { match self.state { TrafState::InChannelData { chan: c, ext: e, ref mut idx, len } - if (c, e) == (chan, ext) => { + if (c, e) == (chan, ext) => { + if *idx > len { + error!("bad idx {} len {} e {:?} c {}", *idx, len, e, c); + } let wlen = (len - *idx).min(buf.len()); buf[..wlen].copy_from_slice(&self.buf[*idx..*idx + wlen]); + // info!("idx {} += wlen {} = {}", *idx, wlen, *idx+wlen); *idx += wlen; if *idx == len {