# Source code for tensiometer.synthetic_probability.loss_functions

"""
This file contains the loss functions for the normalizing flow training.

Since we are combining different loss functions we have different options.
"""

###############################################################################
# initial imports and set-up:

import numpy as np

# tensorflow imports:
import tensorflow as tf

###############################################################################
# helpers:

def _broadcast_sample_weight(sample_weight, losses):
    """
    Broadcast sample weights to match loss tensor shape.

    :param sample_weight: weights to apply.
    :param losses: loss tensor.
    :returns: broadcasted weights matching ``losses``.
    """
    if sample_weight is None:
        return None
    sample_weight = tf.cast(sample_weight, losses.dtype)
    if sample_weight.shape.rank == 0:
        return tf.ones_like(losses) * sample_weight
    if losses.shape.rank is None or sample_weight.shape.rank is None:
        losses_rank = tf.rank(losses)
        weight_rank = tf.rank(sample_weight)
        rank_diff = losses_rank - weight_rank
        new_shape = tf.concat([tf.shape(sample_weight), tf.ones(rank_diff, tf.int32)], axis=0)
        sample_weight = tf.reshape(sample_weight, new_shape)
        return tf.broadcast_to(sample_weight, tf.shape(losses))
    if sample_weight.shape.rank < losses.shape.rank:
        new_shape = sample_weight.shape.as_list() + [1] * (losses.shape.rank - sample_weight.shape.rank)
        sample_weight = tf.reshape(sample_weight, new_shape)
    return tf.broadcast_to(sample_weight, tf.shape(losses))


def _reduce_weighted_loss(losses, sample_weight, reduction):
    """
    Apply sample weights to the losses and reduce them.

    With ``Reduction.NONE`` the weighted losses are returned unreduced and
    with ``Reduction.SUM`` their sum is returned. For any other reduction the
    sum is divided by the sum of the broadcast weights (or by the number of
    loss entries when there are no weights); a zero denominator is replaced
    by one.

    :param losses: per-sample losses.
    :param sample_weight: sample weights for each loss, or ``None``.
    :param reduction: Keras reduction enum.
    :returns: reduced loss tensor.
    """
    reduction_modes = tf.keras.losses.Reduction
    broadcast = _broadcast_sample_weight(sample_weight, losses)
    weighted = losses if broadcast is None else losses * broadcast
    # no reduction requested:
    if reduction == reduction_modes.NONE:
        return weighted
    summed = tf.reduce_sum(weighted)
    # plain sum requested:
    if reduction == reduction_modes.SUM:
        return summed
    # otherwise normalize the sum:
    if broadcast is not None:
        normalization = tf.reduce_sum(broadcast)
    else:
        normalization = tf.cast(tf.size(weighted), weighted.dtype)
    # protect against division by zero:
    unity = tf.constant(1.0, dtype=weighted.dtype)
    normalization = tf.where(normalization == 0, unity, normalization)
    return summed / normalization

###############################################################################
# standard normalizing flow loss function:


class standard_loss(tf.keras.losses.Loss):
    """
    Density loss for normalizing flow training.

    The loss is the negative of the predicted log density.
    """

    def __init__(self):
        """
        Set up the loss with the default ``Loss`` configuration.
        """
        super().__init__()

    def compute_loss_components(self, y_true, y_pred, sample_weight):
        """
        Return the single component of this loss.

        :param y_true: target log density (unused).
        :param y_pred: predicted log density.
        :param sample_weight: sample weights (unused).
        :returns: negative predicted log density.
        """
        return -y_pred

    def call(self, y_true, y_pred):
        """
        Evaluate the loss.

        :param y_true: target log density (unused).
        :param y_pred: predicted log density.
        :returns: negative predicted log density.
        """
        return -y_pred

    def print_feedback(self, padding=''):
        """
        Print which loss function is in use.

        :param padding: string prepended to the printed line.
        """
        message = 'using standard loss function'
        print(padding+message)

    def reset(self):
        """
        Reset loss hyperparameters: this loss has none, so nothing is done.
        """
        return None
