Home | History | Annotate | Download | only in client
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A collection of functions to be used as evaluation metrics."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.contrib import losses
     23 from tensorflow.contrib.learn.python.learn.estimators import prediction_key
     24 from tensorflow.contrib.metrics.python.ops import metric_ops
     25 
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import nn
     29 
# Prediction-dictionary keys that metrics read their inputs from (see the
# prediction-key table further down): class probabilities and predicted classes.
INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES
INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES

# Not referenced in this file; presumably the name under which callers
# expose feature importances — NOTE(review): confirm against users.
FEATURE_IMPORTANCE_NAME = 'global_feature_importance'
     34 
     35 
     36 def _top_k_generator(k):
     37   def _top_k(probabilities, targets):
     38     targets = math_ops.to_int32(targets)
     39     if targets.get_shape().ndims > 1:
     40       targets = array_ops.squeeze(targets, squeeze_dims=[1])
     41     return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k))
     42   return _top_k
     43 
     44 
     45 def _accuracy(predictions, targets, weights=None):
     46   return metric_ops.streaming_accuracy(predictions, targets, weights=weights)
     47 
     48 
     49 def _r2(probabilities, targets, weights=None):
     50   targets = math_ops.to_float(targets)
     51   y_mean = math_ops.reduce_mean(targets, 0)
     52   squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
     53   squares_residuals = math_ops.reduce_sum(
     54       math_ops.square(targets - probabilities), 0)
     55   score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
     56   return metric_ops.streaming_mean(score, weights=weights)
     57 
     58 
     59 def _squeeze_and_onehot(targets, depth):
     60   targets = array_ops.squeeze(targets, squeeze_dims=[1])
     61   return array_ops.one_hot(math_ops.to_int32(targets), depth)
     62 
     63 
     64 def _sigmoid_entropy(probabilities, targets, weights=None):
     65   return metric_ops.streaming_mean(
     66       losses.sigmoid_cross_entropy(probabilities,
     67                                    _squeeze_and_onehot(
     68                                        targets,
     69                                        array_ops.shape(probabilities)[1])),
     70       weights=weights)
     71 
     72 
     73 def _softmax_entropy(probabilities, targets, weights=None):
     74   return metric_ops.streaming_mean(
     75       losses.sparse_softmax_cross_entropy(probabilities,
     76                                           math_ops.to_int32(targets)),
     77       weights=weights)
     78 
     79 
     80 def _predictions(predictions, unused_targets, **unused_kwargs):
     81   return predictions
     82 
     83 
     84 def _class_log_loss(probabilities, targets, weights=None):
     85   return metric_ops.streaming_mean(
     86       losses.log_loss(probabilities,
     87                       _squeeze_and_onehot(targets,
     88                                           array_ops.shape(probabilities)[1])),
     89       weights=weights)
     90 
     91 
     92 def _precision(predictions, targets, weights=None):
     93   return metric_ops.streaming_precision(predictions, targets, weights=weights)
     94 
     95 
     96 def _precision_at_thresholds(predictions, targets, weights=None):
     97   return metric_ops.streaming_precision_at_thresholds(
     98       array_ops.slice(predictions, [0, 1], [-1, 1]),
     99       targets,
    100       np.arange(
    101           0, 1, 0.01, dtype=np.float32),
    102       weights=weights)
    103 
    104 
    105 def _recall(predictions, targets, weights=None):
    106   return metric_ops.streaming_recall(predictions, targets, weights=weights)
    107 
    108 
    109 def _recall_at_thresholds(predictions, targets, weights=None):
    110   return metric_ops.streaming_recall_at_thresholds(
    111       array_ops.slice(predictions, [0, 1], [-1, 1]),
    112       targets,
    113       np.arange(
    114           0, 1, 0.01, dtype=np.float32),
    115       weights=weights)
    116 
    117 
    118 def _auc(probs, targets, weights=None):
    119   return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]),
    120                                   targets, weights=weights)
    121 
    122 
    123 _EVAL_METRICS = {
    124     'auc': _auc,
    125     'sigmoid_entropy': _sigmoid_entropy,
    126     'softmax_entropy': _softmax_entropy,
    127     'accuracy': _accuracy,
    128     'r2': _r2,
    129     'predictions': _predictions,
    130     'classification_log_loss': _class_log_loss,
    131     'precision': _precision,
    132     'precision_at_thresholds': _precision_at_thresholds,
    133     'recall': _recall,
    134     'recall_at_thresholds': _recall_at_thresholds,
    135     'top_5': _top_k_generator(5)
    136 }
    137 
    138 _PREDICTION_KEYS = {
    139     'auc': INFERENCE_PROB_NAME,
    140     'sigmoid_entropy': INFERENCE_PROB_NAME,
    141     'softmax_entropy': INFERENCE_PROB_NAME,
    142     'accuracy': INFERENCE_PRED_NAME,
    143     'r2': prediction_key.PredictionKey.SCORES,
    144     'predictions': INFERENCE_PRED_NAME,
    145     'classification_log_loss': INFERENCE_PROB_NAME,
    146     'precision': INFERENCE_PRED_NAME,
    147     'precision_at_thresholds': INFERENCE_PROB_NAME,
    148     'recall': INFERENCE_PRED_NAME,
    149     'recall_at_thresholds': INFERENCE_PROB_NAME,
    150     'top_5': INFERENCE_PROB_NAME
    151 }
    152 
    153 
    154 def get_metric(metric_name):
    155   """Given a metric name, return the corresponding metric function."""
    156   return _EVAL_METRICS[metric_name]
    157 
    158 
    159 def get_prediction_key(metric_name):
    160   return _PREDICTION_KEYS[metric_name]
    161