# Time: O(N^2) -- an O(N log N) sort, then an O(N) two-pointer scan per index.
# Space: O(N) for the sorted copy; the caller's list is left untouched.
#
# NOTE: i, j, k in the problem refer to *indices*, not to values in the
# array, so duplicate values are fine: every distinct index triple counts.
from typing import List

MOD = 10 ** 9 + 7


class Solution:
    def threeSumMulti(self, arr: List[int], target: int) -> int:
        """Count index triples i < j < k with arr[i] + arr[j] + arr[k] == target.

        Args:
            arr: The input values. The list is not modified.
            target: The required sum of each triple.

        Returns:
            The number of qualifying index triples, modulo 10**9 + 7.
        """
        # Sort a copy so the two-pointer technique applies without mutating
        # the caller's list. Counting is unaffected by reordering because
        # each triple of positions is still counted exactly once.
        nums = sorted(arr)
        n = len(nums)
        result = 0

        for i, first in enumerate(nums):
            diff = target - first
            j, k = i + 1, n - 1
            # Inside this loop it is essentially two-sum on nums[i+1:]:
            # move the left/right pointers inward until the pair hits `diff`.
            while j < k:
                pair_sum = nums[j] + nums[k]
                if pair_sum < diff:
                    j += 1
                elif pair_sum > diff:
                    k -= 1
                elif nums[j] != nums[k]:
                    # pair_sum == diff with two different values: measure the
                    # run of equal values at each end and pair every member
                    # of one run with every member of the other.
                    count_j = count_k = 1
                    while j + 1 < k and nums[j + 1] == nums[j]:
                        count_j += 1
                        j += 1
                    while k - 1 > j and nums[k - 1] == nums[k]:
                        count_k += 1
                        k -= 1
                    result += count_j * count_k
                    j += 1
                    k -= 1
                else:
                    # nums[j] == nums[k]; since nums is sorted, every element
                    # between them has the same value, e.g.
                    #
                    #   [4, 4, 4, 4]
                    #    ^        ^
                    #    |        |
                    #    j        k
                    #
                    # Any two of these (k - j + 1) positions form a valid pair.
                    # Pairs are unordered (we need j < k, so positions (3, 4)
                    # and (4, 3) are the same pair), hence we count
                    # combinations rather than permutations:
                    #
                    #   C(m, 2) = m! / (2! * (m - 2)!) = m * (m - 1) / 2
                    num_dups = k - j + 1
                    result += num_dups * (num_dups - 1) // 2
                    # Every remaining candidate pair for this i lies within
                    # [j, k] and has just been counted, so stop scanning.
                    break

        return result % MOD