Наиболее «красивый» вариант реализации Callback-а для Keras через наследование от класса keras.callbacks.Callback. В этом случае не нужно создавать глобальные переменные для сохранения расчета на предыдущих этапах, поскольку объект этого класса инициализируется однократно при запуске model.fit и в дальнейшем не удаляется. Соотвественно, вызов _init_ происходит 1 раз и далее в переменных класса уже можно хранить значения при вызове методов.
В Keras есть уже ряд написанных callback-ов. Пример написания тестового собственного:
import keras
class MyCustomCallback(keras.callbacks.Callback):
var2_save = ""
var3_save = 0
def __init__(self, training_data):
super().__init__() #вызываем конструктор базового класса
self.var2_save = "On init: " + str(self.var3_save)
print(var2_save)
self.var3_save += 1
self.training_data = training_data
def on_epoch_end(self, epoch, logs=None):
print(logs)
print("The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error']))
self.var2_save = "On epoch end: " + str(self.var3_save)
print(self.var2_save)
self.var3_save += 1Запуск тренировки сети:
history = model.fit([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]],
yTrain[:50000],
epochs= 150,
validation_data=([xTrain01[50000:], xTrainProf01[50000:], xTrainRez01[50000:]],
yTrain[50000:]),
verbose=0,
callbacks=[MyCustomCallback(([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000], yTrain[:50000]))Обращаю внимание на передачу при вызове callbaсk-а нового класса MyCustomCallback обучающей выборки. Соотвественно, в конструкторе callback-а эти данные можно принять.
Второй вариант, это перед вызовом model.fit передать модели обучающую выборку через аттрибут training_data:
model.training_data = ([xTrain01[:50000], xTrainProf01[:50000], xTrainRez01[:50000]], yTrain[:50000])
Оба способа рабочие. Вопрос эстетики какой лучше использовать.
Результат работы при запуске тренировки сети:
{'val_loss': 3483.9638964047717, 'val_mean_absolute_error': 23.31727834354287, 'loss': 7476.93500734375, 'mean_absolute_error': 32.649654276123044}
The average loss for epoch 0 is 7476.94 and mean absolute error is 32.65.
On epoch end:0
{'val_loss': 3442.7591321254627, 'val_mean_absolute_error': 23.067948508804335, 'loss': 6907.680557949218, 'mean_absolute_error': 32.54259696472168}
The average loss for epoch 1 is 6907.68 and mean absolute error is 32.54.
On epoch end:1
{'val_loss': 3463.991695649211, 'val_mean_absolute_error': 23.16559474806618, 'loss': 7425.2137194726565, 'mean_absolute_error': 32.62528173950195}
The average loss for epoch 2 is 7425.21 and mean absolute error is 32.63.
On epoch end:2Пример callback для Keras
Для примера разработал небольшой callback, который делает следующее:
- Выводит текущую mae для эпохи.
- Строит график mae по всем эпохам на обучающей и проверочной выборках.
- Выводит текущий scatter для пройденных эпох.
- Строит гистограмму ошибок для пройденных эпох.
- Выводит время обучения на эпохе.
- Выводит среднее время обучения на всех эпохах до теущей.
- Выводит суммарное время обучения.
- Выводит сколько эпох прошло и сколько осталось.
- Выводит сколько примерно времени осталось до конца обучения.
- Сохраняет в .h5 сеть с лучшей MAE.
Чтобы в logs появилось значение для mae нужно указать в аргументе при компиляции сети metrics:
model.compile(optimizer=Adam(lr=1e-5), loss='mse', metrics=['mae'])
Графики на каждой эпохе — скорее баловство. И при вызове PlotScatterAndHist(self) есть проблемы, colab падает.
import keras
import time
import datetime
import sys
class MyCustomCallback(keras.callbacks.Callback):
#@classmethod
def PlotGraph(self, epoch):
if self.model is None:
return
History = self.model.history.history
if epoch > 0:
plt.plot(History['mean_absolute_error'], label='Средняя абсолютная ошибка на обучающем наборе')
plt.plot(History['val_mean_absolute_error'], label='Средняя абсолютная ошибка на проверочном наборе')
plt.xlabel('Эпоха обучения')
plt.ylabel('Средняя абсолютная ошибка')
plt.legend()
plt.show()
#@classmethod
def PlotScatterAndHist(self):
if self.model is None:
return
yy = self.model.training_data[3] # Зарплата
pred = self.model.predict([self.model.training_data[0], self.model.training_data[1], self.model.training_data[2]]) # Предсказанная зарплата
#res = np.mean(abs(yy - pred))
#print("Mean: ", res)
plt.scatter(yy, pred)
plt.xlabel('Правильные значение')
plt.ylabel('Предсказания')
plt.axis('equal')
plt.xlim(plt.xlim())
plt.ylim(plt.ylim())
plt.plot([-1000, 5000], [-1000, 5000])
plt.show()
# Гистограмма ошибок
delta = pred - yy # Вычитаем от предсказания правильную зп
plt.hist(abs(delta).flatten(), bins = 30)
plt.xlabel("Значение ошибки")
plt.ylabel("Количество")
plt.show()
def __init__(self):
super().__init__()
self.start_datetime = datetime.datetime.now()
print("Start time:", self.start_datetime)
self.sum_mae = 0
self.best_mae = sys.maxsize;
def on_epoch_begin(self, epoch, logs={}):
self.epoch_time_start = time.time()
def on_epoch_end(self, epoch, logs=None):
print(logs)
self.sum_mae += logs['mean_absolute_error']
print("Avg loss for epoch {} is {:7.2f} and MAE is {:7.2f}.".format(epoch, logs['loss'], logs['mean_absolute_error']))
print("Total MAE: {:7.2f}.".format(self.sum_mae / (epoch + 1)))
current_datetime = datetime.datetime.now()
diff = current_datetime - self.start_datetime
avg_time_per_epoch = (diff / (epoch + 1))
left_epochs = (self.params['epochs'] - (epoch + 1))
left = avg_time_per_epoch * left_epochs
print("Время запуска", self.start_datetime, "Текущее время:", current_datetime, "Прошло, сек:", diff.total_seconds(), "Время на эпоху: {:.2f}".format(time.time() - self.epoch_time_start), "Среднее время на эпоху, сек: {:.2f}".format(avg_time_per_epoch.total_seconds()), "Осталось эпох:", left_epochs,"Осталось, сек: {:.2f}".format(left.total_seconds()))
self.PlotGraph(epoch)
self.PlotScatterAndHist()
if logs['mean_absolute_error'] < self.best_mae:
print("Найдено лучшее значение MAE. Было", self.best_mae, "Новое:", logs['mean_absolute_error'], "Сохраняю файл весов.")
self.model.save_weights("best_weights.h5")
self.best_mae = logs['mean_absolute_error'] #Сохраняем значение лучшего результатаПолучение аргументов объекта в Python
При работе с callback-ом потребовалось посмотреть какие аргументы есть у объекта класса Callback и какими значениями заполнены. Это можно сделать вот таким способом:
attrs = vars(self)
attributes = attrs.items()
for attr in attributes:
print(attr)Чтобы посмотреть список аргументов объекта без вывода значений:
dir(self)