From e2426d8ccaa6a396df5625f2cefea60ec360eb9d Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Wed, 22 Jun 2022 00:11:10 +0800
Subject: [PATCH] wrappers for stdin/stdout async

---
 Cargo.lock              |  23 +++++++
 async/Cargo.toml        |   9 ++-
 async/examples/con1.rs  |  19 +++---
 async/src/async_door.rs |  52 ++-------------
 async/src/client.rs     |  73 +++++++++++++++++++++
 async/src/fdio.rs       | 138 ++++++++++++++++++++++++++++++++++++++++
 async/src/lib.rs        |   9 ++-
 7 files changed, 260 insertions(+), 63 deletions(-)
 create mode 100644 async/src/client.rs
 create mode 100644 async/src/fdio.rs

diff --git a/Cargo.lock b/Cargo.lock
index d297564..1a8175f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -319,7 +319,9 @@ dependencies = [
  "async-trait",
  "door-sshproto",
  "futures",
+ "libc",
  "log",
+ "nix",
  "pretty-hex",
  "rpassword",
  "simplelog",
@@ -620,6 +622,15 @@ version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
 
+[[package]]
+name = "memoffset"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "mio"
 version = "0.8.3"
@@ -647,6 +658,18 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "546c37ac5d9e56f55e73b677106873d9d9f5190605e41a856503623648488cae"
 
+[[package]]
+name = "nix"
+version = "0.24.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f17df307904acd05aa8e32e97bb20f2a0df1728bbc2d771ae8f9a90463441e9"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "libc",
+ "memoffset",
+]
+
 [[package]]
 name = "no-panic"
 version = "0.1.15"
diff --git a/async/Cargo.toml b/async/Cargo.toml
index 50f94e7..5ab08b2 100644
--- a/async/Cargo.toml
+++ b/async/Cargo.toml
@@ -24,12 +24,15 @@ tokio = { version = "1.17", features = ["sync"] }
 futures = { version = "0.4.0-alpha.0", git = "https://github.com/rust-lang/futures-rs", revision = "8b0f812f53ada0d0aeb74abc32be22ab9dafae05" }
 async-trait = "0.1"
 
+libc = "0.2"
+nix = "0.24"
+
 # TODO
-anyhow = { version = "1.0" }
 pretty-hex = "0.3"
+snafu = { version = "0.7", default-features = true }
 
 [dev-dependencies]
-snafu = { default-features = true }
-tokio = { features = ["full"] }
+anyhow = { version = "1.0" }
+tokio = { version = "1.17", features = ["full"] }
 pretty-hex = "0.3"
 simplelog = "0.12"
