前言
syn和quote的简单使用——生成结构体-CSDN博客https://blog.csdn.net/qq_63401240/article/details/150609865?spm=1001.2014.3001.5501
前面使用syn和quote,发现挺好玩的,感觉可以干很多事情,不愧是Rust中的宏。
宏分为声明宏和过程宏,过程宏又分为函数宏、派生宏、属性宏。
前面build_struct这个宏是过程宏,或者说更详细点是函数宏
这篇就来使用另一个过程宏——派生宏,
Macro 宏编程 - Rust语言圣经(Rust Course)https://course.rs/advance/macro.html
需求
一个结构体变成mysql的create的语句。
比如
#[derive(Create)]
#[table_name="students"]
struct Student{#[field(pk)]id:i32,#[field(length=40,null=false)]name:String,#[field(null=false)]score: f32,#[field(null=false,default=true)]is_job:bool,
}
生成的sql语句应该是
create table students (id int auto_increment primary key,name varchar(40) not null, score float not null, is_job tinyint(1) not null default true
)
- 如果没有设置表的名字,就用结构体的名字当表名。
- 如果没有主键,就在自定义一个主键=结构体的名字+_id。
- 需要一个实现一个打印sql语句的方法。
还有其他复杂的情况,不考虑,就这样。
正文
一些简单地介绍
定义派生宏——proc_macro_derive,简单来说,定义Create这个派生宏的代码如下
#[proc_macro_derive(Create, attributes(table_name, field))]
pub fn create(input: TokenStream) -> TokenStream {
}
定义派生宏使用proc_macro_derive这个属性(Attribute)。
Create是定义的派生宏的名字。
table_name和field是派生宏Create额外识别的属性名,
#[field(null=false)]score: f32,
这里null=false这个整体可以称为元数据Meta,当然元数据的定义如下
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub enum Meta {Path(Path),/// A structured list within an attribute, like `derive(Copy, Clone)`.List(MetaList),/// A name-value pair within an attribute, like `feature = "nightly"`.NameValue(MetaNameValue),}
关于元数据有三种类型Path、List、NameValue
这三者不是独立的,往往一起出现,比如
对于#[field(pk)],就是List+Path;
#[field(null=false)]是List+NameValue。
#[field]是Path
总之
没有括号,是Path。
有括号,是List。
有等号是NameValue。
思考
整个问题其实还是很复杂的,大致问题如下:
- 如何获取结构体的名字?
- 如何获取结构体字段的名字和类型?
- 如何把Rust类型映射到sql的数据类型?
- 如何解析一个字段中的全部元数据?
- 如何拼接sql语句?
完成这5步感觉就差不多了,整个过程其实也像玩积木游戏,先拆开再拼回去。
获取结构体的名字
首先,传进来的是TokenStream这个类型,需要进行类型装换。转化成能在派生宏中处理的数据,
代码如下
let input = parse_macro_input!(input as DeriveInput);
此时,这个input就变成DeriveInput类型了
看看DeriveInput 的定义
#[cfg_attr(docsrs, doc(cfg(feature = "derive")))]pub struct DeriveInput {pub attrs: Vec<Attribute>,pub vis: Visibility,pub ident: Ident,pub generics: Generics,pub data: Data,}
发现,有5个字段,意思是显然的,
1. pub attrs: Vec<Attribute> :包含所有附加在该类型上的属性。
2. pub vis: Visibility :类型的可见性(visibility)。
3. pub ident: Ident 类型的标识符(名称)。
4. pub generics: Generics 泛型参数信息
5 pub data: Data 类型的具体数据内容(主体部分)
显然,属性就是指前面提到的table_name,field这两个。
因此,获取结构体的名字
let name = input.ident;
获取表的名字
不妨写一个函数
fn get_table_name(input: &DeriveInput) -> String {}
把前面的input传进来。
需要获取table_name这个属性,结合DeriveInput结构体的定义,大致流程如下
- 属性肯定是在attrs中。
- attrs返回一个Vec<Attribute>,获取迭代器,遍历其中的属性。
- 寻找到table_name这个属性。
- 获取属性所对应的值,返回。
- 没找到则使用结构体名字。
因此,部分代码如下
fn get_tablename(input: &DeriveInput) -> String {input.attrs.iter().find(|attr| attr.path().is_ident("table_name")).and_then(|attr| {}).unwrap_or_else(|| input.ident.to_string().to_lowercase())
}
现在找到了table_name这个属性,如何获取对应的值?
考虑到attr的类型是Attribute,定义如下
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub struct Attribute {pub pound_token: Token![#],pub style: AttrStyle,pub bracket_token: token::Bracket,pub meta: Meta,}
显然,要找到table_name对应的值"students"需要在meta中寻找,结合前面Meta的定义,
因此,在and_then闭包中的代码如下
if let Meta::NameValue(meta) = &attr.meta {}
这一步是对元数据的判断。
为什么是Meta::NameValue,因为是table_name="student"。
进一步操作,获取"studnet",同时考虑到前面Meta::NameValue以及其中MetaNameValue的定
义,如下
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub struct MetaNameValue {pub path: Path,pub eq_token: Token![=],pub value: Expr,}
因此,在and_then闭包中的进一步代码如下
if let syn::Expr::Lit(expr_lit) = &meta.value {}
这一步是对字面量表达式的判断,判断"student"是不是字面量表达式。
同理,下一步就是对字符串切片的判断,
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub struct ExprLit {pub attrs: Vec<Attribute>,pub lit: Lit,}
即
if let Lit::Str(lit_str) = &expr_lit.lit {}
三个if,如果都成功,判断是字符串切片。
因此,返回数据。
因为是在and_then的闭包函数中
pub fn and_then<U, F>(self, f: F) -> Option<U>whereF: FnOnce(T) -> Option<U>,
需要返回Option。
因此,返回如下
return Some(lit_str.value());
所以,关于获取表的名字的全部代码如下
fn get_table_name(input: &DeriveInput) -> String {input.attrs.iter().find(|attr| attr.path().is_ident("table_name")).and_then(|attr| {if let Meta::NameValue(meta) = &attr.meta {if let syn::Expr::Lit(expr_lit) = &meta.value {if let Lit::Str(lit_str) = &expr_lit.lit {return Some(lit_str.value());}}}None}).unwrap_or_else(|| input.ident.to_string().to_lowercase())
}
一层一层的深入,获取数据。
获取全部字段
如何获取字段,大致操作如下:
- 是一个结构体
- 有字段的结构体——命名结构体
- 获取字段
就这三步,慢慢来。
判断是否是结构体,需要使用DeriveInput中的data,因为data是Data类型的,Data的定义如下
#[cfg_attr(docsrs, doc(cfg(feature = "derive")))]pub enum Data {Struct(DataStruct),Enum(DataEnum),Union(DataUnion),}
因此,代码如下
let fields = if let Data::Struct(s) = input.data {} else {panic!("不是结构体");};
现在代码中的s的类型是DataStruct ,看看定义
#[cfg_attr(docsrs, doc(cfg(feature = "derive")))]pub struct DataStruct {pub struct_token: Token![struct],pub fields: Fields,pub semi_token: Option<Token![;]>,}
获取field,因为field是Field类型的,其中Field定义
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub enum Fields {/// Named fields of a struct or struct variant such as `Point { x: f64,/// y: f64 }`.Named(FieldsNamed),/// Unnamed fields of a tuple struct or tuple variant such as `Some(T)`.Unnamed(FieldsUnnamed),/// Unit struct or unit variant such as `None`.Unit,}
需要有Named,同时,FieldsNamed的定义如下
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub struct FieldsNamed {pub brace_token: token::Brace,pub named: Punctuated<Field, Token![,]>,}
因此,获取named,代码如下
if let Fields::Named(n) = s.fields {n.named} else {panic!("不是命名结构体");}
因此,获取结构体的全部字段的代码如下
let fields = if let Data::Struct(s) = input.data {if let Fields::Named(n) = s.fields {n.named} else {panic!("不是命名结构体");}} else {panic!("不是结构体");
};
错误处理的不是很好,算了。
此时这个fields就是全部的字段,类型是Punctuated<Field, Token![,]>
Punctuated in syn::punctuated - Rusthttps://docs.rs/syn/latest/syn/punctuated/struct.Punctuated.html
rust类型映射到sql
创建一个函数,用来解析获得的fields。
fn parse_fields(fields: Punctuated<Field, Comma>){
}
返回什么,先不慌。
考虑到Punctuated实现了IntoIterator 这个trait,可以遍历。
因此,先对fields进行循环
for field in fields.into_iter() {}
而field的类型是Field,要获取类型和名字,很简单
let field_name = field.ident.unwrap();
let ty = field.ty;
现在获取字段中的类型,定义一个函数实现类型映射,代码如下
pub fn rust_type_to_sql_type(ty: &Type) -> String {
}
关于Type,这是一个enum,里面有各种类型的定义,
Type in syn - Rusthttps://docs.rs/syn/latest/syn/enum.Type.html因为前面定义的结构体Student全是简单的标识符类型(不是泛型、不是引用、不是数组)。因此,这里就使用 Type::Path。
代码如下
pub fn rust_type_to_sql_type(ty: &Type) -> String {match ty {Type::Path(type_path) if type_path.path.is_ident("i32") => "int".to_string(),Type::Path(type_path) if type_path.path.is_ident("f32") => "float".to_string(),Type::Path(type_path) if type_path.path.is_ident("String") => "varchar".to_string(),Type::Path(type_path) if type_path.path.is_ident("bool") => "tinyint(1)".to_string(),_ => "".to_string(),}
}
解析元数据前的处理——获取元数据
解析元数据之前,先要获取到元数据。
和前面获取table_name类型的,只是更复杂,总体流程如下。
- 从attrs中获取Vec<Attribute>,然后遍历,找到所有属性field,返回Attribute
- 判断和获取
for attr in field.attrs.iter().filter(|attr| attr.path().is_ident("field"))}
前面找到table_name,使用的find,因为只有一个,而这个寻找field,使用了filter
因为
- filter:找出所有符合条件的
- find:只找第一个符合条件的
所以使用filter。
然后,判断元数据类型,因为Student这个结构体中的字段的field属性全是List,因此,第一步代码如下
if let Meta::List(meta_list) = &attr.meta {}
以name字段为例
#[field(length=40,null=false)]name:String,
判断是List,然后要进一步操作,即对length=40,null=false进行操作
因为List这个元数据类型中包含一个字段List(MetaList),考虑到MetaList的定义
#[cfg_attr(docsrs, doc(cfg(any(feature = "full", feature = "derive"))))]pub struct MetaList {pub path: Path,pub delimiter: MacroDelimiter,pub tokens: TokenStream,}
要进一步操作,因此,是获取到tokens,而tokens的类型是TokenStream,这无法操作,需要变成其他类型,因此,需要使用某个类型的解析方法,即parse方法
而且,非常关键的是这个TokenStream是proc_macro2::TokenStream;
因为length=40,null=false是逗号分开的
因此
笔者在这里把tokens解析成前面的Punctuated。
当然,还有其他解析方法,笔者不管这么多了,代码如下
let punctuated = Punctuated::<Meta, Token![,]>::parse_terminated;
if let Ok(nested) = punctuated.parse2(meta_list.tokens.clone()) {}
Punctuated::<Meta, Token![,]> 是定义
parse_terminated 返回解析器,允许最后一个元素后面有分隔符,还有其他方法返回解析器
fn parse2(self, tokens: TokenStream) -> Result<T> ,传一个proc_macro2::TokenStream,返回Result<T>,这里T就是Punctuated::<Meta, Token![,]>
因此,nested就是一个Punctuated,对于length=40,null=false来说
大致是这样的
即Meta,逗号,Meata
当然,不是很准确,没有中括号,意思一下。
总之现在nest是个序列,因此,再次循环一下
for meta in nested {}
序列中的元素的Meta,因此,同理,判断元数据类型。
match meta {Meta::Path(path) => {}Meta::NameValue(nv) => {}_ => {}}
因为走到这一步,只有Path和NameValue两种类型了,
比如对于#[field(pk)],走到这里,就是pk了,类型是Path
对于#[field(null=false)],走到这里,就是`null=false`,类型是NameValue
分别处理就可以了
处理NameValue
不妨定义一个新的函数
pub fn handle_one_nv(nv: MetaNameValue) -> String {}
参数是MetaNameValue,对于这个类型
前面说过,分别获取name和value,对于null=true来说,name就是null,value就是true。
let path = nv.path;let value = nv.value;
可以把这个null变成字符串字面量"null",还有对true的判断
即
match path.get_ident().unwrap().to_string().as_str() {"null" => {}_ => "".to_string(),}
判断也很简单,先判断是不是字面量,然后判断是不是bool。
代码如下
pub fn handle_one_nv(nv: MetaNameValue) -> String {let path = nv.path;let value = nv.value;match path.get_ident().unwrap().to_string().as_str() {"null" => {if let Expr::Lit(expr_lit) = &value {if let Lit::Bool(lit_bool) = &expr_lit.lit {return if lit_bool.value {"".to_string()} else {" not null".to_string()};}}panic!("null 需要设为bool值,例如: #[null = true]");}_ => "".to_string(),}
}
一步一步进去就可以了,对于其他东西,比如default,length,也是如下,就是判断。不细说了
生成sql和主键的处理
在一次循环中,把每一个字段生成一个sql属性,包括sql对应的属性名、类型、约束,变成一个字符串,放进一个vec中
如果没有主键,循环完成后,添加一个主键就可以了
最后,把vec使用逗号,将属性每一个连接起来。
和table_name一起生成create语句
这一步没什么可说,穿插在前面的代码中
宏的全部代码
src/lib.rs文件
mod utils;
use proc_macro::TokenStream;
use quote::quote;
use syn::Lit;
use syn::Meta;
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Data, Field, Fields, parse_macro_input};
use syn::{DeriveInput, Token};
use utils::{handle_length, handle_one_nv, rust_type_to_sql_type};/// 处理表名
fn get_table_name(input: &DeriveInput) -> String {input.attrs.iter().find(|attr| attr.path().is_ident("table_name")).and_then(|attr| {if let Meta::NameValue(meta) = &attr.meta {if let syn::Expr::Lit(expr_lit) = &meta.value {if let Lit::Str(lit_str) = &expr_lit.lit {return Some(lit_str.value());}}}None}).unwrap_or_else(|| input.ident.to_string().to_lowercase())
}/// 解析字段
fn parse_fields(fields: Punctuated<Field, Comma>, name: String) -> Vec<String> {let mut field_vec = Vec::new();let mut has_pk = false;for field in fields.into_iter() {let field_name = field.ident.unwrap();let ty = field.ty;let mut sql_type = rust_type_to_sql_type(&ty); // 让 sql_type 可变let mut sql_constraint = String::new();for attr in field.attrs.iter().filter(|attr| attr.path().is_ident("field")){if let Meta::List(meta_list) = &attr.meta {let punctuated = Punctuated::<Meta, Token![,]>::parse_terminated;if let Ok(nested) = punctuated.parse2(meta_list.tokens.clone()) {for meta in nested {match meta {Meta::Path(path) => {if path.is_ident("pk") {has_pk = true;sql_constraint += " auto_increment primary key";}}Meta::NameValue(nv) => {let ident = nv.path.get_ident().unwrap().to_string();if ident == "length" {// length 应该修改类型部分,而不是约束部分let length_value = handle_length(&nv.value);sql_type = format!("{}({})", sql_type, length_value);} else {let res = handle_one_nv(nv);sql_constraint += &res;}}_ => {}}}}}}let field_sql = format!("{} {}{}", field_name, sql_type, sql_constraint);field_vec.push(field_sql);}if !has_pk {let pk_sql = format!("{}_id int auto_increment primary key", name);field_vec.push(pk_sql);}field_vec
}
#[proc_macro_derive(Create, attributes(table_name, field))]
pub fn create(input: TokenStream) -> TokenStream {let input = parse_macro_input!(input as DeriveInput);let name = &input.ident;//获取表名let table_name = get_table_name(&input);// 获取结构体的属性let fields = if let Data::Struct(s) = input.data {if let Fields::Named(n) = s.fields {n.named} else {panic!("不是命名结构体");}} else {panic!("不是结构体");};// sql属性和约束let word_vec = parse_fields(fields, table_name.clone());// 列let columns = word_vec.join(",\n ");// 拼接let output = quote! {impl #name {pub fn create_table_sql() -> String {format!("create table {} (\n{}\n);",#table_name,#columns)}}};output.into()
}
src/utils.rs文件
use syn::Lit::{Bool, Int};
use syn::MetaNameValue;
use syn::{Expr, Lit};
use syn::Type;pub fn rust_type_to_sql_type(ty: &Type) -> String {match ty {Type::Path(type_path) if type_path.path.is_ident("i32") => "int".to_string(),Type::Path(type_path) if type_path.path.is_ident("f32") => "float".to_string(),Type::Path(type_path) if type_path.path.is_ident("String") => "varchar".to_string(),Type::Path(type_path) if type_path.path.is_ident("bool") => "tinyint".to_string(),_ => "".to_string(),}
}fn handle_default(value: Expr) -> String {if let Expr::Lit(expr_lit) = &value {match &expr_lit.lit {Bool(lit_bool) => {return if lit_bool.value {" default true".to_string()} else {" default false".to_string()};}Lit::Str(lit_str) => {return format!(" default '{}'", lit_str.value());}Int(lit_int) => {return format!(" default {}", lit_int.base10_digits());}Lit::Float(lit_float) => {return format!(" default {}", lit_float.base10_digits());}_ => panic!("default 不支持该类型"),}}panic!("default 需要设为字面量值");
}
pub fn handle_length(value: &Expr) -> String {if let Expr::Lit(expr_lit) = &value {if let Int(lit_int) = &expr_lit.lit {return lit_int.base10_digits().to_string();}}panic!("需要设为整型");
}
pub fn handle_one_nv(nv: MetaNameValue) -> String {let path = nv.path;let value = nv.value;match path.get_ident().unwrap().to_string().as_str() {"null" => {if let Expr::Lit(expr_lit) = &value {if let Bool(lit_bool) = &expr_lit.lit {return if lit_bool.value {"".to_string()} else {" not null".to_string()};}}panic!("null 需要设为bool值,例如: #[null = true]");}"default" => handle_default(value),_ => "".to_string(),}
}
错误处理不是很好,不管那些。
简单测试一下
新建一个binary crate,导入前面宏所在的crate,其中src/main.rs的内容如下
use macro_crate::{Create};
#[derive(Create)]
struct Student{#[field(pk)]id:i32,#[field(length=40,null=false)]name:String,#[field(null=false,default=60.0)]score: f32,#[field(null=false,default=true)]is_job:bool,
}
fn main() {let create=Student::create_table_sql();println!("{}",create);
}
生成sql语句如下
create table student (
id int auto_increment primary key,name varchar(40) not null,score float not null default 60.0,is_job tinyint not null default true
);
运行结果如下
成功,哈哈哈哈哈哈
其它复杂的东西,就不管了。