From 96bd27839b57c75001f7bc2cfa090d58a00f9274 Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Mon, 29 Aug 2022 00:28:33 +0800
Subject: [PATCH] server progress

now can do password auth. pubkey hasn't been tested

serv1 copy loop doesn't work yet
---
 Cargo.lock               |   2 +-
 async/examples/serv1.rs  |  66 +++++++++++-----
 async/src/async_door.rs  |   3 -
 sshproto/src/auth.rs     |  30 ++++++--
 sshproto/src/channel.rs  |   3 +-
 sshproto/src/cliauth.rs  |   6 +-
 sshproto/src/conn.rs     |  50 +++++++++++-
 sshproto/src/kex.rs      |   3 +-
 sshproto/src/packets.rs  |  85 ++++++++++++++-------
 sshproto/src/runner.rs   |   4 +
 sshproto/src/servauth.rs | 160 +++++++++++++++++++++++----------------
 sshproto/src/sign.rs     |  11 ++-
 sshproto/src/sshwire.rs  |   5 +-
 13 files changed, 284 insertions(+), 144 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f129f58..f3a1c77 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1124,7 +1124,7 @@ checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695"
 [[package]]
 name = "salty"
 version = "0.2.0"
-source = "git+https://github.com/mkj/salty?branch=parts#d886ddcef3159eb5e6632a388dd2f93365598311"
+source = "git+https://github.com/mkj/salty?branch=parts#856c2ca9f491cdb776dcf20350a77da5ec0aaa23"
 dependencies = [
  "digest 0.10.3",
  "ed25519",
diff --git a/async/examples/serv1.rs b/async/examples/serv1.rs
index 663b1bd..1603792 100644
--- a/async/examples/serv1.rs
+++ b/async/examples/serv1.rs
@@ -12,7 +12,7 @@ use std::{net::Ipv6Addr, io::Read};
 use std::path::Path;
 
 use door_sshproto::*;
-use door_async::{SSHServer, raw_pty};
+use door_async::*;
 
 use async_trait::async_trait;
 
@@ -98,8 +98,11 @@ fn read_key(p: &str) -> Result<SignKey> {
 }
 
 struct DemoServer {
+    keys: Vec<SignKey>,
+
     sess: Option<u32>,
-    keys: Vec<SignKey>
+    want_shell: bool,
+    shell_started: bool,
 }
 
 impl DemoServer {
@@ -111,6 +114,8 @@ impl DemoServer {
         Ok(Self {
             sess: None,
             keys,
+            want_shell: false,
+            shell_started: false,
         })
     }
 }
@@ -121,16 +126,16 @@ impl ServBehaviour for DemoServer {
     }
 
 
-    fn have_auth_password(&self, user: &str) -> bool {
+    fn have_auth_password(&self, user: TextString) -> bool {
         true
     }
 
-    fn have_auth_pubkey(&self, user: &str) -> bool {
+    fn have_auth_pubkey(&self, user: TextString) -> bool {
         false
     }
 
-    fn auth_password(&mut self, user: &str, password: &str) -> bool {
-        user == "matt" && password == "pw"
+    fn auth_password(&mut self, user: TextString, password: TextString) -> bool {
+        user.as_str().unwrap_or("") == "matt" && password.as_str().unwrap_or("") == "pw"
     }
 
     fn open_session(&mut self, chan: u32) -> ChanOpened {
@@ -142,12 +147,14 @@ impl ServBehaviour for DemoServer {
         }
     }
 
-    fn sess_req_shell(&mut self, _chan: u32) -> bool {
-        true
+    fn sess_req_shell(&mut self, chan: u32) -> bool {
+        let r = !self.want_shell && self.sess == Some(chan);
+        self.want_shell = true;
+        r
     }
 
-    fn sess_pty(&mut self, _chan: u32, _pty: &Pty) -> bool {
-        true
+    fn sess_pty(&mut self, chan: u32, _pty: &Pty) -> bool {
+        self.sess == Some(chan)
     }
 }
 
@@ -158,22 +165,40 @@ fn run_session<'a, R: Send>(args: &'a Args, scope: &'a moro::Scope<'a, '_, R>, m
         let mut app = DemoServer::new(&args.hostkey)?;
         let mut rxbuf = vec![0; 3000];
         let mut txbuf = vec![0; 3000];
-        let mut serv = SSHServer::new(&mut rxbuf, &mut txbuf)?;
+        let mut serv = SSHServer::new(&mut rxbuf, &mut txbuf, &mut app)?;
         let mut s = serv.socket();
 
-        moro::async_scope!(|scope| {
+        let w = moro::async_scope!(|scope| {
 
             scope.spawn(tokio::io::copy_bidirectional(&mut stream, &mut s));
 
-            scope.spawn(async {
+            let v = scope.spawn(async {
                 loop {
-                    let ev = serv.progress(&mut app).await.context("progress loop")?;
+                    serv.progress(&mut app).await.context("progress loop")?;
+                    if app.want_shell && !app.shell_started {
+
+                        if let Some(ch) = app.sess {
+                            let ch = ch.clone();
+                            let (mut inout, mut _ext) = serv.channel(ch).await?;
+                            let mut o = inout.clone();
+                            scope.spawn(async move {
+                                tokio::io::copy(&mut o, &mut inout).await?;
+                                error!("fell out of stdio loop");
+                                Ok::<_, anyhow::Error>(())
+                            });
+                        }
+
+                    }
                 }
                 #[allow(unreachable_code)]
                 Ok::<_, anyhow::Error>(())
-            });
-            Ok::<_, anyhow::Error>(())
-        }).await
+            }).await;
+
+            let r: () = scope.terminate(v).await;
+            Ok(())
+        }).await;
+        trace!("Finished session {:?}", w);
+        w
     });
     Ok(())
 
@@ -181,7 +206,7 @@ fn run_session<'a, R: Send>(args: &'a Args, scope: &'a moro::Scope<'a, '_, R>, m
 
 async fn run(args: &Args) -> Result<()> {
     // TODO not localhost. also ipv6?
-    let listener = TcpListener::bind(("localhost", args.port)).await.context("Listening")?;
+    let listener = TcpListener::bind(("127.6.6.6", args.port)).await.context("Listening")?;
     moro::async_scope!(|scope| {
         scope.spawn(async {
             loop {
@@ -191,8 +216,7 @@ async fn run(args: &Args) -> Result<()> {
             }
             #[allow(unreachable_code)]
             Ok::<_, anyhow::Error>(())
-        });
+        }).await
 
-    }).await;
-    Ok(())
+    }).await
 }
