반응형
파이토치는 모델을 .pt 또는 .pth 확장자를 사용하여 모델을 저장한다.
모델을 저장하는 방법은 두 가지가 있다.
전체 모델을 저장하거나 파라미터만 저장한다.
전체 모델을 저장하면 모델 파일의 사이즈가 커지지만 다른 파일에서 불러올 때 모델 구조에 대해서 미리 선언해주지 않아도 통째로 불러올 수가 있다.
파라미터마 저장하면 모델 파일의 사이즈가 작아지지만 파라미터를 불러오기 전에 동일하게 모델을 선언해주고 그 모델에 파라미터를 대입해주어야한다.
1. 전체 모델 저장 및 로드
# 저장
torch.save(model, 'save/to/path/model.pt')
# 로드
model = torch.load('load/from/path/model.pt')
2. 파라미터 저장 및 로드
# 저장
torch.save(model.state_dict(), 'save/to/path/model.pth')
# 로드
model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))
- state_dict
각 계층을 매개변수 텐서로 매핑되는 Python 딕셔너리
그리고 로드한 모델을 가지고 추론을 할 때는
model.eval()
이를 호출하여 모델을 추론 모드로 바꾸어주어야 파라미터 업데이트가 일어나지 않는다.
반대는 model.train()
더보기
이를 트레이닝에 사용하려면 model.train()를 콜(Call)해야 합니다.
이를 추론(Inference)에 사용하려면 model.eval()를 콜(Call)해야 합니다.
추가 : 체크포인트 저장
# 저장
torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, 'save/to/path/model.pth')
# 로드
model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
반응형
'Programming > Pytorch' 카테고리의 다른 글
Resnet50 finetuning - CIFAR10 DATASET (0) | 2023.02.28 |
---|---|
torch.onnx.export() option (1) | 2023.02.28 |
[yolov7] hyperparameter evolve : Index Error: index 30 is out of bounds for axis 0 with size 30 (0) | 2023.02.22 |