rstml/
visitor.rs

1use std::marker::PhantomData;
2
3use super::node::*;
4use crate::{
5    atoms::{CloseTag, OpenTag},
6    Infallible,
7};
8
9/// Enum that represents the different types with valid Rust code that can be
10/// visited using `syn::Visitor`. Usually `syn::Block` or `syn::Expr`.
11pub enum RustCode<'a> {
12    Block(&'a mut syn::Block),
13    Expr(&'a mut syn::Expr),
14    LitStr(&'a mut syn::LitStr),
15    Pat(&'a mut syn::Pat),
16}
17/// Visitor api provide a way to traverse the node tree and modify its
18/// components. The api allows modification of all types of nodes, and some
19/// atoms like InvalidBlock or NodeName.
20///
21/// Each method returns a bool that indicates if the visitor should continue to
22/// traverse the tree. If the method returns false, the visitor will stop
23/// traversing the tree.
24///
25/// By default Visitor are abstract over CustomNode, but it is possible to
26/// implement a Visitor for concrete CustomNode.
27pub trait Visitor<Custom> {
28    // Visit node types
29    fn visit_node(&mut self, _node: &mut Node<Custom>) -> bool {
30        true
31    }
32    fn visit_block(&mut self, _node: &mut NodeBlock) -> bool {
33        true
34    }
35    fn visit_comment(&mut self, _node: &mut NodeComment) -> bool {
36        true
37    }
38    fn visit_doctype(&mut self, _node: &mut NodeDoctype) -> bool {
39        true
40    }
41    fn visit_raw_node<AnyC: CustomNode>(&mut self, _node: &mut RawText<AnyC>) -> bool {
42        true
43    }
44    fn visit_custom(&mut self, _node: &mut Custom) -> bool {
45        true
46    }
47    fn visit_text_node(&mut self, _node: &mut NodeText) -> bool {
48        true
49    }
50    fn visit_element(&mut self, _node: &mut NodeElement<Custom>) -> bool {
51        true
52    }
53    fn visit_fragment(&mut self, _node: &mut NodeFragment<Custom>) -> bool {
54        true
55    }
56
57    // Visit atoms
58    fn visit_rust_code(&mut self, _code: RustCode) -> bool {
59        true
60    }
61    fn visit_invalid_block(&mut self, _block: &mut InvalidBlock) -> bool {
62        true
63    }
64    fn visit_node_name(&mut self, _name: &mut NodeName) -> bool {
65        true
66    }
67
68    fn visit_open_tag(&mut self, _open_tag: &mut OpenTag) -> bool {
69        true
70    }
71    fn visit_close_tag(&mut self, _closed_tag: &mut CloseTag) -> bool {
72        true
73    }
74    // Visit Attributes
75    fn visit_attribute(&mut self, _attribute: &mut NodeAttribute) -> bool {
76        true
77    }
78    fn visit_keyed_attribute(&mut self, _attribute: &mut KeyedAttribute) -> bool {
79        true
80    }
81    fn visit_attribute_flag(&mut self, _key: &mut NodeName) -> bool {
82        true
83    }
84    fn visit_attribute_binding(&mut self, _key: &mut NodeName, _value: &mut FnBinding) -> bool {
85        true
86    }
87    fn visit_attribute_value(
88        &mut self,
89        _key: &mut NodeName,
90        _value: &mut AttributeValueExpr,
91    ) -> bool {
92        true
93    }
94}
95
/// Default `CustomNodeWalker`: it does not descend into custom nodes, it only
/// lets the traversal continue after them.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct AnyWalker<C>(PhantomData<C>);
98
99/// Define walker for `CustomNode`.
100pub trait CustomNodeWalker {
101    type Custom: CustomNode;
102    fn walk_custom_node_fields<VisitorImpl: Visitor<Self::Custom>>(
103        visitor: &mut VisitorImpl,
104        node: &mut Self::Custom,
105    ) -> bool;
106}
107
108impl<C> CustomNodeWalker for AnyWalker<C>
109where
110    C: CustomNode,
111{
112    type Custom = C;
113    fn walk_custom_node_fields<VisitorImpl: Visitor<Self::Custom>>(
114        _visitor: &mut VisitorImpl,
115        _node: &mut C,
116    ) -> bool {
117        true
118    }
119}
120
// Calls `$method` on the wrapped `$visitor` field and makes the enclosing
// function return `false` as soon as that call returns `false`.
macro_rules! visit_inner {
    ($self:ident.$visitor:ident.$method:ident($($tokens:tt)*)) => {
        match $self.$visitor.$method($($tokens)*) {
            true => {}
            false => return false,
        }
    };
}
128
// Calls `$method` on `$self` and makes the enclosing function return `false`
// as soon as that call returns `false`.
macro_rules! try_visit {
    ($self:ident.$method:ident($($tokens:tt)*)) => {
        match $self.$method($($tokens)*) {
            true => {}
            false => return false,
        }
    };
}
136
137/// Wrapper for visitor that calls inner visitors.
138/// Inner visitor should implement `Visitor` trait and
139/// `syn::visit_mut::VisitMut`.
140///
141/// For regular usecases it is recommended to use `visit_nodes`,
142/// `visit_nodes_with_custom` or `visit_attributes` functions.
143///
144/// But if you need it can be used by calling `visit_*` methods directly.
145///
146/// Example:
147/// ```rust
148/// use quote::quote;
149/// use rstml::{
150///     node::{Node, NodeText},
151///     visitor::{Visitor, Walker},
152///     Infallible,
153/// };
154/// use syn::parse_quote;
155///
156/// struct TestVisitor;
157/// impl<C> Visitor<C> for TestVisitor {
158///     fn visit_text_node(&mut self, node: &mut NodeText) -> bool {
159///         *node = parse_quote!("modified");
160///         true
161///     }
162/// }
163/// impl syn::visit_mut::VisitMut for TestVisitor {}
164///
165/// let mut visitor = Walker::new(TestVisitor);
166///
167/// let tokens = quote! {
168///     <div>
169///         <span>"Some raw text"</span>
170///         <span></span>"And text after span"
171///     </div>
172/// };
173/// let mut nodes = rstml::parse2(tokens).unwrap();
174/// for node in &mut nodes {
175///     visitor.visit_node(node);
176/// }
177/// let result = quote! {
178///     #(#nodes)*
179/// };
180/// assert_eq!(
181///     result.to_string(),
182///     quote! {
183///         <div>
184///             <span>"modified"</span>
185///             <span></span>"modified"
186///         </div>
187///     }
188///     .to_string()
189/// );
190/// ```
191pub struct Walker<V, C = Infallible, CW = AnyWalker<C>>
192where
193    C: CustomNode,
194    V: Visitor<C> + syn::visit_mut::VisitMut,
195    CW: CustomNodeWalker<Custom = C>,
196{
197    visitor: V,
198    // we use callbakc instead of marker for `CustomNodeWalker`
199    // because it will fail to resolve with infinite recursion
200    walker: PhantomData<CW>,
201    _pd: PhantomData<C>,
202}
203
204impl<V, C> Walker<V, C>
205where
206    C: CustomNode,
207    V: Visitor<C> + syn::visit_mut::VisitMut,
208{
209    pub fn new(visitor: V) -> Self {
210        Self {
211            visitor,
212            walker: PhantomData,
213            _pd: PhantomData,
214        }
215    }
216    pub fn with_custom_handler<OtherCW>(visitor: V) -> Walker<V, C, OtherCW>
217    where
218        OtherCW: CustomNodeWalker<Custom = C>,
219    {
220        Walker {
221            visitor,
222            walker: PhantomData,
223            _pd: PhantomData,
224        }
225    }
226}
227impl<V, C, CW> Walker<V, C, CW>
228where
229    C: CustomNode,
230    V: Visitor<C> + syn::visit_mut::VisitMut,
231    CW: CustomNodeWalker<Custom = C>,
232{
233    pub fn destruct(self) -> V {
234        self.visitor
235    }
236}
237
238impl<V, C, CW> Visitor<C> for Walker<V, C, CW>
239where
240    C: CustomNode,
241    V: Visitor<C> + syn::visit_mut::VisitMut,
242    CW: CustomNodeWalker<Custom = C>,
243{
244    fn visit_node(&mut self, node: &mut Node<C>) -> bool {
245        visit_inner!(self.visitor.visit_node(node));
246
247        match node {
248            Node::Block(b) => self.visit_block(b),
249            Node::Comment(c) => self.visit_comment(c),
250            Node::Doctype(d) => self.visit_doctype(d),
251            Node::Element(e) => self.visit_element(e),
252            Node::Fragment(f) => self.visit_fragment(f),
253            Node::Text(t) => self.visit_text_node(t),
254            Node::RawText(r) => self.visit_raw_node(r),
255            Node::Custom(c) => self.visit_custom(c),
256        }
257    }
258    fn visit_block(&mut self, node: &mut NodeBlock) -> bool {
259        visit_inner!(self.visitor.visit_block(node));
260
261        match node {
262            NodeBlock::Invalid(b) => self.visit_invalid_block(b),
263            NodeBlock::ValidBlock(b) => self.visit_rust_code(RustCode::Block(b)),
264        }
265    }
266    fn visit_comment(&mut self, node: &mut NodeComment) -> bool {
267        visit_inner!(self.visitor.visit_comment(node));
268
269        self.visit_rust_code(RustCode::LitStr(&mut node.value))
270    }
271    fn visit_doctype(&mut self, node: &mut NodeDoctype) -> bool {
272        visit_inner!(self.visitor.visit_doctype(node));
273
274        self.visit_raw_node(&mut node.value)
275    }
276    fn visit_raw_node<OtherC: CustomNode>(&mut self, node: &mut RawText<OtherC>) -> bool {
277        visit_inner!(self.visitor.visit_raw_node(node));
278
279        true
280    }
281    fn visit_custom(&mut self, node: &mut C) -> bool {
282        visit_inner!(self.visitor.visit_custom(node));
283
284        CW::walk_custom_node_fields(self, node)
285    }
286    fn visit_text_node(&mut self, node: &mut NodeText) -> bool {
287        visit_inner!(self.visitor.visit_text_node(node));
288
289        self.visit_rust_code(RustCode::LitStr(&mut node.value))
290    }
291    fn visit_element(&mut self, node: &mut NodeElement<C>) -> bool {
292        visit_inner!(self.visitor.visit_element(node));
293
294        try_visit!(self.visit_open_tag(&mut node.open_tag));
295
296        for attribute in node.attributes_mut() {
297            try_visit!(self.visit_attribute(attribute))
298        }
299        for child in node.children_mut() {
300            try_visit!(self.visit_node(child))
301        }
302
303        if let Some(close_tag) = &mut node.close_tag {
304            try_visit!(self.visit_close_tag(close_tag));
305        }
306        true
307    }
308    fn visit_fragment(&mut self, node: &mut NodeFragment<C>) -> bool {
309        visit_inner!(self.visitor.visit_fragment(node));
310
311        for child in node.children_mut() {
312            try_visit!(self.visit_node(child))
313        }
314        true
315    }
316
317    fn visit_open_tag(&mut self, open_tag: &mut OpenTag) -> bool {
318        visit_inner!(self.visitor.visit_open_tag(open_tag));
319
320        try_visit!(self.visit_node_name(&mut open_tag.name));
321
322        true
323    }
324    fn visit_close_tag(&mut self, closed_tag: &mut CloseTag) -> bool {
325        visit_inner!(self.visitor.visit_close_tag(closed_tag));
326
327        try_visit!(self.visit_node_name(&mut closed_tag.name));
328
329        true
330    }
331
332    fn visit_attribute(&mut self, attribute: &mut NodeAttribute) -> bool {
333        visit_inner!(self.visitor.visit_attribute(attribute));
334
335        match attribute {
336            NodeAttribute::Attribute(a) => self.visit_keyed_attribute(a),
337            NodeAttribute::Block(b) => self.visit_block(b),
338        }
339    }
340    fn visit_keyed_attribute(&mut self, attribute: &mut KeyedAttribute) -> bool {
341        visit_inner!(self.visitor.visit_keyed_attribute(attribute));
342
343        match &mut attribute.possible_value {
344            KeyedAttributeValue::None => self.visit_attribute_flag(&mut attribute.key),
345            KeyedAttributeValue::Binding(b) => self.visit_attribute_binding(&mut attribute.key, b),
346            KeyedAttributeValue::Value(v) => self.visit_attribute_value(&mut attribute.key, v),
347        }
348    }
349    fn visit_attribute_flag(&mut self, key: &mut NodeName) -> bool {
350        visit_inner!(self.visitor.visit_attribute_flag(key));
351        true
352    }
353    fn visit_attribute_binding(&mut self, key: &mut NodeName, value: &mut FnBinding) -> bool {
354        visit_inner!(self.visitor.visit_attribute_binding(key, value));
355
356        for input in value.inputs.iter_mut() {
357            try_visit!(self.visit_rust_code(RustCode::Pat(input)))
358        }
359        true
360    }
361    fn visit_attribute_value(
362        &mut self,
363        key: &mut NodeName,
364        value: &mut AttributeValueExpr,
365    ) -> bool {
366        visit_inner!(self.visitor.visit_attribute_value(key, value));
367
368        self.visit_node_name(key);
369        match &mut value.value {
370            KVAttributeValue::Expr(expr) => self.visit_rust_code(RustCode::Expr(expr)),
371            KVAttributeValue::InvalidBraced(braced) => self.visit_invalid_block(braced),
372        }
373    }
374
375    fn visit_invalid_block(&mut self, block: &mut InvalidBlock) -> bool {
376        visit_inner!(self.visitor.visit_invalid_block(block));
377
378        true
379    }
380    fn visit_node_name(&mut self, name: &mut NodeName) -> bool {
381        visit_inner!(self.visitor.visit_node_name(name));
382
383        true
384    }
385    fn visit_rust_code(&mut self, mut code: RustCode) -> bool {
386        {
387            // use rewrap because enum `RustCode` is not Copy
388            let rewrap = match &mut code {
389                RustCode::Block(b) => RustCode::Block(b),
390                RustCode::Expr(e) => RustCode::Expr(e),
391                RustCode::LitStr(l) => RustCode::LitStr(l),
392                RustCode::Pat(p) => RustCode::Pat(p),
393            };
394            visit_inner!(self.visitor.visit_rust_code(rewrap));
395        }
396
397        match code {
398            RustCode::Block(b) => self.visitor.visit_block_mut(b),
399            RustCode::Expr(e) => self.visitor.visit_expr_mut(e),
400            RustCode::LitStr(l) => self.visitor.visit_lit_str_mut(l),
401            RustCode::Pat(p) => self.visitor.visit_pat_mut(p),
402        }
403
404        true
405    }
406}
407/// Visitor entrypoint.
408/// Visit nodes in array calling visitor methods.
409/// Recursively visit nodes in children, and attributes.
410///
411/// Return modified visitor back
412pub fn visit_nodes<V, C>(nodes: &mut [Node<C>], visitor: V) -> V
413where
414    C: CustomNode,
415    V: Visitor<C> + syn::visit_mut::VisitMut,
416{
417    let mut visitor = Walker::<V, C>::new(visitor);
418    for node in nodes {
419        visitor.visit_node(node);
420    }
421    visitor.visitor
422}
423
424/// Visitor entrypoint.
425/// Visit nodes in array calling visitor methods.
426/// Recursively visit nodes in children, and attributes.
427/// Provide custom handler that is used to visit custom nodes.
428/// Custom handler should return true if visitor should continue to traverse,
429/// and call visitor methods for its children.
430///
431/// Return modified visitor back
432pub fn visit_nodes_with_custom<V, C, CW>(nodes: &mut [Node<C>], visitor: V) -> V
433where
434    C: CustomNode,
435    V: Visitor<C> + syn::visit_mut::VisitMut,
436    CW: CustomNodeWalker<Custom = C>,
437{
438    let mut visitor = Walker::with_custom_handler::<CW>(visitor);
439    for node in nodes {
440        visitor.visit_node(node);
441    }
442    visitor.visitor
443}
444
445/// Visit attributes in array calling visitor methods.
446pub fn visit_attributes<V>(attributes: &mut [NodeAttribute], visitor: V) -> V
447where
448    V: Visitor<Infallible> + syn::visit_mut::VisitMut,
449    Walker<V>: Visitor<Infallible>,
450{
451    let mut visitor = Walker::new(visitor);
452    for attribute in attributes {
453        visitor.visit_attribute(attribute);
454    }
455    visitor.visitor
456}
#[cfg(test)]
mod tests {

    use quote::{quote, ToTokens};
    use syn::parse_quote;

    use super::*;
    use crate::Infallible;

    #[test]
    fn collect_node_names() {
        // Records every `NodeName` reported by the walker, in visiting order.
        #[derive(Default)]
        struct NameCollector {
            names: Vec<NodeName>,
        }
        impl<C: CustomNode> Visitor<C> for NameCollector {
            fn visit_node_name(&mut self, name: &mut NodeName) -> bool {
                self.names.push(name.clone());
                true
            }
        }
        // Rust code is not inspected here.
        impl syn::visit_mut::VisitMut for NameCollector {}

        let tokens = quote! {
            <div>
                <span></span>
                <span></span>
            </div>
            <!-- "comment" -->
            <foo attr key=value> </foo>
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        let collector = visit_nodes(&mut nodes, NameCollector::default());
        let names: Vec<String> = collector.names.iter().map(ToString::to_string).collect();

        // Open and close tag names interleave; the flag `attr` is not reported.
        assert_eq!(
            names,
            ["div", "span", "span", "span", "span", "div", "foo", "key", "foo"]
        );
    }

    #[test]
    fn collect_node_elements() {
        // Records the open-tag name of every element.
        #[derive(Default)]
        struct ElementCollector {
            names: Vec<NodeName>,
        }
        impl<C: CustomNode> Visitor<C> for ElementCollector {
            fn visit_element(&mut self, node: &mut NodeElement<C>) -> bool {
                self.names.push(node.open_tag.name.clone());
                true
            }
        }
        // Rust code is not inspected here.
        impl syn::visit_mut::VisitMut for ElementCollector {}

        let tokens = quote! {
            <div>
                <span></span>
                <span></span>
            </div>
            <!-- "comment" -->
            <foo attr key=value> </foo>
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        let collector = visit_nodes(&mut nodes, ElementCollector::default());
        let names: Vec<String> = collector.names.iter().map(ToString::to_string).collect();

        assert_eq!(names, ["div", "span", "span", "foo"]);
    }

    #[test]
    fn collect_rust_blocks() {
        // Records every block handed to `syn::visit_mut::VisitMut`.
        #[derive(Default)]
        struct BlockCollector {
            blocks: Vec<syn::Block>,
        }
        // Node-level callbacks keep their defaults.
        impl<C: CustomNode> Visitor<C> for BlockCollector {}
        impl syn::visit_mut::VisitMut for BlockCollector {
            fn visit_block_mut(&mut self, i: &mut syn::Block) {
                self.blocks.push(i.clone());
            }
        }

        let tokens = quote! {
            <div>
            { let block = "in node position"; }
                <span { block_in_attr_position = foo }></span>
                <span var = {block_in_value}></span>
            </div>
            <!-- "comment" -->
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        let collector = visit_nodes(&mut nodes, BlockCollector::default());
        let mut rendered = Vec::new();
        for block in &collector.blocks {
            rendered.push(block.to_token_stream().to_string());
        }

        assert_eq!(
            rendered,
            [
                "{ let block = \"in node position\" ; }",
                "{ block_in_attr_position = foo }",
                "{ block_in_value }",
            ]
        );
    }

    #[test]
    fn collect_raw_text() {
        // Records every raw-text atom, normalized to `Infallible` custom nodes.
        #[derive(Default)]
        struct RawCollector {
            raw: Vec<RawText<Infallible>>,
        }
        impl<C: CustomNode> Visitor<C> for RawCollector {
            fn visit_raw_node<AnyC: CustomNode>(&mut self, node: &mut RawText<AnyC>) -> bool {
                self.raw.push(node.clone().convert_custom::<Infallible>());
                true
            }
        }
        // Rust code is not inspected here.
        impl syn::visit_mut::VisitMut for RawCollector {}

        let tokens = quote! {
            <!Doctype Other raw text >
            <div>
                <span>Some raw text</span>
                <span></span> And text after span
            </div>
            <!-- "comment" -->
            <foo attr key=value> </foo>
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        let collector = visit_nodes(&mut nodes, RawCollector::default());
        let texts: Vec<String> = collector
            .raw
            .iter()
            .map(|raw| raw.to_string_best())
            .collect();

        assert_eq!(
            texts,
            ["Other raw text", "Some raw text", "And text after span"]
        );
    }

    #[test]
    fn collect_string_literals() {
        // Records every string literal handed to `syn::visit_mut::VisitMut`.
        #[derive(Default)]
        struct LiteralCollector {
            literals: Vec<syn::LitStr>,
        }
        impl<C: CustomNode> Visitor<C> for LiteralCollector {}
        impl syn::visit_mut::VisitMut for LiteralCollector {
            fn visit_lit_str_mut(&mut self, i: &mut syn::LitStr) {
                self.literals.push(i.clone());
            }
        }

        let tokens = quote! {
            <!Doctype Other raw text >
            <div>
                <span>"Some raw text"</span>
                <span></span>"And text after span"
            </div>
            <!-- "comment" -->
            <foo attr key=value> </foo>
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        let collector = visit_nodes(&mut nodes, LiteralCollector::default());
        let values: Vec<String> = collector.literals.iter().map(syn::LitStr::value).collect();

        assert_eq!(values, ["Some raw text", "And text after span", "comment"]);
    }

    #[test]
    fn modify_text_visitor() {
        // Replaces every text node with the literal "modified".
        struct TextRewriter;
        impl<C: CustomNode> Visitor<C> for TextRewriter {
            fn visit_text_node(&mut self, node: &mut NodeText) -> bool {
                *node = parse_quote!("modified");
                true
            }
        }
        impl syn::visit_mut::VisitMut for TextRewriter {}

        let mut walker = Walker::new(TextRewriter);

        let tokens = quote! {
            <div>
                <span>"Some raw text"</span>
                <span></span>"And text after span"
            </div>
        };
        let mut nodes = crate::parse2(tokens).unwrap();
        nodes.iter_mut().for_each(|node| {
            walker.visit_node(node);
        });
        let rewritten = quote! {
            #(#nodes)*
        };
        let expected = quote! {
            <div>
                <span>"modified"</span>
                <span></span>"modified"
            </div>
        };
        assert_eq!(rewritten.to_string(), expected.to_string());
    }
}