"""
This file contains the loss functions for the normalizing flow training.
Since we are combining different loss functions we have different options.
"""
###############################################################################
# initial imports and set-up:
import numpy as np
# tensorflow imports:
import tensorflow as tf
###############################################################################
# helpers:
def _broadcast_sample_weight(sample_weight, losses):
"""
Broadcast sample weights to match loss tensor shape.
:param sample_weight: weights to apply.
:param losses: loss tensor.
:returns: broadcasted weights matching ``losses``.
"""
if sample_weight is None:
return None
sample_weight = tf.cast(sample_weight, losses.dtype)
if sample_weight.shape.rank == 0:
return tf.ones_like(losses) * sample_weight
if losses.shape.rank is None or sample_weight.shape.rank is None:
losses_rank = tf.rank(losses)
weight_rank = tf.rank(sample_weight)
rank_diff = losses_rank - weight_rank
new_shape = tf.concat([tf.shape(sample_weight), tf.ones(rank_diff, tf.int32)], axis=0)
sample_weight = tf.reshape(sample_weight, new_shape)
return tf.broadcast_to(sample_weight, tf.shape(losses))
if sample_weight.shape.rank < losses.shape.rank:
new_shape = sample_weight.shape.as_list() + [1] * (losses.shape.rank - sample_weight.shape.rank)
sample_weight = tf.reshape(sample_weight, new_shape)
return tf.broadcast_to(sample_weight, tf.shape(losses))
def _reduce_weighted_loss(losses, sample_weight, reduction):
    """
    Apply sample weights to the losses and reduce them.

    With ``Reduction.NONE`` the weighted per-sample losses are returned, with
    ``Reduction.SUM`` their sum. For any other reduction the sum is divided by
    the sum of the broadcast weights (or by the number of loss elements when
    no weights are given); a vanishing normalization is replaced by one.

    :param losses: per-sample losses.
    :param sample_weight: sample weights for each loss, or ``None``.
    :param reduction: Keras reduction enum.
    :returns: reduced loss tensor.
    """
    reduction_modes = tf.keras.losses.Reduction
    weights = _broadcast_sample_weight(sample_weight, losses)
    has_weights = weights is not None
    weighted = losses * weights if has_weights else losses
    # no reduction requested:
    if reduction == reduction_modes.NONE:
        return weighted
    summed = tf.reduce_sum(weighted)
    # plain sum requested:
    if reduction == reduction_modes.SUM:
        return summed
    # everything else is a (weighted) mean:
    if has_weights:
        normalization = tf.reduce_sum(weights)
    else:
        normalization = tf.cast(tf.size(weighted), weighted.dtype)
    unity = tf.constant(1.0, dtype=weighted.dtype)
    normalization = tf.where(normalization == 0, unity, normalization)
    return summed / normalization
###############################################################################
# standard normalizing flow loss function:
[docs]
class standard_loss(tf.keras.losses.Loss):
    """
    Density loss for normalizing flow training: the per-sample loss is the
    negative of the predicted log density.
    """

    def __init__(self):
        """
        Set up the standard density loss; there are no hyperparameters.
        """
        super().__init__()

    def compute_loss_components(self, y_true, y_pred, sample_weight):
        """
        Return the single component of this loss.

        :param y_true: target log density (unused).
        :param y_pred: predicted log density.
        :param sample_weight: sample weights (unused).
        :returns: negative predicted log density.
        """
        density_term = -y_pred
        return density_term

    def call(self, y_true, y_pred):
        """
        Per-sample loss used by Keras: minus the predicted log density.

        :param y_true: target log density (unused).
        :param y_pred: predicted log density.
        :returns: negative predicted log density.
        """
        return -y_pred

    def print_feedback(self, padding=''):
        """
        Print a one-line description of the loss.

        :param padding: string prepended to the printed line.
        """
        message = 'using standard loss function'
        print(padding + message)

    def reset(self):
        """
        Reset loss hyperparameters: nothing to do for the standard loss.
        """
        return None
