gbf_core/decompiler/function_decompiler.rs

1#![deny(missing_docs)]
2
3use crate::basic_block::BasicBlockId;
4use crate::function::{Function, FunctionError};
5use crate::instruction::Instruction;
6use crate::opcode::Opcode;
7use crate::operand::OperandError;
8use crate::utils::STRUCTURE_ANALYSIS_MAX_ITERATIONS;
9use serde::Serialize;
10use std::backtrace::Backtrace;
11use std::collections::HashMap;
12use thiserror::Error;
13
14use super::ast::array_kind::ArrayKind;
15use super::ast::expr::ExprKind;
16use super::ast::function::FunctionNode;
17use super::ast::visitors::emit_context::EmitContext;
18use super::ast::visitors::emitter::Gs2Emitter;
19use super::ast::{AstKind, AstVisitable, new_array};
20use super::execution_frame::ExecutionFrame;
21use super::function_decompiler_context::FunctionDecompilerContext;
22use super::structure_analysis::region::{RegionId, RegionType};
23use super::structure_analysis::{ControlFlowEdgeType, StructureAnalysis, StructureAnalysisError};
24
25/// An error when decompiling a function
26#[derive(Debug, Error, Serialize)]
27pub enum FunctionDecompilerError {
28    /// Encountered FunctionError
29    #[error("Encountered FunctionError while decompiling: {source}")]
30    FunctionError {
31        /// The source of the error
32        source: FunctionError,
33        /// The context of the error
34        context: Box<FunctionDecompilerErrorContext>,
35        /// The backtrace of the error
36        #[serde(skip)]
37        backtrace: Backtrace,
38    },
39
40    /// Register not found
41    #[error("Register not found: {register_id}")]
42    RegisterNotFound {
43        /// The register ID that was not found
44        register_id: usize,
45        /// The context of the error
46        context: Box<FunctionDecompilerErrorContext>,
47        /// The backtrace of the error
48        #[serde(skip)]
49        backtrace: Backtrace,
50    },
51
52    /// Encountered an error while processing the operand
53    #[error("Encountered an error while processing the operand: {source}")]
54    OperandError {
55        /// The source of the error
56        source: OperandError,
57        /// The context of the error
58        context: Box<FunctionDecompilerErrorContext>,
59        /// The backtrace of the error
60        #[serde(skip)]
61        backtrace: Backtrace,
62    },
63
64    /// Encountered AstNodeError
65    #[error("Encountered AstNodeError while decompiling: {source}")]
66    AstNodeError {
67        /// The source of the error
68        source: super::ast::AstNodeError,
69        /// The context of the error
70        context: Box<FunctionDecompilerErrorContext>,
71        /// The backtrace of the error
72        #[serde(skip)]
73        backtrace: Backtrace,
74    },
75
76    /// The current instruction must have an operand
77    #[error("The instruction associated with opcode {opcode} must have an operand.")]
78    InstructionMustHaveOperand {
79        /// The opcode associated with the instruction
80        opcode: Opcode,
81        /// The context of the error
82        context: Box<FunctionDecompilerErrorContext>,
83        /// The backtrace of the error
84        #[serde(skip)]
85        backtrace: Backtrace,
86    },
87
88    /// Invalid node type on stack
89    #[error("Unexpected AstNode sub-type on stack. Expected {expected}.")]
90    UnexpectedNodeType {
91        /// The expected node type
92        expected: String,
93        /// The context of the error
94        context: Box<FunctionDecompilerErrorContext>,
95        /// The backtrace of the error
96        #[serde(skip)]
97        backtrace: Backtrace,
98    },
99
100    /// Unimplemented Opcode
101    #[error("Unimplemented Opcode: {opcode}")]
102    UnimplementedOpcode {
103        /// The opcode that is unimplemented
104        opcode: Opcode,
105        /// The context of the error
106        context: Box<FunctionDecompilerErrorContext>,
107        /// The backtrace of the error
108        #[serde(skip)]
109        backtrace: Backtrace,
110    },
111
112    /// Execution state stack is empty
113    #[error("The AST Node stack is empty.")]
114    ExecutionStackEmpty {
115        /// The context of the error
116        context: Box<FunctionDecompilerErrorContext>,
117        /// The backtrace of the error
118        #[serde(skip)]
119        backtrace: Backtrace,
120    },
121
122    /// Unexpected execution state
123    #[error("Unexpected execution state.")]
124    UnexpectedExecutionState {
125        /// The context of the error
126        context: Box<FunctionDecompilerErrorContext>,
127        /// The backtrace of the error
128        #[serde(skip)]
129        backtrace: Backtrace,
130    },
131
132    /// Structure analysis error
133    #[error("A structure analysis error occurred while decompiling the function: {source}")]
134    StructureAnalysisError {
135        /// The source of the error
136        source: Box<StructureAnalysisError>,
137        /// The context of the error
138        context: Box<FunctionDecompilerErrorContext>,
139        /// The backtrace of the error
140        #[serde(skip)]
141        backtrace: Backtrace,
142    },
143
144    /// All other errors
145    #[error("An error occurred while decompiling the function: {message}")]
146    Other {
147        /// Message associated with the error
148        message: String,
149        /// The context of the error
150        context: Box<FunctionDecompilerErrorContext>,
151        /// The backtrace of the error
152        #[serde(skip)]
153        backtrace: Backtrace,
154    },
155}
156
157/// A trait to provide details for a function decompiler error
158pub trait FunctionDecompilerErrorDetails {
159    /// Get the context for the error
160    fn context(&self) -> &FunctionDecompilerErrorContext;
161    /// Get the backtrace for the error
162    fn backtrace(&self) -> &Backtrace;
163    /// Get the type for the error
164    fn error_type(&self) -> String;
165}
166
167/// The context for a function decompiler error
168#[derive(Debug, Serialize, Clone)]
169pub struct FunctionDecompilerErrorContext {
170    /// The current block ID when the error occurred
171    pub current_block_id: BasicBlockId,
172    /// The current instruction when the error occurred
173    pub current_instruction: Instruction,
174    /// The current AST node stack when the error occurred
175    pub current_ast_node_stack: Vec<ExecutionFrame>,
176}
177
178/// The builder for a function decompiler
179pub struct FunctionDecompilerBuilder<'a> {
180    function: &'a Function,
181    emit_context: EmitContext,
182    structure_debug_mode: bool,
183    structure_analysis_max_iterations: usize,
184}
185
186impl<'a> FunctionDecompilerBuilder<'a> {
187    /// Create a new function decompiler builder
188    pub fn new(function: &'a Function) -> Self {
189        FunctionDecompilerBuilder {
190            function,
191            emit_context: EmitContext::default(),
192            structure_debug_mode: false,
193            structure_analysis_max_iterations: STRUCTURE_ANALYSIS_MAX_ITERATIONS,
194        }
195    }
196
197    /// Set the emit context for the function decompiler
198    pub fn emit_context(mut self, emit_context: EmitContext) -> Self {
199        self.emit_context = emit_context;
200        self
201    }
202
203    /// Set the structure debug mode for the function decompiler. These keeps track
204    /// of the structure of the function as it is being analyzed with StructureAnalysis.
205    pub fn structure_debug_mode(mut self, structure_debug_mode: bool) -> Self {
206        self.structure_debug_mode = structure_debug_mode;
207        self
208    }
209
210    /// Sets the maximum number of iterations for the structure analysis
211    pub fn structure_analysis_max_iterations(mut self, max_iterations: usize) -> Self {
212        self.structure_analysis_max_iterations = max_iterations;
213        self
214    }
215
216    /// Build the function decompiler
217    pub fn build(self) -> FunctionDecompiler<'a> {
218        FunctionDecompiler::new(
219            self.function,
220            self.structure_debug_mode,
221            self.structure_analysis_max_iterations,
222        )
223    }
224}
225
226/// A struct to hold the state of a function decompiler
227pub struct FunctionDecompiler<'a> {
228    /// Create a copy of the function to analyze
229    function: &'a Function,
230    /// A conversion from block ids to region ids
231    block_to_region: HashMap<BasicBlockId, RegionId>,
232    /// The current context for the decompiler
233    context: Option<FunctionDecompilerContext>,
234    /// The parameters for the function
235    function_parameters: Option<ArrayKind>,
236    /// The structure analysis
237    struct_analysis: StructureAnalysis,
238    /// Whether the analysis has been run
239    did_run_analysis: bool,
240}
241
242impl<'a> FunctionDecompiler<'a> {
243    /// A new method for the FunctionDecompiler struct.
244    ///
245    /// # Arguments
246    /// - `function`: The function to analyze and decompile.
247    /// - `structure_debug_mode`: Whether to enable debug mode for the structure analysis.
248    /// - `structure_max_iterations`: The maximum number of iterations for the structure analysis.
249    ///
250    /// # Returns
251    /// - A newly constructed `FunctionDecompiler` instance.
252    ///
253    /// # Errors
254    /// - `FunctionDecompilerError` if there is an error while decompiling the function.
255    fn new(
256        function: &'a Function,
257        structure_debug_mode: bool,
258        structure_max_iterations: usize,
259    ) -> Self {
260        FunctionDecompiler {
261            function,
262            block_to_region: HashMap::new(),
263            context: None,
264            function_parameters: None,
265            struct_analysis: StructureAnalysis::new(structure_debug_mode, structure_max_iterations),
266            did_run_analysis: false,
267        }
268    }
269}
270
271// == Private Functions ==
272impl FunctionDecompiler<'_> {
273    /// Decompile the function and emit the AST as a string.
274    ///
275    /// # Arguments
276    /// - `context`: The context for AST emission.
277    ///
278    /// # Returns
279    /// - The emitted AST as a string.
280    ///
281    /// # Errors
282    /// - Returns `FunctionDecompilerError` for any issues encountered during decompilation.
283    pub fn decompile(
284        &mut self,
285        emit_context: EmitContext,
286    ) -> Result<String, FunctionDecompilerError> {
287        self.process_regions()?;
288
289        let entry_block_id = self.function.get_entry_basic_block().id;
290        let entry_region_id = self.block_to_region.get(&entry_block_id).unwrap();
291
292        self.did_run_analysis = true;
293        self.struct_analysis.execute().map_err(|e| {
294            FunctionDecompilerError::StructureAnalysisError {
295                source: Box::new(e),
296                context: self.context.as_ref().unwrap().get_error_context(),
297                backtrace: Backtrace::capture(),
298            }
299        })?;
300        let entry_region = {
301            let region = self
302                .struct_analysis
303                .get_region(*entry_region_id)
304                .expect("[Bug] The entry region should exist.");
305            region.clone()
306        };
307        let entry_region_nodes = entry_region.iter_nodes().cloned().collect::<Vec<_>>();
308
309        let func = AstKind::Function(
310            FunctionNode::new(
311                self.function.id.name.clone(),
312                self.function_parameters
313                    .clone()
314                    .unwrap_or_else(|| new_array::<ExprKind>(vec![]).into()),
315                entry_region_nodes,
316            )
317            .into(),
318        );
319
320        let mut emitter = Gs2Emitter::new(emit_context);
321        let output: String = func.accept(&mut emitter).node;
322
323        Ok(output)
324    }
325
326    /// Get the structure analysis snapshots
327    pub fn get_structure_analysis_snapshots(&self) -> Result<Vec<String>, FunctionDecompilerError> {
328        self.struct_analysis
329            .get_snapshots()
330            .map_err(|e| FunctionDecompilerError::StructureAnalysisError {
331                source: Box::new(e),
332                context: self.context.as_ref().unwrap().get_error_context(),
333                backtrace: Backtrace::capture(),
334            })
335            .cloned()
336    }
337
338    fn generate_regions(&mut self) -> Result<(), FunctionDecompilerError> {
339        for block in self.function.iter() {
340            // If the block is the end of the module, it is a tail region
341            let successors = self.function.get_successors(block.id).map_err(|e| {
342                FunctionDecompilerError::FunctionError {
343                    source: e,
344                    backtrace: Backtrace::capture(),
345                    context: self.context.as_ref().unwrap().get_error_context(),
346                }
347            })?;
348            let region_type = if successors.is_empty() {
349                RegionType::Tail
350            } else {
351                RegionType::Linear
352            };
353
354            let new_region_id = self.struct_analysis.add_region(region_type);
355            self.block_to_region.insert(block.id, new_region_id);
356        }
357        Ok(())
358    }
359
360    fn process_regions(&mut self) -> Result<(), FunctionDecompilerError> {
361        // Generate all the regions before doing anything else
362        self.generate_regions()?;
363
364        let mut ctx = FunctionDecompilerContext::new(self.function.get_entry_basic_block_id());
365
366        // Iterate through all the blocks in reverse post order
367        let reverse_post_order = self
368            .function
369            .get_reverse_post_order(self.function.get_entry_basic_block().id)
370            .map_err(|e| FunctionDecompilerError::FunctionError {
371                source: e,
372                backtrace: Backtrace::capture(),
373                context: ctx.get_error_context(),
374            })?;
375
376        for block_id in &reverse_post_order {
377            // Get the region id for the block
378            let region_id = *self
379                .block_to_region
380                .get(block_id)
381                .expect("[Bug] We just made the regions, so not sure why it doesn't exist.");
382
383            ctx.start_block_processing(*block_id)?;
384
385            // Connect the block's predecessors in the graph
386            self.connect_predecessor_regions(*block_id, region_id)?;
387
388            // Process instructions in the block
389            let instructions: Vec<_> = {
390                let block = self
391                    .function
392                    .get_basic_block_by_id(*block_id)
393                    .map_err(|e| FunctionDecompilerError::FunctionError {
394                        source: e,
395                        backtrace: Backtrace::capture(),
396                        context: ctx.get_error_context(),
397                    })?;
398                block.iter().cloned().collect()
399            };
400
401            for instr in instructions {
402                let processed = ctx.process_instruction(&instr)?;
403                if let Some(node) = processed.node_to_push {
404                    let current_region_id = self
405                        .block_to_region
406                        .get(block_id)
407                        .expect("[Bug] The region should exist.");
408                    self.struct_analysis
409                        .push_to_region(*current_region_id, node);
410                }
411
412                if let Some(params) = processed.function_parameters {
413                    self.function_parameters = Some(params);
414                }
415
416                if let Some(jmp) = &processed.jump_condition {
417                    let current_region_id = self
418                        .block_to_region
419                        .get(block_id)
420                        .expect("[Bug] The region should exist.");
421                    let region = self
422                        .struct_analysis
423                        .get_region_mut(*current_region_id)
424                        .expect("[Bug] The region should exist.");
425
426                    // Get the successor regions -- if all the successors are the same, then we can
427                    // set the region type to linear and remove all duplicate edges.
428                    let successors = self.function.get_successors(*block_id).map_err(|e| {
429                        FunctionDecompilerError::FunctionError {
430                            source: e,
431                            backtrace: Backtrace::capture(),
432                            context: ctx.get_error_context(),
433                        }
434                    })?;
435                    let mut unique_successors = successors.to_vec();
436                    unique_successors.dedup();
437                    // TODO: Double check this logic - this is for cases where a conditional jump
438                    // has the same target for both branches.
439                    if unique_successors.len() == 1 {
440                        region.set_region_type(RegionType::Linear);
441                    } else {
442                        region.set_jump_expr(Some(jmp.clone()));
443                        region.set_region_type(RegionType::ControlFlow);
444                        region.set_branch_opcode(instr.opcode);
445                    }
446                }
447            }
448        }
449
450        self.context = Some(ctx);
451
452        Ok(())
453    }
454
455    /// Get predecessors of a block and return the results as a vector of tuples
456    fn get_predecessors(
457        &self,
458        block_id: BasicBlockId,
459    ) -> Result<Vec<(BasicBlockId, RegionId, ControlFlowEdgeType)>, FunctionDecompilerError> {
460        // Step 1: Get the predecessors of the current block
461        let predecessors = self.function.get_predecessors(block_id).map_err(|e| {
462            FunctionDecompilerError::FunctionError {
463                source: e,
464                backtrace: Backtrace::capture(),
465                context: self.context.as_ref().unwrap().get_error_context(),
466            }
467        })?;
468
469        // Step 2: Map each predecessor to its region ID and determine the edge type
470        let predecessor_regions: Vec<(BasicBlockId, RegionId, ControlFlowEdgeType)> = predecessors
471            .iter()
472            .map(|pred_id| {
473                let pred_region_id = *self.block_to_region.get(pred_id).unwrap();
474
475                // Get the predecessor block
476                let pred_block = self
477                    .function
478                    .get_basic_block_by_id(*pred_id)
479                    .expect("Predecessor block not found");
480
481                // Get the last instruction of the predecessor block
482                // TODO: This is a bug if the block is empty; maybe in this case we should
483                // just get the address of the block?
484                let pred_last_instruction = pred_block.last().expect("Empty block");
485
486                // Get the target block address
487                let target_block = self
488                    .function
489                    .get_basic_block_by_id(block_id)
490                    .expect("Target block not found");
491                let target_address = target_block.id.address;
492
493                // Determine the edge type based on control flow
494                let edge_type = if pred_last_instruction.address + 1 != target_address {
495                    // The target address is NOT the next address, so it's a branch
496                    ControlFlowEdgeType::Branch
497                } else {
498                    ControlFlowEdgeType::Fallthrough
499                };
500
501                (*pred_id, pred_region_id, edge_type)
502            })
503            .collect();
504        Ok(predecessor_regions)
505    }
506
507    fn connect_predecessor_regions(
508        &mut self,
509        block_id: BasicBlockId,
510        region_id: RegionId,
511    ) -> Result<Vec<(BasicBlockId, RegionId, ControlFlowEdgeType)>, FunctionDecompilerError> {
512        // Step 1: Get the predecessors of the current block
513        let predecessor_regions = self.get_predecessors(block_id)?;
514
515        // Step 2: Connect the predecessor regions to the target region in the graph
516        for (_, pred_region_id, edge_type) in &predecessor_regions {
517            self.struct_analysis
518                .connect_regions(*pred_region_id, region_id, *edge_type)
519                .map_err(|e| FunctionDecompilerError::StructureAnalysisError {
520                    source: Box::new(e),
521                    context: self.context.as_ref().unwrap().get_error_context(),
522                    backtrace: Backtrace::capture(),
523                })?;
524        }
525
526        // Step 4: Return the vector of predecessor regions and their edge types
527        Ok(predecessor_regions)
528    }
529}
530
531// == Other Implementations ==
532
533impl FunctionDecompilerErrorDetails for FunctionDecompilerError {
534    fn context(&self) -> &FunctionDecompilerErrorContext {
535        match self {
536            FunctionDecompilerError::FunctionError { context, .. } => context,
537            FunctionDecompilerError::OperandError { context, .. } => context,
538            FunctionDecompilerError::AstNodeError { context, .. } => context,
539            FunctionDecompilerError::InstructionMustHaveOperand { context, .. } => context,
540            FunctionDecompilerError::UnexpectedNodeType { context, .. } => context,
541            FunctionDecompilerError::UnimplementedOpcode { context, .. } => context,
542            FunctionDecompilerError::ExecutionStackEmpty { context, .. } => context,
543            FunctionDecompilerError::UnexpectedExecutionState { context, .. } => context,
544            FunctionDecompilerError::Other { context, .. } => context,
545            FunctionDecompilerError::StructureAnalysisError { context, .. } => context,
546            FunctionDecompilerError::RegisterNotFound { context, .. } => context,
547        }
548    }
549
550    fn backtrace(&self) -> &Backtrace {
551        match self {
552            FunctionDecompilerError::FunctionError { backtrace, .. } => backtrace,
553            FunctionDecompilerError::OperandError { backtrace, .. } => backtrace,
554            FunctionDecompilerError::AstNodeError { backtrace, .. } => backtrace,
555            FunctionDecompilerError::InstructionMustHaveOperand { backtrace, .. } => backtrace,
556            FunctionDecompilerError::UnexpectedNodeType { backtrace, .. } => backtrace,
557            FunctionDecompilerError::UnimplementedOpcode { backtrace, .. } => backtrace,
558            FunctionDecompilerError::ExecutionStackEmpty { backtrace, .. } => backtrace,
559            FunctionDecompilerError::UnexpectedExecutionState { backtrace, .. } => backtrace,
560            FunctionDecompilerError::Other { backtrace, .. } => backtrace,
561            FunctionDecompilerError::StructureAnalysisError { source, .. } => source.backtrace(),
562            FunctionDecompilerError::RegisterNotFound { backtrace, .. } => backtrace,
563        }
564    }
565
566    fn error_type(&self) -> String {
567        match self {
568            FunctionDecompilerError::FunctionError { .. } => "FunctionError".to_string(),
569            FunctionDecompilerError::OperandError { .. } => "OperandError".to_string(),
570            FunctionDecompilerError::AstNodeError { .. } => "AstNodeError".to_string(),
571            FunctionDecompilerError::InstructionMustHaveOperand { .. } => {
572                "InstructionMustHaveOperand".to_string()
573            }
574            FunctionDecompilerError::UnexpectedNodeType { .. } => "UnexpectedNodeType".to_string(),
575            FunctionDecompilerError::UnimplementedOpcode { .. } => {
576                "UnimplementedOpcode".to_string()
577            }
578            FunctionDecompilerError::ExecutionStackEmpty { .. } => {
579                "ExecutionStackEmpty".to_string()
580            }
581            FunctionDecompilerError::UnexpectedExecutionState { .. } => {
582                "UnexpectedExecutionState".to_string()
583            }
584            FunctionDecompilerError::Other { .. } => "Other".to_string(),
585            FunctionDecompilerError::StructureAnalysisError { .. } => {
586                "StructureAnalysisError".to_string()
587            }
588            FunctionDecompilerError::RegisterNotFound { .. } => "RegisterNotFound".to_string(),
589        }
590    }
591}