From d2fc9d40eb3adf1973187630b4c9478016a27aa9 Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Tue, 23 May 2023 00:28:22 +0800
Subject: [PATCH] Add async::Read/Write for cdc usb serial

Ugly generic parameters and longer code, but it should be reusable.
---
 embassy/demos/picow/Cargo.lock       |   1 +
 embassy/demos/picow/Cargo.toml       |   2 +
 embassy/demos/picow/src/usbserial.rs | 147 +++++++++++++++++++--------
 embassy/src/embassy_sunset.rs        |  19 +++-
 embassy/src/lib.rs                   |   2 +-
 5 files changed, 128 insertions(+), 43 deletions(-)

diff --git a/embassy/demos/picow/Cargo.lock b/embassy/demos/picow/Cargo.lock
index 363ef47..7252ddb 100644
--- a/embassy/demos/picow/Cargo.lock
+++ b/embassy/demos/picow/Cargo.lock
@@ -1772,6 +1772,7 @@ dependencies = [
  "embassy-sync 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "embassy-time 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "embassy-usb",
+ "embassy-usb-driver",
  "embedded-hal 1.0.0-alpha.10",
  "embedded-hal-async",
  "embedded-io 0.4.0",
diff --git a/embassy/demos/picow/Cargo.toml b/embassy/demos/picow/Cargo.toml
index 0612258..18ad37f 100644
--- a/embassy/demos/picow/Cargo.toml
+++ b/embassy/demos/picow/Cargo.toml
@@ -23,6 +23,7 @@ embassy-rp = { version = "0.1.0",  features = ["defmt", "unstable-traits", "nigh
 # embassy-net/nightly is required for asynch::Read/Write on TcpReader/TcpWriter
 embassy-net = { version = "0.1.0", features = ["tcp", "dhcpv4", "medium-ethernet", "nightly"] }
 embassy-net-driver = { version = "0.1.0" }
+embassy-usb-driver = { version = "0.1.0" }
 embassy-sync = { version = "0.2.0" }
 embassy-futures = { version = "0.1.0" }
 embassy-usb = { version = "0.1.0" }
@@ -70,6 +71,7 @@ embassy-futures = { git = "https://github.com/embassy-rs/embassy", rev = "3e730a
 embassy-rp = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
 embassy-net = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
 embassy-usb = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
+embassy-usb-driver = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
 # for cyw43
 embassy-net-driver-channel = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
 embassy-net-driver = { git = "https://github.com/embassy-rs/embassy", rev = "3e730aa8b06401003202bf9e21a9c83ec6b21b0e" }
diff --git a/embassy/demos/picow/src/usbserial.rs b/embassy/demos/picow/src/usbserial.rs
index 7a90a46..4bd7f73 100644
--- a/embassy/demos/picow/src/usbserial.rs
+++ b/embassy/demos/picow/src/usbserial.rs
@@ -1,29 +1,32 @@
-
 #[allow(unused_imports)]
 #[cfg(not(feature = "defmt"))]
-pub use {
-    log::{debug, error, info, log, trace, warn},
-};
+pub use log::{debug, error, info, log, trace, warn};
 
 #[allow(unused_imports)]
 #[cfg(feature = "defmt")]
-pub use defmt::{debug, info, warn, panic, error, trace};
+pub use defmt::{debug, error, info, panic, trace, warn};
 
-use embassy_usb::{Builder};
-use embassy_rp::usb::Instance;
-use embassy_usb::class::cdc_acm::{CdcAcmClass, State};
 use embassy_futures::join::join;
+use embassy_rp::usb::Instance;
+use embassy_usb::class::cdc_acm::{self, CdcAcmClass, State};
+use embassy_usb::Builder;
+use embassy_usb_driver::Driver;
 
-use embedded_io::asynch;
+use embedded_io::{asynch, Io};
+use heapless::Vec;
 
 use sunset::*;
+use sunset_embassy::io_copy;
 
-pub async fn usb_serial(usb: embassy_rp::peripherals::USB,
+pub async fn usb_serial<R, W>(
+    usb: embassy_rp::peripherals::USB,
     irq: embassy_rp::interrupt::USBCTRL_IRQ,
-    tx: &mut impl asynch::Write,
-    rx: &mut impl asynch::Read,
-    ) {
-
+    tx: &mut W,
+    rx: &mut R,
+)
+    where R: asynch::Read+Io<Error=sunset::Error>,
+        W: asynch::Write+Io<Error=sunset::Error>
+{
     info!("usb_serial top");
 
     let driver = embassy_rp::usb::Driver::new(usb, irq);
@@ -62,45 +65,24 @@ pub async fn usb_serial(usb: embassy_rp::peripherals::USB,
 
     let cdc = CdcAcmClass::new(&mut builder, &mut state, 64);
     let (mut cdc_tx, mut cdc_rx) = cdc.split();
+    // let cdc_tx = &mut cdc_tx;
+    // let cdc_rx = &mut cdc_rx;
 
     let mut usb = builder.build();
 
     // Run the USB device.
     let usb_fut = usb.run();
 
-    struct IoDone;
-
     let io = async {
         loop {
             info!("usb waiting");
             cdc_rx.wait_connection().await;
             info!("Connected");
+            let mut cdc_tx = CDCWrite::new(&mut cdc_tx);
+            let mut cdc_rx = CDCRead::new(&mut cdc_rx);
 
-            let io_tx = async {
-                let mut b = [0u8; 64];
-                loop {
-                    let n = cdc_rx.read_packet(&mut b).await .map_err(|_| IoDone)?;
-                    let b = &b[..n];
-                    tx.write_all(b).await.map_err(|_| IoDone)?;
-                }
-                #[allow(unreachable_code)]
-                Ok::<_, IoDone>(())
-            };
-
-            let io_rx = async {
-                // limit to 63 so we can ignore dealing with ZLPs for now
-                let mut b = [0u8; 63];
-                loop {
-                    let n = rx.read(&mut b).await.map_err(|_| IoDone)?;
-                    if n == 0 {
-                        return Err(IoDone);
-                    }
-                    let b = &b[..n];
-                    cdc_tx.write_packet(b).await.map_err(|_| IoDone)?;
-                }
-                #[allow(unreachable_code)]
-                Ok::<_, IoDone>(())
-            };
+            let io_tx = io_copy::<64, _, _>(&mut cdc_rx, tx);
+            let io_rx = io_copy::<64, _, _>(rx, &mut cdc_tx);
 
             join(io_rx, io_tx).await;
             info!("Disconnected");
@@ -109,5 +91,88 @@ pub async fn usb_serial(usb: embassy_rp::peripherals::USB,
 
     info!("usb join");
     join(usb_fut, io).await;
+}
+
+pub struct CDCRead<'a, 'p, D: Driver<'a>> {
+    cdc: &'p mut cdc_acm::Receiver<'a, D>,
+    // sufficient for max packet
+    buf: [u8; 64],
+    // when start reaches end, we set both to 0.
+    start: usize,
+    end: usize,
+}
+
+impl<'a, 'p, D: Driver<'a>> CDCRead<'a, 'p, D> {
+    pub fn new(cdc: &'p mut cdc_acm::Receiver<'a, D>) -> Self {
+        Self { cdc, buf: [0u8; 64], start: 0, end: 0 }
+    }
+}
+
+impl<'a, D: Driver<'a>> asynch::Read for CDCRead<'a, '_, D> {
+    async fn read(&mut self, ret: &mut [u8]) -> sunset::Result<usize> {
+        debug_assert!(self.start < self.end || self.end == 0);
+
+        // return existing content first
+        if self.end > 0 {
+            let l = ret.len().min(self.end - self.start);
+            ret.copy_from_slice(&self.buf[self.start..self.start + l]);
+            self.start += l;
+            if self.start == self.end {
+                self.start = 0;
+                self.end = 0;
+            }
+            return Ok(l);
+        }
+
+        debug_assert!(self.start == 0);
+        debug_assert!(self.end == 0);
+
+        let (r, buffered) = if ret.len() >= self.buf.len() {
+            // read in-place
+            (self.cdc.read_packet(ret).await, false)
+        } else {
+            // ret is too small, use buffer
+            (self.cdc.read_packet(self.buf.as_mut()).await, true)
+        };
+
+
+        let n = r.map_err(|_| sunset::Error::ChannelEOF)?;
+
+        if n == 0 {
+            return Err(sunset::Error::ChannelEOF);
+        }
+
+        if buffered {
+            ret.copy_from_slice(&self.buf[..n]);
+        }
+        Ok(n)
+    }
+}
+
+impl<'a, D: Driver<'a>> Io for CDCRead<'a, '_, D> {
+    type Error = sunset::Error;
+}
+
+pub struct CDCWrite<'a, 'p, D: Driver<'a>>(&'p mut cdc_acm::Sender<'a, D>);
+
+impl<'a, 'p, D: Driver<'a>> CDCWrite<'a, 'p, D> {
+    pub fn new(cdc: &'p mut cdc_acm::Sender<'a, D>) -> Self {
+        Self(cdc)
+    }
+}
+
+impl<'a, D: Driver<'a>> asynch::Write for CDCWrite<'a, '_, D> {
+    async fn write(&mut self, buf: &[u8]) -> sunset::Result<usize> {
+        // limit to 63 so we can ignore dealing with ZLPs for now
+        let b = &buf[..buf.len().min(63)];
+        self.0
+            .write_packet(b)
+            .await
+            .map_err(|_| sunset::Error::ChannelEOF)?;
+        Ok(b.len())
+    }
+}
 
+impl<'a, D: Driver<'a>> Io for CDCWrite<'a, '_, D> {
+    type Error = sunset::Error;
 }
diff --git a/embassy/src/embassy_sunset.rs b/embassy/src/embassy_sunset.rs
index 5dab90f..6dd9be8 100644
--- a/embassy/src/embassy_sunset.rs
+++ b/embassy/src/embassy_sunset.rs
@@ -15,7 +15,7 @@ use embassy_sync::mutex::Mutex;
 use embassy_sync::signal::Signal;
 use embassy_futures::select::select;
 use embassy_futures::join;
-use embedded_io::asynch;
+use embedded_io::{asynch, asynch::Write, Io};
 
 // thumbv6m has no atomic usize add/sub
 use atomic_polyfill::AtomicUsize;
@@ -513,3 +513,20 @@ impl<'a, C: CliBehaviour, S: ServBehaviour> EmbassySunset<'a, C, S> {
     }
 }
 
+
+pub async fn io_copy<const B: usize, R, W>(r: &mut R, w: &mut W) -> Result<()>
+    where R: asynch::Read+Io<Error=sunset::Error>,
+        W: asynch::Write+Io<Error=sunset::Error>
+{
+    let mut b = [0u8; B];
+    loop {
+        let n = r.read(&mut b).await?;
+        if n == 0 {
+            return sunset::error::ChannelEOF.fail();
+        }
+        let b = &b[..n];
+        w.write_all(b).await?;
+    }
+    #[allow(unreachable_code)]
+    Ok::<_, Error>(())
+}
diff --git a/embassy/src/lib.rs b/embassy/src/lib.rs
index abe0037..d5db6eb 100644
--- a/embassy/src/lib.rs
+++ b/embassy/src/lib.rs
@@ -19,4 +19,4 @@ pub use client::SSHClient;
 
 pub use embassy_channel::{ChanInOut, ChanIn, ChanOut};
 
-pub use embassy_sunset::{SunsetMutex, SunsetRawMutex};
+pub use embassy_sunset::{SunsetMutex, SunsetRawMutex, io_copy};
-- 
GitLab