###############################################################################
# density and evidence loss with constant weights:
class constant_weight_loss(tf.keras.losses.Loss):
    """
    Combined density and evidence-error loss with fixed weights.

    The per-sample loss is ``alpha * density + (1 - alpha) * evidence_error``.
    """

    def __init__(self, alpha=1.0, beta=0.0):
        """
        Initialize the fixed-weight loss.

        :param alpha: weight for the density component; the evidence-error
            component is weighted by ``1 - alpha``.
        :param beta: additive offset applied to predicted log density.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.alpha = alpha
        self.beta = beta

    def compute_loss_components(self, y_true, y_pred, sample_weight):
        """
        Compute density and evidence-error components.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :returns: tuple of the density loss ``-(y_pred + beta)`` and the
            per-sample squared deviation of ``y_true - y_pred`` from its
            weighted mean.
        """
        # compute difference between true and predicted posterior values:
        diffs = (y_true - y_pred)
        # sum weights:
        tot_weights = tf.reduce_sum(sample_weight)
        # compute overall offset; a batch whose weights sum to zero gives a
        # zero offset instead of NaN, so that the zero-weight protection in
        # the final reduction is not defeated by NaN components:
        mean_diff = tf.math.divide_no_nan(tf.reduce_sum(diffs*sample_weight), tot_weights)
        # compute its variance:
        var_diff = (diffs - mean_diff)**2
        # compute density loss function:
        loss_orig = -(y_pred + self.beta)
        #
        return loss_orig, var_diff

    def compute_loss(self, y_true, y_pred, sample_weight):
        """
        Combine density and evidence-error loss.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :returns: weighted sum of loss components.
        """
        # get components:
        loss_1, loss_2 = self.compute_loss_components(y_true, y_pred, sample_weight)
        #
        return +self.alpha*loss_1 + (1. - self.alpha)*loss_2

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Override ``Loss.__call__`` to support explicit ``sample_weight``.

        The weights enter twice: in the weighted mean inside the loss
        components and in the final weighted reduction.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; defaults to unit weights.
        :returns: weighted loss tensor respecting reduction mode.
        """
        if sample_weight is None:
            sample_weight = tf.ones_like(y_pred, dtype=y_pred.dtype)
        # pick the first usable scope name; ``name`` may exist but be None,
        # which ``tf.name_scope`` does not accept, so fall back to the class name:
        name_scope = (getattr(self, "_name_scope", None)
                      or getattr(self, "name", None)
                      or self.__class__.__name__)
        with tf.name_scope(name_scope):
            losses = self.compute_loss(y_true, y_pred, sample_weight)
            return _reduce_weighted_loss(losses, sample_weight, self.reduction)

    def print_feedback(self, padding=''):
        """
        Print the configured fixed-weight loss settings.

        :param padding: string prepended to every printed line.
        """
        print(padding+'using combined density and evidence-error loss function')
        print(padding+'weight of density loss: %.3g, weight of evidence-error-loss: %.3g' % (self.alpha, 1.-self.alpha))

    def reset(self):
        """
        Reset loss hyperparameters (no-op for constant weights).
        """
        pass
