# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SparsemaxLossOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

test_obs = 10


class SparsemaxLossTest(test.TestCase):

  def _np_sparsemax(self, z):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # sort z in descending order
    z_sorted = np.sort(z, axis=1)[:, ::-1]

    # calculate k(z)
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    z_check = 1 + k * z_sorted > z_cumsum
    # use argmax to get the index by row, as .nonzero() doesn't take an
    # axis argument. np.argmax returns the first index, but the last index
    # is required here, so reverse each row and use `z.shape[1]` to
    # compensate for the reversal afterwards.
    k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)

    # calculate tau(z)
    tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
    tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)

    # calculate p
    return np.maximum(0, z - tau_z)

  def _np_sparsemax_loss(self, z, q):
    z = z - np.mean(z, axis=1)[:, np.newaxis]

    # calculate q^T * z
    z_k = np.sum(q * z, axis=1)

    # calculate sum over S(z)
    p = self._np_sparsemax(z)
    s = p > 0
    # z_i^2 - tau(z)^2 = p_i * (2 * z_i - p_i) for i \in S(z)
    S_sum = np.sum(s * p * (2 * z - p), axis=1)

    # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q)
    q_norm = np.sum(q, axis=1)

    return -z_k + 0.5 * S_sum + 0.5 * q_norm

  def _np_sparsemax_loss_grad(self, z, q):
    # the gradient of the loss w.r.t. z is sparsemax(z) - q; the incoming
    # gradient (chain rule factor) is 1 here.
    grad = 1

    return grad * (-q + self._np_sparsemax(z))

  def _tf_sparsemax(self, z, dtype, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z.astype(dtype))
      tf_sparsemax_out = tf_sparsemax_op.eval()

    return tf_sparsemax_op, tf_sparsemax_out

  def _tf_sparsemax_loss(self, z, q, dtype, use_gpu):
    z = z.astype(dtype)
    q = q.astype(dtype)

    with self.test_session(use_gpu=use_gpu):
      tf_sparsemax_op = sparsemax(z)
      tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q)
      tf_loss_out = tf_loss_op.eval()

    return tf_loss_op, tf_loss_out

  def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss kernel against numpy"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    np_loss = self._np_sparsemax_loss(z, q).astype(dtype)

    self.assertAllCloseAccordingToType(
        np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
    self.assertShapeEqual(np_loss, tf_loss_op)

  def _test_constant_add(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 3"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    c = random.uniform(low=-3, high=3, size=(test_obs, 1))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    _, tf_loss_zpc = self._tf_sparsemax_loss(z + c, q, dtype, use_gpu)

    _, tf_loss_z = self._tf_sparsemax_loss(z, q, dtype, use_gpu)

    self.assertAllCloseAccordingToType(
        tf_loss_zpc,
        tf_loss_z,
        float_atol=5e-6,
        float_rtol=5e-6,
        half_atol=1e-2,
        half_rtol=1e-2)

  def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 4"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

  def _test_sparsemax_loss_zero(self, dtype, random, use_gpu):
    """check sparsemax-loss proposition 5"""
    # construct z and q such that z_k >= 1 + max_{j != k} z_j holds for
    # delta_0 = 1.
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    z[:, 0] = np.max(z, axis=1) + 1.05

    q = np.zeros((test_obs, 10))
    q[:, 0] = 1

    tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
    tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)

    self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out)
    self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)

    self.assertAllCloseAccordingToType(q, tf_sparsemax_out)
    self.assertShapeEqual(q, tf_sparsemax_op)

  def _test_gradient_against_estimate(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against estimated-loss Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
    q = np.zeros((test_obs, 10)).astype(dtype)
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    logits = array_ops.placeholder(dtype, name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q)

    with self.test_session(use_gpu=use_gpu):
      err = gradient_checker.compute_gradient_error(
          logits, z.shape, loss_op, (test_obs,), x_init_value=z, delta=1e-9)

    self.assertLess(err, 1e-4)

  def _test_gradient_against_numpy(self, dtype, random, use_gpu):
    """check sparsemax-loss Rop, against numpy Rop"""
    z = random.uniform(low=-3, high=3, size=(test_obs, 10))
    q = np.zeros((test_obs, 10))
    q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1

    logits = constant_op.constant(z.astype(dtype), name='z')
    sparsemax_op = sparsemax(logits)
    loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype))
    loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0]

    with self.test_session(use_gpu=use_gpu):
      tf_grad = loss_grad_op.eval()
      np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)

      self.assertAllCloseAccordingToType(
          np_grad, tf_grad, half_atol=1e-2, half_rtol=5e-3)
      self.assertShapeEqual(np_grad, loss_grad_op)

  def _test_dtype(self, dtype):
    random = np.random.RandomState(1)

    self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)

    self._test_constant_add(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)

    self._test_sparsemax_loss_zero(dtype, random, use_gpu=False)

    # sparsemax is not a smooth function, so gradient estimation is only
    # possible for float64.
    if dtype == 'float64':
      self._test_gradient_against_estimate(dtype, random, use_gpu=False)

    self._test_gradient_against_numpy(dtype, random, use_gpu=False)

  def testFloat(self):
    self._test_dtype('float32')

  def testDouble(self):
    self._test_dtype('float64')


if __name__ == '__main__':
  test.main()