diff --git a/async/src/async_door.rs b/async/src/async_door.rs
index c1231e9..02583f1 100644
--- a/async/src/async_door.rs
+++ b/async/src/async_door.rs
@@ -32,12 +32,9 @@ pub(crate) struct Inner<'a> {
 }
 
 pub struct AsyncDoor<'a> {
-    // Not contended much since the Runner is inherently single threaded anyway,
-    // using a single buffer for input/output.
     pub(crate) inner: Arc<Mutex<Inner<'a>>>,
 
     progress_notify: Arc<TokioNotify>,
-
 }
 
 impl<'a> AsyncDoor<'a> {
diff --git a/sshproto/src/auth.rs b/sshproto/src/auth.rs
index 5646f3e..5eace8f 100644
--- a/sshproto/src/auth.rs
+++ b/sshproto/src/auth.rs
@@ -14,20 +14,36 @@ use packets::ParseContext;
 use packets::{Packet, Signature, Userauth60};
 use sign::SignKey;
 use sshnames::*;
-use sshwire::BinString;
-use sshwire_derive::SSHEncode;
+use sshwire::{BinString, SSHEncode, WireResult};
+use kex::SessId;
 
 /// The message to be signed in a pubkey authentication message,
 /// RFC4252 Section 7. The packet is a UserauthRequest, with None sig.
-#[derive(SSHEncode)]
 pub(crate) struct AuthSigMsg<'a> {
     pub sess_id: BinString<'a>,
+    pub u: &'a packets::UserauthRequest<'a>,
+}
+
+impl SSHEncode for AuthSigMsg<'_> {
+    fn enc<S>(&self, s: &mut S) -> WireResult<()>
+    where S: sshwire::SSHSink {
+        self.sess_id.enc(s)?;
 
-    // always SSH_MSG_USERAUTH_REQUEST
-    pub msg_num: u8,
+        let m = packets::MessageNumber::SSH_MSG_USERAUTH_REQUEST as u8;
+        m.enc(s)?;
+
+        (*self.u).enc(s)?;
+        Ok(())
+    }
+}
 