###############################################################################
# density and evidence loss with variable weights:
class variable_weight_loss(tf.keras.losses.Loss):
    """
    Combined density and evidence-error loss with weights that change during training.

    The weights and the offset are stored in non-trainable variables; the
    update schedule and the feedback message are left to subclasses.
    """

    def __init__(self, lambda_1=1.0, lambda_2=0.0, beta=0.0):
        """
        Initialize the variable-weight loss.

        :param lambda_1: initial weight for the density term.
        :param lambda_2: initial weight for the evidence-error term.
        :param beta: additive offset applied to predicted log density.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.lambda_1 = self._make_variable(lambda_1, 'loss_lambda_1')
        self.lambda_2 = self._make_variable(lambda_2, 'loss_lambda_2')
        self.beta = self._make_variable(beta, 'loss_beta')
        # save initial parameters (used by ``reset``):
        self.initial_lambda_1 = lambda_1
        self.initial_lambda_2 = lambda_2
        self.initial_beta = beta

    @staticmethod
    def _make_variable(value, name):
        """
        Create a non-trainable variable whose dtype follows the type of ``value``.

        Integer values are converted to float first: an integer variable
        could not be multiplied with, or added to, floating point losses.

        :param value: initial value of the variable.
        :param name: name of the variable.
        :returns: the new ``tf.Variable``.
        """
        if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
            value = float(value)
        return tf.Variable(value, trainable=False, name=name, dtype=type(value))

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start.
        Takes in every kwargs to not crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        :raises NotImplementedError: expected to be overridden in subclasses.
        """
        # base class is empty...
        # use the following syntax:
        # tf.keras.backend.set_value(self.lambda_1, tf.constant(0.5) * epoch)
        raise NotImplementedError

    def compute_loss_components(self, y_true, y_pred, sample_weight, lambda_1=None, lambda_2=None):
        """
        Compute density and evidence-error components with configurable weights.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :param lambda_1: optional override for density weight; defaults to
            the current value of the stored variable.
        :param lambda_2: optional override for evidence-error weight;
            defaults to the current value of the stored variable.
        :returns: tuple of the density loss ``-(y_pred + beta)``, the
            per-sample squared deviation of ``y_true - y_pred`` from its
            weighted mean, and the two active weights.
        """
        # compute difference between true and predicted posterior values:
        diffs = (y_true - y_pred)
        # sum weights:
        tot_weights = tf.reduce_sum(sample_weight)
        # compute overall offset; a batch whose weights sum to zero gives a
        # zero offset instead of NaN, so that the zero-weight protection in
        # the final reduction is not defeated by NaN components:
        mean_diff = tf.math.divide_no_nan(tf.reduce_sum(diffs*sample_weight), tot_weights)
        # compute its variance:
        var_diff = (diffs - mean_diff)**2
        # compute density loss function:
        loss_orig = -(y_pred + self.beta)
        # get weights if not passed:
        if lambda_1 is None:
            lambda_1 = tf.keras.backend.get_value(self.lambda_1)
        if lambda_2 is None:
            lambda_2 = tf.keras.backend.get_value(self.lambda_2)
        #
        return loss_orig, var_diff, lambda_1, lambda_2

    def compute_loss(self, y_true, y_pred, sample_weight):
        """
        Combine density and evidence-error loss using current weights.

        The weight variables themselves (not their current values) are passed
        on, so the combination reads them when it is evaluated.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample.
        :returns: weighted loss value.
        """
        # get components:
        loss_1, loss_2, lambda_1, lambda_2 = self.compute_loss_components(
            y_true, y_pred, sample_weight, self.lambda_1, self.lambda_2)
        #
        return lambda_1*loss_1 + lambda_2*loss_2

    def __call__(self, y_true, y_pred, sample_weight=None):
        """
        Override ``Loss.__call__`` to support explicit ``sample_weight``.

        :param y_true: target log density.
        :param y_pred: predicted log density.
        :param sample_weight: weights for each sample; defaults to unit weights.
        :returns: weighted loss tensor respecting reduction mode.
        """
        if sample_weight is None:
            sample_weight = tf.ones_like(y_pred, dtype=y_pred.dtype)
        name_scope = getattr(self, "_name_scope", None)
        if not name_scope:
            name_scope = getattr(self, "name", self.__class__.__name__)
        with tf.name_scope(name_scope):
            losses = self.compute_loss(y_true, y_pred, sample_weight)
            return _reduce_weighted_loss(losses, sample_weight, self.reduction)

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to the printed lines.
        :raises NotImplementedError: expected to be overridden in subclasses.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset loss functions hyper parameters.

        Re-runs ``__init__`` with the stored initial weights and offset,
        which creates fresh variables.
        """
        self.__init__(lambda_1=self.initial_lambda_1, lambda_2=self.initial_lambda_2, beta=self.initial_beta)
class random_weight_loss(variable_weight_loss):
    """
    Random weighting of the two loss functions.

    After the initial phase each epoch uses either the density loss only or
    the evidence-error loss only, chosen at random.
    """

    def __init__(self, initial_random_epoch=0, lambda_1=1.0, beta=0.0, **kwargs):
        """
        Initialize loss function.

        :param initial_random_epoch: the weights are randomized only for
            epochs strictly greater than this value.
        :param lambda_1: initial weight for the density term; the
            evidence-error term starts at ``1 - lambda_1``.
        :param beta: additive offset applied to predicted log density.
        :param kwargs: accepted and ignored.
        """
        # initialize:
        super().__init__(lambda_1, 1.-lambda_1, beta)
        # set parameters:
        self.initial_random_epoch = initial_random_epoch

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start.
        Takes in every kwargs to not crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        if epoch > self.initial_random_epoch:
            # draw 0 or 1 and give all the weight to one of the two terms:
            _temp_rand = np.random.randint(2)
            tf.keras.backend.set_value(self.lambda_1, _temp_rand)
            tf.keras.backend.set_value(self.lambda_2, 1.-_temp_rand)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to the printed line.
        """
        print(padding+'using randomized loss function')

    def reset(self):
        """
        Reset loss functions hyper parameters.

        The inherited reset re-runs ``__init__`` with the initial weights and
        offset only, which would silently put ``initial_random_epoch`` back to
        its default; here the configured value is passed along as well.
        """
        self.__init__(initial_random_epoch=self.initial_random_epoch,
                      lambda_1=self.initial_lambda_1,
                      beta=self.initial_beta)
