2019-08-16 02:07:43 +08:00
|
|
|
"""
|
|
|
|
Given an array-like data structure A[1..n], how many pairs (i, j) with
1 <= i < j <= n are there such that A[i] > A[j]? These pairs are
called inversions. Counting the number of such inversions in an array-like
object is important. Among other things, counting inversions can help
us determine how close a given array is to being sorted.

In this implementation, I provide two algorithms: a divide-and-conquer
algorithm which runs in O(n log n) and the brute-force O(n^2) algorithm.
|
2019-08-16 02:07:43 +08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def count_inversions_bf(arr):
    """
    Counts the number of inversions using a naive brute-force algorithm.

    Every pair of positions (i, j) with i < j is examined once, so the
    running time is quadratic in ``len(arr)``.

    Parameters
    ----------
    arr: array-like, the list containing the items for which the number
    of inversions is desired. The elements of `arr` must be comparable.

    Returns
    -------
    num_inversions: The total number of inversions in `arr`

    Examples
    ---------
    >>> count_inversions_bf([1, 4, 2, 4, 1])
    4
    >>> count_inversions_bf([1, 1, 2, 4, 4])
    0
    >>> count_inversions_bf([])
    0
    """
    n = len(arr)
    # A pair counts only when the earlier element is strictly greater than
    # the later one; equal elements are not inversions.
    return sum(
        1
        for i in range(n - 1)
        for j in range(i + 1, n)
        if arr[i] > arr[j]
    )
|
|
|
|
|
|
|
|
|
|
|
|
def count_inversions_recursive(arr):
    """
    Counts the number of inversions using a divide-and-conquer algorithm.

    The input is split in half, inversions are counted recursively inside
    each half, and the inversions that straddle the two halves are counted
    while the sorted halves are merged by `_count_cross_inversions`.

    Parameters
    -----------
    arr: array-like, the list containing the items for which the number
    of inversions is desired. The elements of `arr` must be comparable.

    Returns
    -------
    C: a sorted copy of `arr`, always a new list (the input is never
    returned or modified, even when it has fewer than two elements).
    num_inversions: int, the total number of inversions in 'arr'

    Examples
    --------
    >>> count_inversions_recursive([1, 4, 2, 4, 1])
    ([1, 1, 2, 4, 4], 4)
    >>> count_inversions_recursive([1, 1, 2, 4, 4])
    ([1, 1, 2, 4, 4], 0)
    >>> count_inversions_recursive([])
    ([], 0)
    """
    if len(arr) <= 1:
        # Return a fresh list rather than `arr` itself so the result is a
        # genuine copy and callers cannot mutate the input through it.
        return list(arr), 0

    mid = len(arr) // 2
    p = arr[0:mid]
    q = arr[mid:]

    # Each recursive call yields its half sorted plus the inversions that
    # lie entirely inside that half.
    a, inversions_p = count_inversions_recursive(p)
    b, inversions_q = count_inversions_recursive(q)
    # Merging the sorted halves counts the inversions spanning both halves.
    c, cross_inversions = _count_cross_inversions(a, b)

    num_inversions = inversions_p + inversions_q + cross_inversions
    return c, num_inversions
|
2019-08-16 02:07:43 +08:00
|
|
|
|
|
|
|
|
2022-10-13 06:54:20 +08:00
|
|
|
def _count_cross_inversions(p, q):
|
2019-08-16 02:07:43 +08:00
|
|
|
"""
|
|
|
|
Counts the inversions across two sorted arrays.
|
|
|
|
And combine the two arrays into one sorted array
|
|
|
|
For all 1<= i<=len(P) and for all 1 <= j <= len(Q),
|
|
|
|
if P[i] > Q[j], then (i, j) is a cross inversion
|
|
|
|
Parameters
|
|
|
|
----------
|
|
|
|
P: array-like, sorted in non-decreasing order
|
|
|
|
Q: array-like, sorted in non-decreasing order
|
|
|
|
Returns
|
|
|
|
------
|
|
|
|
R: array-like, a sorted array of the elements of `P` and `Q`
|
|
|
|
num_inversion: int, the number of inversions across `P` and `Q`
|
|
|
|
Examples
|
|
|
|
--------
|
|
|
|
>>> _count_cross_inversions([1, 2, 3], [0, 2, 5])
|
|
|
|
([0, 1, 2, 2, 3, 5], 4)
|
|
|
|
>>> _count_cross_inversions([1, 2, 3], [3, 4, 5])
|
|
|
|
([1, 2, 3, 3, 4, 5], 0)
|
|
|
|
"""
|
|
|
|
|
2022-10-13 06:54:20 +08:00
|
|
|
r = []
|
2019-08-16 02:07:43 +08:00
|
|
|
i = j = num_inversion = 0
|
2022-10-13 06:54:20 +08:00
|
|
|
while i < len(p) and j < len(q):
|
|
|
|
if p[i] > q[j]:
|
2019-08-16 02:07:43 +08:00
|
|
|
# if P[1] > Q[j], then P[k] > Q[k] for all i < k <= len(P)
|
|
|
|
# These are all inversions. The claim emerges from the
|
|
|
|
# property that P is sorted.
|
2022-10-13 06:54:20 +08:00
|
|
|
num_inversion += len(p) - i
|
|
|
|
r.append(q[j])
|
2019-08-16 02:07:43 +08:00
|
|
|
j += 1
|
|
|
|
else:
|
2022-10-13 06:54:20 +08:00
|
|
|
r.append(p[i])
|
2019-08-16 02:07:43 +08:00
|
|
|
i += 1
|
|
|
|
|
2022-10-13 06:54:20 +08:00
|
|
|
if i < len(p):
|
|
|
|
r.extend(p[i:])
|
2019-08-16 02:07:43 +08:00
|
|
|
else:
|
2022-10-13 06:54:20 +08:00
|
|
|
r.extend(q[j:])
|
2019-08-16 02:07:43 +08:00
|
|
|
|
2022-10-13 06:54:20 +08:00
|
|
|
return r, num_inversion
|
2019-08-16 02:07:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Cross-checks the brute-force and recursive counters on sample lists.

    For each sample, both implementations must agree with each other and
    with the expected inversion count; the count is then printed.
    """
    arr_1 = [10, 2, 1, 5, 5, 2, 11]

    cases = (
        # this arr has 8 inversions:
        # (10, 2), (10, 1), (10, 5), (10, 5), (10, 2), (2, 1), (5, 2), (5, 2)
        (arr_1, 8),
        # testing an array with zero inversion (a sorted arr_1)
        (sorted(arr_1), 0),
        # an empty list should also have zero inversions
        ([], 0),
    )

    for sample, expected in cases:
        num_inversions_bf = count_inversions_bf(sample)
        _, num_inversions_recursive = count_inversions_recursive(sample)

        assert num_inversions_bf == num_inversions_recursive == expected

        print("number of inversions = ", num_inversions_bf)
|
|
|
|
|
|
|
|
|
|
|
# Run the self-check demo only when the file is executed as a script.
if __name__ == "__main__":
    main()
|