GPUを使った機械学習のデバッグ

エンジニアのコト

Faster R-CNNを使って物体認識を仕様と勉強中です。環境はGoogle Colab、ライブラリはPytorchを使っています。

Pytorchでは以下のようにdeviceを指定することでGPUで学習や推論処理を行うことができます。

import torchvision

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.to(device)

以下のようにして学習をさせようとしていた時、

torch.manual_seed(1)

dataset = Dataset(BASE_PATH)
n_train = int(len(dataset) * 0.7)
n_val = len(dataset) - n_train

train, val = torch.utils.data.random_split(dataset, [n_train, n_val])

def collate_fn(batch):
    return tuple(zip(*batch))

data_loader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True, collate_fn=collate_fn)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

for i, batch in enumerate(data_loader):
    images, targets = batch

    imagess = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    pred = model(images, targets)
    losses = sum(loss for loss in pred.values())
    loss_value = losses.item()

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

    if (i+1) % 10 == 0:
        print(f"epoch #{epoch+1} Iteration #{i+1} loss: {loss_value}")  

RuntimeError: CUDA error: device-side assert triggeredが発生してしまいました。より親切なエラーを見るためにはCPUで実行すればよいということを学んだ。

# cpuで実行するように変更
device = torch.device('cpu') 
model.to(device)

for i, batch in enumerate(data_loader):
    images, targets = batch

    imagess = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    pred = model(images, targets)
    losses = sum(loss for loss in pred.values())
    loss_value = losses.item()

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

    if (i+1) % 10 == 0:
        print(f"epoch #{epoch+1} Iteration #{i+1} loss: {loss_value}")  

そうするとIndexError: Target 6 is out of bounds.ということでfeature mapあたりが間違っていることがわかりました。めでたしめでたし。

コメント

タイトルとURLをコピーしました