NumPy 广播机制

广播是NumPy中一种强大的机制,它允许不同形状的数组之间进行算术运算。当两个数组的形状不完全相同时,NumPy会自动扩展较小数组的维度,使其与较大数组兼容,而无需实际复制数据。理解广播是高效使用NumPy的关键。

1. 什么是广播?

广播描述了NumPy在算术运算期间如何处理不同形状的数组。在一定的约束条件下,较小的数组会“广播”到较大数组的形状,使得它们具有兼容的形状。广播不会导致内存复制,因此非常高效。

import numpy as np

# 最简单的广播:标量与数组
arr = np.array([1, 2, 3])
print("arr + 5 =", arr + 5)   # 5被广播到整个数组

# 一维数组与二维数组
A = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([10, 20, 30])
print("A + b:\n", A + b)      # b被广播到每一行

输出:

arr + 5 = [6 7 8]
A + b:
 [[11 22 33]
 [14 25 36]]

2. 广播的规则

NumPy的广播遵循一组严格的规则,用于确定两个数组是否兼容以及如何进行广播。

  1. 从后往前对齐维度:比较两个数组的形状,从最后一个维度开始向前对齐。
  2. 维度兼容条件:如果两个维度相等,或者其中一个维度为1,或者其中一个数组在该维度缺失(视为1),则这两个维度是兼容的。
  3. 扩展维度:将维度为1的数组扩展到与另一个数组相同的长度(逻辑上,不复制数据)。
  4. 如果不兼容:如果上述条件都不满足,则抛出 ValueError: operands could not be broadcast together

2.1 示例:维度对齐

# 数组形状
A = np.ones((3, 4))      # shape (3, 4)
B = np.ones((4,))        # shape (4,)
# 从后往前对齐:A的最后一个维度是4,B的最后一个维度也是4,兼容
# 然后A的前一个维度是3,B在此维度缺失(视为1),所以B被广播为(1,4),再复制为(3,4)
print("A + B shape:", (A + B).shape)   # (3, 4)

# 另一个例子
C = np.ones((3, 1))      # shape (3, 1)
D = np.ones((4,))        # shape (4,)
# 对齐:C的最后维度1,D的最后维度4,不相等,但C的该维度为1,所以兼容
# 广播后C变为(3,4),D变为(1,4)然后复制为(3,4)?实际上D也会广播,最终结果(3,4)
print("C + D shape:", (C + D).shape)   # (3, 4)

输出:

A + B shape: (3, 4)
C + D shape: (3, 4)

2.2 不兼容的情况

E = np.ones((3, 4))
F = np.ones((3, 5))
# 尝试相加会报错,因为最后一个维度4和5不相等且没有一个为1
# E + F  # ValueError: operands could not be broadcast together

3. 广播的常见模式

3.1 标量与数组

标量可以看作形状为 () 的数组,它会广播到整个数组。

3.2 一维数组与二维数组(行广播)

matrix = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
row = np.array([10, 20, 30])
result = matrix + row   # row 广播到每一行
print("matrix + row:\n", result)

输出:

matrix + row:
 [[11 22 33]
 [14 25 36]
 [17 28 39]]

3.3 列广播(利用维度1)

column = np.array([[10], [20], [30]])   # shape (3,1)
result = matrix + column   # column 广播到每一列
print("matrix + column:\n", result)

输出:

matrix + column:
 [[11 12 13]
 [24 25 26]
 [37 38 39]]

3.4 三维广播

# 创建一个三维数组,形状 (2,3,4)
arr3d = np.arange(24).reshape(2,3,4)
# 一个二维数组,形状 (3,4)
arr2d = np.array([[10,20,30,40],
                  [50,60,70,80],
                  [90,100,110,120]])
# 广播:arr2d 会与 arr3d 的每个“块”相加
result = arr3d + arr2d
print("结果 shape:", result.shape)   # (2,3,4)

4. 手动扩展维度:np.newaxis

有时我们需要手动添加维度以满足广播规则。可以使用 np.newaxis(或 None)在指定位置插入新轴。

a = np.array([1, 2, 3])          # shape (3,)
b = np.array([4, 5])              # shape (2,)

# 我们希望将 a 作为列向量,b 作为行向量,得到外积
# a[:, None] 变为 (3,1),b[None, :] 变为 (1,2),广播得到 (3,2)
outer = a[:, np.newaxis] * b[np.newaxis, :]
print("外积:\n", outer)

# 更简洁的写法:a[:, None] * b
print("简洁写法:\n", a[:, None] * b)

输出:

外积:
 [[ 4  5]
 [ 8 10]
 [12 15]]
简洁写法:
 [[ 4  5]
 [ 8 10]
 [12 15]]

5. 广播的实际应用

  • 数据标准化:减去均值,除以标准差(均值和标准差是一维数组,与二维数据广播)。
  • 计算外积:如上例。
  • 网格生成:结合 ogridmgrid
  • 不同维度的特征组合:在机器学习中常用。
# 数据标准化示例
data = np.random.rand(5, 3)      # 5个样本,3个特征
mean = data.mean(axis=0)         # 每个特征的均值,shape (3,)
std = data.std(axis=0)           # 每个特征的标准差,shape (3,)
# 广播:mean 和 std 扩展到每一行
data_normalized = (data - mean) / std
print("标准化后每列均值应为0:", data_normalized.mean(axis=0))

6. 广播与性能

广播不会在内存中实际扩展数组,它通过巧妙的迭代方式实现。因此,广播非常高效,应优先于显式的 tilerepeat 操作。

# 不推荐:显式复制
big_arr = np.tile(b, (3, 1))   # 创建副本,浪费内存

# 推荐:广播
result = a[:, None] + b
注意: 尽管广播高效,但在某些情况下,广播可能导致中间结果的内存占用变大(例如,如果广播涉及非常大的数组,迭代开销仍然存在)。但通常广播是首选。

7. 常见的广播陷阱

  • 误解形状:例如,试图对形状 (3,) 和 (3,1) 进行运算,结果可能不是预期的 (3,3) 而是 (3,3)。
  • 忘记括号:在切片时使用 Nonenp.newaxis 容易出错。
  • 不兼容的形状:记住维度从后往前匹配的规则,有时需要 reshape。
# 错误示例:试图将 (3,) 与 (4,) 相加,但希望得到 (3,4)
a = np.array([1,2,3])
b = np.array([4,5,6,7])
# a + b   # 会报错,因为形状不兼容
# 正确做法:手动扩展
a_col = a[:, None]   # (3,1)
print("广播结果 shape:", (a_col + b).shape)   # (3,4)

总结

广播是NumPy中一个强大且优雅的特性,它使得数组运算更加灵活和高效。掌握广播规则和常见模式,可以写出更简洁、更快速的代码。下一章我们将介绍NumPy的线性代数模块,进一步探索数组的高级运算。