# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for class Evaluator."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import tempfile
     22 
     23 from tensorflow.contrib.eager.python import evaluator
     24 
     25 from tensorflow.contrib.eager.python import metrics
     26 from tensorflow.contrib.summary import summary_test_util
     27 from tensorflow.python.data.ops import dataset_ops
     28 from tensorflow.python.eager import context
     29 from tensorflow.python.eager import test
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.training import training_util
     33 
     34 
     35 class IdentityModel(object):
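  """Model whose eval_data() returns its input unchanged."""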

  def eval_data(self, d):
    return d


class PrefixLModel(object):
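  """Model whose eval_data() prefixes each key of its input dict with "l_"."""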

  def eval_data(self, d):
    return {"l_" + key: d[key] for key in d}


class SimpleEvaluator(evaluator.Evaluator):
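  """Evaluator that tracks a single Mean metric over all eval data."""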

  def __init__(self, model):
    super(SimpleEvaluator, self).__init__(model)
    self.mean = self.track_metric(metrics.Mean("mean"))

  def call(self, eval_data):
    self.mean(eval_data)


class DelegatingEvaluator(evaluator.Evaluator):
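  """Evaluator that nests a SimpleEvaluator alongside its own metric."""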

  def __init__(self, model):
    super(DelegatingEvaluator, self).__init__(model)
    self.sub = self.track_evaluator("inner", SimpleEvaluator(model))
    self.mean = self.track_metric(metrics.Mean("outer-mean"))

  def call(self, eval_data):
    # Keys here come from PrefixLModel, which adds "l_".
    self.mean(eval_data["l_outer"])
    self.sub.call(eval_data["l_inner"])


# pylint: disable=not-callable
class EvaluatorTest(test.TestCase):

  def testSimple(self):
    e = SimpleEvaluator(IdentityModel())
    e(3.0)
    e([5.0, 7.0, 9.0])
    results = e.all_metric_results()
    self.assertEqual(set(["mean"]), set(results.keys()))
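    # The mean of 3.0, 5.0, 7.0, and 9.0 is 6.0.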
    self.assertEqual(6.0, results["mean"].numpy())

  def testWriteSummaries(self):
    e = SimpleEvaluator(IdentityModel())
    e(3.0)
    e([5.0, 7.0, 9.0])
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()

    e.all_metric_results(logdir)

    events = summary_test_util.events_from_logdir(logdir)
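    # The first event holds the file version; the second holds the summary.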
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)

  def testComposition(self):
    e = DelegatingEvaluator(PrefixLModel())
    e({"inner": 2.0, "outer": 100.0})
    e({"inner": 4.0, "outer": 1000.0})
    results = e.all_metric_results()
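    # inner/mean averages 2.0 and 4.0; outer-mean averages 100.0 and 1000.0.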
    self.assertEqual(set(["inner/mean", "outer-mean"]), set(results.keys()))
    self.assertEqual(3.0, results["inner/mean"].numpy())
    self.assertEqual(550.0, results["outer-mean"].numpy())

  def testMetricVariables(self):
    e = DelegatingEvaluator(PrefixLModel())
    e({"inner": 2.0, "outer": 100.0})
    prefix_count = {}
    for v in e.metric_variables:
      p = v.name.split("/")[0]
      prefix_count[p] = prefix_count.get(p, 0) + 1
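    # Each Mean metric owns two variables, so each prefix appears twice.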
    self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)

  def testDatasetEager(self):
    e = SimpleEvaluator(IdentityModel())
    ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
    results = e.evaluate_on_dataset(ds)
    self.assertEqual(set(["mean"]), set(results.keys()))
    self.assertEqual(6.0, results["mean"].numpy())

  def testDatasetGraph(self):
    with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
      e = SimpleEvaluator(IdentityModel())
      ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
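      # In graph mode, evaluate_on_dataset returns ops for run_evaluation.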
      init_op, call_op, results_op = e.evaluate_on_dataset(ds)
      results = e.run_evaluation(init_op, call_op, results_op)
      self.assertEqual(set(["mean"]), set(results.keys()))
      self.assertEqual(6.0, results["mean"])

  def testWriteSummariesGraph(self):
    with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
      e = SimpleEvaluator(IdentityModel())
      ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
      training_util.get_or_create_global_step()
      logdir = tempfile.mkdtemp()
      init_op, call_op, results_op = e.evaluate_on_dataset(
          ds, summary_logdir=logdir)
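      # Metric variables must be initialized explicitly in graph mode.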
      variables.global_variables_initializer().run()
      e.run_evaluation(init_op, call_op, results_op)

    events = summary_test_util.events_from_logdir(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)

  def testModelProperty(self):
    m = IdentityModel()
    e = SimpleEvaluator(m)
    self.assertIs(m, e.model)

  def testMetricsProperty(self):
    e = DelegatingEvaluator(PrefixLModel())
    names = set([(p, m.name) for p, m in e.metrics])
    self.assertEqual(set([("", "outer-mean"), ("inner/", "mean")]), names)

  def testSharedMetric(self):

    class MetricArgEvaluator(evaluator.Evaluator):
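      """Evaluator that tracks a metric passed into its constructor."""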

      def __init__(self, model, m):
        super(MetricArgEvaluator, self).__init__(model)
        self.m = self.track_metric(m)

    metric = metrics.Mean("mean")
    model = IdentityModel()
    e = MetricArgEvaluator(model, metric)
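    # A metric instance may be tracked by at most one evaluator.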
    with self.assertRaisesRegexp(ValueError, "already added"):
      MetricArgEvaluator(model, metric)
    del e

  def testMetricTrackedTwice(self):

    class MetricTwiceEvaluator(evaluator.Evaluator):
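      """Evaluator that tracks the same metric object twice."""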

      def __init__(self, model):
        super(MetricTwiceEvaluator, self).__init__(model)
        self.m = self.track_metric(metrics.Mean("mean"))
        self.track_metric(self.m)  # okay to track same metric again

    MetricTwiceEvaluator(IdentityModel())


class SparseSoftmaxEvaluatorTest(test.TestCase):

  def testSimple(self):
    e = evaluator.SparseSoftmaxEvaluator(IdentityModel())
    e({e.loss_key: 1.0, e.label_key: 5, e.predicted_class_key: 5})
    e({e.loss_key: [0.0, 3.0, 4.0],
       e.label_key: [1, 2, 3],
       e.predicted_class_key: [1, 1, 3]})
    results = e.all_metric_results()
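    # Avg Loss is (1.0 + 0.0 + 3.0 + 4.0) / 4 = 2.0; Accuracy is 3/4 = 0.75.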
    self.assertEqual(set(["Avg Loss", "Accuracy"]), set(results.keys()))
    self.assertEqual(2.0, results["Avg Loss"].numpy())
    self.assertEqual(0.75, results["Accuracy"].numpy())


if __name__ == "__main__":
  test.main()