# 004 median of two sorted arrays

### 4. Median of Two Sorted Arrays

Hard

Let X[1..n] and Y[1..n] be two arrays, each containing n numbers already in sorted order. Give an O(lg n)-time algorithm to find the median of all 2n elements in arrays X and Y.

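The median can be obtained recursively as follows. Call the two arrays A and B (the X and Y of the problem statement). Pick the median of each sorted array. This takes only O(1) time, since the median is the (n/2)th element of a sorted array. Now compare the median of A, call it a∗, with the median of B, b∗. If they are equal, that value is the median of all the elements. Otherwise we have two cases.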

• a∗ < b∗: In this case, the elements in B[n/2 ··· n] are also greater than a∗. So the median cannot lie in either A[1 ··· n/2] or B[n/2 ··· n]. We can throw these away and recursively solve the smaller subproblem on A[n/2 ··· n] and B[1 ··· n/2].

• a∗ > b∗: In this case, we can symmetrically throw away B[1 ··· n/2] and A[n/2 ··· n], and solve the smaller subproblem on A[1 ··· n/2] and B[n/2 ··· n] recursively.

In either case, the subproblem size is halved and we spend only constant time comparing the medians of A and B. So the recurrence is T(n) = T(n/2) + O(1), which has the solution T(n) = O(log n).

divide and conquer

• If X[n/2] == Y[n/2], the median is found; return it.
• If X[n/2] < Y[n/2], search within X[n/2+1..n] and Y[1..n/2].
• Otherwise, search within X[1..n/2] and Y[n/2..n].
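
The equal-length recursion above can be sketched roughly as follows. This is only an illustration under stated assumptions: the name `median_equal_length` is hypothetical, both inputs are assumed to be non-empty sorted lists of the same length, and the median of the 2n elements is taken as the average of the two middle values. To stay safe at the boundaries it drops `(m - 1) // 2` elements per side instead of exactly half, and it moves index bounds instead of slicing so that each round costs O(1).

```python
def median_equal_length(x, y):
    """Median of two sorted lists of equal, non-zero length (a sketch)."""
    assert len(x) == len(y) and len(x) > 0

    def window_median(a, lo, hi):
        # Median of the sorted window a[lo:hi].
        size = hi - lo
        mid = lo + size // 2
        if size % 2 == 1:
            return a[mid]
        return (a[mid - 1] + a[mid]) / 2

    x_lo, x_hi = 0, len(x)
    y_lo, y_hi = 0, len(y)
    while True:
        m = x_hi - x_lo  # both windows always have the same length
        if m == 1:
            return (x[x_lo] + y[y_lo]) / 2
        if m == 2:
            # Four values left: the two middle ones are max of the lows
            # and min of the highs.
            return (max(x[x_lo], y[y_lo]) + min(x[x_lo + 1], y[y_lo + 1])) / 2
        mx = window_median(x, x_lo, x_hi)
        my = window_median(y, y_lo, y_hi)
        if mx == my:
            return float(mx)
        drop = (m - 1) // 2
        if mx < my:
            # The median cannot be in the low part of x or the high part of y.
            x_lo += drop
            y_hi -= drop
        else:
            # Symmetric case: discard the high part of x and the low part of y.
            x_hi -= drop
            y_lo += drop
```

For example, `median_equal_length([1, 3, 5], [2, 4, 6])` narrows the windows to `[3, 5]` and `[2, 4]` and then averages the two middle values 3 and 4.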

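The LeetCode version differs from the problem above in two ways:

• The two arrays may have different lengths.
• There is not always a single median: if there are two middle elements, their average has to be returned.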

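Example with k = 6: compare A3 with B3 and suppose A3 < B3. Then B3 is greater than at least 5 numbers (A1, A2, A3, B1, B2), so the 6th smallest number could be B1 (A1 < A2 < A3 < A4 < A5 < B1), B2 (A1 < A2 < A3 < B1 < A4 < B2), or B3 (A1 < A2 < A3 < B1 < B2 < B3). In every case A1, A2, A3 can be ruled out, and the problem becomes finding the 3rd smallest number among A4, A5, ... B1, B2, B3, ..., so k is halved. Each round assumes that A is the shorter array; taking pa = min(k/2, lenA) can end with k == 1 or with A empty, and both are termination conditions.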

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        # Python 2 version: `/` on two ints is integer (floor) division.
        n = len(nums1) + len(nums2)
        if n % 2 == 1:
            return self.findKth(nums1, nums2, n / 2 + 1)
        else:
            smaller = self.findKth(nums1, nums2, n / 2)
            bigger = self.findKth(nums1, nums2, n / 2 + 1)
            return (smaller + bigger) / 2.0

    def findKth(self, A, B, k):
        # Return the k-th smallest (1-based) element of sorted lists A and B.
        if len(A) == 0:
            return B[k - 1]
        if len(B) == 0:
            return A[k - 1]
        if k == 1:
            return min(A[0], B[0])

        # Candidate at position k/2 of each list, or None if the list is too short.
        a = A[k / 2 - 1] if len(A) >= k / 2 else None
        b = B[k / 2 - 1] if len(B) >= k / 2 else None

        if b is None or (a is not None and a < b):
            return self.findKth(A[k / 2:], B, k - k / 2)
        # Note: the new k is k - k/2, which is not always equal to k/2 (k may be odd).
        return self.findKth(A, B[k / 2:], k - k / 2)


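In Python 3, `/` returns a float, so the indices have to be rounded down explicitly for the solution to be accepted (AC). Otherwise it fails with TypeError: list indices must be integers or slices, not float.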

from math import floor


class Solution:
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        n = len(nums1) + len(nums2)
        if n % 2 == 1:
            return self.findKth(nums1, nums2, floor(n / 2) + 1)
        else:
            smaller = self.findKth(nums1, nums2, floor(n / 2))
            bigger = self.findKth(nums1, nums2, floor(n / 2) + 1)
            return (smaller + bigger) / 2.0

    def findKth(self, A, B, k):
        # Return the k-th smallest (1-based) element of sorted lists A and B.
        if len(A) == 0:
            return B[k - 1]
        if len(B) == 0:
            return A[k - 1]
        if k == 1:
            return min(A[0], B[0])
        half = floor(k / 2)  # integer number of elements discarded per round
        a = A[half - 1] if len(A) >= half else None
        b = B[half - 1] if len(B) >= half else None
        if b is None or (a is not None and a < b):
            return self.findKth(A[half:], B, k - half)
        else:
            return self.findKth(A, B[half:], k - half)
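
One caveat with the slicing versions above: a Python list slice copies the elements, so every recursive call does extra linear work. A minimal sketch that keeps offsets into the original lists instead is shown below. The names `find_kth` and `find_median` are hypothetical, the inputs are assumed to be sorted with at least one element overall, and k is assumed to satisfy 1 <= k <= len(a) + len(b). It uses the min(k/2, length) idea mentioned earlier to handle a list that is too short.

```python
def find_kth(a, b, k):
    """k-th smallest (1-based) element of two sorted lists, without slicing."""
    i = j = 0  # a[:i] and b[:j] have already been discarded
    while True:
        if i == len(a):
            return b[j + k - 1]
        if j == len(b):
            return a[i + k - 1]
        if k == 1:
            return min(a[i], b[j])
        half = k // 2
        # Clamp the probe positions so a short list never overflows.
        new_i = min(i + half, len(a)) - 1
        new_j = min(j + half, len(b)) - 1
        if a[new_i] <= b[new_j]:
            # a[i..new_i] all rank below the k-th element: drop them.
            k -= new_i - i + 1
            i = new_i + 1
        else:
            k -= new_j - j + 1
            j = new_j + 1


def find_median(nums1, nums2):
    n = len(nums1) + len(nums2)
    if n % 2 == 1:
        return float(find_kth(nums1, nums2, n // 2 + 1))
    smaller = find_kth(nums1, nums2, n // 2)
    bigger = find_kth(nums1, nums2, n // 2 + 1)
    return (smaller + bigger) / 2.0
```

With `nums1 = [1, 3]` and `nums2 = [2]`, the single call `find_kth(nums1, nums2, 2)` discards the 1 in the first round and then returns `min(3, 2)`.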


## Solution 3

// QuickSelect: places the k-th smallest element at a[k-1].
// cutoff, median3, swap and InsertSort are assumed to be defined elsewhere.
void QuickSelect( int a[], int k, int left, int right )
{
    int i, j;
    int pivot;

    if( left + cutoff <= right )
    {
        // Using the median of three as the pivot largely avoids the worst case.
        pivot = median3( a, left, right );
        i = left; j = right - 1;
        for( ; ; )
        {
            while( a[ ++i ] < pivot ){ }
            while( a[ --j ] > pivot ){ }
            if( i < j )
                swap( &a[ i ], &a[ j ] );
            else
                break;
        }
        // Put the pivot back into its final position.
        swap( &a[ i ], &a[ right - 1 ] );

        if( k <= i )
            QuickSelect( a, k, left, i - 1 );
        else if( k > i + 1 )
            QuickSelect( a, k, i + 1, right );
    }
    else
        InsertSort( a + left, right - left + 1 );
}
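
The notes do not spell out how QuickSelect would be applied to this problem. One plausible reading, shown here only as a sketch, is to concatenate the two arrays and select the middle element(s), which takes expected linear time and ignores the fact that the inputs are already sorted. The names `quickselect` and `median_by_quickselect` are hypothetical. Unlike the in-place C routine above, this version builds new lists and uses the middle element as the pivot rather than the median of three.

```python
def quickselect(values, k):
    """Return the k-th smallest (1-based) element of values, 1 <= k <= len(values)."""
    items = list(values)
    while True:
        pivot = items[len(items) // 2]
        lower = [v for v in items if v < pivot]
        equal_count = sum(1 for v in items if v == pivot)
        if k <= len(lower):
            # The answer is among the elements smaller than the pivot.
            items = lower
        elif k <= len(lower) + equal_count:
            # The k-th element is the pivot value itself.
            return pivot
        else:
            # Skip everything <= pivot and continue in the larger part.
            k -= len(lower) + equal_count
            items = [v for v in items if v > pivot]


def median_by_quickselect(nums1, nums2):
    merged = list(nums1) + list(nums2)
    n = len(merged)
    if n % 2 == 1:
        return float(quickselect(merged, n // 2 + 1))
    smaller = quickselect(merged, n // 2)
    bigger = quickselect(merged, n // 2 + 1)
    return (smaller + bigger) / 2.0
```

This does not meet the logarithmic bound asked for in the problem; it is mainly useful as a simple cross-check for the faster solutions.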


class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        def findKth(A, pa, B, pb, k):
            # Walk both lists in merged order and stop at the k-th element: O(k) time.
            res = 0
            m = 0  # number of elements consumed so far
            while pa < len(A) and pb < len(B) and m < k:
                if A[pa] < B[pb]:
                    res = A[pa]
                    m += 1
                    pa += 1
                else:
                    res = B[pb]
                    m += 1
                    pb += 1

            # One list is exhausted: keep consuming from the other.
            while pa < len(A) and m < k:
                res = A[pa]
                pa += 1
                m += 1

            while pb < len(B) and m < k:
                res = B[pb]
                pb += 1
                m += 1
            return res

        # `//` keeps k an integer in both Python 2 and Python 3.
        n = len(nums1) + len(nums2)
        if n % 2 == 1:
            return findKth(nums1, 0, nums2, 0, n // 2 + 1)
        else:
            smaller = findKth(nums1, 0, nums2, 0, n // 2)
            bigger = findKth(nums1, 0, nums2, 0, n // 2 + 1)
            return (smaller + bigger) / 2.0