なぜnp.dotの前に【2次元】に変換するのか?


はじめに

前提

画像処理

  • 扱うデータ
    • 3次元以上の多次元テンソル
    • 畳み込み処理などで重みとdot積を計算
  • 計算の工夫
    • 多次元データをそのまま計算しない
    • 2次元に変換して計算する
  • 形状変換しても大丈夫?
    • 結論は変換しても計算結果は同じ
    • 高速に計算できる2次元(行列)に変換
      • タイトルの疑問に対する回答
  • 以降では、本当に大丈夫なのか以降の実行ログを見ながら説明します

形状変換しても計算結果は同じか?

  • 4次元は数字が多くなるので、3次元で説明

データ作成

変換する方のデータ

import numpy as np
a = np.arange(12).reshape(3,2,2)
b = a.reshape(-1,2)
print(a)
print(a.shape)
print(b)
print(b.shape)
[[[ 0  1]
  [ 2  3]]

 [[ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]]]
(3, 2, 2)
[[ 0  1]
 [ 2  3]
 [ 4  5]
 [ 6  7]
 [ 8  9]
 [10 11]]
(6, 2)
  • 3次元と2次元のndarray(NumPy配列)を作成
  • np.arange(12)の12個の数値データを使って2種類の次元のデータを作成
  • a
    • (3,2,2)
    • 3次元テンソル
  • b
    • (6,2)
    • 2次元テンソルに変換
    • 何となくbの方が計算しやすそうに見えませんか?
  • 行列計算に関係するのは最後の次元なので、どちらも最後の次元の要素数が「2」になるようにreshape
  • 参考

np.dotに使うデータ

w = np.array([[2,3,1],[4,2,2]])
print(w)
print(w.shape)
[[2 3 1]
 [4 2 2]]
(2, 3)
  • dot積を行うデータ
  • 形状(2,3)のndarray(NumPy配列)を作成
    • 後ろ側のデータなので、最初の次元の要素数が「2」にする

計算結果と確認

それぞれでnp.dot計算

c = np.dot(a, w)
d = np.dot(b, w)
print(c)
print(c.shape)
print(d)
print(d.shape)
[[[ 4  2  2]
  [16 12  8]]

 [[28 22 14]
  [40 32 20]]

 [[52 42 26]
  [64 52 32]]]
(3, 2, 3)
[[ 4  2  2]
 [16 12  8]
 [28 22 14]
 [40 32 20]
 [52 42 26]
 [64 52 32]]
(6, 3)
  • 形状の確認
    • c
      • aの(3,2,2)から(3,2,3)になった
    • d
      • (6,2)から(6,3)になった
  • 数値の比較
    • 中身の計算された数字は同じ
  • 形状は異なるが、中身(計算結果)が同じことが分かる

さらに、2次元を3次元に戻して比較

e = d.reshape(a.shape[0], a.shape[1], -1)
print(c)
print(c.shape)
print(e)
print(e.shape)
[[[ 4  2  2]
  [16 12  8]]

 [[28 22 14]
  [40 32 20]]

 [[52 42 26]
  [64 52 32]]]
(3, 2, 3)
[[[ 4  2  2]
  [16 12  8]]

 [[28 22 14]
  [40 32 20]]

 [[52 42 26]
  [64 52 32]]]
(3, 2, 3)
  • まったく同じ計算結果

まとめ

  • タイトルに書いたなぜ?に対する回答
    • 高速化のため
    • 2次元の行列計算にした方が処理が早い(GPUの特性に拠る)
    • しかし、計算結果が変わってしまうと問題
  • 計算結果
    • 3次元から2次元に変換
    • 2次元から3次元に戻す
    • 変換しない場合の計算結果と一致した


Posted by futa