# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Convert a BST into a "greater tree" in place.

    Every node's value is replaced by its original value plus the sum of
    all values that come after it in ascending order. This is done with a
    reverse in-order traversal (right subtree -> node -> left subtree),
    which visits the values in descending order while a running total is
    carried along.
    """

    # The annotations are string forward references: `Optional` and
    # `TreeNode` are not imported or defined in this file (the node class
    # above is only a commented-out reference), so they must not be
    # evaluated when the class body runs.
    def convertBST(self, root: "Optional[TreeNode]") -> "Optional[TreeNode]":
        """Rewrite every node value of the tree rooted at *root*.

        Args:
            root: Root node of the tree, or ``None`` for an empty tree.

        Returns:
            The same *root* object that was passed in; the nodes are
            mutated in place.
        """
        self.walk(root, 0)
        return root

    def walk(self, node, total):
        """Apply the running-sum rewrite to the subtree rooted at *node*.

        Args:
            node: Root of the subtree to process, or ``None``.
            total: Sum of all values already visited, i.e. the values that
                are greater than everything in this subtree.

        Returns:
            The running total after every node of the subtree has been
            visited. When *node* is ``None`` this is *total* unchanged,
            not 0, because the caller's accumulated sum must carry over
            into the subtrees visited afterwards.
        """
        # An explicit stack replaces recursion so that a degenerate
        # (linked-list-shaped) tree cannot exhaust the interpreter's
        # recursion limit.
        pending = []
        current = node
        while pending or current is not None:
            # Descend as far right as possible: larger values come first.
            while current is not None:
                pending.append(current)
                current = current.right

            current = pending.pop()

            # Everything visited so far is greater than this node, so its
            # new value is its own value plus the running total.
            total += current.val
            current.val = total

            # The left subtree holds the next smaller values.
            current = current.left

        return total