1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Path};
4
5#[proc_macro_derive(AstNodeTransform, attributes(convert_to))]
7pub fn ast_node_transform(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9
10 let name = &input.ident;
11 let data = &input.data;
12
13 let convert_to_attr = input
15 .attrs
16 .iter()
17 .find(|attr| attr.path().is_ident("convert_to"))
18 .expect("Missing #[convert_to(...)] attribute for AstNodeTransform");
19
20 let target_variants = parse_convert_to_attribute(convert_to_attr)
21 .expect("Failed to parse #[convert_to(...)] attribute");
22
23 let mut impls = Vec::new();
25
26 for (to_ty, variant) in target_variants {
27 match data {
28 Data::Struct(_) => {
29 impls.push(quote! {
31 impl From<P<#name>> for P<#to_ty> {
32 fn from(id: P<#name>) -> Self {
33 P::from(#variant(id.into()))
34 }
35 }
36 });
37
38 impls.push(quote! {
40 impl From<P<#name>> for #to_ty {
41 fn from(id: P<#name>) -> Self {
42 #variant(id.into())
43 }
44 }
45 });
46
47 impls.push(quote! {
49 impl From<#name> for P<#to_ty> {
50 fn from(id: #name) -> Self {
51 P::from(#variant(id.into()))
52 }
53 }
54 });
55
56 impls.push(quote! {
58 impl From<#name> for #to_ty {
59 fn from(id: #name) -> Self {
60 #variant(id.into())
61 }
62 }
63 });
64 }
65
66 Data::Enum(_) => {
68 impls.push(quote! {
70 impl From<#name> for #to_ty {
71 fn from(id: #name) -> Self {
72 #variant(id.into())
73 }
74 }
75 });
76 }
77
78 _ => panic!("AstNodeTransform can only be applied to enums or structs"),
79 }
80 }
81
82 let output = quote! {
83 #(#impls)*
84 };
85
86 output.into()
87}
88
89fn parse_convert_to_attribute(attr: &Attribute) -> Result<Vec<(Path, Path)>, syn::Error> {
91 attr.parse_args_with(|input: syn::parse::ParseStream| {
92 let mut target_variants = Vec::new();
93
94 while !input.is_empty() {
95 let variant: Path = input.parse()?;
97 let parent = Path::from(variant.segments.first().unwrap().ident.clone());
98 target_variants.push((parent, variant));
99
100 if input.peek(syn::Token![,]) {
102 input.parse::<syn::Token![,]>()?;
103 }
104 }
105
106 Ok(target_variants)
107 })
108}