PytochでDataParallelを利用する

A8バナー広告

Pytorchにはtorch.nn.DataParallelという機能がある。

このクラスを使うと、複数GPUデバイスがあるときに、計算を分散し高速化できる。

利用に少し困るところがあったので、書き置きをしておく。

モデルのattributeアクセスにはmodel.module.attributeの順番

DataParallel の利用前には model.weight とアクセスできたのに、DataParalle の利用後にはそれができない。

代わりに DataParallel.module.weight と記載する。

Python

Posted by user