###############################################################################
# density and evidence loss with constant weights:
[docs]
class constant_weight_loss(tf.keras.losses.Loss):
    """
    Combined density and evidence-error loss with fixed weights.

    The per-sample loss is ``alpha * loss_density + (1 - alpha) * loss_ee``
    where ``loss_density = -(y_pred + beta)`` and ``loss_ee`` is the squared
    deviation of ``y_true - y_pred`` from its weighted mean.
    """

    def __init__(self, alpha=1.0, beta=0.0):
        """
        Initialize the fixed-weight loss.

        :param alpha: weight for the density component; the evidence-error
            component is weighted by ``1 - alpha``.
        :param beta: additive offset applied to predicted log density.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.alpha = alpha
        self.beta = beta

    def compute_loss_components(self, y_true, y_pred, sample_weight):
        """
        Compute density and evidence-error components.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; ``None`` is treated as
            unit weights.
        :returns: tuple of density loss and per-sample squared deviation of
            the residuals from their weighted mean.
        """
        # compute difference between true and predicted posterior values:
        diffs = y_true - y_pred
        # align the weights with the residuals so that the products below
        # do not fail on mismatching dtypes:
        if sample_weight is None:
            sample_weight = tf.ones_like(diffs)
        else:
            sample_weight = tf.cast(sample_weight, diffs.dtype)
        # sum weights:
        tot_weights = tf.reduce_sum(sample_weight)
        # compute overall offset; a vanishing total weight gives a zero
        # offset instead of NaN, consistently with the guarded reduction:
        mean_diff = tf.math.divide_no_nan(tf.reduce_sum(diffs * sample_weight), tot_weights)
        # compute its variance:
        var_diff = (diffs - mean_diff)**2
        # compute density loss function:
        loss_orig = -(y_pred + self.beta)
        #
        return loss_orig, var_diff

    def compute_loss(self, y_true, y_pred, sample_weight):
        """
        Combine density and evidence-error loss.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :returns: weighted sum of loss components.
        """
        # get components:
        loss_1, loss_2 = self.compute_loss_components(y_true, y_pred, sample_weight)
        #
        return +self.alpha*loss_1 + (1. - self.alpha)*loss_2

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Override ``Loss.__call__`` to support explicit ``sample_weight``.

        The weights are needed twice: inside the loss (weighted mean of the
        residuals) and in the final weighted reduction.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; defaults to ones.
        :returns: weighted loss tensor respecting reduction mode.
        """
        if sample_weight is None:
            sample_weight = tf.ones_like(y_pred, dtype=y_pred.dtype)
        # fall back to the loss name, then the class name, for the scope:
        name_scope = getattr(self, "_name_scope", None)
        if not name_scope:
            name_scope = getattr(self, "name", self.__class__.__name__)
        with tf.name_scope(name_scope):
            losses = self.compute_loss(y_true, y_pred, sample_weight)
            return _reduce_weighted_loss(losses, sample_weight, self.reduction)

    def print_feedback(self, padding=''):
        """
        Print the configured fixed-weight loss settings.

        :param padding: string prepended to every printed line.
        """
        print(padding+'using combined density and evidence-error loss function')
        print(padding+'weight of density loss: %.3g, weight of evidence-error-loss: %.3g' % (self.alpha, 1.-self.alpha))

    def reset(self):
        """Reset loss hyperparameters (no-op for constant weights)."""
        pass
