NumPy 排序与搜索

排序和搜索是数据处理中的常见操作。NumPy提供了高效的多维数组排序、部分排序、以及灵活的元素搜索和索引查找功能,能够快速处理大规模数据。

1. 排序函数

1.1 np.sort() - 返回排序后的副本

np.sort(arr, axis=-1, kind='quicksort') 返回数组的排序副本,默认沿最后一个轴进行快速排序。

import numpy as np

arr = np.array([3, 1, 5, 2, 4])
sorted_arr = np.sort(arr)
print("原数组:", arr)
print("排序后:", sorted_arr)

# 二维数组沿指定轴排序
arr2d = np.array([[3, 2, 1], [6, 5, 4]])
print("沿行排序 (axis=1):\n", np.sort(arr2d, axis=1))
print("沿列排序 (axis=0):\n", np.sort(arr2d, axis=0))

输出:

原数组: [3 1 5 2 4]
排序后: [1 2 3 4 5]
沿行排序 (axis=1):
 [[1 2 3]
 [4 5 6]]
沿列排序 (axis=0):
 [[3 2 1]
 [6 5 4]]

1.2 原地排序:ndarray.sort()

数组对象的 sort() 方法会原地修改数组,不返回新数组。

arr = np.array([3, 1, 5, 2, 4])
arr.sort()
print("原地排序后:", arr)

输出:

原地排序后: [1 2 3 4 5]

1.3 降序排序

NumPy没有直接的降序参数,可以通过切片反转实现:

arr = np.array([3, 1, 5, 2, 4])
desc = np.sort(arr)[::-1]
print("降序:", desc)

输出:

降序: [5 4 3 2 1]

1.4 np.argsort() - 返回排序后的索引

有时候我们需要的是元素排序后的索引位置,而不是排序后的值。argsort 返回这些索引。

arr = np.array([10, 30, 20, 40])
idx = np.argsort(arr)
print("排序索引:", idx)          # [0 2 1 3] 对应值 [10,20,30,40]
print("通过索引获取排序值:", arr[idx])

输出:

排序索引: [0 2 1 3]
通过索引获取排序值: [10 20 30 40]

argsort 也支持 axis 参数,可用于多维数组。

1.5 np.lexsort() - 多级排序

lexsort 使用多个键进行间接排序,类似于 SQL 中的 ORDER BY 多个字段。它返回整数索引,最后一个键是主排序键。

# 数据:名字、年龄、身高
names = np.array(['Alice', 'Bob', 'Charlie', 'Alice'])
ages = np.array([25, 30, 35, 20])
heights = np.array([165, 175, 180, 160])

# 按名字排序,名字相同则按年龄排序(注意:lexsort 的键顺序是反向的)
indices = np.lexsort((ages, names))  # 先按 names,再按 ages
print("排序索引:", indices)
print("排序后数据:")
for i in indices:
    print(names[i], ages[i], heights[i])

输出:

排序索引: [3 0 1 2]
排序后数据:
Alice 20 160
Alice 25 165
Bob 30 175
Charlie 35 180

2. 部分排序

当只需要最小的 k 个元素或最大的 k 个元素时,使用 partition 比完整排序更高效。

2.1 np.partition()np.argpartition()

np.partition(arr, kth) 重新排列数组,使得第 k 个位置的元素处于最终排序后的正确位置,左侧元素都小于等于它,右侧都大于等于它,但两侧内部无序。返回副本。

arr = np.array([7, 2, 9, 1, 5, 8, 3])
# 获取最小的3个元素(kth=2,因为索引从0开始)
partitioned = np.partition(arr, 2)
print("partition kth=2:", partitioned)
# 结果:前3个(索引0,1,2)是任意顺序的三个最小元素

# 获取最大的3个元素(使用负数索引)
largest3 = np.partition(arr, -3)[-3:]
print("最大的3个:", largest3)

# argpartition 返回索引
idx = np.argpartition(arr, 2)
print("最小的3个元素的索引:", idx[:3])

输出:

partition kth=2: [1 2 3 7 5 8 9]
最大的3个: [7 8 9]
最小的3个元素的索引: [3 1 6]

注意:partition 不保证左侧或右侧内部的顺序,只保证分界点正确。

3. 搜索函数

3.1 查找最大值/最小值的位置:np.argmin()np.argmax()

返回数组中最小/最大元素的索引(展平后或沿指定轴)。

arr = np.array([[3, 1, 4], [1, 5, 9]])
print("全局最小值索引 (展平):", np.argmin(arr))          # 1
print("全局最大值索引:", np.argmax(arr))                # 5
print("沿列的最小值索引 (axis=0):", np.argmin(arr, axis=0))   # [1 0 0]
print("沿行的最大值索引 (axis=1):", np.argmax(arr, axis=1))   # [2 2]

