Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for cross entropy related functionality in tensorflow.ops.nn."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import math
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.ops import gradient_checker
     28 from tensorflow.python.ops import gradients_impl
     29 from tensorflow.python.ops import nn_impl
     30 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     31 from tensorflow.python.platform import test
     32 
     33 exp = math.exp
     34 log = math.log
     35 
     36 
     37 class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
     38 
     39   def _SigmoidCrossEntropyWithLogits(self, logits, targets):
     40     assert len(logits) == len(targets)
     41     pred = [1 / (1 + exp(-x)) for x in logits]
     42     eps = 0.0001
     43     pred = [min(max(p, eps), 1 - eps) for p in pred]
     44     return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]
     45 
     46   def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None):
     47     x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
     48     y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
     49     assert len(x) == len(y)
     50     sizes = sizes if sizes else [len(x)]
     51     logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
     52     targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
     53     losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
     54     return logits, targets, losses
     55 
     56   def testConstructionNamed(self):
     57     with self.test_session():
     58       logits, targets, _ = self._Inputs()
     59       loss = nn_impl.sigmoid_cross_entropy_with_logits(
     60           labels=targets, logits=logits, name="mylogistic")
     61     self.assertEqual("mylogistic", loss.op.name)
     62 
     63   def testLogisticOutput(self):
     64     for use_gpu in [True, False]:
     65       for dtype in [dtypes.float32, dtypes.float16]:
     66         with self.test_session(use_gpu=use_gpu):
     67           logits, targets, losses = self._Inputs(dtype=dtype)
     68           loss = nn_impl.sigmoid_cross_entropy_with_logits(
     69               labels=targets, logits=logits)
     70           np_loss = np.array(losses).astype(np.float32)
     71           tf_loss = loss.eval()
     72         self.assertAllClose(np_loss, tf_loss, atol=0.001)
     73 
     74   def testLogisticOutputMultiDim(self):
     75     for use_gpu in [True, False]:
     76       for dtype in [dtypes.float32, dtypes.float16]:
     77         with self.test_session(use_gpu=use_gpu):
     78           logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
     79           loss = nn_impl.sigmoid_cross_entropy_with_logits(
     80               labels=targets, logits=logits)
     81           np_loss = np.array(losses).astype(np.float32)
     82           tf_loss = loss.eval()
     83         self.assertAllClose(np_loss, tf_loss, atol=0.001)
     84 
     85   def testGradient(self):
     86     sizes = [4, 2]
     87     with self.test_session():
     88       logits, targets, _ = self._Inputs(sizes=sizes)
     89       loss = nn_impl.sigmoid_cross_entropy_with_logits(
     90           labels=targets, logits=logits)
     91       err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
     92     print("logistic loss gradient err = ", err)
     93     self.assertLess(err, 1e-7)
     94 
     95   def testGradientAtZero(self):
     96     with self.test_session():
     97       logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
     98       targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
     99       loss = nn_impl.sigmoid_cross_entropy_with_logits(
    100           labels=targets, logits=logits)
    101       grads = gradients_impl.gradients(loss, logits)[0].eval()
    102     self.assertAllClose(grads, [0.5, -0.5])
    103 
    104   def testShapeError(self):
    105     with self.assertRaisesRegexp(ValueError, "must have the same shape"):
    106       nn_impl.sigmoid_cross_entropy_with_logits(labels=[1, 2, 3],
    107                                                 logits=[[2, 1]])
    108 
    109 
    110 class WeightedCrossEntropyTest(test.TestCase):
    111 
    112   def _WeightedCrossEntropy(self, logits, targets, pos_coeff):
    113     assert len(logits) == len(targets)
    114     pred = [1 / (1 + exp(-x)) for x in logits]
    115     eps = 0.0001
    116     pred = [min(max(p, eps), 1 - eps) for p in pred]
    117     return [
    118         -z * pos_coeff * log(y) - (1 - z) * log(1 - y)
    119         for y, z in zip(pred, targets)
    120     ]
    121 
    122   def _Inputs(self, x=None, y=None, q=3.0, dtype=dtypes.float64, sizes=None):
    123     x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
    124     y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
    125     assert len(x) == len(y)
    126     sizes = sizes if sizes else [len(x)]
    127     logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
    128     targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
    129     losses = np.array(self._WeightedCrossEntropy(x, y, q)).reshape(*sizes)
    130     return logits, targets, q, losses
    131 
    132   def testConstructionNamed(self):
    133     with self.test_session():
    134       logits, targets, pos_weight, _ = self._Inputs()
    135       loss = nn_impl.weighted_cross_entropy_with_logits(
    136           targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
    137     self.assertEqual("mybce", loss.op.name)
    138 
    139   def testOutput(self):
    140     for use_gpu in [True, False]:
    141       with self.test_session(use_gpu=use_gpu):
    142         logits, targets, pos_weight, losses = self._Inputs(dtype=dtypes.float32)
    143         loss = nn_impl.weighted_cross_entropy_with_logits(
    144             targets=targets, logits=logits, pos_weight=pos_weight)
    145         np_loss = np.array(losses).astype(np.float32)
    146         tf_loss = loss.eval()
    147       self.assertAllClose(np_loss, tf_loss, atol=0.001)
    148 
    149   def testOutputMultiDim(self):
    150     for use_gpu in [True, False]:
    151       with self.test_session(use_gpu=use_gpu):
    152         logits, targets, pos_weight, losses = self._Inputs(
    153             dtype=dtypes.float32, sizes=[2, 2, 2])
    154         loss = nn_impl.weighted_cross_entropy_with_logits(
    155             targets=targets, logits=logits, pos_weight=pos_weight)
    156         np_loss = np.array(losses).astype(np.float32)
    157         tf_loss = loss.eval()
    158       self.assertAllClose(np_loss, tf_loss, atol=0.001)
    159 
    160   def testGradient(self):
    161     sizes = [4, 2]
    162     with self.test_session():
    163       logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
    164       loss = nn_impl.weighted_cross_entropy_with_logits(
    165           targets=targets, logits=logits, pos_weight=pos_weight)
    166       err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
    167     print("logistic loss gradient err = ", err)
    168     self.assertLess(err, 1e-7)
    169 
    170   def testShapeError(self):
    171     with self.assertRaisesRegexp(ValueError, "must have the same shape"):
    172       nn_impl.weighted_cross_entropy_with_logits(
    173           targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
    174 
    175 
# Invoke the TensorFlow test entry point when this file is executed directly.
if __name__ == "__main__":
  test.main()
    178