1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A collection of functions to be used as evaluation metrics.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.contrib import losses 23 from tensorflow.contrib.learn.python.learn.estimators import prediction_key 24 from tensorflow.contrib.metrics.python.ops import metric_ops 25 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.ops import nn 29 30 INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES 31 INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES 32 33 FEATURE_IMPORTANCE_NAME = 'global_feature_importance' 34 35 36 def _top_k_generator(k): 37 def _top_k(probabilities, targets): 38 targets = math_ops.to_int32(targets) 39 if targets.get_shape().ndims > 1: 40 targets = array_ops.squeeze(targets, squeeze_dims=[1]) 41 return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k)) 42 return _top_k 43 44 45 def _accuracy(predictions, targets, weights=None): 46 return metric_ops.streaming_accuracy(predictions, targets, weights=weights) 47 48 49 def _r2(probabilities, targets, weights=None): 50 targets = math_ops.to_float(targets) 51 y_mean = math_ops.reduce_mean(targets, 0) 52 squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) 53 squares_residuals = math_ops.reduce_sum( 54 math_ops.square(targets - probabilities), 0) 55 score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) 56 return metric_ops.streaming_mean(score, weights=weights) 57 58 59 def _squeeze_and_onehot(targets, depth): 60 targets = array_ops.squeeze(targets, squeeze_dims=[1]) 61 return array_ops.one_hot(math_ops.to_int32(targets), depth) 62 63 64 def _sigmoid_entropy(probabilities, targets, weights=None): 65 return metric_ops.streaming_mean( 66 losses.sigmoid_cross_entropy(probabilities, 67 _squeeze_and_onehot( 68 targets, 69 array_ops.shape(probabilities)[1])), 70 weights=weights) 71 72 73 def _softmax_entropy(probabilities, targets, weights=None): 74 return metric_ops.streaming_mean( 75 losses.sparse_softmax_cross_entropy(probabilities, 76 math_ops.to_int32(targets)), 77 weights=weights) 78 79 80 def _predictions(predictions, unused_targets, **unused_kwargs): 81 return predictions 82 83 84 def _class_log_loss(probabilities, targets, weights=None): 85 return metric_ops.streaming_mean( 86 losses.log_loss(probabilities, 87 _squeeze_and_onehot(targets, 88 array_ops.shape(probabilities)[1])), 89 weights=weights) 90 91 92 def _precision(predictions, targets, weights=None): 93 return metric_ops.streaming_precision(predictions, targets, weights=weights) 94 95 96 def _precision_at_thresholds(predictions, targets, weights=None): 97 return metric_ops.streaming_precision_at_thresholds( 98 array_ops.slice(predictions, [0, 1], [-1, 1]), 99 targets, 100 np.arange( 101 0, 1, 0.01, dtype=np.float32), 102 weights=weights) 103 104 105 def _recall(predictions, targets, weights=None): 106 return metric_ops.streaming_recall(predictions, targets, weights=weights) 107 108 109 def _recall_at_thresholds(predictions, targets, weights=None): 110 return metric_ops.streaming_recall_at_thresholds( 111 array_ops.slice(predictions, [0, 1], [-1, 1]), 112 targets, 113 np.arange( 114 0, 1, 0.01, dtype=np.float32), 115 weights=weights) 116 117 118 def _auc(probs, targets, weights=None): 119 return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), 120 targets, weights=weights) 121 122 123 _EVAL_METRICS = { 124 'auc': _auc, 125 'sigmoid_entropy': _sigmoid_entropy, 126 'softmax_entropy': _softmax_entropy, 127 'accuracy': _accuracy, 128 'r2': _r2, 129 'predictions': _predictions, 130 'classification_log_loss': _class_log_loss, 131 'precision': _precision, 132 'precision_at_thresholds': _precision_at_thresholds, 133 'recall': _recall, 134 'recall_at_thresholds': _recall_at_thresholds, 135 'top_5': _top_k_generator(5) 136 } 137 138 _PREDICTION_KEYS = { 139 'auc': INFERENCE_PROB_NAME, 140 'sigmoid_entropy': INFERENCE_PROB_NAME, 141 'softmax_entropy': INFERENCE_PROB_NAME, 142 'accuracy': INFERENCE_PRED_NAME, 143 'r2': prediction_key.PredictionKey.SCORES, 144 'predictions': INFERENCE_PRED_NAME, 145 'classification_log_loss': INFERENCE_PROB_NAME, 146 'precision': INFERENCE_PRED_NAME, 147 'precision_at_thresholds': INFERENCE_PROB_NAME, 148 'recall': INFERENCE_PRED_NAME, 149 'recall_at_thresholds': INFERENCE_PROB_NAME, 150 'top_5': INFERENCE_PROB_NAME 151 } 152 153 154 def get_metric(metric_name): 155 """Given a metric name, return the corresponding metric function.""" 156 return _EVAL_METRICS[metric_name] 157 158 159 def get_prediction_key(metric_name): 160 return _PREDICTION_KEYS[metric_name] 161