###############################################################################
# density and evidence loss with variable weights:
[docs]
class variable_weight_loss(tf.keras.losses.Loss):
    """
    Combined density and evidence-error loss whose weights are stored in
    non-trainable variables, so that they can be changed during training.

    The per-sample loss is ``lambda_1 * loss_density + lambda_2 * loss_ee``
    where ``loss_density = -(y_pred + beta)`` and ``loss_ee`` is the squared
    deviation of ``y_true - y_pred`` from its weighted mean. Subclasses
    implement the schedule in ``update_lambda_values_on_epoch_begin``.
    """

    def __init__(self, lambda_1=1.0, lambda_2=0.0, beta=0.0):
        """
        Initialize the variable-weight loss.

        The dtype of each variable is derived from the Python type of the
        corresponding argument.

        :param lambda_1: initial weight for the density term.
        :param lambda_2: initial weight for the evidence-error term.
        :param beta: additive offset applied to predicted log density.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.lambda_1 = tf.Variable(lambda_1, trainable=False, name='loss_lambda_1', dtype=type(lambda_1))
        self.lambda_2 = tf.Variable(lambda_2, trainable=False, name='loss_lambda_2', dtype=type(lambda_2))
        self.beta = tf.Variable(beta, trainable=False, name='loss_beta', dtype=type(beta))
        # save initial parameters:
        self.initial_lambda_1 = lambda_1
        self.initial_lambda_2 = lambda_2
        self.initial_beta = beta

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start. Takes in every kwargs to not
        crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        :raises NotImplementedError: expected to be overridden in subclasses.
        """
        # base class is empty...
        # use the following syntax:
        # tf.keras.backend.set_value(self.lambda_1, tf.constant(0.5) * epoch)
        raise NotImplementedError

    def compute_loss_components(self, y_true, y_pred, sample_weight, lambda_1=None, lambda_2=None):
        """
        Compute density and evidence-error components with configurable weights.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; ``None`` is treated as
            unit weights.
        :param lambda_1: optional override for density weight; when ``None``
            the current value of the stored variable is read.
        :param lambda_2: optional override for evidence-error weight; when
            ``None`` the current value of the stored variable is read.
        :returns: tuple of density loss, per-sample squared deviation of the
            residuals from their weighted mean, and the two active weights.
        """
        # compute difference between true and predicted posterior values:
        diffs = y_true - y_pred
        # align the weights with the residuals so that the products below
        # do not fail on mismatching dtypes:
        if sample_weight is None:
            sample_weight = tf.ones_like(diffs)
        else:
            sample_weight = tf.cast(sample_weight, diffs.dtype)
        # sum weights:
        tot_weights = tf.reduce_sum(sample_weight)
        # compute overall offset; a vanishing total weight gives a zero
        # offset instead of NaN, consistently with the guarded reduction:
        mean_diff = tf.math.divide_no_nan(tf.reduce_sum(diffs * sample_weight), tot_weights)
        # compute its variance:
        var_diff = (diffs - mean_diff)**2
        # compute density loss function:
        loss_orig = -(y_pred + self.beta)
        # get weights if not passed:
        if lambda_1 is None:
            lambda_1 = tf.keras.backend.get_value(self.lambda_1)
        if lambda_2 is None:
            lambda_2 = tf.keras.backend.get_value(self.lambda_2)
        #
        return loss_orig, var_diff, lambda_1, lambda_2

    def compute_loss(self, y_true, y_pred, sample_weight):
        """
        Combine density and evidence-error loss using current weights.

        The weight variables themselves (not their values) are passed on.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :returns: weighted loss value.
        """
        # get components:
        loss_1, loss_2, lambda_1, lambda_2 = self.compute_loss_components(
            y_true, y_pred, sample_weight, self.lambda_1, self.lambda_2)
        #
        return lambda_1*loss_1 + lambda_2*loss_2

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Override ``Loss.__call__`` to support explicit ``sample_weight``.

        The weights are needed twice: inside the loss (weighted mean of the
        residuals) and in the final weighted reduction.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; defaults to ones.
        :returns: weighted loss tensor respecting reduction mode.
        """
        if sample_weight is None:
            sample_weight = tf.ones_like(y_pred, dtype=y_pred.dtype)
        # fall back to the loss name, then the class name, for the scope:
        name_scope = getattr(self, "_name_scope", None)
        if not name_scope:
            name_scope = getattr(self, "name", self.__class__.__name__)
        with tf.name_scope(name_scope):
            losses = self.compute_loss(y_true, y_pred, sample_weight)
            return _reduce_weighted_loss(losses, sample_weight, self.reduction)

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to every printed line.
        :raises NotImplementedError: expected to be overridden in subclasses.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset the loss weights to the values given at construction.

        The existing variables are updated in place. Re-running ``__init__``
        with the base-class keywords would dispatch to the subclass
        constructor, whose signature differs: its own hyperparameters would
        silently fall back to their defaults and new variable objects would
        replace the ones already in use. Subclasses that keep additional
        mutable state should override this method to reset it as well.
        """
        tf.keras.backend.set_value(self.lambda_1, self.initial_lambda_1)
        tf.keras.backend.set_value(self.lambda_2, self.initial_lambda_2)
        # some subclasses replace beta with a plain number that is never
        # updated: in that case there is nothing to restore.
        if isinstance(self.beta, tf.Variable):
            tf.keras.backend.set_value(self.beta, self.initial_beta)
[docs]
class random_weight_loss(variable_weight_loss):
    """
    Loss that, at every epoch after an initial phase, randomly selects one of
    the two loss components and gives it all the weight.
    """

    def __init__(self, initial_random_epoch=0, lambda_1=1.0, beta=0.0, **kwargs):
        """
        Initialize the randomized loss.

        :param initial_random_epoch: the random selection is active only for
            epochs strictly greater than this value.
        :param lambda_1: initial weight of the density term; the
            evidence-error term starts with ``1 - lambda_1``.
        :param beta: additive offset applied to predicted log density.
        :param kwargs: accepted for interface compatibility and ignored.
        """
        evidence_weight = 1.-lambda_1
        super().__init__(lambda_1, evidence_weight, beta)
        # epoch after which the weights start being randomized:
        self.initial_random_epoch = initial_random_epoch

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Draw new weights at epoch start. Takes in every kwargs to not crowd
        the interface.

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        # still in the initial phase, leave the weights alone:
        if not epoch > self.initial_random_epoch:
            return None
        # draw 0 or 1 and give all the weight to one of the two terms:
        coin = np.random.randint(2)
        set_value = tf.keras.backend.set_value
        set_value(self.lambda_1, coin)
        set_value(self.lambda_2, 1.-coin)
        return None

    def print_feedback(self, padding=''):
        """
        Print a one-line description of the loss.

        :param padding: string prepended to the printed line.
        """
        message = 'using randomized loss function'
        print(padding + message)
