gbf_core/
function.rs

1#![deny(missing_docs)]
2
3use petgraph::Direction;
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::{DfsPostOrder, Walker};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::{self, Display, Formatter};
9use std::hash::Hash;
10use std::ops::{Deref, Index};
11use thiserror::Error;
12
13use crate::basic_block::{BasicBlock, BasicBlockId, BasicBlockType};
14use crate::cfg_dot::{CfgDot, CfgDotConfig, DotRenderableGraph, NodeResolver};
15use crate::utils::{GBF_BLUE, GBF_GREEN, GBF_RED, Gs2BytecodeAddress};
16
/// Represents an error that can occur when working with functions.
///
/// Returned by the lookup and graph-editing methods of [`Function`].
// NOTE(review): this enum derives `Serialize`/`Deserialize`; index-based serde
// formats encode variants by position, so append new variants at the end.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum FunctionError {
    /// The requested `BasicBlock` was not found by its block id.
    #[error("BasicBlock not found by its block id: {0}")]
    BasicBlockNotFoundById(BasicBlockId),

    /// The requested `BasicBlock` was not found by its address.
    ///
    /// Address lookups only match the exact start address of a block.
    #[error("BasicBlock not found by its address: {0}")]
    BasicBlockNotFoundByAddress(Gs2BytecodeAddress),

    /// The requested `BasicBlock` does not have a `NodeIndex`, i.e. the block
    /// id has no corresponding node in the function's control-flow graph.
    #[error("BasicBlock with id {0} does not have a NodeIndex")]
    BasicBlockNodeIndexNotFound(BasicBlockId),

    /// The function already has an entry block.
    ///
    /// `Function::new` always creates the entry block, so asking
    /// `Function::create_block` for another `BasicBlockType::Entry` block fails
    /// with this error.
    #[error("Function already has an entry block")]
    EntryBlockAlreadyExists,
}
36
/// Represents the identifier of a function.
///
/// Equality and hashing consider every field. The derived ordering is
/// lexicographic in declaration order: `index`, then `name`, then `address`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct FunctionId {
    /// The index of the function in the module (set through the constructors).
    index: usize,
    /// The name of the function, if it is not the entry point.
    pub name: Option<String>,
    /// The address of the function in the module.
    pub address: Gs2BytecodeAddress,
}
46
47impl FunctionId {
48    /// Create a new `FunctionId`.
49    ///
50    /// # Arguments
51    /// - `index`: The index of the function in the module.
52    /// - `name`: The name of the function, if it is not the entry point.
53    /// - `address`: The address of the function in the module.
54    ///
55    /// # Returns
56    /// - A new `FunctionId` instance.
57    ///
58    /// # Example
59    /// ```
60    /// use gbf_core::function::FunctionId;
61    ///
62    /// let entry = FunctionId::new_without_name(0, 0);
63    /// let add = FunctionId::new(1, Some("add"), 0x100);
64    /// ```
65    pub fn new<S>(index: usize, name: Option<S>, address: Gs2BytecodeAddress) -> Self
66    where
67        S: Into<String>,
68    {
69        Self {
70            index,
71            name: name.map(|n| n.into()),
72            address,
73        }
74    }
75
76    /// Helper method for creating a `FunctionId` without a name.
77    pub fn new_without_name(index: usize, address: Gs2BytecodeAddress) -> Self {
78        Self::new(index, None::<String>, address)
79    }
80
81    /// If the function has a name.
82    ///
83    /// # Returns
84    /// - `true` if the function has a name.
85    /// - `false` if the function does not have a name.
86    ///
87    /// # Example
88    /// ```
89    /// use gbf_core::function::FunctionId;
90    ///
91    /// let entry = FunctionId::new_without_name(0, 0);
92    ///
93    /// assert!(entry.is_named());
94    /// ```
95    pub fn is_named(&self) -> bool {
96        self.name.is_none()
97    }
98}
99
/// Represents a function in a module.
///
/// A function owns its `BasicBlock`s plus a control-flow graph over them.
/// `Function::new` and `Function::create_block` give every block exactly one
/// graph node. The nodes and edges carry no data (`()`), so two side maps
/// translate between graph `NodeIndex`es and `BasicBlockId`s.
#[derive(Debug, Serialize, Deserialize)]
pub struct Function {
    /// The identifier of the function.
    pub id: FunctionId,
    /// A vector of all the `BasicBlock`s in the function, in creation order.
    /// When built through `Function::new`, index 0 is the entry block.
    blocks: Vec<BasicBlock>,
    /// Maps `BasicBlockId` to their index in the `blocks` vector.
    block_map: HashMap<BasicBlockId, usize>,
    /// The control-flow graph of the function (nodes and edges carry no payload).
    cfg: DiGraph<(), ()>,
    /// Used to convert `NodeIndex` to `BasicBlockId`.
    graph_node_to_block: HashMap<NodeIndex, BasicBlockId>,
    /// Used to convert `BasicBlockId` to `NodeIndex`.
    block_to_graph_node: HashMap<BasicBlockId, NodeIndex>,
    /// A map of function addresses to their IDs.
    // NOTE(review): nothing visible in this file inserts into this map; it is
    // always created empty. Confirm whether it is still needed.
    address_to_id: HashMap<Gs2BytecodeAddress, FunctionId>,
}
118
119impl Function {
120    /// Create a new `Function`. Automatically creates an entry block.
121    ///
122    /// # Arguments
123    /// - `id`: The `FunctionId` of the function.
124    ///
125    /// # Returns
126    /// - A new `Function` instance.
127    pub fn new(id: FunctionId) -> Self {
128        let mut blocks = Vec::new();
129        let mut block_map = HashMap::new();
130        let mut graph_node_to_block = HashMap::new();
131        let mut block_to_graph_node = HashMap::new();
132        let address_to_id = HashMap::new();
133        let mut cfg = DiGraph::new();
134
135        // Initialize entry block
136        let entry_block = BasicBlockId::new(blocks.len(), BasicBlockType::Entry, id.address);
137        blocks.push(BasicBlock::new(entry_block));
138        block_map.insert(entry_block, 0);
139
140        // Add an empty node in the graph to represent this BasicBlock
141        let entry_node_id = cfg.add_node(());
142        graph_node_to_block.insert(entry_node_id, entry_block);
143        block_to_graph_node.insert(entry_block, entry_node_id);
144
145        Self {
146            id,
147            blocks,
148            block_map,
149            cfg,
150            graph_node_to_block,
151            block_to_graph_node,
152            address_to_id,
153        }
154    }
155
156    /// Create a new `BasicBlock` and add it to the function.
157    ///
158    /// # Arguments
159    /// - `block_type`: The type of the block.
160    ///
161    /// # Returns
162    /// - A `BasicBlockId` for the new block.
163    ///
164    /// # Example
165    /// ```
166    /// use gbf_core::function::{Function, FunctionId};
167    /// use gbf_core::basic_block::BasicBlockType;
168    ///
169    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
170    /// let block = function.create_block(BasicBlockType::Normal, 0);
171    /// ```
172    pub fn create_block(
173        &mut self,
174        block_type: BasicBlockType,
175        address: Gs2BytecodeAddress,
176    ) -> Result<BasicBlockId, FunctionError> {
177        // do not allow entry block to be created more than once
178        if block_type == BasicBlockType::Entry {
179            return Err(FunctionError::EntryBlockAlreadyExists);
180        }
181
182        let id = BasicBlockId::new(self.blocks.len(), block_type, address);
183        self.blocks.push(BasicBlock::new(id));
184        self.block_map.insert(id, self.blocks.len() - 1);
185
186        // Insert a node in the petgraph to represent this BasicBlock
187        let node_id = self.cfg.add_node(());
188        self.block_to_graph_node.insert(id, node_id);
189        self.graph_node_to_block.insert(node_id, id);
190
191        Ok(id)
192    }
193
194    /// Get a reference to a `BasicBlock` by its `BasicBlockId`.
195    ///
196    /// # Arguments
197    /// - `id`: The `BasicBlockId` of the block.
198    ///
199    /// # Returns
200    /// - A reference to the `BasicBlock`.
201    ///
202    /// # Errors
203    /// - `FunctionError::BasicBlockNotFound` if the block does not exist.
204    ///
205    /// # Example
206    /// ```
207    /// use gbf_core::function::{Function, FunctionId};
208    /// use gbf_core::basic_block::BasicBlockType;
209    ///
210    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
211    /// let block_id = function.create_block(BasicBlockType::Normal, 0).unwrap();
212    /// let block_ref = function.get_basic_block_by_id(block_id).unwrap();
213    /// ```
214    pub fn get_basic_block_by_id(&self, id: BasicBlockId) -> Result<&BasicBlock, FunctionError> {
215        let index = self
216            .block_map
217            .get(&id)
218            .ok_or(FunctionError::BasicBlockNotFoundById(id))?;
219        Ok(&self.blocks[*index])
220    }
221
222    /// Get a reference to a `BasicBlock` by its address. The block address
223    /// -must- be the start address of the block.
224    ///
225    /// # Arguments
226    /// - `id`: The `BasicBlockId` of the block.
227    ///
228    /// # Returns
229    /// - A mutable reference to the `BasicBlock`.
230    ///
231    /// # Errors
232    /// - `FunctionError::BasicBlockNotFound` if the block does not exist.
233    ///
234    /// # Example
235    /// ```
236    /// use gbf_core::function::{Function, FunctionId};
237    /// use gbf_core::basic_block::BasicBlockType;
238    ///
239    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
240    /// let block_id = function.create_block(BasicBlockType::Normal, 0).unwrap();
241    /// let block_ref = function.get_basic_block_by_id_mut(block_id).unwrap();
242    /// ```
243    pub fn get_basic_block_by_id_mut(
244        &mut self,
245        id: BasicBlockId,
246    ) -> Result<&mut BasicBlock, FunctionError> {
247        let index = self
248            .block_map
249            .get(&id)
250            .ok_or(FunctionError::BasicBlockNotFoundById(id))?;
251        Ok(&mut self.blocks[*index])
252    }
253
254    /// Get a reference to a `BasicBlock` by its address.
255    ///
256    /// # Arguments
257    /// - `address`: The address of the block.
258    ///
259    /// # Returns
260    /// - A reference to the `BasicBlock`.
261    ///
262    /// # Errors
263    /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
264    ///
265    /// # Example
266    /// ```
267    /// use gbf_core::function::{Function, FunctionId};
268    /// use gbf_core::basic_block::BasicBlockType;
269    ///
270    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
271    /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
272    /// let block_ref = function.get_basic_block_by_start_address(0x100).unwrap();
273    /// ```
274    pub fn get_basic_block_by_start_address(
275        &self,
276        address: Gs2BytecodeAddress,
277    ) -> Result<&BasicBlock, FunctionError> {
278        let id = self.get_basic_block_id_by_start_address(address)?;
279        self.get_basic_block_by_id(id)
280    }
281
282    /// Get a reference to a `BasicBlock` by its address (mutable). The block address
283    /// -must- be the start address of the block.
284    ///
285    /// # Arguments
286    /// - `address`: The address of the block.
287    ///
288    /// # Returns
289    /// - A reference to the `BasicBlock`.
290    ///
291    /// # Errors
292    /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
293    ///
294    /// # Example
295    /// ```
296    /// use gbf_core::function::{Function, FunctionId};
297    /// use gbf_core::basic_block::BasicBlockType;
298    ///
299    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
300    /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
301    /// let block_ref = function.get_basic_block_by_start_address_mut(0x100).unwrap();
302    /// ```
303    pub fn get_basic_block_by_start_address_mut(
304        &mut self,
305        address: Gs2BytecodeAddress,
306    ) -> Result<&mut BasicBlock, FunctionError> {
307        let id = self.get_basic_block_id_by_start_address(address)?;
308        self.get_basic_block_by_id_mut(id)
309    }
310
311    /// Check if a block exists by its address.
312    ///
313    /// # Arguments
314    /// - `address`: The address of the block.
315    ///
316    /// # Returns
317    /// - `true` if the block exists.
318    /// - `false` if the block does not exist.
319    ///
320    /// # Example
321    /// ```
322    /// use gbf_core::function::{Function, FunctionId};
323    /// use gbf_core::basic_block::BasicBlockType;
324    ///
325    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
326    /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
327    /// assert!(function.basic_block_exists_by_address(0x100));
328    /// ```
329    pub fn basic_block_exists_by_address(&self, address: Gs2BytecodeAddress) -> bool {
330        self.blocks.iter().any(|block| block.id.address == address)
331    }
332
333    /// Gets the entry basic block id of the function.
334    ///
335    /// # Returns
336    /// - The `BasicBlockId` of the entry block.
337    ///
338    /// # Example
339    /// ```
340    /// use gbf_core::function::{Function, FunctionId};
341    ///
342    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
343    /// let entry = function.get_entry_basic_block_id();
344    /// ```
345    pub fn get_entry_basic_block_id(&self) -> BasicBlockId {
346        self.blocks[0].id
347    }
348
349    /// Get the entry basic block of the function.
350    ///
351    /// # Returns
352    /// - A reference to the entry block.
353    ///
354    /// # Example
355    /// ```
356    /// use gbf_core::function::{Function, FunctionId};
357    ///
358    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
359    /// let entry = function.get_entry_basic_block();
360    /// ```
361    pub fn get_entry_basic_block(&self) -> &BasicBlock {
362        self.blocks.first().unwrap()
363    }
364
365    /// Get the entry block of the function.
366    ///
367    /// # Returns
368    /// - A mutable reference to the entry block.
369    ///
370    /// # Example
371    /// ```
372    /// use gbf_core::function::{Function, FunctionId};
373    ///
374    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
375    /// let entry = function.get_entry_basic_block_mut();
376    /// ```
377    pub fn get_entry_basic_block_mut(&mut self) -> &mut BasicBlock {
378        self.blocks.first_mut().unwrap()
379    }
380
381    /// Add an edge between two `BasicBlock`s.
382    ///
383    /// # Arguments
384    /// - `source`: The `BasicBlockId` of the source block.
385    /// - `target`: The `BasicBlockId` of the target block.
386    ///
387    /// # Errors
388    /// - `FunctionError::BasicBlockNodeIndexNotFound` if either block does not have a `NodeIndex`.
389    /// - `FunctionError::GraphError` if the edge could not be added to the graph.
390    ///
391    /// # Example
392    /// ```
393    /// use gbf_core::function::{Function, FunctionId};
394    /// use gbf_core::basic_block::BasicBlockType;
395    ///
396    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
397    /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
398    /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
399    /// function.add_edge(block1, block2);
400    /// ```
401    pub fn add_edge(
402        &mut self,
403        source: BasicBlockId,
404        target: BasicBlockId,
405    ) -> Result<(), FunctionError> {
406        let source_node_id = self
407            .block_id_to_node_id(source)
408            .ok_or(FunctionError::BasicBlockNodeIndexNotFound(source))?;
409        let target_node_id = self
410            .block_id_to_node_id(target)
411            .ok_or(FunctionError::BasicBlockNodeIndexNotFound(target))?;
412
413        // With petgraph, this does not fail, so we simply do it:
414        // It can panic if the node does not exist, but we have already checked that.
415        self.cfg.add_edge(source_node_id, target_node_id, ());
416        Ok(())
417    }
418
419    /// Get the number of `BasicBlock`s in the function.
420    ///
421    /// # Returns
422    /// - The number of `BasicBlock`s in the function.
423    ///
424    /// # Example
425    /// ```
426    /// use gbf_core::function::{Function, FunctionId};
427    /// use gbf_core::basic_block::BasicBlockType;
428    ///
429    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
430    /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
431    /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
432    /// let block3 = function.create_block(BasicBlockType::Normal, 0).unwrap();
433    ///
434    /// assert_eq!(function.len(), 4);
435    /// ```
436    pub fn len(&self) -> usize {
437        self.blocks.len()
438    }
439
440    /// Check if the function is empty.
441    ///
442    /// # Returns
443    /// - `true` if the function is empty.
444    ///
445    /// # Example
446    /// ```
447    /// use gbf_core::function::{Function, FunctionId};
448    ///
449    /// let function = Function::new(FunctionId::new_without_name(0, 0));
450    /// assert!(!function.is_empty());
451    /// ```
452    pub fn is_empty(&self) -> bool {
453        // This will always be false since we always create an entry block
454        self.blocks.is_empty()
455    }
456
457    /// Get the predecessors of a `BasicBlock`.
458    ///
459    /// # Arguments
460    /// - `id`: The `BasicBlockId` of the block.
461    ///
462    /// # Returns
463    /// - A vector of `BasicBlockId`s that are predecessors of the block.
464    ///
465    /// # Errors
466    /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
467    /// - `FunctionError::GraphError` if the predecessors could not be retrieved from the graph.
468    ///
469    /// # Example
470    /// ```
471    /// use gbf_core::function::{Function, FunctionId};
472    /// use gbf_core::basic_block::BasicBlockType;
473    ///
474    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
475    /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
476    /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
477    ///
478    /// function.add_edge(block1, block2);
479    /// let preds = function.get_predecessors(block2).unwrap();
480    /// ```
481    pub fn get_predecessors(&self, id: BasicBlockId) -> Result<Vec<BasicBlockId>, FunctionError> {
482        let node_id = self
483            .block_id_to_node_id(id)
484            .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
485
486        // Collect all incoming neighbors
487        let preds = self
488            .cfg
489            .neighbors_directed(node_id, Direction::Incoming)
490            .collect::<Vec<_>>();
491
492        Ok(preds
493            .into_iter()
494            .filter_map(|pred| self.node_id_to_block_id(pred))
495            .collect())
496    }
497
498    /// Get the successors of a `BasicBlock`.
499    ///
500    /// # Arguments
501    /// - `id`: The `BasicBlockId` of the block.
502    ///
503    /// # Returns
504    /// - A vector of `BasicBlockId`s that are successors of the block.
505    ///
506    /// # Errors
507    /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
508    /// - `FunctionError::GraphError` if the successors could not be retrieved from the graph.
509    ///
510    /// # Example
511    /// ```
512    /// use gbf_core::function::{Function, FunctionId};
513    /// use gbf_core::basic_block::BasicBlockType;
514    ///
515    /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
516    /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
517    /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
518    ///
519    /// function.add_edge(block1, block2);
520    /// let succs = function.get_successors(block1).unwrap();
521    /// ```
522    pub fn get_successors(&self, id: BasicBlockId) -> Result<Vec<BasicBlockId>, FunctionError> {
523        let node_id = self
524            .block_id_to_node_id(id)
525            .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
526
527        // Collect all outgoing neighbors
528        let succs = self
529            .cfg
530            .neighbors_directed(node_id, Direction::Outgoing)
531            .collect::<Vec<_>>();
532
533        Ok(succs
534            .into_iter()
535            .filter_map(|succ| self.node_id_to_block_id(succ))
536            .collect())
537    }
538
539    /// Get the blocks in reverse post order
540    ///
541    /// # Arguments
542    /// - `id`: The `BasicBlockId` of the starting block
543    ///
544    /// # Returns
545    /// - A vector of `BasicBlockId`s that sort the graph in reverse post order
546    ///
547    /// # Errors
548    /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
549    /// - `FunctionError::GraphError` if the successors could not be retrieved from the graph.
550    pub fn get_reverse_post_order(
551        &self,
552        id: BasicBlockId,
553    ) -> Result<Vec<BasicBlockId>, FunctionError> {
554        let node_id = self
555            .block_id_to_node_id(id)
556            .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
557
558        let dfs = DfsPostOrder::new(&self.cfg, node_id)
559            .iter(&self.cfg)
560            .collect::<Vec<_>>();
561
562        Ok(dfs
563            .into_iter()
564            .rev()
565            .filter_map(|node_id| self.node_id_to_block_id(node_id))
566            .collect())
567    }
568}
569
570/// Internal API for `Function`.
571impl Function {
572    /// Gets a block id based on its address.
573    ///
574    /// # Arguments
575    /// - `address`: The address of the block.
576    ///
577    /// # Returns
578    /// - The `BasicBlockId` of the block with the corresponding address.
579    ///
580    /// # Errors
581    /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
582    pub fn get_basic_block_id_by_start_address(
583        &self,
584        address: Gs2BytecodeAddress,
585    ) -> Result<BasicBlockId, FunctionError> {
586        self.blocks
587            .iter()
588            .find(|block| block.id.address == address)
589            .map(|block| block.id)
590            .ok_or(FunctionError::BasicBlockNotFoundByAddress(address))
591    }
592
593    /// Convert a `NodeIndex` to a `BasicBlockId`.
594    ///
595    /// # Arguments
596    /// - `node_id`: The `NodeIndex` to convert.
597    ///
598    /// # Returns
599    /// - The `BasicBlockId` of the block with the corresponding `NodeIndex`.
600    fn node_id_to_block_id(&self, node_id: NodeIndex) -> Option<BasicBlockId> {
601        self.graph_node_to_block.get(&node_id).cloned()
602    }
603
604    /// Convert a `BasicBlockId` to a `NodeIndex`.
605    ///
606    /// # Arguments
607    /// - `block_id`: The `BasicBlockId` to convert.
608    ///
609    /// # Returns
610    /// - The `NodeIndex` of the block with the corresponding `BasicBlockId`.
611    fn block_id_to_node_id(&self, block_id: BasicBlockId) -> Option<NodeIndex> {
612        self.block_to_graph_node.get(&block_id).cloned()
613    }
614}
615
616// === Implementations ===
617
618/// Display implementation for `FunctionId`.
619impl Display for FunctionId {
620    /// Display the `Function` as its name.
621    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
622        if let Some(name) = &self.name {
623            write!(f, "{}", name)
624        } else {
625            write!(f, "Unnamed Function (Entry)")
626        }
627    }
628}
629
630/// Clone implementation for Function
631impl Clone for Function {
632    fn clone(&self) -> Self {
633        let mut blocks = Vec::new();
634        let mut block_map = HashMap::new();
635        let mut graph_node_to_block = HashMap::new();
636        let mut block_to_graph_node = HashMap::new();
637        let address_to_id = HashMap::new();
638        let mut cfg = DiGraph::new();
639
640        // Clone blocks
641        for block in &self.blocks {
642            let new_block = block.clone();
643            let new_block_id = new_block.id;
644            blocks.push(new_block);
645            block_map.insert(new_block_id, blocks.len() - 1);
646
647            // Insert a node in the petgraph to represent this BasicBlock
648            let node_id = cfg.add_node(());
649            block_to_graph_node.insert(new_block_id, node_id);
650            graph_node_to_block.insert(node_id, new_block_id);
651        }
652
653        // Clone edges
654        for edge in self.cfg.raw_edges() {
655            let source = self.graph_node_to_block[&edge.source()];
656            let target = self.graph_node_to_block[&edge.target()];
657            let source_node_id = block_to_graph_node[&source];
658            let target_node_id = block_to_graph_node[&target];
659            cfg.add_edge(source_node_id, target_node_id, ());
660        }
661
662        Self {
663            id: self.id.clone(),
664            blocks,
665            block_map,
666            cfg,
667            graph_node_to_block,
668            block_to_graph_node,
669            address_to_id,
670        }
671    }
672}
673
674/// Deref implementation for Function
675impl Deref for Function {
676    type Target = [BasicBlock];
677
678    fn deref(&self) -> &Self::Target {
679        &self.blocks
680    }
681}
682
683/// Index implementation for function, with usize
684impl Index<usize> for Function {
685    type Output = BasicBlock;
686
687    fn index(&self, index: usize) -> &Self::Output {
688        &self.blocks[index]
689    }
690}
691
692/// IntoIterator implementation immutable reference
693impl<'a> IntoIterator for &'a Function {
694    type Item = &'a BasicBlock;
695    type IntoIter = std::slice::Iter<'a, BasicBlock>;
696
697    fn into_iter(self) -> Self::IntoIter {
698        self.blocks.iter()
699    }
700}
701
702/// IntoIterator implementation mutable reference
703impl<'a> IntoIterator for &'a mut Function {
704    type Item = &'a mut BasicBlock;
705    type IntoIter = std::slice::IterMut<'a, BasicBlock>;
706
707    fn into_iter(self) -> Self::IntoIter {
708        self.blocks.iter_mut()
709    }
710}
711
712impl NodeResolver for Function {
713    type NodeData = BasicBlock;
714
715    fn resolve(&self, node_index: NodeIndex) -> Option<&Self::NodeData> {
716        self.graph_node_to_block
717            .get(&node_index)
718            .and_then(|block_id| {
719                self.block_map
720                    .get(block_id)
721                    .and_then(|index| self.blocks.get(*index))
722            })
723    }
724
725    fn resolve_edge_color(&self, source: NodeIndex, target: NodeIndex) -> String {
726        // Get the last instruction of the source block
727        let source_block_id = self
728            .graph_node_to_block
729            .get(&source)
730            .expect("Source block not found");
731        let source_block = self
732            .get_basic_block_by_id(*source_block_id)
733            .expect("Source block not found");
734        let source_last_instruction = source_block.last().unwrap();
735
736        let target_block_id = self
737            .graph_node_to_block
738            .get(&target)
739            .expect("Target block not found");
740        let target_block = self
741            .get_basic_block_by_id(*target_block_id)
742            .expect("Target block not found");
743
744        // Figure out if the edge represents a branch by seeing if the target
745        // block address is NOT the next address after the source instruction.
746        let source_last_address = source_last_instruction.address;
747        let target_address = target_block.id.address;
748        if source_last_address + 1 != target_address {
749            // This represents a branch. Color the edge green.
750            return GBF_GREEN.to_string();
751        }
752
753        // If the opcode of the last instruction is a fall through, color the edge red since
754        // the target block's address is the next address
755        if source_last_instruction.opcode.has_fall_through() {
756            return GBF_RED.to_string();
757        }
758
759        // Otherwise, color the edge cyan (e.g. normal control flow)
760        GBF_BLUE.to_string()
761    }
762}
763
764impl DotRenderableGraph for Function {
765    /// Convert the Graph to `dot` format.
766    ///
767    /// # Returns
768    /// - A `String` containing the `dot` representation of the graph.
769    fn render_dot(&self, config: CfgDotConfig) -> String {
770        let cfg = CfgDot { config };
771        cfg.render(&self.cfg, self)
772    }
773}
774
775#[cfg(test)]
776mod tests {
777    use super::*;
778
779    #[test]
780    fn create_function() {
781        let id = FunctionId::new_without_name(0, 0);
782        let function = Function::new(id.clone());
783
784        assert_eq!(function.id, id);
785        assert_eq!(function.blocks.len(), 1);
786    }
787
788    #[test]
789    fn create_block() {
790        let id = FunctionId::new_without_name(0, 0);
791        let mut function = Function::new(id.clone());
792        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();
793
794        assert_eq!(function.len(), 2);
795
796        // check block id & node id mappings
797        let node_id = function.block_to_graph_node.get(&block_id).unwrap();
798        let new_block_id = function.graph_node_to_block.get(node_id).unwrap();
799        assert_eq!(*new_block_id, block_id);
800
801        // test EntryBlockAlreadyExists error
802        let result = function.create_block(BasicBlockType::Entry, 0);
803        assert!(result.is_err());
804    }
805
806    #[test]
807    fn get_block() {
808        let id = FunctionId::new_without_name(0, 0);
809        let mut function = Function::new(id.clone());
810        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();
811        let block = function.get_basic_block_by_id(block_id).unwrap();
812
813        assert_eq!(block.id, block_id);
814    }
815
816    #[test]
817    fn get_block_mut() {
818        let id = FunctionId::new_without_name(0, 0);
819        let mut function = Function::new(id);
820        let block_id = function.create_block(BasicBlockType::Normal, 43).unwrap();
821        let block = function.get_basic_block_by_id_mut(block_id).unwrap();
822
823        block.id = BasicBlockId::new(0, BasicBlockType::Exit, 43);
824        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 43));
825    }
826
827    #[test]
828    fn test_get_block_not_found() {
829        let id = FunctionId::new_without_name(0, 0);
830        let function = Function::new(id.clone());
831        let result =
832            function.get_basic_block_by_id(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
833
834        assert!(result.is_err());
835
836        // test mut version
837        let mut function = Function::new(id.clone());
838        let result =
839            function.get_basic_block_by_id_mut(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
840
841        assert!(result.is_err());
842
843        // get by start address
844        let result = function.get_basic_block_by_start_address(0x100);
845        assert!(result.is_err());
846
847        // get by start address mut
848        let result = function.get_basic_block_by_start_address_mut(0x100);
849        assert!(result.is_err());
850    }
851
852    #[test]
853    fn test_get_block_by_address() {
854        let id = FunctionId::new_without_name(0, 0);
855        let mut function = Function::new(id.clone());
856        let block_id = function
857            .create_block(BasicBlockType::Normal, 0x100)
858            .unwrap();
859        let block = function.get_basic_block_by_start_address(0x100).unwrap();
860
861        assert_eq!(block.id, block_id);
862
863        // test mut version
864        let block = function
865            .get_basic_block_by_start_address_mut(0x100)
866            .unwrap();
867        block.id = BasicBlockId::new(0, BasicBlockType::Exit, 0x100);
868        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 0x100));
869    }
870
871    #[test]
872    fn test_display_function_id() {
873        let id = FunctionId::new_without_name(0, 0);
874        assert_eq!(id.to_string(), "Unnamed Function (Entry)");
875
876        let id = FunctionId::new(0, Some("test".to_string()), 0);
877        assert_eq!(id.to_string(), "test");
878    }
879
880    #[test]
881    fn test_into_iter_mut() {
882        let id = FunctionId::new_without_name(0, 0);
883        let mut function = Function::new(id.clone());
884        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();
885
886        for block in &mut function {
887            if block.id == block_id {
888                block.id = BasicBlockId::new(0, BasicBlockType::Exit, 32);
889            }
890        }
891
892        let block = function.get_basic_block_by_id(block_id).unwrap();
893        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 32));
894    }
895
896    #[test]
897    fn test_is_named() {
898        let id = FunctionId::new_without_name(0, 0);
899        assert!(id.is_named());
900
901        let id = FunctionId::new(0, Some("test".to_string()), 0);
902        assert!(!id.is_named());
903    }
904
905    #[test]
906    fn test_get_entry_block() {
907        let id = FunctionId::new_without_name(0, 0);
908        let function = Function::new(id.clone());
909        let entry = function.get_entry_basic_block();
910
911        assert_eq!(entry.id, function.get_entry_basic_block().id);
912    }
913
914    #[test]
915    fn test_get_entry_block_mut() {
916        let id = FunctionId::new_without_name(0, 0);
917        let mut function = Function::new(id.clone());
918        let entry_id = function.get_entry_basic_block().id;
919        let entry = function.get_entry_basic_block_mut();
920
921        assert_eq!(entry.id, entry_id);
922    }
923
924    #[test]
925    fn test_add_edge() {
926        let id = FunctionId::new_without_name(0, 0);
927        let mut function = Function::new(id.clone());
928        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
929        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();
930
931        let result = function.add_edge(block1, block2);
932        assert!(result.is_ok());
933
934        let preds = function.get_predecessors(block2).unwrap();
935        assert_eq!(preds.len(), 1);
936        assert_eq!(preds[0], block1);
937
938        let succs = function.get_successors(block1).unwrap();
939        assert_eq!(succs.len(), 1);
940        assert_eq!(succs[0], block2);
941
942        // test source not found
943        let result = function.add_edge(BasicBlockId::new(1234, BasicBlockType::Normal, 0), block2);
944        assert!(result.is_err());
945
946        // test target not found
947        let result = function.add_edge(block1, BasicBlockId::new(1234, BasicBlockType::Normal, 0));
948        assert!(result.is_err());
949    }
950
951    #[test]
952    fn test_basic_block_is_empty() {
953        // will always be false since we always create an entry block
954        let id = FunctionId::new_without_name(0, 0);
955        let function = Function::new(id.clone());
956        assert!(!function.is_empty());
957    }
958
959    #[test]
960    fn test_get_predecessors() {
961        let id = FunctionId::new_without_name(0, 0);
962        let mut function = Function::new(id.clone());
963        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
964        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();
965
966        function.add_edge(block1, block2).unwrap();
967        let preds = function.get_predecessors(block2).unwrap();
968
969        assert_eq!(preds.len(), 1);
970        assert_eq!(preds[0], block1);
971
972        // test error where block not found
973        let result = function.get_predecessors(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
974        assert!(result.is_err());
975    }
976
977    #[test]
978    fn test_get_successors() {
979        let id = FunctionId::new_without_name(0, 0);
980        let mut function = Function::new(id.clone());
981        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
982        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();
983
984        function.add_edge(block1, block2).unwrap();
985        let succs = function.get_successors(block1).unwrap();
986
987        assert_eq!(succs.len(), 1);
988        assert_eq!(succs[0], block2);
989
990        // test error where block not found
991        let result = function.get_successors(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
992        assert!(result.is_err());
993    }
994
995    #[test]
996    fn test_get_entry_basic_block_id() {
997        let id = FunctionId::new_without_name(0, 0);
998        let function = Function::new(id.clone());
999        let entry = function.get_entry_basic_block_id();
1000
1001        assert_eq!(entry, function.get_entry_basic_block().id);
1002    }
1003}