Custom callback в Keras

Наиболее «красивый» вариант реализации Callback-а для Keras через наследование от класса keras.callbacks.Callback. В этом случае не нужно создавать глобальные переменные для сохранения расчета на предыдущих этапах, поскольку объект этого класса инициализируется однократно при запуске model.fit и в дальнейшем не удаляется. Соотвественно, вызов _init_ происходит 1 раз и далее в переменных класса уже можно хранить значения при вызове методов.

В Keras есть уже ряд написанных callback-ов. Пример написания тестового собственного:

import keras

class MyCustomCallback(keras.callbacks.Callback):
  var2_save = ""
  var3_save = 0

  def __init__(self, training_data):
    super().__init__() #вызываем конструктор базового класса
    self.var2_save = "On init: " + str(self.var3_save)
    print(var2_save)
    self.var3_save += 1
    self.training_data = training_data

  def on_epoch_end(self, epoch, logs=None):
    print(logs)
    print("The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error']))
    self.var2_save = "On epoch end: " + str(self.var3_save)
    print(self.var2_save)
    self.var3_save += 1

Запуск тренировки сети:

history = model.fit([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]], 
                    yTrain[:50000], 
                    epochs= 150, 
                    validation_data=([xTrain01[50000:], xTrainProf01[50000:], xTrainRez01[50000:]], 
                    yTrain[50000:]), 
                    verbose=0, 
                    callbacks=[MyCustomCallback(([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000], yTrain[:50000]))

Обращаю внимание на передачу при вызове callbaсk-а нового класса MyCustomCallback обучающей выборки. Соотвественно, в конструкторе callback-а эти данные можно принять.

Второй вариант, это перед вызовом model.fit передать модели обучающую выборку через аттрибут training_data:

model.training_data = ([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]], yTrain[:50000])

Оба способа рабочие. Вопрос эстетики какой лучше использовать.

Результат работы при запуске тренировки сети:

{'val_loss': 3483.9638964047717, 'val_mean_absolute_error': 23.31727834354287, 'loss': 7476.93500734375, 'mean_absolute_error': 32.649654276123044}
The average loss for epoch 0 is 7476.94 and mean absolute error is   32.65.
On epoch end:0
{'val_loss': 3442.7591321254627, 'val_mean_absolute_error': 23.067948508804335, 'loss': 6907.680557949218, 'mean_absolute_error': 32.54259696472168}
The average loss for epoch 1 is 6907.68 and mean absolute error is   32.54.
On epoch end:1
{'val_loss': 3463.991695649211, 'val_mean_absolute_error': 23.16559474806618, 'loss': 7425.2137194726565, 'mean_absolute_error': 32.62528173950195}
The average loss for epoch 2 is 7425.21 and mean absolute error is   32.63.
On epoch end:2

Пример callback для Keras

Для примера разработал небольшой callback, который делает следующее:

  • Выводит текущую mae для эпохи.
  • Строит график mae по всем эпохам на обучающей и проверочной выборках.
  • Выводит текущий scatter для пройденных эпох.
  • Строит гистограмму ошибок для пройденных эпох.
  • Выводит время обучения на эпохе.
  • Выводит среднее время обучения на всех эпохах до теущей.
  • Выводит суммарное время обучения.
  • Выводит сколько эпох прошло и сколько осталось.
  • Выводит сколько примерно времени осталось до конца обучения.
  • Сохраняет в .h5 сеть с лучшей MAE.

Чтобы в logs появилось значение для mae нужно указать в аргументе при компиляции сети metrics:

model.compile(optimizer=Adam(lr=1e-5), loss='mse', metrics=['mae'])

Графики на каждой эпохе — скорее баловство. И при вызове PlotScatterAndHist(self) есть проблемы, colab падает.

import keras
import time
import datetime
import sys

class MyCustomCallback(keras.callbacks.Callback):
  #@classmethod
  def PlotGraph(self, epoch):
    if self.model is None:
      return
    
    History = self.model.history.history
    if epoch > 0:
      plt.plot(History['mean_absolute_error'], label='Средняя абсолютная ошибка на обучающем наборе')
      plt.plot(History['val_mean_absolute_error'], label='Средняя абсолютная ошибка на проверочном наборе')
      plt.xlabel('Эпоха обучения')
      plt.ylabel('Средняя абсолютная ошибка')
      plt.legend()
      plt.show() 

  #@classmethod
  def PlotScatterAndHist(self):
    if self.model is None:
      return
    
    yy = self.model.training_data[3] # Зарплата

    pred = self.model.predict([self.model.training_data[0], self.model.training_data[1], self.model.training_data[2]]) # Предсказанная зарплата
    #res = np.mean(abs(yy - pred))
    #print("Mean: ", res)
    plt.scatter(yy, pred)
    plt.xlabel('Правильные значение')
    plt.ylabel('Предсказания')
    plt.axis('equal')
    plt.xlim(plt.xlim())
    plt.ylim(plt.ylim())
    plt.plot([-1000, 5000], [-1000, 5000])
    plt.show()

    # Гистограмма ошибок
    delta = pred - yy # Вычитаем от предсказания правильную зп
    plt.hist(abs(delta).flatten(), bins = 30)
    plt.xlabel("Значение ошибки")
    plt.ylabel("Количество")
    plt.show() 

  def __init__(self):
    super().__init__()
    
    self.start_datetime = datetime.datetime.now()
    print("Start time:", self.start_datetime)
    self.sum_mae = 0
    self.best_mae = sys.maxsize;

  def on_epoch_begin(self, epoch, logs={}):
    self.epoch_time_start = time.time()  

  def on_epoch_end(self, epoch, logs=None):
    print(logs)

    self.sum_mae += logs['mean_absolute_error']
    print("Avg loss for epoch {} is {:7.2f} and MAE is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error']))
    print("Total MAE: {:7.2f}.".format(self.sum_mae / (epoch + 1)))
    
    current_datetime = datetime.datetime.now()
    diff = current_datetime - self.start_datetime
    avg_time_per_epoch = (diff / (epoch + 1))
    left_epochs = (self.params['epochs'] - (epoch + 1))
    left = avg_time_per_epoch * left_epochs
    print("Время запуска", self.start_datetime, "Текущее время:", current_datetime, "Прошло, сек:", diff.total_seconds(), "Время на эпоху: {:.2f}".format(time.time() - self.epoch_time_start), "Среднее время на эпоху, сек: {:.2f}".format(avg_time_per_epoch.total_seconds()), "Осталось эпох:", left_epochs,"Осталось, сек: {:.2f}".format(left.total_seconds()))

    self.PlotGraph(epoch)

    self.PlotScatterAndHist()

    if logs['mean_absolute_error'] < self.best_mae:
      print("Найдено лучшее значение MAE. Было", self.best_mae, "Новое:", logs['mean_absolute_error'], "Сохраняю файл весов.")
      self.model.save_weights("best_weights.h5")
      self.best_mae = logs['mean_absolute_error'] #Сохраняем значение лучшего результата

Получение аргументов объекта в Python

При работе с callback-ом потребовалось посмотреть какие аргументы есть у объекта класса Callback и какими значениями заполнены. Это можно сделать вот таким способом:

    attrs = vars(self)
    attributes = attrs.items()
    for attr in attributes: 
      print(attr)

Чтобы посмотреть список аргументов объекта без вывода значений:

dir(self)
Spread the love
Запись опубликована в рубрике IT рецепты. Добавьте в закладки постоянную ссылку.

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *