numpy axisについて 【Python】

numpyのaxisの操作について解説します。

numpyのaxisはよく次のように説明されます。

  • 2次元配列の場合、
    • axis=0は行方向
    • axis=1は列方向
  • 3次元配列の場合、
    • axis=0は奥行き方向
    • axis=1は行方向
    • axis=2は列方向

上記のような説明だと、3次元まで直感的に理解できるのですが、4次元以降だと理解が苦しくなります。ディープラーニングの場合、4次元以降の配列を扱うことはよくあります。(バッチサイズ, チャネル, 高さ, 幅)のような配列です。このようなとき、先ほどのような理解だと難しくなります。

結論から言うと、axis=行方向・列方向ではなく、配列の何番目の次元を潰すか、と考えると一番混乱しません。指定したaxisが何か演算されて消えるというイメージも有効です。

目次

基本の考え方

例えば、numpy配列が以下のようにアクセスされるとき

a[i][j][k]

これは

  • axis=0 iの方向
  • axis=1 jの方向
  • axis=2 kの方向

を意味します。つまり、配列の添字が左からaxis=0, 1, 2と決まっていきます。

例:2次元配列

a[i][j]

np.sum(a, axis=0) iを消す

a = np.array([
    [1, 2, 3],
    [4, 5, 6]
]) # shape = (2, 3)
np.sum(a, axis=0)
>> array([5, 7, 9]) # [1+4, 2+5, 3+6] → [5, 7, 9] shape (3,)

np.sum(a, axis=1) jを消す

a = np.array([
    [1, 2, 3],
    [4, 5, 6]
]) # shape = (2, 3)
np.sum(a, axis=1)
>> array([6, 15]) # [1+2+3, 4+5+6] → [6, 15] shape = (2,)

上記の結果から分かるように、(2, 3)というサイズに対してnp.sumという演算をしています。

  • axis=0のとき、(2, 3) → (3, )
  • axis=1の時、(2, 3) → (2, )

(2, 3)というサイズを左から0次元目, 1次元目と考えるとaxis=0のときは, 0次元目の次元が消えています。axis=1のときは1次元目の次元が消えて(2, )となっています。

例:3次元配列

a[i][j][k]

axis=0(iを消す)

i=0とi=1を足すイメージです

a = np.array([
    [
        [1, 2, 3],
        [4, 5, 6]
    ],
    [
        [7, 8, 9],
        [10, 11, 12]
    ]
]) # shape = (2, 2, 3)
np.sum(a, axis=0)
>> array([[ 8, 10, 12],
       [14, 16, 18]])
# [
#  [1+7,  2+8,  3+9 ],
#  [4+10, 5+11, 6+12]
# ]

axis=1(jを消す)

j=0 と j=1 を足すイメージです

a = np.array([
    [
        [1, 2, 3],
        [4, 5, 6]
    ],
    [
        [7, 8, 9],
        [10, 11, 12]
    ]
]) # shape = (2, 2, 3)
np.sum(a, axis=1)
>> array([[ 5,  7,  9],
       [17, 19, 21]])
# [
#  [1+4,  2+5,  3+6 ],
#  [7+10, 8+11, 9+12]
# ]
a = np.array([
    [
        [1, 2, 3],
        [4, 5, 6]
    ],
    [
        [7, 8, 9],
        [10, 11, 12]
    ]
]) # shape = (2, 2, 3)
np.sum(a, axis=1)
>> array([[ 6, 15],
       [24, 33]])
# [
#  [1+2+3,  4+5+6],
#  [7+8+9, 10+11+12]
# ]

覚え方

axisは消える次元

入力 shape: (d0, d1, d2)
axis=0 → (   , d1, d2)
axis=1 → (d0,    , d2)
axis=2 → (d0, d1,    )

消えたところでsum/mean/maxなどが実行される。

次元を残す

keepdims=Trueをオプションとしてつけると消した次元を長さ1で残すことができます。

2次元配列の例のところでも解説しましたが

axis=0の場合(2, 3) → (3, )

のようにオプションを付けないで実行すると次元が消えます。

線形代数的にはベクトルも [N x 1]の行列と捉えることができるので、(3, 1)のように明示的に次元を残した方が可読性やデバッグなどの関係でいいことがあります。

a = np.array([
    [1, 2, 3],
    [4, 5, 6]
]) # shape = (2, 3)
np.sum(a, axis=0, )
>> array([[5, 7, 9]])  # shape (1, 3)

まとめ

numpyのaxisの考え方を消す次元を指定するというようなやり方をするとわかりやすいということを解説しました。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

コメント

コメントする

目次