From 39b07b0cf75a77fb86b3f711f81d33b29b3e99ca Mon Sep 17 00:00:00 2001
From: Matt Johnston <matt@ucc.asn.au>
Date: Wed, 1 Jun 2022 00:19:54 +0800
Subject: [PATCH] use parse_tagged_attribute for container attributes

use SSH_NAME constants
---
 Cargo.lock                |  2 ++
 Cargo.toml                |  3 --
 sshproto/src/packets.rs   | 14 ++++----
 sshproto/src/sshnames.rs  |  5 ---
 sshwire_derive/src/lib.rs | 71 ++++++++++++---------------------------
 5 files changed, 30 insertions(+), 65 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5a7b68e..1d980fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1265,6 +1265,8 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
 [[package]]
 name = "virtue"
 version = "0.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b60dcd6a64dd45abf9bd426970c9843726da7fc08f44cd6fcebf68c21220a63"
 
 [[package]]
 name = "void"
diff --git a/Cargo.toml b/Cargo.toml
index e9a518a..0eb4833 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,3 @@ debug = 1
 
 [patch."https://github.com/mkj/ed25519-dalek"]
 ed25519-dalek = { path = "/home/matt/3rd/rs/ed25519-dalek" }
-
-[patch.crates-io]
-virtue = { path = "/home/matt/3rd/rs/virtue" }
diff --git a/sshproto/src/packets.rs b/sshproto/src/packets.rs
index 031b167..03cc81c 100644
--- a/sshproto/src/packets.rs
+++ b/sshproto/src/packets.rs
@@ -101,11 +101,11 @@ pub struct UserauthRequest<'a> {
 #[derive(Debug, SSHEncode, SSHDecode)]
 #[sshwire(variant_prefix)]
 pub enum AuthMethod<'a> {