[docs]
class annealed_weight_loss(variable_weight_loss):
    """
    Loss that moves the weight from the density term to the evidence-error
    term once a given epoch has been passed.
    """

    def __init__(self, anneal_epoch=125, lambda_1=1.0, beta=0.0, roll_off_nepoch=10, **kwargs):
        """
        Initialize the annealed loss.

        :param anneal_epoch: the weights change only for epochs strictly
            greater than this value.
        :param lambda_1: initial weight of the density term; the
            evidence-error term starts with ``1 - lambda_1``.
        :param beta: additive offset applied to predicted log density.
        :param roll_off_nepoch: scale, in epochs, of the exponential factor
            applied to the density weight.
        :param kwargs: accepted for interface compatibility and ignored.
        """
        evidence_weight = 1.-lambda_1
        super().__init__(lambda_1, evidence_weight, beta)
        # annealing schedule:
        self.anneal_epoch = anneal_epoch
        self.roll_off_nepoch = roll_off_nepoch

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Decrease the density weight at epoch start. Takes in every kwargs to
        not crowd the interface.

        The factor ``exp(-(epoch - anneal_epoch) / roll_off_nepoch)``
        multiplies the current value of the weight, so the reduction
        compounds from one epoch to the next.

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        # annealing has not started yet:
        if not epoch > self.anneal_epoch:
            return None
        backend = tf.keras.backend
        elapsed = epoch - self.anneal_epoch
        decay = np.exp(-1.*elapsed/self.roll_off_nepoch)
        # scale the current density weight:
        density_weight = backend.get_value(self.lambda_1)
        density_weight *= decay
        # store it and give the remainder to the evidence-error term:
        backend.set_value(self.lambda_1, density_weight)
        backend.set_value(self.lambda_2, 1.-density_weight)
        return None

    def print_feedback(self, padding=''):
        """
        Print a one-line description of the loss.

        :param padding: string prepended to the printed line.
        """
        message = 'using annealed loss function'
        print(padding + message)