class annealed_weight_loss(variable_weight_loss):
    """
    Slowly go from density to evidence-error loss.
    """

    def __init__(self, anneal_epoch=125, lambda_1=1.0, beta=0.0, roll_off_nepoch=10, **kwargs):
        """
        Initialize loss function.

        :param anneal_epoch: the density weight starts decreasing for epochs
            strictly greater than this value.
        :param lambda_1: initial weight for the density term; the
            evidence-error term starts at ``1 - lambda_1``.
        :param beta: additive offset applied to predicted log density.
        :param roll_off_nepoch: number of epochs setting the scale of the
            exponential roll off.
        :param kwargs: accepted and ignored.
        """
        # initialize:
        super().__init__(lambda_1, 1.-lambda_1, beta)
        # set parameters:
        self.anneal_epoch = anneal_epoch
        self.roll_off_nepoch = roll_off_nepoch

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start.
        Takes in every kwargs to not crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        if epoch > self.anneal_epoch:
            # the factor multiplies the current value of the weight, so the
            # suppression compounds from one call to the next:
            _lambda_1 = tf.keras.backend.get_value(self.lambda_1)
            _lambda_1 *= np.exp(-1.*(epoch - self.anneal_epoch)/self.roll_off_nepoch)
            tf.keras.backend.set_value(self.lambda_1, _lambda_1)
            # set second by enforcing sum to one:
            tf.keras.backend.set_value(self.lambda_2, 1.-_lambda_1)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to the printed line.
        """
        print(padding+'using annealed loss function')

    def reset(self):
        """
        Reset loss functions hyper parameters.

        The inherited reset re-runs ``__init__`` with the initial weights and
        offset only, which would silently put ``anneal_epoch`` and
        ``roll_off_nepoch`` back to their defaults; here the configured
        values are passed along as well.
        """
        self.__init__(anneal_epoch=self.anneal_epoch,
                      lambda_1=self.initial_lambda_1,
                      beta=self.initial_beta,
                      roll_off_nepoch=self.roll_off_nepoch)
class SoftAdapt_weight_loss(variable_weight_loss):
    """
    Implement SoftAdapt as in arXiv:1912.12355, with optional smoothing
    """

    def __init__(self, tau=1.0, beta=0.0, smoothing=True, smoothing_tau=20,
                 quantity_1='val_rho_loss', quantity_2='val_ee_loss', **kwargs):
        """
        Initialize loss function.

        :param tau: factor multiplying the rates inside the exponentials.
        :param beta: additive offset applied to predicted log density.
        :param smoothing: whether rates and weights are exponentially smoothed.
        :param smoothing_tau: smoothing time scale; the smoothing coefficient
            is ``1 / smoothing_tau``.
        :param quantity_1: key of the logged quantity driving the density weight.
        :param quantity_2: key of the logged quantity driving the
            evidence-error weight.
        :param kwargs: accepted and ignored.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.tau = tau
        # beta is stored as a plain value here, replacing the variable
        # created by the parent class:
        self.beta = beta
        # record the configured offset, the parent only saw its default:
        self.initial_beta = beta
        self.smoothing = smoothing
        self.smoothing_tau = smoothing_tau
        self.smoothing_alpha = 1. / smoothing_tau
        # quantities to be monitored:
        self.quantity_1 = quantity_1
        self.quantity_2 = quantity_2
        # smoothing buffers:
        self.rate_1_buffer = 0.0
        self.rate_2_buffer = 0.0
        self.lambda_1_buffer = 1.0
        self.lambda_2_buffer = 0.0

    @staticmethod
    def _latest_rate(history):
        """
        Difference between the last two entries of ``history``, or zero when
        fewer than two entries are available.
        """
        if len(history) < 2:
            return 0.0
        return history[-1] - history[-2]

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start.
        Takes in every kwargs to not crowd the interface...

        :param epoch: current epoch index (unused).
        :param kwargs: must contain ``logs``, which is indexed with the two
            monitored quantity names; the last two entries of each are used.
        """
        # get logs:
        logs = kwargs.get('logs')
        quantity_1 = logs[self.quantity_1]
        quantity_2 = logs[self.quantity_2]
        # get the two rates:
        rate_1 = self._latest_rate(quantity_1)
        rate_2 = self._latest_rate(quantity_2)
        # smooth the two rates:
        if self.smoothing:
            rate_1 = self.smoothing_alpha * rate_1 + (1. - self.smoothing_alpha) * self.rate_1_buffer
            rate_2 = self.smoothing_alpha * rate_2 + (1. - self.smoothing_alpha) * self.rate_2_buffer
            self.rate_1_buffer = rate_1
            self.rate_2_buffer = rate_2
        # protect for initial phase:
        if rate_1 == 0.0:
            _lambda_1 = 1.0
        else:
            # softmax of the two scaled rates; subtracting the largest
            # exponent leaves the ratios unchanged and prevents the
            # exponentials from overflowing to inf (and the ratio to NaN):
            exponent_1 = self.tau * rate_1
            exponent_2 = self.tau * rate_2
            shift = max(exponent_1, exponent_2)
            lambda_1 = np.exp(exponent_1 - shift)
            lambda_2 = np.exp(exponent_2 - shift)
            _tot = lambda_1 + lambda_2
            if self.smoothing:
                _lambda_1 = self.smoothing_alpha * lambda_1 / _tot + (1. - self.smoothing_alpha) * self.lambda_1_buffer
                _lambda_2 = self.smoothing_alpha * lambda_2 / _tot + (1. - self.smoothing_alpha) * self.lambda_2_buffer
                # normalize both with the same total, computed before either
                # of the two is overwritten:
                _norm = _lambda_1 + _lambda_2
                _lambda_1 = _lambda_1 / _norm
                _lambda_2 = _lambda_2 / _norm
                self.lambda_1_buffer = _lambda_1
                self.lambda_2_buffer = _lambda_2
            else:
                _lambda_1 = lambda_1 / _tot
        # set second by enforcing sum to one:
        tf.keras.backend.set_value(self.lambda_1, _lambda_1)
        tf.keras.backend.set_value(self.lambda_2, 1.-_lambda_1)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to every printed line.
        """
        print(padding+'using SoftAdapt loss function')
        if self.smoothing:
            print(padding+' with smoothing')
            print(padding+' with smoothing_tau = ', 1./self.smoothing_alpha)
        print(padding+' with tau = ', self.tau)

    def reset(self):
        """
        Reset loss functions hyper parameters.

        The inherited reset re-runs ``__init__`` with the initial weights and
        offset only, which would silently put ``tau``, the smoothing settings,
        the monitored quantities and ``beta`` back to their defaults. Here the
        configured values are passed along, while weights and smoothing
        buffers go back to their initial state.
        """
        self.__init__(tau=self.tau,
                      beta=self.initial_beta,
                      smoothing=self.smoothing,
                      smoothing_tau=self.smoothing_tau,
                      quantity_1=self.quantity_1,
                      quantity_2=self.quantity_2)
