ModelCheckpoint

2024. 9. 30. 09:17AI/OpenCV

반응형

`ModelCheckpoint`는 모델 학습 중에 특정 조건을 만족할 때 모델의 가중치(weight)나 전체 모델 상태를 저장하는 Keras 콜백 함수이고 학습이 중단되거나 중간에 멈췄을 때도 이전에 저장된 모델을 불러와 이어서 학습할 수 있다.

 


`ModelCheckpoint`에 저장되는 내용

1. 모델 가중치 (Weights):
   - 모델의 훈련된 가중치(파라미터)를 저장한다. 가중치는 모델의 학습 중에 업데이트되는 중요한 값이다.
   - 기본적으로 `save_weights_only=False`로 설정되면, 모델의 전체 구조와 가중치를 모두 저장한다. 만약 `save_weights_only=True`로 설정하면, 가중치만 저장된다.

2. 모델의 전체 상태:
   - 모델의 가중치뿐만 아니라 모델의 구조(아키텍처), 옵티마이저 상태, 컴파일 정보 등도 함께 저장할 수 있다. 
   - 전체 상태를 저장하면 모델을 재로드할 때 옵티마이저 상태와 함께 불러오므로, 학습이 중단된 지점부터 이어서 학습할 수 있다.

3. 에포크와 손실/평가지표 정보:
   - 특정 조건 (예: validation loss가 가장 낮을 때 등)을 만족할 때마다 체크포인트가 저장되며, 이를 통해 어떤 에포크에서 모델이 저장되었는지 추적할 수 있다.
이 코드를 사용하면 `val_loss`가 가장 낮은 에포크에서 모델이 자동으로 저장된다.

 

 

 

ModelCheckpoint의 주요 파라미터

  • filepath: 모델이 저장될 파일 경로. 체크포인트 파일의 저장 경로를 지정한다.
  • monitor: 모니터링할 값 (예: 'val_loss', 'val_accuracy' 등).
  • save_best_only: True로 설정하면, 성능이 향상된 경우에만 모델을 저장한다.
  • save_weights_only: True로 설정하면, 가중치만 저장한다.
  • mode: 모니터링 값이 min 또는 max일 때 저장한다. (예: val_loss는 최소값이 목표이므로 'min', val_accuracy는 최대값이 목표이므로 'max')

 

from tensorflow.keras.callbacks import ModelCheckpoint

# ModelCheckpoint 콜백 정의
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',    # 모델이 저장될 경로
    monitor='val_loss',          # 모니터링할 메트릭
    save_best_only=True,         # 가장 성능이 좋은 모델만 저장
    save_weights_only=False,     # 가중치와 모델 구조 모두 저장
    mode='min',                  # 'val_loss'는 최소값이 목표
    verbose=1
)

# 모델 학습 시 콜백으로 추가
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=10,
    callbacks=[checkpoint]
)
 
 

 

 

반응형