rstml_component_macro/
func.rs

1use proc_macro2::{Span, TokenStream};
2use proc_macro2_diagnostics::{Diagnostic, Level};
3use quote::quote;
4use syn::{
5	parse::Parse, punctuated::Punctuated, spanned::Spanned, token::Comma, Field, FnArg, GenericParam,
6	Generics, Ident, Pat, Path, PathArguments, PathSegment, Token, Type, TypeImplTrait, TypeParam,
7	TypePath, Visibility,
8};
9
/// A source of string fragments that, taken in order, make up a name.
///
/// Implementors hand each fragment to the supplied callback instead of
/// returning an allocated collection, so the trait stays usable behind
/// `&dyn IdentPath`.
trait IdentPath {
	/// Calls `target` once per fragment, in order.
	fn write_to(&self, target: &mut dyn FnMut(&str));
}
13
14impl IdentPath for &str {
15	fn write_to(&self, target: &mut dyn FnMut(&str)) {
16		target(self);
17	}
18}
19
20impl IdentPath for usize {
21	fn write_to(&self, target: &mut dyn FnMut(&str)) {
22		target(&self.to_string());
23	}
24}
25
26impl<L, R> IdentPath for (L, R)
27where
28	L: IdentPath,
29	R: IdentPath,
30{
31	fn write_to(&self, target: &mut dyn FnMut(&str)) {
32		self.0.write_to(target);
33		self.1.write_to(target);
34	}
35}
36
37impl IdentPath for &dyn IdentPath {
38	fn write_to(&self, target: &mut dyn FnMut(&str)) {
39		(*self).write_to(target);
40	}
41}
42
43#[derive(Clone)]
44struct IdentParts<I>(I);
45
46impl<'a, I> IdentPath for IdentParts<I>
47where
48	I: Iterator<Item = &'a str> + Clone,
49{
50	fn write_to(&self, target: &mut dyn FnMut(&str)) {
51		for part in self.0.clone() {
52			target(part);
53		}
54	}
55}
56
57struct ComponentAttrs {
58	vis: syn::Visibility,
59	name: syn::Ident,
60}
61
62impl Parse for ComponentAttrs {
63	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
64		Ok(Self {
65			vis: input.parse()?,
66			name: input.parse()?,
67		})
68	}
69}
70
71struct ComponentStructBuilder {
72	vis: syn::Visibility,
73	ident: syn::Ident,
74	generics: Generics,
75	fields: Punctuated<Field, Comma>,
76}
77
78impl ComponentStructBuilder {
79	fn new(attr: ComponentAttrs, generics: Generics) -> Self {
80		Self {
81			vis: attr.vis,
82			ident: attr.name,
83			generics,
84			fields: Punctuated::new(),
85		}
86	}
87
88	fn resolve_type(&mut self, ty: &Type, ident_hint: &dyn IdentPath) -> Type {
89		match ty {
90			Type::ImplTrait(inner) => self.add_generic(inner.clone(), ident_hint),
91
92			Type::Array(inner) => {
93				let elem = self.resolve_type(&inner.elem, ident_hint);
94				Type::Array(syn::TypeArray {
95					bracket_token: inner.bracket_token,
96					elem: Box::new(elem),
97					semi_token: inner.semi_token,
98					len: inner.len.clone(),
99				})
100			}
101
102			Type::Paren(inner) => {
103				let elem = self.resolve_type(&inner.elem, ident_hint);
104				Type::Paren(syn::TypeParen {
105					paren_token: inner.paren_token,
106					elem: Box::new(elem),
107				})
108			}
109
110			Type::Ptr(inner) => {
111				let elem = self.resolve_type(&inner.elem, ident_hint);
112				Type::Ptr(syn::TypePtr {
113					star_token: inner.star_token,
114					const_token: inner.const_token,
115					mutability: inner.mutability,
116					elem: Box::new(elem),
117				})
118			}
119
120			Type::Reference(inner) => {
121				let elem = self.resolve_type(&inner.elem, ident_hint);
122				Type::Reference(syn::TypeReference {
123					and_token: inner.and_token,
124					lifetime: inner.lifetime.clone(),
125					mutability: inner.mutability,
126					elem: Box::new(elem),
127				})
128			}
129
130			Type::Slice(inner) => {
131				let elem = self.resolve_type(&inner.elem, ident_hint);
132				Type::Slice(syn::TypeSlice {
133					bracket_token: inner.bracket_token,
134					elem: Box::new(elem),
135				})
136			}
137
138			Type::Tuple(inner) => {
139				let elems = inner
140					.elems
141					.iter()
142					.enumerate()
143					.map(|(idx, elem)| self.resolve_type(elem, &(ident_hint, idx)))
144					.collect();
145
146				Type::Tuple(syn::TypeTuple {
147					paren_token: inner.paren_token,
148					elems,
149				})
150			}
151
152			_ => ty.clone(),
153		}
154	}
155
156	fn push_field(&mut self, ident: Ident, ty: Type) {
157		let ty = self.resolve_type(&ty, &IdentParts(ident.to_string().split('_')));
158
159		let field = Field {
160			attrs: vec![],
161			vis: Visibility::Public(Token![pub](Span::call_site())),
162			mutability: syn::FieldMutability::None,
163			ident: Some(ident),
164			ty,
165			colon_token: Some(Token![:](Span::call_site())),
166		};
167
168		self.fields.push(field);
169	}
170
171	fn add_generic(&mut self, impl_type: TypeImplTrait, ident_hint: &dyn IdentPath) -> Type {
172		let mut type_ident_str = String::new();
173		type_ident_str.push('T');
174		ident_hint.write_to(&mut |part| {
175			if part.is_empty() {
176				return;
177			}
178
179			let first_char = part.chars().next().unwrap();
180			type_ident_str.push(first_char.to_ascii_uppercase());
181			type_ident_str.push_str(&part[first_char.len_utf8()..]);
182		});
183
184		let existing_set = self
185			.generics
186			.type_params()
187			.map(|par| par.ident.to_string())
188			.collect::<std::collections::HashSet<_>>();
189
190		while existing_set.contains(&type_ident_str) {
191			type_ident_str.push('_');
192		}
193
194		let type_ident = Ident::new(&type_ident_str, Span::call_site());
195		let type_param = TypeParam {
196			attrs: vec![],
197			ident: type_ident.clone(),
198			colon_token: Some(Token![:](Span::call_site())),
199			bounds: impl_type.bounds,
200			eq_token: None,
201			default: None,
202		};
203
204		self.generics.params.push(GenericParam::Type(type_param));
205
206		Type::Path(TypePath {
207			qself: None,
208			path: Path {
209				leading_colon: None,
210				segments: Punctuated::from_iter(vec![PathSegment {
211					ident: type_ident,
212					arguments: PathArguments::None,
213				}]),
214			},
215		})
216	}
217
218	fn build(self) -> (TokenStream, Ident, Generics, Punctuated<Field, Comma>) {
219		let Self {
220			vis,
221			ident,
222			generics,
223			fields,
224		} = self;
225		let (impl_generics, _, where_clause) = generics.split_for_impl();
226
227		let generated_struct = quote! {
228			#[derive(::rstml_component::HtmlComponent)]
229			#[allow(non_snake_case)]
230			#vis struct #ident #impl_generics #where_clause {#fields}
231		};
232
233		(generated_struct, ident, generics, fields)
234	}
235}
236
237pub fn component(attr: TokenStream, input: TokenStream) -> TokenStream {
238	let mut diagnostics: Vec<Diagnostic> = vec![];
239
240	// parse input
241	let input: syn::ItemFn = match syn::parse2(input) {
242		Ok(input) => input,
243		Err(err) => return err.to_compile_error(),
244	};
245
246	let attr: ComponentAttrs = match syn::parse2(attr) {
247		Ok(attr) => attr,
248		Err(err) => return err.to_compile_error(),
249	};
250
251	// check if input is valid
252	if let Some(constness) = input.sig.constness {
253		diagnostics.push(Diagnostic::spanned(
254			constness.span(),
255			Level::Error,
256			"component function must not be const",
257		))
258	} else if let Some(asyncness) = input.sig.asyncness {
259		diagnostics.push(Diagnostic::spanned(
260			asyncness.span(),
261			Level::Error,
262			"component function must not be async",
263		));
264	} else if let Some(unsafety) = input.sig.unsafety {
265		diagnostics.push(Diagnostic::spanned(
266			unsafety.span(),
267			Level::Error,
268			"component function must not be unsafe",
269		));
270	} else if let Some(ref abi) = input.sig.abi {
271		diagnostics.push(Diagnostic::spanned(
272			abi.span(),
273			Level::Error,
274			"component function must not have an abi",
275		));
276	}
277
278	let mut struct_builder = ComponentStructBuilder::new(attr, input.sig.generics.clone());
279
280	for arg in input.sig.inputs.iter() {
281		match arg {
282			FnArg::Receiver(_) => {
283				diagnostics.push(Diagnostic::spanned(
284					arg.span(),
285					Level::Error,
286					"component function must not have self argument",
287				));
288			}
289
290			FnArg::Typed(pat_type) => {
291				let pat = *pat_type.pat.clone();
292				let ty = *pat_type.ty.clone();
293
294				match pat {
295					Pat::Ident(pat) => struct_builder.push_field(pat.ident, ty),
296
297					Pat::TupleStruct(pat) => {
298						diagnostics.push(Diagnostic::spanned(
299							pat.span(),
300							Level::Error,
301							"tuple struct pattern not supported",
302						));
303					}
304
305					Pat::Struct(pat) => {
306						diagnostics.push(Diagnostic::spanned(
307							pat.span(),
308							Level::Error,
309							"struct pattern not supported",
310						));
311					}
312
313					Pat::Tuple(pat) => {
314						diagnostics.push(Diagnostic::spanned(
315							pat.span(),
316							Level::Error,
317							"tuple pattern not supported",
318						));
319					}
320
321					Pat::Slice(pat) => {
322						diagnostics.push(Diagnostic::spanned(
323							pat.span(),
324							Level::Error,
325							"slice pattern not supported",
326						));
327					}
328
329					_ => {
330						diagnostics.push(Diagnostic::spanned(
331							pat.span(),
332							Level::Error,
333							"couldn't parse function argument",
334						));
335					}
336				}
337			}
338		}
339	}
340
341	let (generated_struct, ident, generics, fields) = struct_builder.build();
342
343	let input_ident = input.sig.ident.clone();
344	let mut fn_args = Vec::new();
345	for field in fields.iter() {
346		// unwrap shouldn't panic since each field is generated with an ident
347		let ident = field.ident.clone().unwrap();
348		fn_args.push(quote!(self.#ident,));
349	}
350
351	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
352
353	let impl_block = quote! {
354		impl #impl_generics ::rstml_component::HtmlContent for #ident #ty_generics #where_clause {
355			fn fmt(self, formatter: &mut ::rstml_component::HtmlFormatter) -> std::fmt::Result {
356				formatter.write_content(#input_ident (#(#fn_args)*))
357			}
358		}
359	};
360
361	let diagnostics = diagnostics.iter().map(|d| d.clone().emit_as_item_tokens());
362	quote! {
363		#(#diagnostics)*
364		#input
365		#generated_struct
366		#impl_block
367	}
368}