-    //TODO: does encoding the whole Packet enum bloat binary?
-    pub u: packets::UserauthRequest<'a>,
+impl<'a> AuthSigMsg<'a> {
+    pub fn new(u: &'a packets::UserauthRequest<'a>, sess_id: &'a SessId) -> Self {
+        auth::AuthSigMsg {
+            sess_id: BinString(sess_id.as_ref()),
+            u,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
diff --git a/sshproto/src/channel.rs b/sshproto/src/channel.rs
index 6203f2c..8fa2bc2 100644
--- a/sshproto/src/channel.rs
+++ b/sshproto/src/channel.rs
@@ -410,6 +410,7 @@ impl Channels {
     }
 }
 
+#[derive(Clone, Copy)]
 pub enum ChanType {
     Session,
     Tcp,
@@ -546,7 +547,7 @@ enum ChanState {
 }
 
 pub(crate) struct Channel {
-    ty: ChanType,
+    pub ty: ChanType,
     state: ChanState,
     sent_eof: bool,
     sent_close: bool,
diff --git a/sshproto/src/cliauth.rs b/sshproto/src/cliauth.rs
index 6acaab8..f9cdc8e 100644
--- a/sshproto/src/cliauth.rs
+++ b/sshproto/src/cliauth.rs
@@ -164,11 +164,7 @@ impl CliAuth {
                 }),
             };
 
-            let msg = auth::AuthSigMsg {
-                sess_id: BinString(sess_id.as_ref()),
-                msg_num: MessageNumber::SSH_MSG_USERAUTH_REQUEST as u8,
-                u: sig_packet,
-            };
+            let msg = auth::AuthSigMsg::new(&sig_packet, sess_id);
             let mut ctx = ParseContext::default();
             ctx.method_pubkey_force_sig_bool = true;
             key.sign(&msg, Some(&ctx))
diff --git a/sshproto/src/conn.rs b/sshproto/src/conn.rs
index e646783..16ec256 100644
--- a/sshproto/src/conn.rs
+++ b/sshproto/src/conn.rs
@@ -192,13 +192,50 @@ impl<'a> Conn<'a> {
         }
     }
 
+    /// Check that a packet is received in the correct state
+    fn check_packet(&self, p: &Packet) -> Result<()> {
+        let r = match p.category() {
+            packets::Category::All => Ok(()),
+            packets::Category::Kex => {
+                match self.state {
+                    ConnState::InKex {..} => Ok(()),
+                    _ => Err(Error::SSHProtoError),
+                }
+            }
+            packets::Category::Auth => {
+                match self.state {
+                    | ConnState::PreAuth
+                    | ConnState::Authed
+                    => Ok(()),
+                    _ => Err(Error::SSHProtoError),
+                }
+            }
+            packets::Category::Sess => {
+                match self.state {
+                    ConnState::Authed
+                    => Ok(()),
+                    _ => Err(Error::SSHProtoError),
+                }
+            }
+        };
+
+        if r.is_err() {
+            error!("Received unexpected packet {}",
+                p.message_num() as u8);
+            debug!("state is {:?}", self.state);
+        }
+        r
+    }
+
     async fn dispatch_packet<'p>(
         &mut self, packet: Packet<'p>, s: &mut TrafSend<'_, '_>, b: &mut Behaviour<'_>,
     ) -> Result<Dispatched, Error> {
-        // TODO: perhaps could consolidate packet allowed checks into a separate function
-        // to run first?
+        // TODO: perhaps could consolidate packet client vs server checks
         trace!("Incoming {packet:#?}");
         let mut disp = Dispatched(None);
+
+        self.check_packet(&packet)?;
+
         match packet {
             Packet::KexInit(_) => {
                 if matches!(self.state, ConnState::InKex { .. }) {
@@ -296,7 +333,11 @@ impl<'a> Conn<'a> {
             }
             Packet::UserauthRequest(p) => {
                 if let ClientServer::Server(serv) = &mut self.cliserv {
-                    serv.auth.request(p, s, b.server()?)?;
+                    let sess_id = self.sess_id.as_ref().trap()?;
+                    let success = serv.auth.request(p, sess_id, s, b.server()?)?;
+                    if success {
+                        self.state = ConnState::Authed;
+                    }
                 } else {
                     debug!("Server sent an auth request");
                     return Err(Error::SSHProtoError)
@@ -335,7 +376,8 @@ impl<'a> Conn<'a> {
             Packet::Userauth60(p) => {
                 // TODO: client only
                 if let ClientServer::Client(cli) = &mut self.cliserv {
-                    cli.auth.auth60(&p, self.sess_id.as_ref().trap()?, &mut self.parse_ctx, s).await?;
+                    let sess_id = self.sess_id.as_ref().trap()?;
+                    cli.auth.auth60(&p, sess_id, &mut self.parse_ctx, s).await?;
                 } else {
                     debug!("Received userauth60 as a server");
                     return Err(Error::SSHProtoError)
diff --git a/sshproto/src/kex.rs b/sshproto/src/kex.rs
index 6e6f65d..27453c5 100644
--- a/sshproto/src/kex.rs
+++ b/sshproto/src/kex.rs
@@ -435,7 +435,8 @@ impl SharedSecret {
         };
 
         // TODO: error message on signature failure.
-        algos.hostsig.verify(&p.k_s.0, kex_out.h.as_ref(), &p.sig.0)?;
+        let h: &[u8] = kex_out.h.as_ref();
+        algos.hostsig.verify(&p.k_s.0, &h, &p.sig.0)?;
         debug!("Hostkey signature is valid");
         if matches!(b.valid_hostkey(&p.k_s.0), Ok(true)) {
             Ok(kex_out)
diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs
index e4bd7c8..d255cfb 100644
--- a/sshproto/src/packets.rs
+++ b/sshproto/src/packets.rs
@@ -602,7 +602,13 @@ impl ParseContext {
 /// We have repeated `match` statements for the various packet types, use a macro
 macro_rules! messagetypes {
     (
-        $( ( $message_num:literal, $SpecificPacketVariant:ident, $SpecificPacketType:ty, $SSH_MESSAGE_NAME:ident ), )*
+        $( ( $message_num:literal,
+            $SpecificPacketVariant:ident,
+            $SpecificPacketType:ty,
+            $SSH_MESSAGE_NAME:ident,
+            $category:ident
+            ),
+             )*
     ) => {
 
 
@@ -697,6 +703,16 @@ impl<'a> Packet<'a> {
             )*
         }
     }
+
+    pub fn category(&self) -> Category {
+        match self {
+            // eg
+            // Packet::KexInit() => Category::Kex,
+            $(
+            Packet::$SpecificPacketVariant(_) => Category::$category,
+            )*
+        }
+    }
 }
 
 $(
@@ -709,43 +725,56 @@ impl<'a> From<$SpecificPacketType> for Packet<'a> {
 
 } } // macro
 
+pub enum Category {
+    /// Allowed at any time.
+    /// TODO: may need to limit some of these during KEX.
+    All,
+    /// After kexinit, before newkeys complete (other packets are not allowed during
+    /// that time).
+    Kex,
+    /// Post-kex
+    Auth,
+    /// Post-auth
+    Sess,
+}
+
 messagetypes![
-(1, Disconnect, Disconnect<'a>, SSH_MSG_DISCONNECT),
-(2, Ignore, Ignore, SSH_MSG_IGNORE),
-(3, Unimplemented, Unimplemented, SSH_MSG_UNIMPLEMENTED),
-(4, DebugPacket, DebugPacket<'a>, SSH_MSG_DEBUG),
-(5, ServiceRequest, ServiceRequest<'a>, SSH_MSG_SERVICE_REQUEST),
-(6, ServiceAccept, ServiceAccept<'a>, SSH_MSG_SERVICE_ACCEPT),
-(20, KexInit, KexInit<'a>, SSH_MSG_KEXINIT),
-(21, NewKeys, NewKeys, SSH_MSG_NEWKEYS),
-(30, KexDHInit, KexDHInit<'a>, SSH_MSG_KEXDH_INIT),
-(31, KexDHReply, KexDHReply<'a>, SSH_MSG_KEXDH_REPLY),
-
-(50, UserauthRequest, UserauthRequest<'a>, SSH_MSG_USERAUTH_REQUEST),
-(51, UserauthFailure, UserauthFailure<'a>, SSH_MSG_USERAUTH_FAILURE),
-(52, UserauthSuccess, UserauthSuccess, SSH_MSG_USERAUTH_SUCCESS),
-(53, UserauthBanner, UserauthBanner<'a>, SSH_MSG_USERAUTH_BANNER),
+(1, Disconnect, Disconnect<'a>, SSH_MSG_DISCONNECT, All),
+(2, Ignore, Ignore, SSH_MSG_IGNORE, All),
+(3, Unimplemented, Unimplemented, SSH_MSG_UNIMPLEMENTED, All),
+(4, DebugPacket, DebugPacket<'a>, SSH_MSG_DEBUG, All),
+(5, ServiceRequest, ServiceRequest<'a>, SSH_MSG_SERVICE_REQUEST, Auth),
+(6, ServiceAccept, ServiceAccept<'a>, SSH_MSG_SERVICE_ACCEPT, Auth),
+(20, KexInit, KexInit<'a>, SSH_MSG_KEXINIT, All),
+(21, NewKeys, NewKeys, SSH_MSG_NEWKEYS, Kex),
+(30, KexDHInit, KexDHInit<'a>, SSH_MSG_KEXDH_INIT, Kex),
+(31, KexDHReply, KexDHReply<'a>, SSH_MSG_KEXDH_REPLY, Kex),
+
+(50, UserauthRequest, UserauthRequest<'a>, SSH_MSG_USERAUTH_REQUEST, Auth),
+(51, UserauthFailure, UserauthFailure<'a>, SSH_MSG_USERAUTH_FAILURE, Auth),
+(52, UserauthSuccess, UserauthSuccess, SSH_MSG_USERAUTH_SUCCESS, Auth),
+(53, UserauthBanner, UserauthBanner<'a>, SSH_MSG_USERAUTH_BANNER, Auth),
 // One of
 // SSH_MSG_USERAUTH_PASSWD_CHANGEREQ
 // SSH_MSG_USERAUTH_PK_OK
 // SSH_MSG_USERAUTH_INFO_REQUEST
-(60, Userauth60, Userauth60<'a>, SSH_MSG_USERAUTH_60),
+(60, Userauth60, Userauth60<'a>, SSH_MSG_USERAUTH_60, Auth),
 // (61, SSH_MSG_USERAUTH_INFO_RESPONSE),
 
 // (80            SSH_MSG_GLOBAL_REQUEST),
 // (81            SSH_MSG_REQUEST_SUCCESS),
 // (82            SSH_MSG_REQUEST_FAILURE),
-(90, ChannelOpen, ChannelOpen<'a>, SSH_MSG_CHANNEL_OPEN),
-(91, ChannelOpenConfirmation, ChannelOpenConfirmation, SSH_MSG_CHANNEL_OPEN_CONFIRMATION),
-(92, ChannelOpenFailure, ChannelOpenFailure<'a>, SSH_MSG_CHANNEL_OPEN_FAILURE),
-(93, ChannelWindowAdjust, ChannelWindowAdjust, SSH_MSG_CHANNEL_WINDOW_ADJUST),
-(94, ChannelData, ChannelData<'a>, SSH_MSG_CHANNEL_DATA),
-(95, ChannelDataExt, ChannelDataExt<'a>, SSH_MSG_CHANNEL_EXTENDED_DATA),
-(96, ChannelEof, ChannelEof, SSH_MSG_CHANNEL_EOF),
-(97, ChannelClose, ChannelClose, SSH_MSG_CHANNEL_CLOSE),
-(98, ChannelRequest, ChannelRequest<'a>, SSH_MSG_CHANNEL_REQUEST),
-(99, ChannelSuccess, ChannelSuccess, SSH_MSG_CHANNEL_SUCCESS),
-(100, ChannelFailure, ChannelFailure, SSH_MSG_CHANNEL_FAILURE),
+(90, ChannelOpen, ChannelOpen<'a>, SSH_MSG_CHANNEL_OPEN, Sess),
+(91, ChannelOpenConfirmation, ChannelOpenConfirmation, SSH_MSG_CHANNEL_OPEN_CONFIRMATION, Sess),
+(92, ChannelOpenFailure, ChannelOpenFailure<'a>, SSH_MSG_CHANNEL_OPEN_FAILURE, Sess),
+(93, ChannelWindowAdjust, ChannelWindowAdjust, SSH_MSG_CHANNEL_WINDOW_ADJUST, Sess),
+(94, ChannelData, ChannelData<'a>, SSH_MSG_CHANNEL_DATA, Sess),
+(95, ChannelDataExt, ChannelDataExt<'a>, SSH_MSG_CHANNEL_EXTENDED_DATA, Sess),
+(96, ChannelEof, ChannelEof, SSH_MSG_CHANNEL_EOF, Sess),
+(97, ChannelClose, ChannelClose, SSH_MSG_CHANNEL_CLOSE, Sess),
+(98, ChannelRequest, ChannelRequest<'a>, SSH_MSG_CHANNEL_REQUEST, Sess),
+(99, ChannelSuccess, ChannelSuccess, SSH_MSG_CHANNEL_SUCCESS, Sess),
+(100, ChannelFailure, ChannelFailure, SSH_MSG_CHANNEL_FAILURE, Sess),
 ];
 
 #[cfg(test)]
diff --git a/sshproto/src/runner.rs b/sshproto/src/runner.rs
index 3c26752..31602ca 100644
--- a/sshproto/src/runner.rs
+++ b/sshproto/src/runner.rs
@@ -172,6 +172,10 @@ impl<'a> Runner<'a> {
         Ok(chan)
     }
 
+    pub fn channel_type(&self, chan: u32) -> Result<channel::ChanType> {
+        self.conn.channels.get(chan).map(|c| c.ty)
+    }
+
     /// Send data from this application out the wire.
     /// Returns `Some` the length of `buf` consumed, or `None` on EOF
     pub fn channel_send(
diff --git a/sshproto/src/servauth.rs b/sshproto/src/servauth.rs
index 325660c..5764cc4 100644
--- a/sshproto/src/servauth.rs
+++ b/sshproto/src/servauth.rs
@@ -9,22 +9,14 @@ use heapless::Vec;
 use crate::sshnames::*;
 use crate::*;
 use packets::{AuthMethod, Userauth60, UserauthPkOk};
+use sshwire::{BinString, Blob};
 use traffic::TrafSend;
+use kex::SessId;
 
 pub(crate) struct ServAuth {
     pub authed: bool,
 }
 
-// for auth_inner()
-enum AuthResp {
-    // success
-    Success,
-    // failed, send a response
-    Failure,
-    // failure, a response has already been send
-    FailNoReply,
-}
-
 impl ServAuth {
     pub fn new(b: &mut dyn ServBehaviour) -> Self {
         Self { authed: false }
@@ -33,20 +25,71 @@ impl ServAuth {
     /// Returns `true` if auth succeeds
     pub fn request(
         &self,
-        p: packets::UserauthRequest,
+        mut p: packets::UserauthRequest,
+        sess_id: &SessId,
         s: &mut TrafSend,
         b: &mut dyn ServBehaviour,
     ) -> Result<bool> {
-        let r = self.auth_inner(p, s, b)?;
 
-        match r {
+        enum AuthResp {
+            Success,
+            Failure,
+            // failure, a response has already been send
+            FailNoReply,
+        }
+
+        let username = p.username.clone();
+
+        let inner = || {
+            // even allows "none" auth
+            if b.auth_unchallenged(p.username) {
+                return Ok(AuthResp::Success) as Result<_>
+            }
+
+            let success = match p.method {
+                AuthMethod::Password(m) => b.auth_password(p.username, m.password),
+                AuthMethod::PubKey(ref m) => {
+                    let allowed_key = b.auth_pubkey(p.username, &m.pubkey.0);
+                    if allowed_key {
+                        if m.sig.is_some() {
+                            self.verify_sig(&mut p, sess_id)
+                        } else {
+                            s.send(Userauth60::PkOk(UserauthPkOk {
+                                algo: m.sig_algo,
+                                key: m.pubkey.clone(),
+                            }))?;
+                            return Ok(AuthResp::FailNoReply);
+                        }
+                    } else {
+                        false
+                    }
+                }
+                AuthMethod::None => {
+                    // nothing to do
+                    false
+                }
+                AuthMethod::Unknown(u) => {
+                    debug!("Request for unknown auth method {}", u);
+                    false
+                }
+            };
+
+            if success {
+                Ok(AuthResp::Success)
+            } else {
+                Ok(AuthResp::Failure)
+            }
+        };
+
+        // failure sends a list of available methods
+        match inner()? {
             AuthResp::Success => {
                 s.send(packets::UserauthSuccess {})?;
                 Ok(true)
             }
             AuthResp::Failure => {
-                let mut n: Vec<&str, NUM_AUTHMETHOD> = Vec::new();
-                let methods = self.avail_methods(&mut n);
+                let mut n: Vec<&'static str, NUM_AUTHMETHOD> = Vec::new();
+                let methods = self.avail_methods(username, &mut n, b);
                 let methods = (&methods).into();
 
                 s.send(packets::UserauthFailure { methods, partial: false })?;
@@ -56,57 +99,31 @@ impl ServAuth {
         }
     }
 
-    pub fn auth_inner(
-        &self,
-        p: packets::UserauthRequest,
-        s: &mut TrafSend,
-        b: &mut dyn ServBehaviour,
-    ) -> Result<AuthResp> {
-        // even allows "none" auth
-        if b.auth_unchallenged(p.username) {
-            return Ok(AuthResp::Success);
-        }
+    /// Must be passed a MethodPubkey packet with a signature part
+    fn verify_sig(&self, p: &mut packets::UserauthRequest, sess_id: &SessId) -> bool {
+        // Remove the signature from the packet - the signature message includes
+        // packet without that signature part.
 
-        let success = match p.method {
-            AuthMethod::Password(m) => b.auth_password(p.username, m.password),
-            AuthMethod::PubKey(m) => {
-                let allowed_key = b.auth_pubkey(p.username, &m.pubkey.0);
-                if allowed_key {
-                    if m.sig.is_none() {
-                        s.send(Userauth60::PkOk(UserauthPkOk {
-                            algo: m.sig_algo,
-                            key: m.pubkey,
-                        }))?;
-                        return Ok(AuthResp::FailNoReply);
-                    } else {
-                        self.verify_pubkey(&m)
-                    }
-                } else {
-                    false
-                }
-            }
-            AuthMethod::None => {
-                // nothing to do
-                false
-            }
-            AuthMethod::Unknown(u) => {
-                debug!("Request for unknown auth method {}", u);
-                false
+        let sig = match &mut p.method {
+            AuthMethod::PubKey(m) => m.sig.take(),
+            _ => {
+                debug_assert!(false, "must be passed MethodPubkey");
+                return false;
             }
         };
 
-        if success {
-            Ok(AuthResp::Success)
-        } else {
-            Ok(AuthResp::Failure)
-        }
-    }
+        // clumsy splitting m and p
+        let m = match &p.method {
+            AuthMethod::PubKey(m) => m,
+            _ => return false,
+        };
 
-    /// Returns `true` on successful signature verifcation. `false` on bad signature.
-    fn verify_pubkey(&self, m: &packets::MethodPubKey) -> bool {
-        let sig = match m.sig.as_ref() {
+        let sig = match sig.as_ref() {
             Some(s) => &s.0,
-            None => return false,
+            None => {
+                debug_assert!(false, "missing signature");
+                return false;
+            }
         };
 
         let sig_type = match sig.sig_type() {
@@ -114,19 +131,28 @@ impl ServAuth {
             Err(_) => return false,
         };
 
-        false
-        // XXX
-        // sig_type.verify(&m.pubkey.0, sess_id.as)
-
-        // m.pubkey.
+        let msg = auth::AuthSigMsg::new(&p, sess_id);
+        match sig_type.verify(&m.pubkey.0, &msg, sig) {
+            Ok(()) => true,
+            Err(_) => false,
+        }
     }
 
     fn avail_methods<'f>(
         &self,
-        buf: &'f mut Vec<&str, NUM_AUTHMETHOD>,
+        user: TextString,
+        buf: &'f mut Vec<&'static str, NUM_AUTHMETHOD>,
+        b: &mut dyn ServBehaviour,
     ) -> namelist::LocalNames<'f> {
         buf.clear();
-        // for
+
+        // OK unwrap: buf is large enough
+        if b.have_auth_password(user) {
+            buf.push(SSH_AUTHMETHOD_PASSWORD).unwrap()
+        }
+        if b.have_auth_pubkey(user) {
+            buf.push(SSH_AUTHMETHOD_PUBLICKEY).unwrap()
+        }
         buf.as_slice().into()
     }
 }
diff --git a/sshproto/src/sign.rs b/sshproto/src/sign.rs
index d156893..902f1fe 100644
--- a/sshproto/src/sign.rs
+++ b/sshproto/src/sign.rs
@@ -47,7 +47,7 @@ impl SigType {
 
     /// Returns `Ok(())` on success
     pub fn verify(
-        &self, pubkey: &PubKey, message: &[u8], sig: &Signature) -> Result<()> {
+        &self, pubkey: &PubKey, msg: &impl SSHEncode, sig: &Signature) -> Result<()> {
 
         // Check that the signature type is known
         let sig_type = sig.sig_type().map_err(|_| Error::BadSig)?;
@@ -68,7 +68,10 @@ impl SigType {
                 let k: salty::PublicKey = k.try_into().map_err(|_| Error::BadKey)?;
                 let s: &[u8; 64] = s.sig.0.try_into().map_err(|_| Error::BadSig)?;
                 let s: salty::Signature = s.into();
-                k.verify(message, &s).map_err(|_| Error::BadSig)
+                k.verify_parts(&s, |h| {
+                    sshwire::hash_ser(h, msg, None).map_err(|_| salty::Error::ContextTooLong)
+                })
+                .map_err(|_| Error::BadSig)
             }
 
             (SigType::RSA256, PubKey::RSA(_k), Signature::RSA256(_s)) => {
@@ -145,11 +148,11 @@ impl SignKey {
         }
     }
 
-    pub(crate) fn sign<'s>(&self, msg: &'s impl SSHEncode, parse_ctx: Option<&ParseContext>) -> Result<OwnedSig> {
+    pub(crate) fn sign(&self, msg: &impl SSHEncode, parse_ctx: Option<&ParseContext>) -> Result<OwnedSig> {
         match self {
             SignKey::Ed25519(k) => {
                 k.sign_parts(|h| {
-                    sshwire::hash_ser(h, parse_ctx, msg).map_err(|_| salty::Error::ContextTooLong)
+                    sshwire::hash_ser(h, msg, parse_ctx).map_err(|_| salty::Error::ContextTooLong)
                 })
                 .trap()
                 .map(|s| s.into())
diff --git a/sshproto/src/sshwire.rs b/sshproto/src/sshwire.rs
index 421f885..357a488 100644
--- a/sshproto/src/sshwire.rs
+++ b/sshproto/src/sshwire.rs
@@ -132,13 +132,14 @@ where
 {
     let len: u32 = length_enc(value)?;
     hash_ctx.update(&len.to_be_bytes());
-    hash_ser(hash_ctx, None, value)
+    hash_ser(hash_ctx, value, None)
 }
 
 /// Hashes the SSH wire format representation of `value`
 pub fn hash_ser<T>(hash_ctx: &mut impl digest::DynDigest,
+    value: &T,
     parse_ctx: Option<&ParseContext>,
-    value: &T) -> Result<()>
+    ) -> Result<()>
 where
     T: SSHEncode,
 {
-- 
GitLab