gbf_core/decompiler/ast/bin_op.rs

1#![deny(missing_docs)]
2
3use gbf_macros::AstNodeTransform;
4use serde::{Deserialize, Serialize};
5
6use super::AstKind;
7use super::ptr::P;
8use super::visitors::AstVisitor;
9use super::{AstNodeError, expr::ExprKind};
10use crate::decompiler::ast::AstVisitable;
11use crate::define_ast_enum_type;
12
// Declares the `BinOpType` enum through the project's `define_ast_enum_type!` macro.
// Each entry maps a variant name to the operator text; the tests in this file check
// that `as_str()` returns that text and that `emit` prints it between the operands
// (e.g. `a + b`).
// NOTE(review): the generated `all_variants()` may follow declaration order — confirm
// against the macro definition before reordering or inserting variants.
define_ast_enum_type! {
    BinOpType {
        // Arithmetic operators.
        Add => "+",
        Sub => "-",
        Mul => "*",
        Div => "/",
        Mod => "%",
        // Bitwise operators. XOR is spelled as the keyword `xor`, not `^`.
        And => "&",
        Or => "|",
        Xor => "xor",
        // Short-circuit logical operators.
        LogicalAnd => "&&",
        LogicalOr => "||",
        // Comparison operators.
        Equal => "==",
        NotEqual => "!=",
        Greater => ">",
        Less => "<",
        GreaterOrEqual => ">=",
        LessOrEqual => "<=",
        // Shift operators.
        ShiftLeft => "<<",
        ShiftRight => ">>",
        // Membership test, emitted as the keyword `in`.
        In => "in",
        // NOTE(review): presumably string join/concatenation — confirm against the emitter.
        Join => "@",
        // `^` is exponentiation here (see `Xor` above for the bitwise form).
        Power => "^",
        // NOTE(review): presumably the iteration separator in foreach-style loops — confirm.
        Foreach => ":"
    }
}
39
40/// Represents a binary operation node in the AST, such as `a + b`.
41#[derive(Debug, Clone, Serialize, Deserialize, Eq, AstNodeTransform)]
42#[convert_to(ExprKind::BinOp, AstKind::Expression)]
43pub struct BinaryOperationNode {
44    /// The left-hand side of the binary operation.
45    pub lhs: ExprKind,
46    /// The right-hand side of the binary operation.
47    pub rhs: ExprKind,
48    /// The binary operation type.
49    pub op_type: BinOpType,
50}
51
52impl BinaryOperationNode {
53    /// Creates a new `BinaryOperationNode` after validating `lhs` and `rhs`.
54    ///
55    /// # Arguments
56    /// - `lhs` - The left-hand side expression.
57    /// - `rhs` - The right-hand side expression.
58    /// - `op_type` - The binary operation type.
59    ///
60    /// # Returns
61    /// A new `BinaryOperationNode`.
62    ///
63    /// # Errors
64    /// Returns an `AstNodeError` if `lhs` or `rhs` is of an unsupported type.
65    pub fn new(lhs: ExprKind, rhs: ExprKind, op_type: BinOpType) -> Result<Self, AstNodeError> {
66        Ok(Self { lhs, rhs, op_type })
67    }
68}
69
70// == Other implementations for binary operations ==
71impl PartialEq for BinaryOperationNode {
72    fn eq(&self, other: &Self) -> bool {
73        self.lhs == other.lhs && self.rhs == other.rhs && self.op_type == other.op_type
74    }
75}
76
77impl AstVisitable for P<BinaryOperationNode> {
78    fn accept<V: AstVisitor>(&self, visitor: &mut V) -> V::Output {
79        visitor.visit_bin_op(self)
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decompiler::ast::{emit, new_bin_op, new_id};

    /// Every operator is emitted infix, separated from its operands by single spaces.
    #[test]
    fn test_bin_op_emit() -> Result<(), AstNodeError> {
        for variant in BinOpType::all_variants() {
            let node = new_bin_op(new_id("a"), new_id("b"), variant.clone())?;
            let expected = format!("a {} b", variant.as_str());
            assert_eq!(emit(node), expected);
        }
        Ok(())
    }

    /// A lower-precedence sub-expression on the left is parenthesized when emitted.
    #[test]
    fn test_nested_bin_op_emit() -> Result<(), AstNodeError> {
        let sum = new_bin_op(new_id("a"), new_id("b"), BinOpType::Add)?;
        let product = new_bin_op(sum, new_id("c"), BinOpType::Mul)?;
        assert_eq!(emit(product), "(a + b) * c");
        Ok(())
    }

    /// Nodes compare equal only when operands and operator all match.
    #[test]
    fn test_bin_op_eq() -> Result<(), AstNodeError> {
        let base = new_bin_op(new_id("a"), new_id("b"), BinOpType::Add)?;
        let same = new_bin_op(new_id("a"), new_id("b"), BinOpType::Add)?;
        let other_op = new_bin_op(new_id("a"), new_id("b"), BinOpType::Sub)?;
        let other_rhs = new_bin_op(new_id("a"), new_id("c"), BinOpType::Add)?;

        assert_eq!(base, same);
        assert_ne!(base, other_op);
        assert_ne!(base, other_rhs);
        Ok(())
    }
}