# Time: O(N), where N is the number of nodes in the tree
# Space: O(N) in the worst case, for the stack used by the traversal

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def trimBST(self, root: "Optional[TreeNode]", low: int, high: int) -> "Optional[TreeNode]":
        """Trim a binary search tree so every remaining value lies in [low, high].

        The tree is modified in place: out-of-range nodes are unlinked and the
        surviving nodes keep their relative structure.

        Args:
            root: Root of the tree, or None for an empty tree.
            low: Smallest value to keep (inclusive).
            high: Largest value to keep (inclusive).

        Returns:
            The root of the trimmed tree, or None if no node is in range.

        The traversal uses an explicit stack instead of recursion, so a deeply
        skewed tree cannot exceed Python's recursion limit (RecursionError).
        The annotations are quoted so `Optional` and `TreeNode` do not have to
        be in scope when this class body is executed.
        """

        def skip_out_of_range(node):
            # Walk past out-of-range nodes until an in-range node (or None) is
            # reached. In a BST, if node.val < low its whole left subtree is
            # also < low, so only the right side can survive; symmetrically,
            # if node.val > high only the left side can survive.
            while node is not None and not (low <= node.val <= high):
                node = node.right if node.val < low else node.left
            return node

        root = skip_out_of_range(root)

        # In-range nodes whose children still have to be trimmed.
        pending = [root] if root is not None else []
        while pending:
            node = pending.pop()
            node.left = skip_out_of_range(node.left)
            node.right = skip_out_of_range(node.right)
            if node.left is not None:
                pending.append(node.left)
            if node.right is not None:
                pending.append(node.right)
        return root