-    #[sshwire(variant = "password")]
+    #[sshwire(variant = SSH_AUTHMETHOD_PASSWORD)]
     Password(MethodPassword<'a>),
-    #[sshwire(variant = "publickey")]
+    #[sshwire(variant = SSH_AUTHMETHOD_PUBLICKEY)]
     PubKey(MethodPubKey<'a>),
-    #[sshwire(variant = "none")]
+    #[sshwire(variant = SSH_NAME_NONE)]
     None,
     #[sshwire(unknown)]
     Unknown(Unknown<'a>),
@@ -228,9 +228,9 @@ pub struct UserauthBanner<'a> {
 #[derive(SSHEncode, SSHDecode, Debug, Clone, PartialEq)]
 #[sshwire(variant_prefix)]
 pub enum PubKey<'a> {
-    #[sshwire(variant = "ssh-ed25519")]
+    #[sshwire(variant = SSH_NAME_ED25519)]
     Ed25519(Ed25519PubKey<'a>),
-    #[sshwire(variant = "ssh-rsa")]
+    #[sshwire(variant = SSH_NAME_RSA)]
     RSA(RSAPubKey<'a>),
     #[sshwire(unknown)]
     Unknown(Unknown<'a>),
@@ -262,9 +262,9 @@ pub struct RSAPubKey<'a> {
 #[derive(Debug, SSHEncode,  SSHDecode)]
 #[sshwire(variant_prefix)]
 pub enum Signature<'a> {
-    #[sshwire(variant = "ssh-ed25519")]
+    #[sshwire(variant = SSH_NAME_ED25519)]
     Ed25519(Ed25519Sig<'a>),
-    #[sshwire(variant = "rsa-sha2-256")]
+    #[sshwire(variant = SSH_NAME_RSA_SHA256)]
     RSA256(RSA256Sig<'a>),
     #[sshwire(unknown)]
     Unknown(Unknown<'a>),
diff --git a/sshproto/src/sshnames.rs b/sshproto/src/sshnames.rs
index f97e01a..5a535ba 100644
--- a/sshproto/src/sshnames.rs
+++ b/sshproto/src/sshnames.rs
@@ -1,11 +1,6 @@
 //! Named SSH algorithms, methods and extensions. This module also serves as
 //! an index of SSH specifications.
 
-// Note that some names are listed as string literals in packets.rs instead,
-// for #[serde(rename)].  Those could be moved here if this is resolved
-// https://github.com/serde-rs/serde/issues/1964
-// "Rename With Expressions"
-
 /// [RFC8731](https://tools.ietf.org/html/rfc8731)
 pub const SSH_NAME_CURVE25519: &str = "curve25519-sha256";
 /// An older alias prior to standardisation. Eventually could be removed
diff --git a/sshwire_derive/src/lib.rs b/sshwire_derive/src/lib.rs
index 4ca2651..01709f3 100644
--- a/sshwire_derive/src/lib.rs
+++ b/sshwire_derive/src/lib.rs
@@ -3,6 +3,7 @@ use std::collections::HashSet;
 use proc_macro::Delimiter;
 use virtue::generate::FnSelfArg;
 use virtue::parse::{Attribute, AttributeLocation, EnumBody, StructBody};
+use virtue::utils::{parse_tagged_attribute, ParsedAttribute};
 use virtue::prelude::*;
 
 #[proc_macro_derive(SSHEncode, attributes(sshwire))]
@@ -76,58 +77,28 @@ enum FieldAtt {
 }
 
 fn take_cont_atts(atts: &[Attribute]) -> Result<Vec<ContainerAtt>> {
-    atts.iter()
+    let x = atts.iter()
         .filter_map(|a| {
-            match a.location {
-                AttributeLocation::Container => {
-                    let mut s = a.tokens.stream().into_iter();
-                    if &s.next().expect("missing attribute name").to_string()
-                        != "sshwire"
-                    {
-                        // skip attributes other than "sshwire"
-                        return None;
-                    }
-                    Some(if let Some(TokenTree::Group(g)) = s.next() {
-                        let mut g = g.stream().into_iter();
-                        let f = match g.next() {
-                            Some(TokenTree::Ident(l))
-                                if l.to_string() == "no_variant_names" =>
-                            {
-                                Ok(ContainerAtt::NoNames)
-                            }
-
-                            Some(TokenTree::Ident(l))
-                                if l.to_string() == "variant_prefix" =>
-                            {
-                                Ok(ContainerAtt::VariantPrefix)
-                            }
-
-                            _ => Err(Error::Custom {
-                                error: "Unknown sshwire atttribute".into(),
-                                span: Some(a.tokens.span()),
-                            }),
-                        };
+            parse_tagged_attribute(&a.tokens, "sshwire")
+            .transpose()
+        });
 
-                        if let Some(_) = g.next() {
-                            Err(Error::Custom {
-                                error: "Extra unhandled parts".into(),
-                                span: Some(a.tokens.span()),
-                            })
-                        } else {
-                            f
-                        }
-                    } else {
-                        Err(Error::Custom {
-                            error: "#[sshwire(...)] attribute is missing (...) part"
-                                .into(),
-                            span: Some(a.tokens.span()),
-                        })
-                    })
-                }
-                _ => panic!("Non-field attribute for field: {a:#?}"),
-            }
-        })
-        .collect()
+    let mut ret = vec![];
+    // flatten the lists
+    for a in x {
+        for a in a? {
+            let l = match a {
+                ParsedAttribute::Tag(l) if l.to_string() == "no_variant_names" => Ok(ContainerAtt::NoNames),
+                ParsedAttribute::Tag(l) if l.to_string() == "variant_prefix" => Ok(ContainerAtt::VariantPrefix),
+                _ => Err(Error::Custom {
+                    error: "Unknown sshwire atttribute".into(),
+                    span: None,
+                }),
+            }?;
+            ret.push(l);
+        }
+    }
+    Ok(ret)
 }
 
 // TODO: we could use virtue parse_tagged_attribute() though it doesn't support Literals
-- 
GitLab