1use proc_macro2::{Span, TokenStream};
2use proc_macro2_diagnostics::{Diagnostic, Level};
3use quote::quote;
4use syn::{
5 parse::Parse, punctuated::Punctuated, spanned::Spanned, token::Comma, Field, FnArg, GenericParam,
6 Generics, Ident, Pat, Path, PathArguments, PathSegment, Token, Type, TypeImplTrait, TypeParam,
7 TypePath, Visibility,
8};
9
/// A name hint made of string fragments, streamed in order to a callback.
///
/// `ComponentStructBuilder::add_generic` uses it to derive the names of
/// synthesized generic parameters. Implementors may emit any number of
/// fragments, including none.
trait IdentPath {
    /// Feeds every fragment of this hint, in order, to `sink`.
    fn write_to(&self, sink: &mut dyn FnMut(&str));
}
13
14impl IdentPath for &str {
15 fn write_to(&self, target: &mut dyn FnMut(&str)) {
16 target(self);
17 }
18}
19
20impl IdentPath for usize {
21 fn write_to(&self, target: &mut dyn FnMut(&str)) {
22 target(&self.to_string());
23 }
24}
25
26impl<L, R> IdentPath for (L, R)
27where
28 L: IdentPath,
29 R: IdentPath,
30{
31 fn write_to(&self, target: &mut dyn FnMut(&str)) {
32 self.0.write_to(target);
33 self.1.write_to(target);
34 }
35}
36
37impl IdentPath for &dyn IdentPath {
38 fn write_to(&self, target: &mut dyn FnMut(&str)) {
39 (*self).write_to(target);
40 }
41}
42
/// A name hint backed by a clonable iterator of string fragments,
/// for example the pieces of an identifier split on `_`.
#[derive(Clone)]
struct IdentParts<I>(
    /// Iterator yielding the fragments in order.
    I,
);
45
46impl<'a, I> IdentPath for IdentParts<I>
47where
48 I: Iterator<Item = &'a str> + Clone,
49{
50 fn write_to(&self, target: &mut dyn FnMut(&str)) {
51 for part in self.0.clone() {
52 target(part);
53 }
54 }
55}
56
57struct ComponentAttrs {
58 vis: syn::Visibility,
59 name: syn::Ident,
60}
61
62impl Parse for ComponentAttrs {
63 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
64 Ok(Self {
65 vis: input.parse()?,
66 name: input.parse()?,
67 })
68 }
69}
70
71struct ComponentStructBuilder {
72 vis: syn::Visibility,
73 ident: syn::Ident,
74 generics: Generics,
75 fields: Punctuated<Field, Comma>,
76}
77
78impl ComponentStructBuilder {
79 fn new(attr: ComponentAttrs, generics: Generics) -> Self {
80 Self {
81 vis: attr.vis,
82 ident: attr.name,
83 generics,
84 fields: Punctuated::new(),
85 }
86 }
87
88 fn resolve_type(&mut self, ty: &Type, ident_hint: &dyn IdentPath) -> Type {
89 match ty {
90 Type::ImplTrait(inner) => self.add_generic(inner.clone(), ident_hint),
91
92 Type::Array(inner) => {
93 let elem = self.resolve_type(&inner.elem, ident_hint);
94 Type::Array(syn::TypeArray {
95 bracket_token: inner.bracket_token,
96 elem: Box::new(elem),
97 semi_token: inner.semi_token,
98 len: inner.len.clone(),
99 })
100 }
101
102 Type::Paren(inner) => {
103 let elem = self.resolve_type(&inner.elem, ident_hint);
104 Type::Paren(syn::TypeParen {
105 paren_token: inner.paren_token,
106 elem: Box::new(elem),
107 })
108 }
109
110 Type::Ptr(inner) => {
111 let elem = self.resolve_type(&inner.elem, ident_hint);
112 Type::Ptr(syn::TypePtr {
113 star_token: inner.star_token,
114 const_token: inner.const_token,
115 mutability: inner.mutability,
116 elem: Box::new(elem),
117 })
118 }
119
120 Type::Reference(inner) => {
121 let elem = self.resolve_type(&inner.elem, ident_hint);
122 Type::Reference(syn::TypeReference {
123 and_token: inner.and_token,
124 lifetime: inner.lifetime.clone(),
125 mutability: inner.mutability,
126 elem: Box::new(elem),
127 })
128 }
129
130 Type::Slice(inner) => {
131 let elem = self.resolve_type(&inner.elem, ident_hint);
132 Type::Slice(syn::TypeSlice {
133 bracket_token: inner.bracket_token,
134 elem: Box::new(elem),
135 })
136 }
137
138 Type::Tuple(inner) => {
139 let elems = inner
140 .elems
141 .iter()
142 .enumerate()
143 .map(|(idx, elem)| self.resolve_type(elem, &(ident_hint, idx)))
144 .collect();
145
146 Type::Tuple(syn::TypeTuple {
147 paren_token: inner.paren_token,
148 elems,
149 })
150 }
151
152 _ => ty.clone(),
153 }
154 }
155
156 fn push_field(&mut self, ident: Ident, ty: Type) {
157 let ty = self.resolve_type(&ty, &IdentParts(ident.to_string().split('_')));
158
159 let field = Field {
160 attrs: vec![],
161 vis: Visibility::Public(Token)),
162 mutability: syn::FieldMutability::None,
163 ident: Some(ident),
164 ty,
165 colon_token: Some(Token)),
166 };
167
168 self.fields.push(field);
169 }
170
171 fn add_generic(&mut self, impl_type: TypeImplTrait, ident_hint: &dyn IdentPath) -> Type {
172 let mut type_ident_str = String::new();
173 type_ident_str.push('T');
174 ident_hint.write_to(&mut |part| {
175 if part.is_empty() {
176 return;
177 }
178
179 let first_char = part.chars().next().unwrap();
180 type_ident_str.push(first_char.to_ascii_uppercase());
181 type_ident_str.push_str(&part[first_char.len_utf8()..]);
182 });
183
184 let existing_set = self
185 .generics
186 .type_params()
187 .map(|par| par.ident.to_string())
188 .collect::<std::collections::HashSet<_>>();
189
190 while existing_set.contains(&type_ident_str) {
191 type_ident_str.push('_');
192 }
193
194 let type_ident = Ident::new(&type_ident_str, Span::call_site());
195 let type_param = TypeParam {
196 attrs: vec![],
197 ident: type_ident.clone(),
198 colon_token: Some(Token)),
199 bounds: impl_type.bounds,
200 eq_token: None,
201 default: None,
202 };
203
204 self.generics.params.push(GenericParam::Type(type_param));
205
206 Type::Path(TypePath {
207 qself: None,
208 path: Path {
209 leading_colon: None,
210 segments: Punctuated::from_iter(vec![PathSegment {
211 ident: type_ident,
212 arguments: PathArguments::None,
213 }]),
214 },
215 })
216 }
217
218 fn build(self) -> (TokenStream, Ident, Generics, Punctuated<Field, Comma>) {
219 let Self {
220 vis,
221 ident,
222 generics,
223 fields,
224 } = self;
225 let (impl_generics, _, where_clause) = generics.split_for_impl();
226
227 let generated_struct = quote! {
228 #[derive(::rstml_component::HtmlComponent)]
229 #[allow(non_snake_case)]
230 #vis struct #ident #impl_generics #where_clause {#fields}
231 };
232
233 (generated_struct, ident, generics, fields)
234 }
235}
236
237pub fn component(attr: TokenStream, input: TokenStream) -> TokenStream {
238 let mut diagnostics: Vec<Diagnostic> = vec![];
239
240 let input: syn::ItemFn = match syn::parse2(input) {
242 Ok(input) => input,
243 Err(err) => return err.to_compile_error(),
244 };
245
246 let attr: ComponentAttrs = match syn::parse2(attr) {
247 Ok(attr) => attr,
248 Err(err) => return err.to_compile_error(),
249 };
250
251 if let Some(constness) = input.sig.constness {
253 diagnostics.push(Diagnostic::spanned(
254 constness.span(),
255 Level::Error,
256 "component function must not be const",
257 ))
258 } else if let Some(asyncness) = input.sig.asyncness {
259 diagnostics.push(Diagnostic::spanned(
260 asyncness.span(),
261 Level::Error,
262 "component function must not be async",
263 ));
264 } else if let Some(unsafety) = input.sig.unsafety {
265 diagnostics.push(Diagnostic::spanned(
266 unsafety.span(),
267 Level::Error,
268 "component function must not be unsafe",
269 ));
270 } else if let Some(ref abi) = input.sig.abi {
271 diagnostics.push(Diagnostic::spanned(
272 abi.span(),
273 Level::Error,
274 "component function must not have an abi",
275 ));
276 }
277
278 let mut struct_builder = ComponentStructBuilder::new(attr, input.sig.generics.clone());
279
280 for arg in input.sig.inputs.iter() {
281 match arg {
282 FnArg::Receiver(_) => {
283 diagnostics.push(Diagnostic::spanned(
284 arg.span(),
285 Level::Error,
286 "component function must not have self argument",
287 ));
288 }
289
290 FnArg::Typed(pat_type) => {
291 let pat = *pat_type.pat.clone();
292 let ty = *pat_type.ty.clone();
293
294 match pat {
295 Pat::Ident(pat) => struct_builder.push_field(pat.ident, ty),
296
297 Pat::TupleStruct(pat) => {
298 diagnostics.push(Diagnostic::spanned(
299 pat.span(),
300 Level::Error,
301 "tuple struct pattern not supported",
302 ));
303 }
304
305 Pat::Struct(pat) => {
306 diagnostics.push(Diagnostic::spanned(
307 pat.span(),
308 Level::Error,
309 "struct pattern not supported",
310 ));
311 }
312
313 Pat::Tuple(pat) => {
314 diagnostics.push(Diagnostic::spanned(
315 pat.span(),
316 Level::Error,
317 "tuple pattern not supported",
318 ));
319 }
320
321 Pat::Slice(pat) => {
322 diagnostics.push(Diagnostic::spanned(
323 pat.span(),
324 Level::Error,
325 "slice pattern not supported",
326 ));
327 }
328
329 _ => {
330 diagnostics.push(Diagnostic::spanned(
331 pat.span(),
332 Level::Error,
333 "couldn't parse function argument",
334 ));
335 }
336 }
337 }
338 }
339 }
340
341 let (generated_struct, ident, generics, fields) = struct_builder.build();
342
343 let input_ident = input.sig.ident.clone();
344 let mut fn_args = Vec::new();
345 for field in fields.iter() {
346 let ident = field.ident.clone().unwrap();
348 fn_args.push(quote!(self.#ident,));
349 }
350
351 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
352
353 let impl_block = quote! {
354 impl #impl_generics ::rstml_component::HtmlContent for #ident #ty_generics #where_clause {
355 fn fmt(self, formatter: &mut ::rstml_component::HtmlFormatter) -> std::fmt::Result {
356 formatter.write_content(#input_ident (#(#fn_args)*))
357 }
358 }
359 };
360
361 let diagnostics = diagnostics.iter().map(|d| d.clone().emit_as_item_tokens());
362 quote! {
363 #(#diagnostics)*
364 #input
365 #generated_struct
366 #impl_block
367 }
368}