From 61e1d06562b7da46945a00d1d2f70733989119fe Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Mon, 7 Nov 2022 20:15:57 +0800
Subject: [PATCH] Move run() into embassy_sunset

Now is common between CliBehaviour and ServBehaviour

Thanks to bruh![moment] on Discord for details how to make it work
---
 embassy/demos/picow/src/main.rs |  3 +-
 embassy/demos/std/src/main.rs   |  3 +-
 embassy/src/embassy_sunset.rs   | 64 ++++++++++++++++++++++++--
 embassy/src/server.rs           | 79 +++------------------------------
 src/behaviour.rs                | 14 +++++-
 5 files changed, 84 insertions(+), 79 deletions(-)

diff --git a/embassy/demos/picow/src/main.rs b/embassy/demos/picow/src/main.rs
index cff5d3c..95aa07a 100644
--- a/embassy/demos/picow/src/main.rs
+++ b/embassy/demos/picow/src/main.rs
@@ -233,8 +233,9 @@ async fn session(socket: &mut TcpSocket<'_>) -> sunset::Result<()> {
     let serv = &serv;
 
     let app = Mutex::<NoopRawMutex, _>::new(app);
+    let app = &app as &Mutex::<NoopRawMutex, dyn ServBehaviour>;
 
-    let run = serv.run(socket, &app);
+    let run = serv.run(socket, app);
 
     let session = async {
         loop {
diff --git a/embassy/demos/std/src/main.rs b/embassy/demos/std/src/main.rs
index abd740b..c7e5c4b 100644
--- a/embassy/demos/std/src/main.rs
+++ b/embassy/demos/std/src/main.rs
@@ -189,8 +189,9 @@ async fn session(socket: &mut TcpSocket<'_>) -> sunset::Result<()> {
     let serv = &serv;
 
     let app = Mutex::<NoopRawMutex, _>::new(app);
+    let app = &app as &Mutex::<NoopRawMutex, dyn ServBehaviour>;
 
-    serv.run(socket, &app).await
+    serv.run(socket, app).await
 }
 
 static EXECUTOR: StaticCell<Executor> = StaticCell::new();
diff --git a/embassy/src/embassy_sunset.rs b/embassy/src/embassy_sunset.rs
index 31f4b12..b28c362 100644
--- a/embassy/src/embassy_sunset.rs
+++ b/embassy/src/embassy_sunset.rs
@@ -10,6 +10,8 @@ use embassy_sync::waitqueue::WakerRegistration;
 use embassy_sync::mutex::Mutex;
 use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
 use embassy_sync::signal::Signal;
+use embassy_futures::join::join3;
+use embassy_net::tcp::TcpSocket;
 
 use pin_utils::pin_mut;
 
@@ -46,6 +48,60 @@ impl<'a> EmbassySunset<'a> {
          }
     }
 
+    pub async fn run<M, B: ?Sized>(&self, socket: &mut TcpSocket<'_>,
+        b: &Mutex<M, B>) -> Result<()>
+        where
+            M: RawMutex,
+            for<'f> Behaviour<'f>: From<&'f mut B>
+    {
+        let (mut rsock, mut wsock) = socket.split();
+
+        let tx = async {
+            loop {
+                // TODO: make sunset read directly from socket, no intermediate buffer.
+                let mut buf = [0; 1024];
+                let l = self.read(&mut buf).await?;
+                let mut buf = &buf[..l];
+                while buf.len() > 0 {
+                    let n = wsock.write(buf).await.expect("TODO handle write error");
+                    buf = &buf[n..];
+                }
+            }
+            #[allow(unreachable_code)]
+            Ok::<_, sunset::Error>(())
+        };
+
+        let rx = async {
+            loop {
+                // TODO: make sunset read directly from socket, no intermediate buffer.
+                let mut buf = [0; 1024];
+                let l = rsock.read(&mut buf).await.expect("TODO handle read error");
+                let mut buf = &buf[..l];
+                while buf.len() > 0 {
+                    let n = self.write(&buf).await?;
+                    buf = &buf[n..];
+                }
+            }
+            #[allow(unreachable_code)]
+            Ok::<_, sunset::Error>(())
+        };
+
+        let prog = async {
+            loop {
+                self.progress(b).await?;
+            }
+            #[allow(unreachable_code)]
+            Ok::<_, sunset::Error>(())
+        };
+
+
+        // TODO: handle results
+        join3(rx, tx, prog).await;
+
+        Ok(())
+    }
+
+
     fn wake_channels(&self, inner: &mut Inner) {
 
             if let Some((chan, _ext)) = inner.runner.ready_channel_input() {
@@ -60,10 +116,12 @@ impl<'a> EmbassySunset<'a> {
     }
 
     // XXX could we have a concrete NoopRawMutex instead of M?
-    pub async fn progress_server<M, B>(&self,
+    pub async fn progress<M, B: ?Sized>(&self,
         b: &Mutex<M, B>)
         -> Result<()>
-        where M: RawMutex, B: ServBehaviour
+        where
+            M: RawMutex,
+            for<'f> Behaviour<'f>: From<&'f mut B>
         {
 
         {
@@ -74,7 +132,7 @@ impl<'a> EmbassySunset<'a> {
                     warn!("progress locked");
                     // XXX: unsure why we need this explicit type
                     let b: &mut B = &mut b;
-                    let mut b = Behaviour::new_server(b);
+                    let mut b: Behaviour = b.into();
                     inner.runner.progress(&mut b).await?;
                     // b is dropped, allowing other users
                 }
diff --git a/embassy/src/server.rs b/embassy/src/server.rs
index b5958cb..5ecec6c 100644
--- a/embassy/src/server.rs
+++ b/embassy/src/server.rs
@@ -1,6 +1,5 @@
 use embassy_sync::mutex::Mutex;
 use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
-use embassy_futures::join::join3;
 use embassy_net::tcp::TcpSocket;
 
 use sunset::*;
@@ -21,78 +20,12 @@ impl<'a> SSHServer<'a> {
         Ok(Self { sunset })
     }
 
-    pub async fn progress<M>(&self,
-        b: &Mutex<M, impl ServBehaviour>)
-        -> Result<()>
-        where M: RawMutex
+    pub async fn run<M, B: ?Sized>(&self, socket: &mut TcpSocket<'_>,
+        b: &Mutex<M, B>) -> Result<()>
+        where
+            M: RawMutex,
+            for<'f> Behaviour<'f>: From<&'f mut B>
     {
-        // let mut b = Behaviour::new_server(b);
-        self.sunset.progress_server(b).await
-    }
-
-    pub async fn run<M>(&self, socket: &mut TcpSocket<'_>, b: &Mutex<M, impl ServBehaviour>) -> Result<()>
-        where M: RawMutex
-    {
-        let (mut rsock, mut wsock) = socket.split();
-
-        let tx = async {
-            loop {
-                // TODO: make sunset read directly from socket, no intermediate buffer.
-                let mut buf = [0; 1024];
-                let l = self.read(&mut buf).await?;
-                let mut buf = &buf[..l];
-                while buf.len() > 0 {
-                    let n = wsock.write(buf).await.expect("TODO handle write error");
-                    buf = &buf[n..];
-                }
-            }
-            #[allow(unreachable_code)]
-            Ok::<_, sunset::Error>(())
-        };
-
-        let rx = async {
-            loop {
-                // TODO: make sunset read directly from socket, no intermediate buffer.
-                let mut buf = [0; 1024];
-                let l = rsock.read(&mut buf).await.expect("TODO handle read error");
-                let mut buf = &buf[..l];
-                while buf.len() > 0 {
-                    let n = self.write(&buf).await?;
-                    buf = &buf[n..];
-                }
-            }
-            #[allow(unreachable_code)]
-            Ok::<_, sunset::Error>(())
-        };
-
-        let prog = async {
-            loop {
-                self.progress(b).await?;
-            }
-            #[allow(unreachable_code)]
-            Ok::<_, sunset::Error>(())
-        };
-
-
-        // TODO: handle results
-        join3(rx, tx, prog).await;
-
-        Ok(())
-    }
-
-    // pub async fn channel(&mut self, ch: u32) -> Result<(ChanInOut<'a>, Option<ChanExtOut<'a>>)> {
-    //     let ty = self.sunset.with_runner(|r| r.channel_type(ch)).await?;
-    //     let inout = ChanInOut::new(ch, &self.sunset);
-    //     // TODO ext
-    //     let ext = None;
-    //     Ok((inout, ext))
-    // }
-
-    pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
-        self.sunset.read(buf).await
-    }
-
-    pub async fn write(&self, buf: &[u8]) -> Result<usize> {
-        self.sunset.write(buf).await
+        self.sunset.run(socket, b).await
     }
 }
diff --git a/src/behaviour.rs b/src/behaviour.rs
index f0f59dd..c2a2768 100644
--- a/src/behaviour.rs
+++ b/src/behaviour.rs
@@ -38,10 +38,22 @@ pub type ResponseString = heapless::String<100>;
 // into a separate trait (which impls the non-async trait)
 
 pub enum Behaviour<'a> {
-    Client(&'a mut (dyn CliBehaviour + Send)),
+    Client(&'a mut (dyn CliBehaviour)),
     Server(&'a mut (dyn ServBehaviour)),
 }
 
+impl<'a, 'b> From<&'a mut (dyn ServBehaviour + 'b)> for Behaviour<'a> {
+    fn from(b: &'a mut (dyn ServBehaviour + 'b)) -> Self {
+        Self::Server(b)
+    }
+}
+
+impl<'a, 'b> From<&'a mut (dyn CliBehaviour + 'b)> for Behaviour<'a> {
+    fn from(b: &'a mut (dyn CliBehaviour + 'b)) -> Self {
+        Self::Client(b)
+    }
+}
+
 impl<'a> Behaviour<'a> {
     // TODO: make these From<>
     pub fn new_client(b: &'a mut (dyn CliBehaviour + Send)) -> Self {
-- 
GitLab