Skip to content
Snippets Groups Projects
Commit 71b57696 authored by Matt Johnston's avatar Matt Johnston
Browse files

Bits of async channel, isn't quite working yet

parent 7a05a73f
Branches
Tags
No related merge requests found
...@@ -67,3 +67,7 @@ so `.await` fits well. ...@@ -67,3 +67,7 @@ so `.await` fits well.
The problem is that `async_trait` requires `Box`, won't work on `no_std`. The `Behaviour` struct has a `cfg` to switch The problem is that `async_trait` requires `Box`, won't work on `no_std`. The `Behaviour` struct has a `cfg` to switch
between async and non-async traits, hiding that from the main code. Eventually `async fn` [should work OK](https://github.com/rust-lang/rust/issues/91611) in static traits on `no_std`, and then it can be unified. between async and non-async traits, hiding that from the main code. Eventually `async fn` [should work OK](https://github.com/rust-lang/rust/issues/91611) in static traits on `no_std`, and then it can be unified.
## Async
The majority of packet dispatch handling isn't async, it just returns Ready straight away. Becaues of that we just have a Tokio `Mutex` which occassionally
holds the mutex across the `.await` boundary - it should seldom be contended.
...@@ -142,6 +142,7 @@ async fn run(args: &Args) -> Result<()> { ...@@ -142,6 +142,7 @@ async fn run(args: &Args) -> Result<()> {
// let mut stream = TcpStream::connect("130.95.13.18:22").await?; // let mut stream = TcpStream::connect("130.95.13.18:22").await?;
let mut work = vec![0; 3000]; let mut work = vec![0; 3000];
let work = Box::leak(Box::new(work));
let mut sess = door_smol::SimpleClient::new(args.username.as_ref().unwrap()); let mut sess = door_smol::SimpleClient::new(args.username.as_ref().unwrap());
for i in &args.identityfile { for i in &args.identityfile {
sess.add_authkey(read_key(&i) sess.add_authkey(read_key(&i)
...@@ -161,7 +162,6 @@ async fn run(args: &Args) -> Result<()> { ...@@ -161,7 +162,6 @@ async fn run(args: &Args) -> Result<()> {
pin_mut!(netio); pin_mut!(netio);
// let mut f = future::try_zip(netwrite, netread).fuse(); // let mut f = future::try_zip(netwrite, netread).fuse();
// f.await; // f.await;
let mut main_ch;
loop { loop {
tokio::select! { tokio::select! {
...@@ -178,10 +178,21 @@ async fn run(args: &Args) -> Result<()> { ...@@ -178,10 +178,21 @@ async fn run(args: &Args) -> Result<()> {
match ev { match ev {
Some(Event::Authenticated) => { Some(Event::Authenticated) => {
info!("auth auth"); info!("auth auth");
let ch = door.with_runner(|runner| { let r = door.open_client_session(Some("Cowsay it works"), false).await;
runner.open_client_session(Some("cowsay it works"), false) match r {
}).await?; Ok((mut stdio, mut _stderr)) => {
main_ch = Some(ch); tokio::spawn(async move {
trace!("io copy thread");
let r = tokio::io::copy(&mut stdio, &mut tokio::io::stdout()).await;
if let Err(e) = r {
warn!("IO error: {e}");
}
});
},
Err(e) => {
warn!("Failed opening session: {e}")
}
}
} }
Some(_) => unreachable!(), Some(_) => unreachable!(),
None => {}, None => {},
......
use door::sshnames::SSH_EXTENDED_DATA_STDERR;
#[allow(unused_imports)] #[allow(unused_imports)]
use log::{debug, error, info, log, trace, warn}; use log::{debug, error, info, log, trace, warn};
...@@ -6,15 +7,16 @@ use core::pin::Pin; ...@@ -6,15 +7,16 @@ use core::pin::Pin;
use core::task::{Context, Poll}; use core::task::{Context, Poll};
use pin_utils::pin_mut; use pin_utils::pin_mut;
use futures::lock::Mutex;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::Notify as TokioNotify; use tokio::sync::Notify as TokioNotify;
use std::io::Error as IoError; use std::io::Error as IoError;
use std::io::ErrorKind; use std::io::ErrorKind;
use core::task::Waker; use core::task::Waker;
use std::sync::{Arc, Mutex, MutexGuard}; use std::sync::Arc;
use futures::task::AtomicWaker; use futures::task::AtomicWaker;
// TODO // TODO
...@@ -35,22 +37,24 @@ pub struct Inner<'a> { ...@@ -35,22 +37,24 @@ pub struct Inner<'a> {
behaviour: Behaviour<'a>, behaviour: Behaviour<'a>,
} }
#[derive(Clone)]
pub struct AsyncDoor<'a> { pub struct AsyncDoor<'a> {
inner: Arc<TokioMutex<Inner<'a>>>, // Not contended much since the Runner is inherently single threaded anyway,
// using a single buffer for input/output.
inner: Arc<Mutex<Inner<'a>>>,
read_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
progress_notify: Arc<TokioNotify>, progress_notify: Arc<TokioNotify>,
} }
impl<'a> AsyncDoor<'a> { impl<'a> AsyncDoor<'a> {
pub fn new(runner: Runner<'a>, behaviour: Behaviour<'a>) -> Self { pub fn new(runner: Runner<'a>, behaviour: Behaviour<'a>) -> Self {
let inner = Arc::new(TokioMutex::new(Inner { runner, behaviour })); let inner = Arc::new(Mutex::new(Inner { runner, behaviour }));
let read_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
let progress_notify = Arc::new(TokioNotify::new()); let progress_notify = Arc::new(TokioNotify::new());
Self { inner, read_waker, write_waker, progress_notify } Self { inner, progress_notify }
}
pub fn clone(&'_ self) -> Self {
Self { inner: self.inner.clone(),
progress_notify: self.progress_notify.clone() }
} }
pub async fn progress<F, R>(&mut self, f: F) pub async fn progress<F, R>(&mut self, f: F)
...@@ -69,8 +73,6 @@ impl<'a> AsyncDoor<'a> { ...@@ -69,8 +73,6 @@ impl<'a> AsyncDoor<'a> {
Ok(None) Ok(None)
} }
}; };
self.read_waker.take().map(|w| w.wake());
self.write_waker.take().map(|w| w.wake());
// TODO: currently this is only woken by incoming data, should it // TODO: currently this is only woken by incoming data, should it
// also wake internally from runner or conn? It runs once at start // also wake internally from runner or conn? It runs once at start
// to kick off the outgoing handshake at least. // to kick off the outgoing handshake at least.
...@@ -86,6 +88,25 @@ impl<'a> AsyncDoor<'a> { ...@@ -86,6 +88,25 @@ impl<'a> AsyncDoor<'a> {
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().await;
f(&mut inner.runner) f(&mut inner.runner)
} }
// fn channel_poll_read(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &mut ReadBuf,
pub async fn open_client_session(&mut self, exec: Option<&str>, pty: bool)
-> Result<(ChanInOut<'a>, ChanExtIn<'a>)> {
let chan = self.with_runner(|runner| {
runner.open_client_session(exec, pty)
}).await?;
let door = self.clone();
let cstd = ChanInOut { door, chan };
let door = self.clone();
let cerr = ChanExtIn { door, chan, ext: SSH_EXTENDED_DATA_STDERR };
Ok((cstd, cerr))
}
} }
impl<'a> AsyncRead for AsyncDoor<'a> { impl<'a> AsyncRead for AsyncDoor<'a> {
...@@ -96,21 +117,21 @@ impl<'a> AsyncRead for AsyncDoor<'a> { ...@@ -96,21 +117,21 @@ impl<'a> AsyncRead for AsyncDoor<'a> {
) -> Poll<Result<(), IoError>> { ) -> Poll<Result<(), IoError>> {
trace!("poll_read"); trace!("poll_read");
// try to lock, or return pending // TODO: can this go into a common function returning a Poll<MappedMutexGuard<Runner>>?
self.read_waker.register(cx.waker()); // Lifetimes seem tricky.
let mut inner = self.inner.try_lock(); let g = self.inner.lock();
let runner = if let Ok(ref mut inner) = inner { pin_mut!(g);
self.read_waker.take(); let mut g = g.poll(cx);
&mut inner.deref_mut().runner let runner = match g {
} else { Poll::Ready(ref mut i) => &mut i.runner,
return Poll::Pending Poll::Pending => return Poll::Pending,
}; };
let b = buf.initialize_unfilled(); let b = buf.initialize_unfilled();
let r = runner.output(b).map_err(|e| IoError::new(ErrorKind::Other, e)); let r = runner.output(b).map_err(|e| IoError::new(ErrorKind::Other, e));
let r = match r { match r {
// sz=0 means EOF // sz=0 means EOF, we don't want that
Ok(0) => { Ok(0) => {
runner.set_output_waker(cx.waker().clone()); runner.set_output_waker(cx.waker().clone());
Poll::Pending Poll::Pending
...@@ -121,15 +142,7 @@ impl<'a> AsyncRead for AsyncDoor<'a> { ...@@ -121,15 +142,7 @@ impl<'a> AsyncRead for AsyncDoor<'a> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
Err(e) => Poll::Ready(Err(e)), Err(e) => Poll::Ready(Err(e)),
}; }
// drop the mutex guard before waking others
drop(inner);
self.write_waker.take().map(|w| {
trace!("wake write_waker");
w.wake()
});
r
} }
} }
...@@ -141,14 +154,12 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { ...@@ -141,14 +154,12 @@ impl<'a> AsyncWrite for AsyncDoor<'a> {
) -> Poll<Result<usize, IoError>> { ) -> Poll<Result<usize, IoError>> {
trace!("poll_write"); trace!("poll_write");
// try to lock, or return pending let g = self.inner.lock();
self.write_waker.register(cx.waker()); pin_mut!(g);
let mut inner = self.inner.try_lock(); let mut g = g.poll(cx);
let runner = if let Ok(ref mut inner) = inner { let runner = match g {
self.write_waker.take(); Poll::Ready(ref mut i) => &mut i.runner,
&mut inner.deref_mut().runner Poll::Pending => return Poll::Pending,
} else {
return Poll::Pending
}; };
// TODO: should runner just have poll_write/poll_read? // TODO: should runner just have poll_write/poll_read?
...@@ -164,19 +175,14 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { ...@@ -164,19 +175,14 @@ impl<'a> AsyncWrite for AsyncDoor<'a> {
Poll::Pending Poll::Pending
}; };
// drop the mutex guard before waking others // drop before waking others
drop(inner); drop(runner);
if let Poll::Ready(_) = r { if let Poll::Ready(_) = r {
// TODO: only notify if packet traffic.payload().is_some() ? // TODO: only notify if packet traffic.payload().is_some() ?
self.progress_notify.notify_one(); self.progress_notify.notify_one();
trace!("notify progress"); trace!("notify progress");
} }
// TODO: check output_pending() before waking?
self.read_waker.take().map(|w| {
trace!("wake read_waker");
w.wake()
});
r r
} }
...@@ -194,3 +200,137 @@ impl<'a> AsyncWrite for AsyncDoor<'a> { ...@@ -194,3 +200,137 @@ impl<'a> AsyncWrite for AsyncDoor<'a> {
todo!("poll_close") todo!("poll_close")
} }
} }
pub struct ChanInOut<'a> {
chan: u32,
door: AsyncDoor<'a>,
}
pub struct ChanExtIn<'a> {
chan: u32,
ext: u32,
door: AsyncDoor<'a>,
}
pub struct ChanExtOut<'a> {
chan: u32,
ext: u32,
door: AsyncDoor<'a>,
}
impl<'a> ChanInOut<'a> {
fn new(chan: u32, door: &AsyncDoor<'a>) -> Self {
Self {
chan, door: door.clone(),
}
}
}
impl<'a> ChanExtIn<'a> {
fn new(chan: u32, ext: u32, door: &AsyncDoor<'a>) -> Self {
Self {
chan, ext, door: door.clone(),
}
}
}
impl<'a> AsyncRead for ChanInOut<'a> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<Result<(), IoError>> {
let this = self.deref_mut();
chan_poll_read(&mut this.door, this.chan, None, cx, buf)
}
}
impl<'a> AsyncRead for ChanExtIn<'a> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<Result<(), IoError>> {
let this = self.deref_mut();
chan_poll_read(&mut this.door, this.chan, Some(this.ext), cx, buf)
}
}
// Common for `ChanInOut` and `ChanExtIn`
fn chan_poll_read(
door: &mut AsyncDoor,
chan: u32,
ext: Option<u32>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<Result<(), IoError>> {
error!("chan_poll_read {chan} {ext:?}");
let g = door.inner.lock();
pin_mut!(g);
let mut g = g.poll(cx);
let runner = match g {
Poll::Ready(ref mut i) => &mut i.runner,
Poll::Pending => {
trace!("lock pending");
return Poll::Pending
}
};
trace!("chan_poll_read locked");
let b = buf.initialize_unfilled();
let r = runner.channel_input(chan, ext, b)
.map_err(|e| IoError::new(std::io::ErrorKind::Other, e));
match r {
// sz=0 means EOF, we don't want that
Ok(0) => {
runner.set_output_waker(cx.waker().clone());
Poll::Pending
}
Ok(sz) => {
buf.advance(sz);
Poll::Ready(Ok(()))
}
Err(e) => Poll::Ready(Err(e)),
}
}
impl<'a> AsyncWrite for ChanInOut<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
let this = self.deref_mut();
chan_poll_write(&mut this.door, this.chan, None, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
// perhaps common between InOut and ExtOut?
todo!("channel poll_shutdown")
}
}
fn chan_poll_write(
door: &mut AsyncDoor,
chan: u32,
ext: Option<u32>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, IoError>> {
let mut g = door.inner.lock();
pin_mut!(g);
let runner = match g.poll(cx) {
Poll::Ready(ref mut i) => &mut i.runner,
Poll::Pending => return Poll::Pending,
};
todo!()
}
...@@ -121,7 +121,9 @@ impl<'a> Runner<'a> { ...@@ -121,7 +121,9 @@ impl<'a> Runner<'a> {
} }
pub fn done_payload(&mut self) -> Result<()> { pub fn done_payload(&mut self) -> Result<()> {
self.traffic.done_payload() self.traffic.done_payload()?;
self.wake();
Ok(())
} }
pub fn wake(&mut self) { pub fn wake(&mut self) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment