Prune tree
XXX. Binary Tree Pruning
Recursion
Binary Tree
Problem Statement:
Given the root of a binary tree whose node values are each 0 or 1, return the same tree with every subtree (of the given tree) that does not contain a 1 removed.
A subtree of a node is the node plus every node that is a descendant of that node.
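To make the definition concrete, here is a minimal sketch that assumes a node class with an int `val` and `left`/`right` children (the original does not show its `TreeNode`). The hypothetical helper `subtreeHasOne` walks a node's subtree, meaning the node plus all of its descendants, without modifying anything, and reports whether any value in it is 1.

```java
// Assumed node shape (not shown in the original): an int value plus two children.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class SubtreeCheck {
    // Hypothetical helper: true if the subtree rooted at node
    // (the node plus all of its descendants) contains a node with value 1.
    static boolean subtreeHasOne(TreeNode node) {
        if (node == null) return false;
        return node.val == 1 || subtreeHasOne(node.left) || subtreeHasOne(node.right);
    }

    public static void main(String[] args) {
        // Tree: 1 with a right child 0, which has children 0 (left) and 1 (right)
        TreeNode root = new TreeNode(1, null,
                new TreeNode(0, new TreeNode(0), new TreeNode(1)));
        System.out.println(subtreeHasOne(root.right));      // true
        System.out.println(subtreeHasOne(root.right.left)); // false
    }
}
```

In this tree, the subtree rooted at the lone left 0 contains no 1, so it is exactly the kind of subtree the problem asks to remove.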
Algorithm:
- Define a recursive function that checks whether a subtree contains a node with value 1.
- For each node:
  - Recursively check the left and right subtrees.
  - If a subtree does not contain a 1, prune it by setting the corresponding child pointer to null.
- Return the root of the modified tree if the tree contains a 1; otherwise, return null.
Complexity:
Time: O(n), where n is the number of nodes in the tree. Each node is visited once.
Space: O(h), where h is the height of the tree, for the recursion stack. This is O(log n) for a balanced tree and O(n) in the worst case (a skewed tree).
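The O(h) space comes from the call stack, and for a skewed tree h approaches n, which may overflow Java's default thread stack on very deep inputs. As a swapped-in alternative (not part of the original solution), here is a minimal sketch of the same pruning done with an explicit post-order traversal on an `ArrayDeque`. The names `IterativePrune`, `pruneIterative`, and `isPrunable` are hypothetical, and `TreeNode` is an assumed val/left/right node class.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Assumed node shape: an int value plus left/right children.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class IterativePrune {
    // A node can be removed once its own children have been pruned:
    // it holds 0 and has no children left.
    static boolean isPrunable(TreeNode node) {
        return node != null && node.val == 0 && node.left == null && node.right == null;
    }

    // Hypothetical iterative variant: post-order traversal with an explicit stack,
    // so deep trees use heap memory instead of the call stack.
    static TreeNode pruneIterative(TreeNode root) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        TreeNode lastVisited = null;
        while (current != null || !stack.isEmpty()) {
            if (current != null) {
                // Descend left, remembering the path on the stack
                stack.push(current);
                current = current.left;
            } else {
                TreeNode top = stack.peek();
                if (top.right != null && top.right != lastVisited) {
                    // Right subtree has not been processed yet
                    current = top.right;
                } else {
                    // Both subtrees are finished, so each child is already pruned
                    if (isPrunable(top.left)) top.left = null;
                    if (isPrunable(top.right)) top.right = null;
                    lastVisited = stack.pop();
                }
            }
        }
        // Finally decide whether the root itself survives (null input stays null)
        return isPrunable(root) ? null : root;
    }
}
```

Each node is pushed and popped once, so time stays O(n) and the stack still holds at most one root-to-node path. Pruning is decided at the parent, by inspecting already-processed children, which avoids tracking parent links.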
Java Implementation:
// Assumes a TreeNode class with fields: int val; TreeNode left, right;
public class Solution {
    // Main function: prunes the tree in place and returns its root (or null)
    public TreeNode pruneTree(TreeNode root) {
        // Return the root only if the tree contains a 1; otherwise, return null.
        return containsOne(root) ? root : null;
    }

    // Helper: returns true if the subtree rooted at node contains at least one 1,
    // pruning every child subtree that contains no 1 along the way
    private boolean containsOne(TreeNode node) {
        if (node == null) return false;
        // Check whether the left subtree contains a 1
        boolean leftContainsOne = containsOne(node.left);
        // Check whether the right subtree contains a 1
        boolean rightContainsOne = containsOne(node.right);
        // Prune the left subtree if it does not contain a 1
        if (!leftContainsOne) node.left = null;
        // Prune the right subtree if it does not contain a 1
        if (!rightContainsOne) node.right = null;
        // True if the current node or either of its subtrees contains a 1
        return node.val == 1 || leftContainsOne || rightContainsOne;
    }

    // A more compact alternative: each call returns the pruned subtree, or null if it is removed
    public TreeNode pruneTreeFancy(TreeNode root) {
        if (root == null) return null;
        // Recursively prune the left and right subtrees first (post-order)
        root.left = pruneTreeFancy(root.left);
        root.right = pruneTreeFancy(root.right);
        // Remove this node if it is 0 and both of its subtrees were pruned away
        return (root.val == 0 && root.left == null && root.right == null) ? null : root;
    }
}
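To try the approach end to end, here is a self-contained sketch. It repeats the pruneTreeFancy logic as a static `prune` method and adds a hypothetical `serialize` helper that prints a tree in level order, with "null" for missing children and trailing nulls trimmed, so results are easy to compare. `TreeNode` is again an assumed val/left/right class.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

// Assumed node shape: an int value plus left/right children.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class PruneDemo {
    // Same post-order logic as pruneTreeFancy: prune children, then this node
    static TreeNode prune(TreeNode root) {
        if (root == null) return null;
        root.left = prune(root.left);
        root.right = prune(root.right);
        return (root.val == 0 && root.left == null && root.right == null) ? null : root;
    }

    // Hypothetical helper: level-order listing with "null" for missing children,
    // trailing "null" entries trimmed
    static String serialize(TreeNode root) {
        List<String> out = new ArrayList<>();
        Queue<TreeNode> queue = new LinkedList<>(); // LinkedList permits null entries
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (node == null) {
                out.add("null");
                continue;
            }
            out.add(String.valueOf(node.val));
            queue.add(node.left);
            queue.add(node.right);
        }
        int end = out.size();
        while (end > 0 && out.get(end - 1).equals("null")) end--;
        return "[" + String.join(",", out.subList(0, end)) + "]";
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1, null,
                new TreeNode(0, new TreeNode(0), new TreeNode(1)));
        System.out.println(serialize(root));        // [1,null,0,0,1]
        System.out.println(serialize(prune(root))); // [1,null,0,null,1]
        // An all-zero tree is pruned away entirely
        System.out.println(serialize(prune(new TreeNode(0, new TreeNode(0), null)))); // []
    }
}
```

A `LinkedList` is used as the queue because `ArrayDeque` rejects null elements, and the serializer needs to enqueue missing children.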