class SharpStep(variable_weight_loss):
    """
    Implement sharp stepping between two values
    """

    def __init__(self, step_epoch=50, value_1=1.0, value_2=0.1, beta=0., **kwargs):
        """
        Initialize loss function.

        :param step_epoch: epoch at which the density weight switches from
            ``value_1`` to ``value_2``.
        :param value_1: density weight for epochs before ``step_epoch``.
        :param value_2: density weight from ``step_epoch`` onwards.
        :param beta: additive offset applied to predicted log density.
        :param kwargs: accepted and ignored.
        """
        # initialize:
        super().__init__()
        # set parameters:
        self.step_epoch = step_epoch
        self.value_1 = value_1
        self.value_2 = value_2
        # beta is stored as a plain value here, replacing the variable
        # created by the parent class:
        self.beta = beta
        # record the configured offset, the parent only saw its default:
        self.initial_beta = beta

    def update_lambda_values_on_epoch_begin(self, epoch, **kwargs):
        """
        Update values of lambda at epoch start.
        Takes in every kwargs to not crowd the interface...

        :param epoch: current epoch index.
        :param kwargs: unused passthrough arguments for compatibility.
        """
        if epoch < self.step_epoch:
            lambda_1 = self.value_1
        else:
            lambda_1 = self.value_2
        # set second by enforcing sum to one:
        lambda_2 = 1. - lambda_1
        #
        tf.keras.backend.set_value(self.lambda_1, lambda_1)
        tf.keras.backend.set_value(self.lambda_2, lambda_2)
        #
        return None

    def print_feedback(self, padding=''):
        """
        Print feedback to screen.

        :param padding: string prepended to the printed line.
        """
        print(padding+'using sharp step loss function')

    def reset(self):
        """
        Reset loss functions hyper parameters.

        The inherited reset re-runs ``__init__`` with the initial weights and
        offset only, which would silently put ``step_epoch``, ``value_1``,
        ``value_2`` and ``beta`` back to their defaults; here the configured
        values are passed along.
        """
        self.__init__(step_epoch=self.step_epoch,
                      value_1=self.value_1,
                      value_2=self.value_2,
                      beta=self.initial_beta)