1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Ops tests.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.learn.python.learn import ops 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import random_seed 27 from tensorflow.python.ops import variables 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.platform import test 30 31 32 class OpsTest(test.TestCase): 33 """Ops tests.""" 34 35 def test_softmax_classifier(self): 36 with self.test_session() as session: 37 features = array_ops.placeholder(dtypes.float32, [None, 3]) 38 labels = array_ops.placeholder(dtypes.float32, [None, 2]) 39 weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]]) 40 biases = constant_op.constant([0.2, 0.3]) 41 class_weight = constant_op.constant([0.1, 0.9]) 42 prediction, loss = ops.softmax_classifier(features, labels, weights, 43 biases, class_weight) 44 self.assertEqual(prediction.get_shape()[1], 2) 45 self.assertEqual(loss.get_shape(), []) 46 value = session.run(loss, {features: [[0.2, 0.3, 0.2]], labels: [[0, 1]]}) 47 self.assertAllClose(value, 0.55180627) 48 49 def test_embedding_lookup(self): 50 d_embed = 5 51 n_embed = 10 52 ids_shape = (2, 3, 4) 53 embeds = np.random.randn(n_embed, d_embed) 54 ids = np.random.randint(0, n_embed, ids_shape) 55 with self.test_session(): 56 embed_np = embeds[ids] 57 embed_tf = ops.embedding_lookup(embeds, ids).eval() 58 self.assertEqual(embed_np.shape, embed_tf.shape) 59 self.assertAllClose(embed_np, embed_tf) 60 61 def test_categorical_variable(self): 62 random_seed.set_random_seed(42) 63 with self.test_session() as sess: 64 cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2]) 65 embeddings = ops.categorical_variable( 66 cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var") 67 sess.run(variables.global_variables_initializer()) 68 emb1 = sess.run(embeddings, 69 feed_dict={cat_var_idx.name: [[0, 1], [2, 3]]}) 70 emb2 = sess.run(embeddings, 71 feed_dict={cat_var_idx.name: [[0, 2], [1, 3]]}) 72 self.assertEqual(emb1.shape, emb2.shape) 73 self.assertAllEqual(np.transpose(emb2, axes=[1, 0, 2]), emb1) 74 75 76 if __name__ == "__main__": 77 test.main() 78