Numpy 函数中axis的含义
Numpy中很多的基本操作函数,都有一个参数axis
,比如:
- argmax 返回最大元素的索引
- argmin 返回最大元素的索引
- sum
- max
- min
- mean
- average
- median
官方解释
我们从numpy doc
里面的argmax
函数可以看到下面的解释(有删减)
1 | def argmax(a, axis=None, out=None): |
官方 Examples
1 | 6).reshape(2,3) a = np.arange( |
怎么样更好的理解axis
呢?
以argmax
为例,功能是返回最大元素的索引。可以这么来理解。
当数组是一维的,里面的元素就是一维数组里面的单个值,此时
axis
是没有作用的,只能取值为0
,比如1
2
3
4import numpy as np
1, 2, 3, 4, 5, 6, 7]) a = np.array([
np.argmax(a)
6当数组是二维的,就分了几种情况:
- 不指定axis时,把整个数组当作一维数组来处理,假定数组是3x3,二维数组
1
2
3
4
5
6
7
8
91, 3, 5],[2, 4, 6],[3, 5, 8]]) arr = np.array([[
print(arr.shape)
(3, 3)
print(arr)
[[1 3 5]
[2 4 6]
[3 5 8]]
# 不指定axis时,把整个数组展开成一维数组来处理 np.argmax(arr)
8 - 当
axis=0
时,假定数组是2x3,二维数组,输出的shape应该是有3个元素的索引的一维数组,按列统计,共有3列,给出每列最大值在列方向上的索引1
2
3
4
5
6
7
8
9
10
111, 3, 5],[6, 4, 2]]) arr = np.array([[
print(arr)
[[1 3 5]
[6 4 2]]
# 第一列最大值时6,在列方向上的索引为1
# 第二列最大值为4,在列方向上的索引为1
# 第二列最大值为5,在列方向上的索引为0
0) np.argmax(arr, axis=
array([1, 1, 0], dtype=int64)
2) np.argmax(arr, axis=-
array([1, 1, 0], dtype=int64) - 当
axis=1
时,假定数组是2x3,二维数组,输出的shape应该是有2个元素的索引的一维数组,按行统计,共有3行,给出每行最大值在行方向上的索引1
2
3
4
5
6
7
8
9
101, 3, 5],[6, 4, 2]]) arr = np.array([[
print(arr)
[[1 3 5]
[6 4 2]]
# 第一行最大值时5,在行方向上的索引为2
# 第二行最大值为6,在行方向上的索引为0
1) np.argmax(arr, axis=
array([2, 0], dtype=int64)
1) np.argmax(arr, axis=-
array([2, 0], dtype=int64)
- 不指定axis时,把整个数组当作一维数组来处理,假定数组是3x3,二维数组
当数组是三维数组时。
- 不指定
axis
时,将数组展开成一维数组,很好理解 - 当
axis=0
时,假定数组时2x3x2的三维数组,输出的shape应该是3x2个元素的索引的二维数组。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
4212, size=(2, 3, 2)) arr = np.random.randint(
print(arr)
[[[11 8]
[11 5]
[ 7 0]]
[[ 8 10]
[ 9 5]
[ 4 10]]]
#
print(arr[0,:,:])
[[11 8]
[11 5]
[ 7 0]]
print(arr[1,:,:])
[[ 8 10]
[ 9 5]
[ 4 10]]
# 分别对比 arr[0,:,:], arr[1,:,:] 对应位置,取其中最大值的索引,索引取值范围[0-1]
# out[0][0] = index(max(11, 8)) = 0
# out[0][1] = index(max(8, 10)) = 1
# out[1][0] = index(max(11, 9)) = 0
# out[1][1] = index(max(5, 5)) = 0
# out[2][0] = index(max(7, 4)) = 0
# out[2][1] = index(max(0, 10)) = 1
# out = [[0, 1],
# [0, 0],
# [0, 1]]
print(np.argmax(arr, axis=0).shape)
(3, 2)
0) np.argmax(arr, axis=
array([[0, 1],
[0, 0],
[0, 1]], dtype=int64)
3) np.argmax(arr, axis=-
array([[0, 1],
[0, 0],
[0, 1]], dtype=int64)
0-arr.ndim) np.argmax(arr, axis=
array([[0, 1],
[0, 0],
[0, 1]], dtype=int64) - 当
axis=1
时,假定数组时2x3x2的三维数组,输出的shape应该是2x2个元素的索引的二维数组。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
3612, size=(2, 3, 2)) arr = np.random.randint(
print(arr)
[[[11 8]
[11 5]
[ 7 0]]
[[ 8 10]
[ 9 5]
[ 4 10]]]
print(arr[:,0,:])
[[11 8]
[ 8 10]]
print(arr[:,1,:])
[[11 5]
[ 9 5]]
print(arr[:,2,:])
[[ 7 0]
[ 4 10]]
# 分别对比 arr[:,0,:], arr[:,1,:], arr[:,2,:] 对应位置,取其中最大值的索引,索引取值范围[0-2]
# out[0][0] = index(max(11, 11, 7 )) = 0
# out[0][1] = index(max(8 , 5, 0 )) = 0
# out[1][0] = index(max(8 , 9, 4 )) = 1
# out[1][1] = index(max(10, 5, 10)) = 0
# out = [[0, 0],
# [1, 0]]
print(np.argmax(arr, axis=1).shape)
(2, 2)
1) np.argmax(arr, axis=
array([[0, 0],
[1, 0]], dtype=int64)
2) np.argmax(arr, axis=-
array([[0, 0],
[1, 0]], dtype=int64)
1-arr.ndim) np.argmax(arr, axis=
array([[0, 0],
[1, 0]], dtype=int64) - 当
axis=2
时,假定数组时2x3x2的三维数组,输出的shape应该是2x3个元素的索引的二维数组。我们将数组的三个维度依次称为行,列,高1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
3512, size=(2, 3, 2)) arr = np.random.randint(
print(arr)
[[[11 8]
[11 5]
[ 7 0]]
[[ 8 10]
[ 9 5]
[ 4 10]]]
print(arr[:,:,0])
[[11 11 7]
[ 8 9 4]]
print(arr[:,:,1])
[[ 8 5 0]
[10 5 10]]
print(np.argmax(arr, axis=2).shape)
(2, 3)
# 分别对比 arr[:,:,0], arr[:,:,1] 对应位置,取其中最大值的索引,索引取值范围[0-1]
# out[0][0] = index(max(11, 8)) = 0
# out[0][1] = index(max(11, 5)) = 0
# out[0][2] = index(max(7 , 0)) = 0
# out[1][0] = index(max(8 , 10)) = 1
# out[1][1] = index(max(9 , 5)) = 0
# out[1][2] = index(max(4 , 10)) = 1
# out = [[0, 0, 0],
# [1, 0, 1]]
2) np.argmax(arr, axis=
array([[0, 0, 0],
[1, 0, 1]], dtype=int64)
1) np.argmax(arr, axis=-
array([[0, 0, 0],
[1, 0, 1]], dtype=int64)
2-arr.ndim) np.argmax(arr, axis=
array([[0, 0, 0],
[1, 0, 1]], dtype=int64)
- 不指定