gbf_core/function.rs
1#![deny(missing_docs)]
2
3use petgraph::Direction;
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::{DfsPostOrder, Walker};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::{self, Display, Formatter};
9use std::hash::Hash;
10use std::ops::{Deref, Index};
11use thiserror::Error;
12
13use crate::basic_block::{BasicBlock, BasicBlockId, BasicBlockType};
14use crate::cfg_dot::{CfgDot, CfgDotConfig, DotRenderableGraph, NodeResolver};
15use crate::utils::{GBF_BLUE, GBF_GREEN, GBF_RED, Gs2BytecodeAddress};
16
/// Errors that can occur when working with a `Function` and its basic blocks.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum FunctionError {
    /// No `BasicBlock` with the given `BasicBlockId` exists in the function.
    #[error("BasicBlock not found by its block id: {0}")]
    BasicBlockNotFoundById(BasicBlockId),

    /// No `BasicBlock` in the function starts at the given address.
    #[error("BasicBlock not found by its address: {0}")]
    BasicBlockNotFoundByAddress(Gs2BytecodeAddress),

    /// The given `BasicBlockId` has no corresponding node (`NodeIndex`) in the
    /// function's control-flow graph (presumably because the block does not
    /// belong to this function).
    #[error("BasicBlock with id {0} does not have a NodeIndex")]
    BasicBlockNodeIndexNotFound(BasicBlockId),

    /// An attempt was made to create a second entry block; a `Function` is
    /// given its entry block when it is constructed.
    #[error("Function already has an entry block")]
    EntryBlockAlreadyExists,
}
36
/// Identifies a function within a module.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order, so ids
/// sort by `index` first, then `name`, then `address`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct FunctionId {
    /// The index of the function in the module.
    index: usize,
    /// The name of the function, or `None` for the unnamed entry-point function.
    pub name: Option<String>,
    /// The address of the function in the module.
    pub address: Gs2BytecodeAddress,
}
46
47impl FunctionId {
48 /// Create a new `FunctionId`.
49 ///
50 /// # Arguments
51 /// - `index`: The index of the function in the module.
52 /// - `name`: The name of the function, if it is not the entry point.
53 /// - `address`: The address of the function in the module.
54 ///
55 /// # Returns
56 /// - A new `FunctionId` instance.
57 ///
58 /// # Example
59 /// ```
60 /// use gbf_core::function::FunctionId;
61 ///
62 /// let entry = FunctionId::new_without_name(0, 0);
63 /// let add = FunctionId::new(1, Some("add"), 0x100);
64 /// ```
65 pub fn new<S>(index: usize, name: Option<S>, address: Gs2BytecodeAddress) -> Self
66 where
67 S: Into<String>,
68 {
69 Self {
70 index,
71 name: name.map(|n| n.into()),
72 address,
73 }
74 }
75
76 /// Helper method for creating a `FunctionId` without a name.
77 pub fn new_without_name(index: usize, address: Gs2BytecodeAddress) -> Self {
78 Self::new(index, None::<String>, address)
79 }
80
81 /// If the function has a name.
82 ///
83 /// # Returns
84 /// - `true` if the function has a name.
85 /// - `false` if the function does not have a name.
86 ///
87 /// # Example
88 /// ```
89 /// use gbf_core::function::FunctionId;
90 ///
91 /// let entry = FunctionId::new_without_name(0, 0);
92 ///
93 /// assert!(entry.is_named());
94 /// ```
95 pub fn is_named(&self) -> bool {
96 self.name.is_none()
97 }
98}
99
/// A function in a module: its basic blocks plus the control-flow graph (CFG)
/// connecting them.
///
/// The CFG nodes carry no data; the two node/block `HashMap`s translate between
/// `BasicBlockId`s and CFG `NodeIndex`es.
#[derive(Debug, Serialize, Deserialize)]
pub struct Function {
    /// The identifier of the function.
    pub id: FunctionId,
    /// All `BasicBlock`s of the function in creation order; index 0 is the entry block.
    blocks: Vec<BasicBlock>,
    /// Maps each `BasicBlockId` to its index in the `blocks` vector.
    block_map: HashMap<BasicBlockId, usize>,
    /// The control-flow graph of the function; nodes and edges carry `()`.
    cfg: DiGraph<(), ()>,
    /// Used to convert a CFG `NodeIndex` to its `BasicBlockId`.
    graph_node_to_block: HashMap<NodeIndex, BasicBlockId>,
    /// Used to convert a `BasicBlockId` to its CFG `NodeIndex`.
    block_to_graph_node: HashMap<BasicBlockId, NodeIndex>,
    /// A map of function addresses to their IDs.
    // NOTE(review): nothing in this file ever inserts into this map; it appears
    // unused — confirm before relying on its contents.
    address_to_id: HashMap<Gs2BytecodeAddress, FunctionId>,
}
118
119impl Function {
120 /// Create a new `Function`. Automatically creates an entry block.
121 ///
122 /// # Arguments
123 /// - `id`: The `FunctionId` of the function.
124 ///
125 /// # Returns
126 /// - A new `Function` instance.
127 pub fn new(id: FunctionId) -> Self {
128 let mut blocks = Vec::new();
129 let mut block_map = HashMap::new();
130 let mut graph_node_to_block = HashMap::new();
131 let mut block_to_graph_node = HashMap::new();
132 let address_to_id = HashMap::new();
133 let mut cfg = DiGraph::new();
134
135 // Initialize entry block
136 let entry_block = BasicBlockId::new(blocks.len(), BasicBlockType::Entry, id.address);
137 blocks.push(BasicBlock::new(entry_block));
138 block_map.insert(entry_block, 0);
139
140 // Add an empty node in the graph to represent this BasicBlock
141 let entry_node_id = cfg.add_node(());
142 graph_node_to_block.insert(entry_node_id, entry_block);
143 block_to_graph_node.insert(entry_block, entry_node_id);
144
145 Self {
146 id,
147 blocks,
148 block_map,
149 cfg,
150 graph_node_to_block,
151 block_to_graph_node,
152 address_to_id,
153 }
154 }
155
156 /// Create a new `BasicBlock` and add it to the function.
157 ///
158 /// # Arguments
159 /// - `block_type`: The type of the block.
160 ///
161 /// # Returns
162 /// - A `BasicBlockId` for the new block.
163 ///
164 /// # Example
165 /// ```
166 /// use gbf_core::function::{Function, FunctionId};
167 /// use gbf_core::basic_block::BasicBlockType;
168 ///
169 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
170 /// let block = function.create_block(BasicBlockType::Normal, 0);
171 /// ```
172 pub fn create_block(
173 &mut self,
174 block_type: BasicBlockType,
175 address: Gs2BytecodeAddress,
176 ) -> Result<BasicBlockId, FunctionError> {
177 // do not allow entry block to be created more than once
178 if block_type == BasicBlockType::Entry {
179 return Err(FunctionError::EntryBlockAlreadyExists);
180 }
181
182 let id = BasicBlockId::new(self.blocks.len(), block_type, address);
183 self.blocks.push(BasicBlock::new(id));
184 self.block_map.insert(id, self.blocks.len() - 1);
185
186 // Insert a node in the petgraph to represent this BasicBlock
187 let node_id = self.cfg.add_node(());
188 self.block_to_graph_node.insert(id, node_id);
189 self.graph_node_to_block.insert(node_id, id);
190
191 Ok(id)
192 }
193
194 /// Get a reference to a `BasicBlock` by its `BasicBlockId`.
195 ///
196 /// # Arguments
197 /// - `id`: The `BasicBlockId` of the block.
198 ///
199 /// # Returns
200 /// - A reference to the `BasicBlock`.
201 ///
202 /// # Errors
203 /// - `FunctionError::BasicBlockNotFound` if the block does not exist.
204 ///
205 /// # Example
206 /// ```
207 /// use gbf_core::function::{Function, FunctionId};
208 /// use gbf_core::basic_block::BasicBlockType;
209 ///
210 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
211 /// let block_id = function.create_block(BasicBlockType::Normal, 0).unwrap();
212 /// let block_ref = function.get_basic_block_by_id(block_id).unwrap();
213 /// ```
214 pub fn get_basic_block_by_id(&self, id: BasicBlockId) -> Result<&BasicBlock, FunctionError> {
215 let index = self
216 .block_map
217 .get(&id)
218 .ok_or(FunctionError::BasicBlockNotFoundById(id))?;
219 Ok(&self.blocks[*index])
220 }
221
222 /// Get a reference to a `BasicBlock` by its address. The block address
223 /// -must- be the start address of the block.
224 ///
225 /// # Arguments
226 /// - `id`: The `BasicBlockId` of the block.
227 ///
228 /// # Returns
229 /// - A mutable reference to the `BasicBlock`.
230 ///
231 /// # Errors
232 /// - `FunctionError::BasicBlockNotFound` if the block does not exist.
233 ///
234 /// # Example
235 /// ```
236 /// use gbf_core::function::{Function, FunctionId};
237 /// use gbf_core::basic_block::BasicBlockType;
238 ///
239 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
240 /// let block_id = function.create_block(BasicBlockType::Normal, 0).unwrap();
241 /// let block_ref = function.get_basic_block_by_id_mut(block_id).unwrap();
242 /// ```
243 pub fn get_basic_block_by_id_mut(
244 &mut self,
245 id: BasicBlockId,
246 ) -> Result<&mut BasicBlock, FunctionError> {
247 let index = self
248 .block_map
249 .get(&id)
250 .ok_or(FunctionError::BasicBlockNotFoundById(id))?;
251 Ok(&mut self.blocks[*index])
252 }
253
254 /// Get a reference to a `BasicBlock` by its address.
255 ///
256 /// # Arguments
257 /// - `address`: The address of the block.
258 ///
259 /// # Returns
260 /// - A reference to the `BasicBlock`.
261 ///
262 /// # Errors
263 /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
264 ///
265 /// # Example
266 /// ```
267 /// use gbf_core::function::{Function, FunctionId};
268 /// use gbf_core::basic_block::BasicBlockType;
269 ///
270 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
271 /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
272 /// let block_ref = function.get_basic_block_by_start_address(0x100).unwrap();
273 /// ```
274 pub fn get_basic_block_by_start_address(
275 &self,
276 address: Gs2BytecodeAddress,
277 ) -> Result<&BasicBlock, FunctionError> {
278 let id = self.get_basic_block_id_by_start_address(address)?;
279 self.get_basic_block_by_id(id)
280 }
281
282 /// Get a reference to a `BasicBlock` by its address (mutable). The block address
283 /// -must- be the start address of the block.
284 ///
285 /// # Arguments
286 /// - `address`: The address of the block.
287 ///
288 /// # Returns
289 /// - A reference to the `BasicBlock`.
290 ///
291 /// # Errors
292 /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
293 ///
294 /// # Example
295 /// ```
296 /// use gbf_core::function::{Function, FunctionId};
297 /// use gbf_core::basic_block::BasicBlockType;
298 ///
299 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
300 /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
301 /// let block_ref = function.get_basic_block_by_start_address_mut(0x100).unwrap();
302 /// ```
303 pub fn get_basic_block_by_start_address_mut(
304 &mut self,
305 address: Gs2BytecodeAddress,
306 ) -> Result<&mut BasicBlock, FunctionError> {
307 let id = self.get_basic_block_id_by_start_address(address)?;
308 self.get_basic_block_by_id_mut(id)
309 }
310
311 /// Check if a block exists by its address.
312 ///
313 /// # Arguments
314 /// - `address`: The address of the block.
315 ///
316 /// # Returns
317 /// - `true` if the block exists.
318 /// - `false` if the block does not exist.
319 ///
320 /// # Example
321 /// ```
322 /// use gbf_core::function::{Function, FunctionId};
323 /// use gbf_core::basic_block::BasicBlockType;
324 ///
325 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
326 /// let block_id = function.create_block(BasicBlockType::Normal, 0x100).unwrap();
327 /// assert!(function.basic_block_exists_by_address(0x100));
328 /// ```
329 pub fn basic_block_exists_by_address(&self, address: Gs2BytecodeAddress) -> bool {
330 self.blocks.iter().any(|block| block.id.address == address)
331 }
332
333 /// Gets the entry basic block id of the function.
334 ///
335 /// # Returns
336 /// - The `BasicBlockId` of the entry block.
337 ///
338 /// # Example
339 /// ```
340 /// use gbf_core::function::{Function, FunctionId};
341 ///
342 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
343 /// let entry = function.get_entry_basic_block_id();
344 /// ```
345 pub fn get_entry_basic_block_id(&self) -> BasicBlockId {
346 self.blocks[0].id
347 }
348
349 /// Get the entry basic block of the function.
350 ///
351 /// # Returns
352 /// - A reference to the entry block.
353 ///
354 /// # Example
355 /// ```
356 /// use gbf_core::function::{Function, FunctionId};
357 ///
358 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
359 /// let entry = function.get_entry_basic_block();
360 /// ```
361 pub fn get_entry_basic_block(&self) -> &BasicBlock {
362 self.blocks.first().unwrap()
363 }
364
365 /// Get the entry block of the function.
366 ///
367 /// # Returns
368 /// - A mutable reference to the entry block.
369 ///
370 /// # Example
371 /// ```
372 /// use gbf_core::function::{Function, FunctionId};
373 ///
374 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
375 /// let entry = function.get_entry_basic_block_mut();
376 /// ```
377 pub fn get_entry_basic_block_mut(&mut self) -> &mut BasicBlock {
378 self.blocks.first_mut().unwrap()
379 }
380
381 /// Add an edge between two `BasicBlock`s.
382 ///
383 /// # Arguments
384 /// - `source`: The `BasicBlockId` of the source block.
385 /// - `target`: The `BasicBlockId` of the target block.
386 ///
387 /// # Errors
388 /// - `FunctionError::BasicBlockNodeIndexNotFound` if either block does not have a `NodeIndex`.
389 /// - `FunctionError::GraphError` if the edge could not be added to the graph.
390 ///
391 /// # Example
392 /// ```
393 /// use gbf_core::function::{Function, FunctionId};
394 /// use gbf_core::basic_block::BasicBlockType;
395 ///
396 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
397 /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
398 /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
399 /// function.add_edge(block1, block2);
400 /// ```
401 pub fn add_edge(
402 &mut self,
403 source: BasicBlockId,
404 target: BasicBlockId,
405 ) -> Result<(), FunctionError> {
406 let source_node_id = self
407 .block_id_to_node_id(source)
408 .ok_or(FunctionError::BasicBlockNodeIndexNotFound(source))?;
409 let target_node_id = self
410 .block_id_to_node_id(target)
411 .ok_or(FunctionError::BasicBlockNodeIndexNotFound(target))?;
412
413 // With petgraph, this does not fail, so we simply do it:
414 // It can panic if the node does not exist, but we have already checked that.
415 self.cfg.add_edge(source_node_id, target_node_id, ());
416 Ok(())
417 }
418
419 /// Get the number of `BasicBlock`s in the function.
420 ///
421 /// # Returns
422 /// - The number of `BasicBlock`s in the function.
423 ///
424 /// # Example
425 /// ```
426 /// use gbf_core::function::{Function, FunctionId};
427 /// use gbf_core::basic_block::BasicBlockType;
428 ///
429 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
430 /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
431 /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
432 /// let block3 = function.create_block(BasicBlockType::Normal, 0).unwrap();
433 ///
434 /// assert_eq!(function.len(), 4);
435 /// ```
436 pub fn len(&self) -> usize {
437 self.blocks.len()
438 }
439
440 /// Check if the function is empty.
441 ///
442 /// # Returns
443 /// - `true` if the function is empty.
444 ///
445 /// # Example
446 /// ```
447 /// use gbf_core::function::{Function, FunctionId};
448 ///
449 /// let function = Function::new(FunctionId::new_without_name(0, 0));
450 /// assert!(!function.is_empty());
451 /// ```
452 pub fn is_empty(&self) -> bool {
453 // This will always be false since we always create an entry block
454 self.blocks.is_empty()
455 }
456
457 /// Get the predecessors of a `BasicBlock`.
458 ///
459 /// # Arguments
460 /// - `id`: The `BasicBlockId` of the block.
461 ///
462 /// # Returns
463 /// - A vector of `BasicBlockId`s that are predecessors of the block.
464 ///
465 /// # Errors
466 /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
467 /// - `FunctionError::GraphError` if the predecessors could not be retrieved from the graph.
468 ///
469 /// # Example
470 /// ```
471 /// use gbf_core::function::{Function, FunctionId};
472 /// use gbf_core::basic_block::BasicBlockType;
473 ///
474 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
475 /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
476 /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
477 ///
478 /// function.add_edge(block1, block2);
479 /// let preds = function.get_predecessors(block2).unwrap();
480 /// ```
481 pub fn get_predecessors(&self, id: BasicBlockId) -> Result<Vec<BasicBlockId>, FunctionError> {
482 let node_id = self
483 .block_id_to_node_id(id)
484 .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
485
486 // Collect all incoming neighbors
487 let preds = self
488 .cfg
489 .neighbors_directed(node_id, Direction::Incoming)
490 .collect::<Vec<_>>();
491
492 Ok(preds
493 .into_iter()
494 .filter_map(|pred| self.node_id_to_block_id(pred))
495 .collect())
496 }
497
498 /// Get the successors of a `BasicBlock`.
499 ///
500 /// # Arguments
501 /// - `id`: The `BasicBlockId` of the block.
502 ///
503 /// # Returns
504 /// - A vector of `BasicBlockId`s that are successors of the block.
505 ///
506 /// # Errors
507 /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
508 /// - `FunctionError::GraphError` if the successors could not be retrieved from the graph.
509 ///
510 /// # Example
511 /// ```
512 /// use gbf_core::function::{Function, FunctionId};
513 /// use gbf_core::basic_block::BasicBlockType;
514 ///
515 /// let mut function = Function::new(FunctionId::new_without_name(0, 0));
516 /// let block1 = function.create_block(BasicBlockType::Normal, 0).unwrap();
517 /// let block2 = function.create_block(BasicBlockType::Normal, 0).unwrap();
518 ///
519 /// function.add_edge(block1, block2);
520 /// let succs = function.get_successors(block1).unwrap();
521 /// ```
522 pub fn get_successors(&self, id: BasicBlockId) -> Result<Vec<BasicBlockId>, FunctionError> {
523 let node_id = self
524 .block_id_to_node_id(id)
525 .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
526
527 // Collect all outgoing neighbors
528 let succs = self
529 .cfg
530 .neighbors_directed(node_id, Direction::Outgoing)
531 .collect::<Vec<_>>();
532
533 Ok(succs
534 .into_iter()
535 .filter_map(|succ| self.node_id_to_block_id(succ))
536 .collect())
537 }
538
539 /// Get the blocks in reverse post order
540 ///
541 /// # Arguments
542 /// - `id`: The `BasicBlockId` of the starting block
543 ///
544 /// # Returns
545 /// - A vector of `BasicBlockId`s that sort the graph in reverse post order
546 ///
547 /// # Errors
548 /// - `FunctionError::BasicBlockNodeIndexNotFound` if the block does not exist.
549 /// - `FunctionError::GraphError` if the successors could not be retrieved from the graph.
550 pub fn get_reverse_post_order(
551 &self,
552 id: BasicBlockId,
553 ) -> Result<Vec<BasicBlockId>, FunctionError> {
554 let node_id = self
555 .block_id_to_node_id(id)
556 .ok_or(FunctionError::BasicBlockNodeIndexNotFound(id))?;
557
558 let dfs = DfsPostOrder::new(&self.cfg, node_id)
559 .iter(&self.cfg)
560 .collect::<Vec<_>>();
561
562 Ok(dfs
563 .into_iter()
564 .rev()
565 .filter_map(|node_id| self.node_id_to_block_id(node_id))
566 .collect())
567 }
568}
569
570/// Internal API for `Function`.
571impl Function {
572 /// Gets a block id based on its address.
573 ///
574 /// # Arguments
575 /// - `address`: The address of the block.
576 ///
577 /// # Returns
578 /// - The `BasicBlockId` of the block with the corresponding address.
579 ///
580 /// # Errors
581 /// - `FunctionError::BasicBlockNotFoundByAddress` if the block does not exist.
582 pub fn get_basic_block_id_by_start_address(
583 &self,
584 address: Gs2BytecodeAddress,
585 ) -> Result<BasicBlockId, FunctionError> {
586 self.blocks
587 .iter()
588 .find(|block| block.id.address == address)
589 .map(|block| block.id)
590 .ok_or(FunctionError::BasicBlockNotFoundByAddress(address))
591 }
592
593 /// Convert a `NodeIndex` to a `BasicBlockId`.
594 ///
595 /// # Arguments
596 /// - `node_id`: The `NodeIndex` to convert.
597 ///
598 /// # Returns
599 /// - The `BasicBlockId` of the block with the corresponding `NodeIndex`.
600 fn node_id_to_block_id(&self, node_id: NodeIndex) -> Option<BasicBlockId> {
601 self.graph_node_to_block.get(&node_id).cloned()
602 }
603
604 /// Convert a `BasicBlockId` to a `NodeIndex`.
605 ///
606 /// # Arguments
607 /// - `block_id`: The `BasicBlockId` to convert.
608 ///
609 /// # Returns
610 /// - The `NodeIndex` of the block with the corresponding `BasicBlockId`.
611 fn block_id_to_node_id(&self, block_id: BasicBlockId) -> Option<NodeIndex> {
612 self.block_to_graph_node.get(&block_id).cloned()
613 }
614}
615
616// === Implementations ===
617
618/// Display implementation for `FunctionId`.
619impl Display for FunctionId {
620 /// Display the `Function` as its name.
621 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
622 if let Some(name) = &self.name {
623 write!(f, "{}", name)
624 } else {
625 write!(f, "Unnamed Function (Entry)")
626 }
627 }
628}
629
630/// Clone implementation for Function
631impl Clone for Function {
632 fn clone(&self) -> Self {
633 let mut blocks = Vec::new();
634 let mut block_map = HashMap::new();
635 let mut graph_node_to_block = HashMap::new();
636 let mut block_to_graph_node = HashMap::new();
637 let address_to_id = HashMap::new();
638 let mut cfg = DiGraph::new();
639
640 // Clone blocks
641 for block in &self.blocks {
642 let new_block = block.clone();
643 let new_block_id = new_block.id;
644 blocks.push(new_block);
645 block_map.insert(new_block_id, blocks.len() - 1);
646
647 // Insert a node in the petgraph to represent this BasicBlock
648 let node_id = cfg.add_node(());
649 block_to_graph_node.insert(new_block_id, node_id);
650 graph_node_to_block.insert(node_id, new_block_id);
651 }
652
653 // Clone edges
654 for edge in self.cfg.raw_edges() {
655 let source = self.graph_node_to_block[&edge.source()];
656 let target = self.graph_node_to_block[&edge.target()];
657 let source_node_id = block_to_graph_node[&source];
658 let target_node_id = block_to_graph_node[&target];
659 cfg.add_edge(source_node_id, target_node_id, ());
660 }
661
662 Self {
663 id: self.id.clone(),
664 blocks,
665 block_map,
666 cfg,
667 graph_node_to_block,
668 block_to_graph_node,
669 address_to_id,
670 }
671 }
672}
673
674/// Deref implementation for Function
675impl Deref for Function {
676 type Target = [BasicBlock];
677
678 fn deref(&self) -> &Self::Target {
679 &self.blocks
680 }
681}
682
683/// Index implementation for function, with usize
684impl Index<usize> for Function {
685 type Output = BasicBlock;
686
687 fn index(&self, index: usize) -> &Self::Output {
688 &self.blocks[index]
689 }
690}
691
692/// IntoIterator implementation immutable reference
693impl<'a> IntoIterator for &'a Function {
694 type Item = &'a BasicBlock;
695 type IntoIter = std::slice::Iter<'a, BasicBlock>;
696
697 fn into_iter(self) -> Self::IntoIter {
698 self.blocks.iter()
699 }
700}
701
702/// IntoIterator implementation mutable reference
703impl<'a> IntoIterator for &'a mut Function {
704 type Item = &'a mut BasicBlock;
705 type IntoIter = std::slice::IterMut<'a, BasicBlock>;
706
707 fn into_iter(self) -> Self::IntoIter {
708 self.blocks.iter_mut()
709 }
710}
711
712impl NodeResolver for Function {
713 type NodeData = BasicBlock;
714
715 fn resolve(&self, node_index: NodeIndex) -> Option<&Self::NodeData> {
716 self.graph_node_to_block
717 .get(&node_index)
718 .and_then(|block_id| {
719 self.block_map
720 .get(block_id)
721 .and_then(|index| self.blocks.get(*index))
722 })
723 }
724
725 fn resolve_edge_color(&self, source: NodeIndex, target: NodeIndex) -> String {
726 // Get the last instruction of the source block
727 let source_block_id = self
728 .graph_node_to_block
729 .get(&source)
730 .expect("Source block not found");
731 let source_block = self
732 .get_basic_block_by_id(*source_block_id)
733 .expect("Source block not found");
734 let source_last_instruction = source_block.last().unwrap();
735
736 let target_block_id = self
737 .graph_node_to_block
738 .get(&target)
739 .expect("Target block not found");
740 let target_block = self
741 .get_basic_block_by_id(*target_block_id)
742 .expect("Target block not found");
743
744 // Figure out if the edge represents a branch by seeing if the target
745 // block address is NOT the next address after the source instruction.
746 let source_last_address = source_last_instruction.address;
747 let target_address = target_block.id.address;
748 if source_last_address + 1 != target_address {
749 // This represents a branch. Color the edge green.
750 return GBF_GREEN.to_string();
751 }
752
753 // If the opcode of the last instruction is a fall through, color the edge red since
754 // the target block's address is the next address
755 if source_last_instruction.opcode.has_fall_through() {
756 return GBF_RED.to_string();
757 }
758
759 // Otherwise, color the edge cyan (e.g. normal control flow)
760 GBF_BLUE.to_string()
761 }
762}
763
764impl DotRenderableGraph for Function {
765 /// Convert the Graph to `dot` format.
766 ///
767 /// # Returns
768 /// - A `String` containing the `dot` representation of the graph.
769 fn render_dot(&self, config: CfgDotConfig) -> String {
770 let cfg = CfgDot { config };
771 cfg.render(&self.cfg, self)
772 }
773}
774
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_function() {
        let id = FunctionId::new_without_name(0, 0);
        let function = Function::new(id.clone());

        assert_eq!(function.id, id);
        // `Function::new` always creates exactly one (entry) block.
        assert_eq!(function.blocks.len(), 1);
    }

    #[test]
    fn create_block() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();

        assert_eq!(function.len(), 2);

        // check block id & node id mappings
        let node_id = function.block_to_graph_node.get(&block_id).unwrap();
        let new_block_id = function.graph_node_to_block.get(node_id).unwrap();
        assert_eq!(*new_block_id, block_id);

        // test EntryBlockAlreadyExists error
        let result = function.create_block(BasicBlockType::Entry, 0);
        assert!(matches!(result, Err(FunctionError::EntryBlockAlreadyExists)));
    }

    #[test]
    fn get_block() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();
        let block = function.get_basic_block_by_id(block_id).unwrap();

        assert_eq!(block.id, block_id);
    }

    #[test]
    fn get_block_mut() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id);
        let block_id = function.create_block(BasicBlockType::Normal, 43).unwrap();
        let block = function.get_basic_block_by_id_mut(block_id).unwrap();

        // Only exercises mutable access; rewriting `id` desynchronizes the
        // function's internal maps, so `function` is not used afterwards.
        block.id = BasicBlockId::new(0, BasicBlockType::Exit, 43);
        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 43));
    }

    #[test]
    fn test_get_block_not_found() {
        let id = FunctionId::new_without_name(0, 0);
        let function = Function::new(id.clone());
        let result =
            function.get_basic_block_by_id(BasicBlockId::new(1234, BasicBlockType::Normal, 0));

        assert!(matches!(result, Err(FunctionError::BasicBlockNotFoundById(_))));

        // test mut version
        let mut function = Function::new(id.clone());
        let result =
            function.get_basic_block_by_id_mut(BasicBlockId::new(1234, BasicBlockType::Normal, 0));

        assert!(matches!(result, Err(FunctionError::BasicBlockNotFoundById(_))));

        // get by start address
        let result = function.get_basic_block_by_start_address(0x100);
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNotFoundByAddress(0x100))
        ));

        // get by start address mut
        let result = function.get_basic_block_by_start_address_mut(0x100);
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNotFoundByAddress(0x100))
        ));
    }

    #[test]
    fn test_get_block_by_address() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block_id = function
            .create_block(BasicBlockType::Normal, 0x100)
            .unwrap();
        let block = function.get_basic_block_by_start_address(0x100).unwrap();

        assert_eq!(block.id, block_id);

        // test mut version
        let block = function
            .get_basic_block_by_start_address_mut(0x100)
            .unwrap();
        block.id = BasicBlockId::new(0, BasicBlockType::Exit, 0x100);
        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 0x100));
    }

    #[test]
    fn test_basic_block_exists_by_address() {
        let mut function = Function::new(FunctionId::new_without_name(0, 0));
        function
            .create_block(BasicBlockType::Normal, 0x100)
            .unwrap();

        // The entry block starts at the function's address (0 here).
        assert!(function.basic_block_exists_by_address(0));
        assert!(function.basic_block_exists_by_address(0x100));
        assert!(!function.basic_block_exists_by_address(0x200));
    }

    #[test]
    fn test_display_function_id() {
        let id = FunctionId::new_without_name(0, 0);
        assert_eq!(id.to_string(), "Unnamed Function (Entry)");

        let id = FunctionId::new(0, Some("test".to_string()), 0);
        assert_eq!(id.to_string(), "test");
    }

    #[test]
    fn test_into_iter_mut() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block_id = function.create_block(BasicBlockType::Normal, 32).unwrap();

        for block in &mut function {
            if block.id == block_id {
                block.id = BasicBlockId::new(0, BasicBlockType::Exit, 32);
            }
        }

        let block = function.get_basic_block_by_id(block_id).unwrap();
        assert_eq!(block.id, BasicBlockId::new(0, BasicBlockType::Exit, 32));
    }

    #[test]
    fn test_is_named() {
        // NOTE(review): this pins the current, inverted behavior of `is_named`
        // (it returns `true` when the function has NO name). If `is_named` is
        // ever corrected, these two assertions must be flipped with it.
        let id = FunctionId::new_without_name(0, 0);
        assert!(id.is_named());

        let id = FunctionId::new(0, Some("test".to_string()), 0);
        assert!(!id.is_named());
    }

    #[test]
    fn test_get_entry_block() {
        let id = FunctionId::new_without_name(0, 0x10);
        let function = Function::new(id);
        let entry = function.get_entry_basic_block();

        // The entry block is created by `Function::new` at index 0, typed
        // `Entry`, starting at the function's own address.
        assert_eq!(entry.id, BasicBlockId::new(0, BasicBlockType::Entry, 0x10));
        assert_eq!(entry.id, function.get_entry_basic_block_id());
        assert_eq!(entry.id.address, 0x10);
    }

    #[test]
    fn test_get_entry_block_mut() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let entry_id = function.get_entry_basic_block().id;
        let entry = function.get_entry_basic_block_mut();

        assert_eq!(entry.id, entry_id);
    }

    #[test]
    fn test_add_edge() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();

        let result = function.add_edge(block1, block2);
        assert!(result.is_ok());

        let preds = function.get_predecessors(block2).unwrap();
        assert_eq!(preds.len(), 1);
        assert_eq!(preds[0], block1);

        let succs = function.get_successors(block1).unwrap();
        assert_eq!(succs.len(), 1);
        assert_eq!(succs[0], block2);

        // test source not found
        let result = function.add_edge(BasicBlockId::new(1234, BasicBlockType::Normal, 0), block2);
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNodeIndexNotFound(_))
        ));

        // test target not found
        let result = function.add_edge(block1, BasicBlockId::new(1234, BasicBlockType::Normal, 0));
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNodeIndexNotFound(_))
        ));
    }

    #[test]
    fn test_basic_block_is_empty() {
        // will always be false since we always create an entry block
        let id = FunctionId::new_without_name(0, 0);
        let function = Function::new(id.clone());
        assert!(!function.is_empty());
    }

    #[test]
    fn test_get_predecessors() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();

        function.add_edge(block1, block2).unwrap();
        let preds = function.get_predecessors(block2).unwrap();

        assert_eq!(preds.len(), 1);
        assert_eq!(preds[0], block1);

        // test error where block not found
        let result = function.get_predecessors(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNodeIndexNotFound(_))
        ));
    }

    #[test]
    fn test_get_successors() {
        let id = FunctionId::new_without_name(0, 0);
        let mut function = Function::new(id.clone());
        let block1 = function.create_block(BasicBlockType::Normal, 32).unwrap();
        let block2 = function.create_block(BasicBlockType::Normal, 32).unwrap();

        function.add_edge(block1, block2).unwrap();
        let succs = function.get_successors(block1).unwrap();

        assert_eq!(succs.len(), 1);
        assert_eq!(succs[0], block2);

        // test error where block not found
        let result = function.get_successors(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNodeIndexNotFound(_))
        ));
    }

    #[test]
    fn test_get_reverse_post_order() {
        let mut function = Function::new(FunctionId::new_without_name(0, 0));
        let entry = function.get_entry_basic_block_id();
        let block1 = function.create_block(BasicBlockType::Normal, 1).unwrap();
        let block2 = function.create_block(BasicBlockType::Normal, 2).unwrap();
        function.add_edge(entry, block1).unwrap();
        function.add_edge(block1, block2).unwrap();

        // A simple chain: reverse post order is the chain itself.
        let order = function.get_reverse_post_order(entry).unwrap();
        assert_eq!(order, vec![entry, block1, block2]);

        // Only blocks reachable from the start are included.
        let order = function.get_reverse_post_order(block1).unwrap();
        assert_eq!(order, vec![block1, block2]);

        let result =
            function.get_reverse_post_order(BasicBlockId::new(1234, BasicBlockType::Normal, 0));
        assert!(matches!(
            result,
            Err(FunctionError::BasicBlockNodeIndexNotFound(_))
        ));
    }

    #[test]
    fn test_clone_preserves_blocks_and_edges() {
        let mut function = Function::new(FunctionId::new(0, Some("f"), 0));
        let entry = function.get_entry_basic_block_id();
        let block1 = function.create_block(BasicBlockType::Normal, 0x10).unwrap();
        function.add_edge(entry, block1).unwrap();

        let cloned = function.clone();
        assert_eq!(cloned.id, function.id);
        assert_eq!(cloned.len(), function.len());
        assert_eq!(cloned.get_successors(entry).unwrap(), vec![block1]);
        assert_eq!(cloned.get_predecessors(block1).unwrap(), vec![entry]);
        assert_eq!(
            cloned.get_basic_block_id_by_start_address(0x10).unwrap(),
            block1
        );
    }

    #[test]
    fn test_get_entry_basic_block_id() {
        let id = FunctionId::new_without_name(0, 0);
        let function = Function::new(id.clone());
        let entry = function.get_entry_basic_block_id();

        assert_eq!(entry, function.get_entry_basic_block().id);
    }
}