Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Ops tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.learn.python.learn import ops
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import random_seed
     27 from tensorflow.python.ops import variables
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class OpsTest(test.TestCase):
     33   """Ops tests."""
     34 
     35   def test_softmax_classifier(self):
     36     with self.test_session() as session:
     37       features = array_ops.placeholder(dtypes.float32, [None, 3])
     38       labels = array_ops.placeholder(dtypes.float32, [None, 2])
     39       weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
     40       biases = constant_op.constant([0.2, 0.3])
     41       class_weight = constant_op.constant([0.1, 0.9])
     42       prediction, loss = ops.softmax_classifier(features, labels, weights,
     43                                                 biases, class_weight)
     44       self.assertEqual(prediction.get_shape()[1], 2)
     45       self.assertEqual(loss.get_shape(), [])
     46       value = session.run(loss, {features: [[0.2, 0.3, 0.2]], labels: [[0, 1]]})
     47       self.assertAllClose(value, 0.55180627)
     48 
     49   def test_embedding_lookup(self):
     50     d_embed = 5
     51     n_embed = 10
     52     ids_shape = (2, 3, 4)
     53     embeds = np.random.randn(n_embed, d_embed)
     54     ids = np.random.randint(0, n_embed, ids_shape)
     55     with self.test_session():
     56       embed_np = embeds[ids]
     57       embed_tf = ops.embedding_lookup(embeds, ids).eval()
     58     self.assertEqual(embed_np.shape, embed_tf.shape)
     59     self.assertAllClose(embed_np, embed_tf)
     60 
     61   def test_categorical_variable(self):
     62     random_seed.set_random_seed(42)
     63     with self.test_session() as sess:
     64       cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
     65       embeddings = ops.categorical_variable(
     66           cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
     67       sess.run(variables.global_variables_initializer())
     68       emb1 = sess.run(embeddings,
     69                       feed_dict={cat_var_idx.name: [[0, 1], [2, 3]]})
     70       emb2 = sess.run(embeddings,
     71                       feed_dict={cat_var_idx.name: [[0, 2], [1, 3]]})
     72     self.assertEqual(emb1.shape, emb2.shape)
     73     self.assertAllEqual(np.transpose(emb2, axes=[1, 0, 2]), emb1)
     74 
     75 
     76 if __name__ == "__main__":
     77   test.main()
     78