NumPy:【集計関数のaxis】の説明と実行例


サマリ

  • np.sum
    • 合計
  • np.mean
    • 平均
  • np.max
    • 最大
  • np.min
    • 最小
  • axis
    • 集計する方向

はじめに

前提

基本

データ作成

import numpy as np
a = np.arange(5)
print(a)
print(a.shape)
[0 1 2 3 4]
(5,)
  • 形状(5, )のndarray(NumPy配列)を作成
  • 1次元のデータ

集計関数を実行

print(np.sum(a)) #合計
print(np.mean(a)) #平均
print(np.max(a)) #最大
print(np.min(a)) #最小
10
2.0
4
0
  • 合計は10
  • 平均は2.0
  • 最大は4
  • 最小は0

2つの使い方を理解

b = np.max(a)
c = a.max()
print(b)
print(c)
4
4
  • np.max(a)
    • 引数にndarray(NumPy配列)を指定する
  • a.max()
    • ndarray(NumPy配列)のメソッドとして実行
    • 後ろにつける時に()を付け忘れないように
  • いずれでも実行可能
    • 以降で説明するaxisはnp.max(a, axis=0)、a.max(axis=0)のように記述できる

axisの方向

はじめに注意事項

  • このイメージは最も基本的な2次元配列(行列)で考える
  • 集計関数にaxisとして指定する”方向”を理解するためのイメージ
    • 3次元以降の多次元はイメージで理解するのは難しい
    • あくまで初歩的な理解として

方向を理解するためのイメージ

axis=0

axis=0のイメージ
  • 下方向から板を当てて、上に向かって力を加えて、ぐちゃっとつぶす
  • axis=0の方向は下から上方向ってことが分かる
  • (3, )
    • (2,3)の1次元目がなくなり(3, )になる

axis=1

axis=1のイメージ
  • 1つ目の矢印
    • 右方向から板を当てて、左に向かって力を加えて、ぐちゃっとつぶす
    • 方向としてはaxis=1の方向は右から左方向ってことが分かる
  • 2つ目の矢印
    • 勝手にコロッと回転させられる
    • ここが混乱の元
  • (2, )
    • (2,3)の2次元目がなくなり(2, )になる

サンプルコードで理解を深める

2次元データで考える

データ作成

import numpy as np
a = np.random.randint(3, 6, (2,3))
print(a)
print(a.shape)
[[5 5 3]
 [3 4 4]]
(2, 3)
  • 形状(2,3)のndarray(NumPy配列)を作成
  • 行方向にまとめ(sum)た場合の想定
    • 5+3=8
    • 5+4=9
    • 3+4=7
  • 列方向にまとめ(sum)た場合の想定
    • 5+5+3=13
    • 3+4+4=11

行方向の集計(sum)の実行

b = a.sum(axis=0)
print(b)
print(b.shape)
[8 9 7]
(3,)
  • 行方向に足している
    • sum(和)の計算結果は想定の通り
  • (2, 3)から(3, )
    • 1次元目の2がなくなって3になる

列方向の集計(sum)の実行(axis=1)

c = a.sum(axis=1)
print(c)
print(c.shape)
[13 11]
(2,)
  • 列方向に足している
    • sum(和)の計算結果は想定の通り
  • (2, 3)から(2, )
    • 2次元目の3がなくなって2になる

多次元だとどうなるか?

データ作成

import numpy as np
a = np.random.randint(3, 6, (2,3,2,4,3))
print(a)
print(a.shape)
[[[[[4 3 3]
    [4 4 4]
    [5 3 5]
    [4 4 5]]

   [[5 4 3]
    [5 4 5]
    [3 4 5]
    [3 5 3]]]


  [[[3 5 4]
    [5 4 3]
    [4 5 5]
    [4 4 3]]

   [[5 3 4]
    [5 3 4]
    [4 3 5]
    [5 3 5]]]


  [[[5 3 4]
    [3 3 3]
    [4 5 5]
    [5 5 5]]

   [[3 5 4]
    [3 4 5]
    [5 5 5]
    [3 5 5]]]]



 [[[[4 3 5]
    [5 3 4]
    [4 3 3]
    [4 4 5]]

   [[5 4 5]
    [5 3 5]
    [5 5 4]
    [5 3 4]]]


  [[[4 4 4]
    [4 3 3]
    [5 5 4]
    [3 3 3]]

   [[5 3 4]
    [3 5 3]
    [4 5 4]
    [3 3 4]]]


  [[[4 3 4]
    [3 5 3]
    [5 3 4]
    [5 5 5]]

   [[5 5 4]
    [3 4 5]
    [4 3 3]
    [3 4 5]]]]]
(2, 3, 2, 4, 3)
  • 多次元(この例では5次元)を考える
    • 2x3x2x4x3=144個のデータ
    • この程度のデータ数になると集計結果を計算で確かめるのは大変
  • 注目するところはndarray(NumPy配列)の形状(2,3,2,4,3)
    • 1次元目(0)の要素数は2
    • 2次元目(1)の要素数は3
    • 3次元目(2)の要素数は2
    • 4次元目(3)の要素数は4
    • 5次元目(4)の要素数は3
  • 次にaxis=3での集計を試してみる
    • axis=3は0から始まるので4次元目(3)のことで要素数は4

集計実行

b = a.sum(axis=3)
print(b)
print(b.shape)
[[[[17 14 17]
   [16 17 16]]

  [[16 18 15]
   [19 12 18]]

  [[17 16 17]
   [14 19 19]]]


 [[[17 13 17]
   [20 15 18]]

  [[16 15 14]
   [15 16 15]]

  [[17 16 16]
   [15 16 17]]]]
(2, 3, 2, 3)
  • 計算結果の確認は難しいので集計結果のndarrayの形状に注目する
  • (2,3,2,3)
    • axis=3である4次元目(3)の要素数4がなくなっている
  • 考え方
    • 4次元目を方向をイメージはできない
    • 4次元目がなくなって、残ったものに集計されたと考えられる

まとめ

axisについてのまとめ

  • axisはどの方向(何次元目)に向かって集計するかと言う指定
  • 多次元の場合には、何次元目がなくなるのかと考えるようにする
    • 2次元で考えたような方向でイメージを作るのは難しい
  • axisが理解できると、AIで扱う次元の理解が整った頃だと思う


Posted by futa