From 9acf1a516c79ba3995c6ca67ec4fa989d24eac56 Mon Sep 17 00:00:00 2001 From: Matt Johnston <matt@ucc.asn.au> Date: Sun, 7 Aug 2022 22:52:33 +0800 Subject: [PATCH] Pass Behaviour as an argument instead. Get rid of boxing --- async/examples/con1.rs | 5 +++-- async/src/async_door.rs | 8 +++++--- async/src/client.rs | 17 +++++++++++------ sshproto/src/async_behaviour.rs | 14 ++++++-------- sshproto/src/behaviour.rs | 11 ++++------- sshproto/src/block_behaviour.rs | 1 - sshproto/src/runner.rs | 10 +++------- 7 files changed, 32 insertions(+), 34 deletions(-) diff --git a/async/examples/con1.rs b/async/examples/con1.rs index c973e24..b1e3b88 100644 --- a/async/examples/con1.rs +++ b/async/examples/con1.rs @@ -145,12 +145,13 @@ async fn run(args: &Args) -> Result<()> { let tx = vec![0; 3000]; let tx = Box::leak(Box::new(tx)).as_mut_slice(); + // cli is a Behaviour let mut cli = door_async::CmdlineClient::new(args.username.as_ref().unwrap()); for i in &args.identityfile { cli.add_authkey(read_key(&i).with_context(|| format!("loading key {i}"))?); } - let mut door = SSHClient::new(work, tx, Box::new(cli))?; + let mut door = SSHClient::new(work, tx)?; let mut s = door.socket(); moro::async_scope!(|scope| { @@ -158,7 +159,7 @@ async fn run(args: &Args) -> Result<()> { scope.spawn(async { loop { - let ev = door.progress(|ev| { + let ev = door.progress(&mut cli, |ev| { trace!("progress event {ev:?}"); let e = match ev { Event::CliAuthed => Some(Event::CliAuthed), diff --git a/async/src/async_door.rs b/async/src/async_door.rs index 3db916a..df4419f 100644 --- a/async/src/async_door.rs +++ b/async/src/async_door.rs @@ -19,7 +19,7 @@ use core::ops::DerefMut; use std::sync::Arc; use door_sshproto as door; -use door::{Runner, Result, Event, ChanEvent}; +use door::{Runner, Result, Event, ChanEvent, Behaviour}; use pretty_hex::PrettyHex; @@ -59,7 +59,9 @@ impl<'a> AsyncDoor<'a> { AsyncDoorSocket::new(self) } - pub async fn progress<F, R>(&mut self, f: F) + pub async fn progress<F, R>(&mut self, + b: &mut Behaviour<'_>, + f: F) -> Result<Option<R>> where F: FnOnce(door::Event) -> Result<Option<R>> { trace!("progress"); @@ -67,7 +69,7 @@ impl<'a> AsyncDoor<'a> { let res = { let mut inner = self.inner.lock().await; let inner = inner.deref_mut(); - let ev = inner.runner.progress().await?; + let ev = inner.runner.progress(b).await?; let r = if let Some(ev) = ev { let r = match ev { Event::Channel(ChanEvent::Eof { num }) => { diff --git a/async/src/client.rs b/async/src/client.rs index cf63026..6303029 100644 --- a/async/src/client.rs +++ b/async/src/client.rs @@ -31,10 +31,8 @@ pub struct SSHClient<'a> { impl<'a> SSHClient<'a> { pub fn new(inbuf: &'a mut [u8], - outbuf: &'a mut [u8], - behaviour: Box<dyn AsyncCliBehaviour+Send>) -> Result<Self> { - let b = Behaviour::new_async_client(behaviour); - let runner = Runner::new_client(inbuf, outbuf, b)?; + outbuf: &'a mut [u8]) -> Result<Self> { + let runner = Runner::new_client(inbuf, outbuf)?; let door = AsyncDoor::new(runner); Ok(Self { door @@ -45,10 +43,17 @@ impl<'a> SSHClient<'a> { self.door.socket() } - pub async fn progress<F, R>(&mut self, f: F) + /// Takes a closure to run on the "output" of the progress call. + /// (This output can't be returned directly since it refers + /// to contents of `Self` and would hit lifetime issues). + pub async fn progress<F, R>(&mut self, + behaviour: &mut (dyn AsyncCliBehaviour+Send), + f: F) -> Result<Option<R>> where F: FnOnce(door::Event) -> Result<Option<R>> { - self.door.progress(f).await + + let mut b = Behaviour::new_async_client(behaviour); + self.door.progress(&mut b, f).await } // TODO: return a Channel object that gives events like WinChange or exit status diff --git a/sshproto/src/async_behaviour.rs b/sshproto/src/async_behaviour.rs index a180584..2dba472 100644 --- a/sshproto/src/async_behaviour.rs +++ b/sshproto/src/async_behaviour.rs @@ -9,24 +9,23 @@ use { }; use async_trait::async_trait; -use std::boxed::Box; use crate::{*, conn::RespPackets}; use behaviour::*; -pub(crate) enum AsyncCliServ { - Client(Box<dyn AsyncCliBehaviour + Send>), - Server(Box<dyn AsyncServBehaviour + Send>), +pub(crate) enum AsyncCliServ<'a> { + Client(&'a mut (dyn AsyncCliBehaviour + Send)), + Server(&'a mut (dyn AsyncServBehaviour + Send)), } -impl AsyncCliServ { +impl<'a> AsyncCliServ<'a> { pub fn client(&mut self) -> Result<CliBehaviour> { let c = match self { Self::Client(c) => c, _ => Error::bug_msg("Not client")?, }; let c = CliBehaviour { - inner: c.as_mut(), + inner: *c, }; Ok(c) } @@ -37,8 +36,7 @@ impl AsyncCliServ { _ => Error::bug_msg("Not server")?, }; let c = ServBehaviour { - inner: c.as_mut(), - phantom: core::marker::PhantomData::default(), + inner: *c, }; Ok(c) } diff --git a/sshproto/src/behaviour.rs b/sshproto/src/behaviour.rs index 914148e..8fd0e68 100644 --- a/sshproto/src/behaviour.rs +++ b/sshproto/src/behaviour.rs @@ -47,7 +47,7 @@ pub type ReplyChannel = bhtokio::ReplyChannel; pub struct Behaviour<'a> { #[cfg(feature = "std")] - inner: crate::async_behaviour::AsyncCliServ, + inner: crate::async_behaviour::AsyncCliServ<'a>, #[cfg(not(feature = "std"))] inner: crate::block_behaviour::BlockCliServ<'a>, @@ -55,15 +55,15 @@ pub struct Behaviour<'a> { } #[cfg(feature = "std")] -impl Behaviour<'_> { - pub fn new_async_client(b: Box<dyn async_behaviour::AsyncCliBehaviour + Send>) -> Self { +impl<'a> Behaviour<'a> { + pub fn new_async_client(b: &'a mut (dyn AsyncCliBehaviour + Send)) -> Self { Self { inner: async_behaviour::AsyncCliServ::Client(b), phantom: PhantomData::default(), } } - pub fn new_async_server(b: Box<dyn async_behaviour::AsyncServBehaviour + Send>) -> Self { + pub fn new_async_server(b: &'a mut (dyn AsyncServBehaviour + Send)) -> Self { Self { inner: async_behaviour::AsyncCliServ::Server(b), phantom: PhantomData::default(), @@ -86,14 +86,12 @@ impl<'a> Behaviour<'a> pub fn new_blocking_client(b: &'a mut dyn BlockCliBehaviour) -> Self { Self { inner: block_behaviour::BlockCliServ::Client(b), - phantom: PhantomData::default(), } } pub fn new_blocking_server(b: &'a mut dyn BlockServBehaviour) -> Self { Self { inner: block_behaviour::BlockCliServ::Server(b), - phantom: PhantomData::default(), } } @@ -184,7 +182,6 @@ pub struct ServBehaviour<'a> { pub inner: &'a mut dyn async_behaviour::AsyncServBehaviour, #[cfg(not(feature = "std"))] pub inner: &'a mut dyn block_behaviour::BlockServBehaviour, - pub phantom: core::marker::PhantomData<&'a ()>, } #[cfg(feature = "std")] diff --git a/sshproto/src/block_behaviour.rs b/sshproto/src/block_behaviour.rs index 945db7f..86240f1 100644 --- a/sshproto/src/block_behaviour.rs +++ b/sshproto/src/block_behaviour.rs @@ -35,7 +35,6 @@ impl BlockCliServ<'_> }; let c = ServBehaviour { inner: *c, - phantom: core::marker::PhantomData::default(), }; Ok(c) } diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs index 4c99e4d..ddae3a0 100644 --- a/sshproto/src/runner.rs +++ b/sshproto/src/runner.rs @@ -18,8 +18,6 @@ use channel::ChanEventMaker; pub struct Runner<'a> { conn: Conn<'a>, - behaviour: Behaviour<'a>, - /// Binary packet handling to and from the network buffer traffic: Traffic<'a>, @@ -35,7 +33,6 @@ impl<'a> Runner<'a> { pub fn new_client( inbuf: &'a mut [u8], outbuf: &'a mut [u8], - behaviour: Behaviour<'a>, ) -> Result<Runner<'a>, Error> { let conn = Conn::new_client()?; let runner = Runner { @@ -44,7 +41,6 @@ impl<'a> Runner<'a> { keys: KeyState::new_cleartext(), output_waker: None, input_waker: None, - behaviour, }; Ok(runner) @@ -74,7 +70,7 @@ impl<'a> Runner<'a> { /// Optionally returns `Event` which provides channel or session /// event to the application. /// [`done_payload()`] must be called after any `Ok` result. - pub async fn progress<'f>(&'f mut self) -> Result<Option<Event<'f>>, Error> { + pub async fn progress<'f>(&'f mut self, behaviour: &mut Behaviour<'_>) -> Result<Option<Event<'f>>, Error> { let em = if let Some((payload, seq)) = self.traffic.payload() { // Lifetimes here are a bit subtle. // `payload` has self.traffic lifetime, used until `handle_payload` @@ -83,7 +79,7 @@ impl<'a> Runner<'a> { // by the send_packet(). // After that progress() can perform more send_packet() itself. - let d = self.conn.handle_payload(payload, seq, &mut self.keys, &mut self.behaviour).await?; + let d = self.conn.handle_payload(payload, seq, &mut self.keys, behaviour).await?; self.traffic.handled_payload()?; if !d.resp.is_empty() || d.event.is_none() { @@ -123,7 +119,7 @@ impl<'a> Runner<'a> { } } else { trace!("no em, conn progress"); - self.conn.progress(&mut self.traffic, &mut self.keys, &mut self.behaviour).await?; + self.conn.progress(&mut self.traffic, &mut self.keys, behaviour).await?; self.wake(); None }; -- GitLab