Наиболее «красивый» вариант реализации Callback-а для Keras через наследование от класса keras.callbacks.Callback. В этом случае не нужно создавать глобальные переменные для сохранения расчета на предыдущих этапах, поскольку объект этого класса инициализируется однократно при запуске model.fit и в дальнейшем не удаляется. Соотвественно, вызов _init_ происходит 1 раз и далее в переменных класса уже можно хранить значения при вызове методов.
В Keras есть уже ряд написанных callback-ов. Пример написания тестового собственного:
import keras class MyCustomCallback(keras.callbacks.Callback): var2_save = "" var3_save = 0 def __init__(self, training_data): super().__init__() #вызываем конструктор базового класса self.var2_save = "On init: " + str(self.var3_save) print(var2_save) self.var3_save += 1 self.training_data = training_data def on_epoch_end(self, epoch, logs=None): print(logs) print("The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error'])) self.var2_save = "On epoch end: " + str(self.var3_save) print(self.var2_save) self.var3_save += 1
Запуск тренировки сети:
history = model.fit([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]], yTrain[:50000], epochs= 150, validation_data=([xTrain01[50000:], xTrainProf01[50000:], xTrainRez01[50000:]], yTrain[50000:]), verbose=0, callbacks=[MyCustomCallback(([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000], yTrain[:50000]))
Обращаю внимание на передачу при вызове callbaсk-а нового класса MyCustomCallback обучающей выборки. Соотвественно, в конструкторе callback-а эти данные можно принять.
Второй вариант, это перед вызовом model.fit передать модели обучающую выборку через аттрибут training_data:
model.training_data = ([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]], yTrain[:50000])
Оба способа рабочие. Вопрос эстетики какой лучше использовать.
Результат работы при запуске тренировки сети:
{'val_loss': 3483.9638964047717, 'val_mean_absolute_error': 23.31727834354287, 'loss': 7476.93500734375, 'mean_absolute_error': 32.649654276123044} The average loss for epoch 0 is 7476.94 and mean absolute error is 32.65. On epoch end:0 {'val_loss': 3442.7591321254627, 'val_mean_absolute_error': 23.067948508804335, 'loss': 6907.680557949218, 'mean_absolute_error': 32.54259696472168} The average loss for epoch 1 is 6907.68 and mean absolute error is 32.54. On epoch end:1 {'val_loss': 3463.991695649211, 'val_mean_absolute_error': 23.16559474806618, 'loss': 7425.2137194726565, 'mean_absolute_error': 32.62528173950195} The average loss for epoch 2 is 7425.21 and mean absolute error is 32.63. On epoch end:2
Пример callback для Keras
Для примера разработал небольшой callback, который делает следующее:
- Выводит текущую mae для эпохи.
- Строит график mae по всем эпохам на обучающей и проверочной выборках.
- Выводит текущий scatter для пройденных эпох.
- Строит гистограмму ошибок для пройденных эпох.
- Выводит время обучения на эпохе.
- Выводит среднее время обучения на всех эпохах до теущей.
- Выводит суммарное время обучения.
- Выводит сколько эпох прошло и сколько осталось.
- Выводит сколько примерно времени осталось до конца обучения.
- Сохраняет в .h5 сеть с лучшей MAE.
Чтобы в logs появилось значение для mae нужно указать в аргументе при компиляции сети metrics:
model.compile(optimizer=Adam(lr=1e-5), loss='mse', metrics=['mae'])
Графики на каждой эпохе — скорее баловство. И при вызове PlotScatterAndHist(self) есть проблемы, colab падает.
import keras import time import datetime import sys class MyCustomCallback(keras.callbacks.Callback): #@classmethod def PlotGraph(self, epoch): if self.model is None: return History = self.model.history.history if epoch > 0: plt.plot(History['mean_absolute_error'], label='Средняя абсолютная ошибка на обучающем наборе') plt.plot(History['val_mean_absolute_error'], label='Средняя абсолютная ошибка на проверочном наборе') plt.xlabel('Эпоха обучения') plt.ylabel('Средняя абсолютная ошибка') plt.legend() plt.show() #@classmethod def PlotScatterAndHist(self): if self.model is None: return yy = self.model.training_data[3] # Зарплата pred = self.model.predict([self.model.training_data[0], self.model.training_data[1], self.model.training_data[2]]) # Предсказанная зарплата #res = np.mean(abs(yy - pred)) #print("Mean: ", res) plt.scatter(yy, pred) plt.xlabel('Правильные значение') plt.ylabel('Предсказания') plt.axis('equal') plt.xlim(plt.xlim()) plt.ylim(plt.ylim()) plt.plot([-1000, 5000], [-1000, 5000]) plt.show() # Гистограмма ошибок delta = pred - yy # Вычитаем от предсказания правильную зп plt.hist(abs(delta).flatten(), bins = 30) plt.xlabel("Значение ошибки") plt.ylabel("Количество") plt.show() def __init__(self): super().__init__() self.start_datetime = datetime.datetime.now() print("Start time:", self.start_datetime) self.sum_mae = 0 self.best_mae = sys.maxsize; def on_epoch_begin(self, epoch, logs={}): self.epoch_time_start = time.time() def on_epoch_end(self, epoch, logs=None): print(logs) self.sum_mae += logs['mean_absolute_error'] print("Avg loss for epoch {} is {:7.2f} and MAE is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error'])) print("Total MAE: {:7.2f}.".format(self.sum_mae / (epoch + 1))) current_datetime = datetime.datetime.now() diff = current_datetime - self.start_datetime avg_time_per_epoch = (diff / (epoch + 1)) left_epochs = (self.params['epochs'] - (epoch + 1)) left = avg_time_per_epoch * left_epochs print("Время запуска", self.start_datetime, "Текущее время:", current_datetime, "Прошло, сек:", diff.total_seconds(), "Время на эпоху: {:.2f}".format(time.time() - self.epoch_time_start), "Среднее время на эпоху, сек: {:.2f}".format(avg_time_per_epoch.total_seconds()), "Осталось эпох:", left_epochs,"Осталось, сек: {:.2f}".format(left.total_seconds())) self.PlotGraph(epoch) self.PlotScatterAndHist() if logs['mean_absolute_error'] < self.best_mae: print("Найдено лучшее значение MAE. Было", self.best_mae, "Новое:", logs['mean_absolute_error'], "Сохраняю файл весов.") self.model.save_weights("best_weights.h5") self.best_mae = logs['mean_absolute_error'] #Сохраняем значение лучшего результата
Получение аргументов объекта в Python
При работе с callback-ом потребовалось посмотреть какие аргументы есть у объекта класса Callback и какими значениями заполнены. Это можно сделать вот таким способом:
attrs = vars(self) attributes = attrs.items() for attr in attributes: print(attr)
Чтобы посмотреть список аргументов объекта без вывода значений:
dir(self)