Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.contrib.labeled_tensor.python.ops import core
     21 from tensorflow.contrib.labeled_tensor.python.ops import nn
     22 from tensorflow.contrib.labeled_tensor.python.ops import test_util
     23 from tensorflow.python.ops import nn_impl
     24 from tensorflow.python.ops import nn_ops
     25 
     26 
     27 class NNTests(test_util.Base):
     28 
     29   def setUp(self):
     30     super(NNTests, self).setUp()
     31     self.axes = ['x']
     32     self.original_lt = core.LabeledTensor([0.0, 0.5, 1.0], self.axes)
     33     self.other_lt = 1 - self.original_lt
     34 
     35   def test_unary_ops(self):
     36     ops = [
     37         ('relu', nn_ops.relu, nn.relu),
     38         ('relu6', nn_ops.relu6, nn.relu6),
     39         ('crelu', nn_ops.crelu, nn.crelu),
     40         ('elu', nn_ops.elu, nn.elu),
     41         ('softplus', nn_ops.softplus, nn.softplus),
     42         ('l2_loss', nn_ops.l2_loss, nn.l2_loss),
     43         ('softmax', nn_ops.softmax, nn.softmax),
     44         ('log_softmax', nn_ops.log_softmax, nn.log_softmax),
     45     ]
     46     for op_name, tf_op, lt_op in ops:
     47       golden_tensor = tf_op(self.original_lt.tensor)
     48       golden_lt = core.LabeledTensor(golden_tensor, self.axes)
     49       actual_lt = lt_op(self.original_lt)
     50       self.assertIn(op_name, actual_lt.name)
     51       self.assertLabeledTensorsEqual(golden_lt, actual_lt)
     52 
     53   def test_binary_ops(self):
     54     ops = [
     55         ('sigmoid_cross_entropy_with_logits',
     56          nn_impl.sigmoid_cross_entropy_with_logits,
     57          nn.sigmoid_cross_entropy_with_logits),
     58         ('softmax_cross_entropy_with_logits',
     59          nn_ops.softmax_cross_entropy_with_logits,
     60          nn.softmax_cross_entropy_with_logits),
     61         ('sparse_softmax_cross_entropy_with_logits',
     62          nn_ops.sparse_softmax_cross_entropy_with_logits,
     63          nn.sparse_softmax_cross_entropy_with_logits),
     64     ]
     65     for op_name, tf_op, lt_op in ops:
     66       golden_tensor = tf_op(self.original_lt.tensor, self.other_lt.tensor)
     67       golden_lt = core.LabeledTensor(golden_tensor, self.axes)
     68       actual_lt = lt_op(self.original_lt, self.other_lt)
     69       self.assertIn(op_name, actual_lt.name)
     70       self.assertLabeledTensorsEqual(golden_lt, actual_lt)
     71