# Time: O(N)
# Space: O(N)

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def goodNodes(self, root: "TreeNode | None") -> int:
        """Count the "good" nodes in a binary tree.

        A node is good when no node on the path from the root down to it
        (inclusive) holds a value greater than its own. The root is
        therefore always good.

        The tree is walked depth-first with an explicit stack rather than
        recursion, so a deep or heavily skewed tree cannot exhaust
        Python's recursion limit.

        Args:
            root: Root of the tree, or None for an empty tree. Every node
                must expose `val`, `left` and `right` attributes. The
                annotation is a string because `TreeNode` is only
                sketched in a comment in this file.

        Returns:
            The number of good nodes; 0 for an empty tree.
        """
        if root is None:
            return 0

        good = 0
        # Each entry pairs a node with the largest value seen on the path
        # from the root down to (but not including) that node. Seeding the
        # root with its own value makes the root count as good.
        stack = [(root, root.val)]
        while stack:
            node, path_max = stack.pop()

            # A node at least as large as everything above it is good, and
            # it becomes the new maximum for its own descendants.
            if node.val >= path_max:
                good += 1
                path_max = node.val

            # Visiting order does not affect the count, so the children
            # can be pushed in any order; missing children are skipped.
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, path_max))

        return good