diff --git a/async/examples/con1.rs b/async/examples/con1.rs
index 05c63c6..e7d36fd 100644
--- a/async/examples/con1.rs
+++ b/async/examples/con1.rs
@@ -159,22 +159,23 @@ async fn run(args: &Args) -> Result<()> {
             match ev {
                 Some(Event::Authenticated) => {
                     info!("auth auth");
-                    let r = door.open_client_session(Some("cowsay it works"), false).await
+                    let r = door.open_client_session_nopty(Some("cowsay it works")).await
                         .context("Opening session")?;
                     let (mut io, mut err) = r;
                     tokio::spawn(async move {
-                        trace!("io copy thread");
-                        // let mut i = tokio::io::stdin();
-                        let mut o = tokio::io::stdout();
+                        trace!("channel copy");
+                        let mut i = door_async::stdin().unwrap();
+                        // let mut o = tokio::io::stdout();
+                        let mut o = door_async::stdout().unwrap();
                         let mut e = tokio::io::stderr();
                         let mut io2 = io.clone();
                         let co = tokio::io::copy(&mut io, &mut o);
-                        // let ci = tokio::io::copy(&mut i, &mut io2);
+                        let ci = tokio::io::copy(&mut i, &mut io2);
                         let ce = tokio::io::copy(&mut err, &mut e);
-                        let r = futures::join!(co, ce);
-                        r.0?;
-                        r.1?;
-                        // r.2?;
+                        let (r1, r2, r3) = futures::join!(co, ci, ce);
+                        r1?;
+                        r2?;
+                        r3?;
                         Ok::<_, anyhow::Error>(())
                     });
                 }
diff --git a/async/src/async_door.rs b/async/src/async_door.rs
index b06dcbd..cee550b 100644
--- a/async/src/async_door.rs
+++ b/async/src/async_door.rs
@@ -1,4 +1,3 @@
-use door::sshnames::SSH_EXTENDED_DATA_STDERR;
 #[allow(unused_imports)]
 use log::{debug, error, info, log, trace, warn};
 
@@ -19,11 +18,8 @@ use core::task::Waker;
 use core::ops::DerefMut;
 use std::sync::Arc;
 
-// TODO
-use anyhow::{anyhow, Context as _, Error, Result};
-
 use door_sshproto as door;
-use door::{Behaviour, AsyncCliBehaviour, Runner, Conn};
+use door::{Behaviour, AsyncCliBehaviour, Runner, Conn, Result};
 // use door_sshproto::client::*;
 
 use pretty_hex::PrettyHex;
@@ -68,7 +64,7 @@ impl<'a> AsyncDoor<'a> {
         let res = {
             let mut inner = self.inner.lock().await;
             let inner = inner.deref_mut();
-            let ev = inner.runner.progress(&mut inner.behaviour).await.context("progess")?;
+            let ev = inner.runner.progress(&mut inner.behaviour).await?;
             let r = if let Some(ev) = ev {
                 let r = f(ev);
                 inner.runner.done_payload()?;
@@ -132,46 +128,6 @@ impl Clone for AsyncDoor<'_> {
 }
 
 
-pub struct SSHClient<'a> {
-    door: AsyncDoor<'a>,
-}
-
-impl<'a> SSHClient<'a> {
-    pub fn new(buf: &'a mut [u8], behaviour: Box<dyn AsyncCliBehaviour+Send>) -> Result<Self> {
-        let conn = Conn::new_client()?;
-        let runner = Runner::new(conn, buf)?;
-        let b = Behaviour::new_async_client(behaviour);
-        let door = AsyncDoor::new(runner, b);
-        Ok(Self {
-            door
-        })
-    }
-
-    pub fn socket(&self) -> AsyncDoorSocket<'a> {
-        self.door.socket()
-    }
-
-    pub async fn progress<F, R>(&mut self, f: F)
-        -> Result<Option<R>>
-        where F: FnOnce(door::Event) -> Result<Option<R>> {
-        self.door.progress(f).await
-    }
-
-    // TODO: return a Channel object that gives events like WinChange or exit status
-    // TODO: move to SimpleClient or something?
-    pub async fn open_client_session(&mut self, exec: Option<&str>, pty: bool)
-    -> Result<(ChanInOut<'a>, ChanExtIn<'a>)> {
-        let chan = self.door.with_runner(|runner| {
-            runner.open_client_session(exec, pty)
-        }).await?;
-
-        let cstd = ChanInOut::new(chan, &self.door);
-        let cerr = ChanExtIn::new(chan, SSH_EXTENDED_DATA_STDERR, &self.door);
-        Ok((cstd, cerr))
-    }
-
-}
-
 /// Tries to lock Inner for a poll_read()/poll_write().
 /// lock_fut from the caller holds the future so that it can
 /// be woken later if the lock was contended
@@ -322,7 +278,7 @@ pub struct ChanExtOut<'a> {
 }
 
 impl<'a> ChanInOut<'a> {
-    fn new(chan: u32, door: &AsyncDoor<'a>) -> Self {
+    pub(crate) fn new(chan: u32, door: &AsyncDoor<'a>) -> Self {
         Self {
             chan, door: door.clone(),
             rlfut: None, wlfut: None,
@@ -340,7 +296,7 @@ impl Clone for ChanInOut<'_> {
 }
 
 impl<'a> ChanExtIn<'a> {
-    fn new(chan: u32, ext: u32, door: &AsyncDoor<'a>) -> Self {
+    pub(crate) fn new(chan: u32, ext: u32, door: &AsyncDoor<'a>) -> Self {
         Self {
             chan, ext, door: door.clone(),
             rlfut: None,
diff --git a/async/src/client.rs b/async/src/client.rs
new file mode 100644
index 0000000..45f57c5
--- /dev/null
+++ b/async/src/client.rs
@@ -0,0 +1,73 @@
+#[allow(unused_imports)]
+use log::{debug, error, info, log, trace, warn};
+
+use snafu::{prelude::*, Whatever};
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::io::unix::AsyncFd;
+use std::os::unix::io::{RawFd, FromRawFd};
+use std::io::{Read, Write};
+
+use std::io::Error as IoError;
+use std::io::ErrorKind;
+
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+use nix::fcntl::{fcntl, FcntlArg, OFlag};
+
+use crate::*;
+use crate::async_door::*;
+
+use door_sshproto as door;
+use door::{Behaviour, AsyncCliBehaviour, Runner, Conn, Result};
+use door::sshnames::SSH_EXTENDED_DATA_STDERR;
+
+pub struct SSHClient<'a> {
+    door: AsyncDoor<'a>,
+}
+
+impl<'a> SSHClient<'a> {
+    pub fn new(buf: &'a mut [u8], behaviour: Box<dyn AsyncCliBehaviour+Send>) -> Result<Self> {
+        let conn = Conn::new_client()?;
+        let runner = Runner::new(conn, buf)?;
+        let b = Behaviour::new_async_client(behaviour);
+        let door = AsyncDoor::new(runner, b);
+        Ok(Self {
+            door
+        })
+    }
+
+    pub fn socket(&self) -> AsyncDoorSocket<'a> {
+        self.door.socket()
+    }
+
+    pub async fn progress<F, R>(&mut self, f: F)
+        -> Result<Option<R>>
+        where F: FnOnce(door::Event) -> Result<Option<R>> {
+        self.door.progress(f).await
+    }
+
+    // TODO: return a Channel object that gives events like WinChange or exit status
+    // TODO: move to SimpleClient or something?
+    pub async fn open_client_session_nopty(&mut self, exec: Option<&str>)
+    -> Result<(ChanInOut<'a>, ChanExtIn<'a>)> {
+        let chan = self.door.with_runner(|runner| {
+            runner.open_client_session(exec, false)
+        }).await?;
+
+        let cstd = ChanInOut::new(chan, &self.door);
+        let cerr = ChanExtIn::new(chan, SSH_EXTENDED_DATA_STDERR, &self.door);
+        Ok((cstd, cerr))
+    }
+
+    pub async fn open_client_session_pty(&mut self, exec: Option<&str>)
+    -> Result<ChanInOut<'a>> {
+        let chan = self.door.with_runner(|runner| {
+            runner.open_client_session(exec, false)
+        }).await?;
+
+        let cstd = ChanInOut::new(chan, &self.door);
+        Ok(cstd)
+    }
+}
diff --git a/async/src/fdio.rs b/async/src/fdio.rs
new file mode 100644
index 0000000..deb99a1
--- /dev/null
+++ b/async/src/fdio.rs
@@ -0,0 +1,138 @@
+#[allow(unused_imports)]
+use log::{debug, error, info, log, trace, warn};
+
+use snafu::{prelude::*, Whatever};
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::io::unix::AsyncFd;
+use std::os::unix::io::{RawFd, FromRawFd};
+use std::io::{Read, Write};
+
+use std::io::Error as IoError;
+
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+use nix::fcntl::{fcntl, FcntlArg, OFlag};
+use nix::errno::Errno;
+
+fn dup_async(orig_fd: libc::c_int) -> Result<AsyncFd<RawFd>, Whatever> {
+    let fd = nix::unistd::dup(orig_fd).whatever_context("dup() failed")?;
+    fcntl(fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK)).whatever_context("fcntl failed")?;
+    AsyncFd::new(fd).whatever_context("asyncfd")
+}
+
+pub struct Stdin {
+    f: AsyncFd<RawFd>,
+}
+pub struct Stdout {
+    f: AsyncFd<RawFd>,
+}
+pub struct Stderr {
+    f: AsyncFd<RawFd>,
+}
+
+pub fn stdin() -> Result<Stdin, Whatever> {
+    Ok(Stdin {
+        f: dup_async(libc::STDIN_FILENO)?,
+    })
+}
+pub fn stdout() -> Result<Stdout, Whatever> {
+    Ok(Stdout {
+        f: dup_async(libc::STDOUT_FILENO)?,
+    })
+}
+pub fn stderr() -> Result<Stderr, Whatever> {
+    Ok(Stderr {
+        f: dup_async(libc::STDERR_FILENO)?,
+    })
+}
+
+impl AsRef<AsyncFd<RawFd>> for Stdin {
+    fn as_ref(&self) -> &AsyncFd<RawFd> {
+        &self.f
+    }
+}
+
+impl AsRef<AsyncFd<RawFd>> for Stdout {
+    fn as_ref(&self) -> &AsyncFd<RawFd> {
+        &self.f
+    }
+}
+
+impl AsRef<AsyncFd<RawFd>> for Stderr {
+    fn as_ref(&self) -> &AsyncFd<RawFd> {
+        &self.f
+    }
+}
+
+impl AsyncRead for Stdin {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf,
+    ) -> Poll<Result<(), IoError>> {
+        loop {
+            let mut guard = match self.f.poll_read_ready(cx)? {
+                Poll::Ready(r) => r,
+                Poll::Pending => return Poll::Pending,
+            };
+
+            match guard.try_io(|inner| {
+                let fd = *inner.get_ref();
+                let b = buf.initialize_unfilled();
+
+                let r = nix::unistd::read(fd, b);
+                match r {
+                    Ok(s) => {
+                        buf.advance(s);
+                        Ok(())
+                    }
+                    Err(_) => Err(std::io::Error::last_os_error()),
+                }
+            }) {
+                Ok(result) => return Poll::Ready(result),
+                Err(_would_block) => continue,
+            }
+        }
+    }
+}
+
+impl AsyncWrite for Stdout {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8]
+    ) -> Poll<std::io::Result<usize>> {
+        loop {
+            let mut guard = match self.f.poll_write_ready(cx)? {
+                Poll::Ready(r) => r,
+                Poll::Pending => return Poll::Pending,
+            };
+
+            match guard.try_io(|inner| {
+                let fd = *inner.get_ref();
+                nix::unistd::write(fd, buf)
+                    .map_err(|_| std::io::Error::last_os_error())
+            }) {
+                Ok(result) => return Poll::Ready(result),
+                Err(_would_block) => continue,
+            }
+        }
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+        }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        nix::sys::socket::shutdown(*self.f.get_ref(), nix::sys::socket::Shutdown::Write)?;
+        Poll::Ready(Ok(()))
+    }
+}
diff --git a/async/src/lib.rs b/async/src/lib.rs
index dd6a837..954d514 100644
--- a/async/src/lib.rs
+++ b/async/src/lib.rs
@@ -1,8 +1,11 @@
 
-mod simple_client;
+mod client;
 mod async_door;
+mod simple_client;
+mod fdio;
 
-pub use simple_client::SimpleClient;
 pub use async_door::AsyncDoor;
-pub use async_door::SSHClient;
+pub use client::SSHClient;
+pub use fdio::{stdin, stdout, stderr};
+pub use simple_client::SimpleClient;
 
-- 
GitLab