输出:

全局最小值索引 (展平): 1
全局最大值索引: 5
沿列的最小值索引 (axis=0): [1 0 0]
沿行的最大值索引 (axis=1): [2 2]

3.2 条件搜索:np.where()

np.where(condition, [x, y]) 有两种用法:

  • 如果只给条件,返回满足条件的元素的索引元组(可用于多维索引)。
  • 如果同时给 x 和 y,则根据条件从 x 或 y 中选择元素,类似三元运算符。

arr = np.array([1, 5, 3, 8, 2])
indices = np.where(arr > 3)
print("arr > 3 的索引:", indices)        # (array([1, 3]),)
print("arr > 3 的值:", arr[indices])

# 多维数组
arr2d = np.array([[1, 2], [3, 4]])
rows, cols = np.where(arr2d > 2)
print("大于2的元素坐标:", list(zip(rows, cols)))

# 三元选择
result = np.where(arr > 3, arr, -1)
print("大于3保留原值,否则替换为-1:", result)

输出:

arr > 3 的索引: (array([1, 3]),)
arr > 3 的值: [5 8]
大于2的元素坐标: [(1, 0), (1, 1)]
大于3保留原值,否则替换为-1: [-1  5 -1  8 -1]

3.3 有序数组中的搜索:np.searchsorted()

在有序数组中查找插入位置以保持顺序,返回索引。对于需要插入多个元素的情况非常高效。

sorted_arr = np.array([1, 3, 5, 7])
print("2 应插入的位置:", np.searchsorted(sorted_arr, 2))   # 1(插入到索引1前)
print("多个值:", np.searchsorted(sorted_arr, [2, 4, 6]))  # [1,2,3]

# 可以指定 side='right' 得到右侧插入位置
print("6 插入右侧:", np.searchsorted(sorted_arr, 6, side='right'))  # 3(插入到索引3后)

输出:

2 应插入的位置: 1
多个值: [1 2 3]
6 插入右侧: 3

3.4 非零元素:np.nonzero()

返回非零元素的索引,与 np.where(arr != 0) 相同。

arr = np.array([0, 2, 0, 3, 4])
nonzero_idx = np.nonzero(arr)
print("非零索引:", nonzero_idx)          # (array([1, 3, 4]),)
print("非零值:", arr[nonzero_idx])

输出:

非零索引: (array([1, 3, 4]),)
非零值: [2 3 4]

3.5 提取元素:np.extract()

根据条件提取元素,返回一维数组。

arr = np.array([1, 2, 3, 4, 5])
condition = arr % 2 == 0
result = np.extract(condition, arr)
print("偶数:", result)   # [2 4]

输出:

偶数: [2 4]

4. 应用示例:获取排序后的索引并重新排列其他数组

使用 argsort 可以对一个数组排序,同时按相同顺序重排其他数组。

# 假设我们有学生的分数和姓名
scores = np.array([85, 92, 78, 90])
names = np.array(['Alice', 'Bob', 'Charlie', 'David'])

# 按分数升序排序
sorted_idx = np.argsort(scores)
print("按分数排序后的姓名:", names[sorted_idx])
print("对应的分数:", scores[sorted_idx])

# 降序
desc_idx = np.argsort(scores)[::-1]
print("降序姓名:", names[desc_idx])

输出:

按分数排序后的姓名: ['Charlie' 'Alice' 'David' 'Bob']
对应的分数: [78 85 90 92]
降序姓名: ['Bob' 'David' 'Alice' 'Charlie']

5. 性能考虑

  • 对于大型数组,np.sortnp.argsort 默认使用快速排序(不稳定),也可选择 kind='mergesort'(稳定)或 'heapsort'
  • 如果只需要部分排序,优先使用 np.partition,它比完整排序快得多。
  • searchsorted 对于在有序数组中查找插入点非常高效(O(log n))。
  • wherenonzero 返回的索引元组可以直接用于花式索引。
提示: 对于多维数组,argsortpartition 的 axis 参数可以灵活控制排序方向。熟悉这些函数可以避免编写复杂的 Python 循环,提高代码效率。

总结

本节介绍了 NumPy 中用于排序和搜索的核心函数。通过 sortargsortpartition 可以高效地重排数据,wheresearchsortedargmin 等函数则帮助快速定位元素。结合这些工具,可以轻松完成许多数据预处理任务。下一章我们将学习 NumPy 的文件输入输出操作。