[docs]
class SoftAdapt_weight_loss(variable_weight_loss):
    """
    Implement SoftAdapt as in arXiv:1912.12355, with optional smoothing.

    At every epoch the weights are set to the softmax of ``tau`` times the
    latest change of two monitored quantities. With smoothing, both the rates
    and the resulting weights are exponentially averaged.
    """

    def __init__(self, tau=1.0, beta=0.0, smoothing=True, smoothing_tau=20, quantity_1='val_rho_loss', quantity_2='val_ee_loss', **kwargs):
        """
        Initialize the SoftAdapt loss.

        :param tau: factor multiplying the rates inside the softmax.
        :param beta: additive offset applied to predicted log density.
        :param smoothing: whether to smooth rates and weights.
        :param smoothing_tau: smoothing scale; the weight of the newest value
            in the running averages is ``1 / smoothing_tau``.
        :param quantity_1: key of the logs entry driving ``lambda_1``.
        :param quantity_2: key of the logs entry driving ``lambda_2``.
        :param kwargs: accepted for interface compatibility and ignored.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.tau = tau
        # beta is kept as a plain number; the initial value is recorded so
        # that a reset does not fall back to the base class default:
        self.beta = beta
        self.initial_beta = beta
        self.smoothing = smoothing
        self.smoothing_alpha = 1. / smoothing_tau
        # quantities to be monitored:
        self.quantity_1 = quantity_1
        self.quantity_2 = quantity_2
        # running averages used by the smoothing:
        self.rate_1_buffer = 0.0
        self.rate_2_buffer = 0.0
        self.lambda_1_buffer = 1.0
        self.lambda_2_buffer = 0.0

    @staticmethod
    def _last_increment(history):
        """
        Difference between the last two entries of ``history``; zero when
        fewer than two entries are available.
        """
        if history is None or len(history) < 2:
            return 0.0
        return history[-1] - history[-2]

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start. Takes in every kwargs to not
        crowd the interface...

        :param epoch: current epoch index (unused).
        :param kwargs: ``logs`` is expected to map the monitored quantity
            names to their history; missing logs or missing entries are
            treated as an empty history.
        """
        # get logs:
        logs = kwargs.get('logs')
        if logs is None:
            logs = {}
        # get the two rates:
        rate_1 = self._last_increment(logs.get(self.quantity_1))
        rate_2 = self._last_increment(logs.get(self.quantity_2))
        # smooth the two rates:
        if self.smoothing:
            rate_1 = self.smoothing_alpha * rate_1 + (1. - self.smoothing_alpha) * self.rate_1_buffer
            rate_2 = self.smoothing_alpha * rate_2 + (1. - self.smoothing_alpha) * self.rate_2_buffer
            self.rate_1_buffer = rate_1
            self.rate_2_buffer = rate_2
        # protect for initial phase:
        if rate_1 == 0.0:
            _lambda_1 = 1.0
        else:
            # softmax of the scaled rates; subtracting the largest exponent
            # leaves the ratios unchanged and avoids overflow for large rates:
            exponent_1 = self.tau * rate_1
            exponent_2 = self.tau * rate_2
            shift = max(exponent_1, exponent_2)
            lambda_1 = np.exp(exponent_1 - shift)
            lambda_2 = np.exp(exponent_2 - shift)
            _tot = lambda_1 + lambda_2
            if self.smoothing:
                _lambda_1 = self.smoothing_alpha * lambda_1 / _tot + (1. - self.smoothing_alpha) * self.lambda_1_buffer
                _lambda_2 = self.smoothing_alpha * lambda_2 / _tot + (1. - self.smoothing_alpha) * self.lambda_2_buffer
                # normalize both with the same total (computed before either
                # of the two is rescaled):
                _norm = _lambda_1 + _lambda_2
                _lambda_1 = _lambda_1 / _norm
                _lambda_2 = _lambda_2 / _norm
                self.lambda_1_buffer = _lambda_1
                self.lambda_2_buffer = _lambda_2
            else:
                _lambda_1 = lambda_1 / _tot
        # set second by enforcing sum to one:
        tf.keras.backend.set_value(self.lambda_1, _lambda_1)
        tf.keras.backend.set_value(self.lambda_2, 1.-_lambda_1)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print the SoftAdapt settings to screen.

        :param padding: string prepended to every printed line.
        """
        print(padding+'using SoftAdapt loss function')
        if self.smoothing:
            print(padding+'    with smoothing')
            print(padding+'    with smoothing_tau = ', 1./self.smoothing_alpha)
        print(padding+'    with tau = ', self.tau)

    def reset(self):
        """
        Reset the adaptive state: smoothing buffers and loss weights go back
        to their initial values, while the configured hyperparameters (tau,
        beta, smoothing settings, monitored quantities) are left untouched.
        """
        self.rate_1_buffer = 0.0
        self.rate_2_buffer = 0.0
        self.lambda_1_buffer = 1.0
        self.lambda_2_buffer = 0.0
        # update the existing variables in place:
        tf.keras.backend.set_value(self.lambda_1, self.initial_lambda_1)
        tf.keras.backend.set_value(self.lambda_2, self.initial_lambda_2)
[docs]
class SharpStep(variable_weight_loss):
    """
    Implement sharp stepping between two values of the density weight.
    """

    def __init__(self, step_epoch=50, value_1=1.0, value_2=0.1, beta=0., **kwargs):
        """
        Initialize the sharp-step loss.

        :param step_epoch: epoch at which the density weight switches.
        :param value_1: density weight for epochs before ``step_epoch``.
        :param value_2: density weight from ``step_epoch`` onward.
        :param beta: additive offset applied to predicted log density.
        :param kwargs: accepted for interface compatibility and ignored.
        """
        # initialize the weights at the pre-step values, so that they agree
        # with the schedule even before the first epoch update. The explicit
        # float conversion keeps the variables floating point when an
        # integer value is passed:
        initial_lambda_1 = float(value_1)
        super().__init__(lambda_1=initial_lambda_1, lambda_2=1. - initial_lambda_1)
        # set parameters:
        self.step_epoch = step_epoch
        self.value_1 = value_1
        self.value_2 = value_2
        # beta is kept as a plain number; the initial value is recorded so
        # that a reset does not fall back to the base class default:
        self.beta = beta
        self.initial_beta = beta

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start. Takes in every kwargs to not
        crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        if epoch < self.step_epoch:
            lambda_1 = self.value_1
        else:
            lambda_1 = self.value_2
        # set second by enforcing sum to one:
        lambda_2 = 1. - lambda_1
        #
        tf.keras.backend.set_value(self.lambda_1, lambda_1)
        tf.keras.backend.set_value(self.lambda_2, lambda_2)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print a one-line description of the loss.

        :param padding: string prepended to the printed line.
        """
        print(padding+'using sharp step loss function')