1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Utilities for helping test ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 from six.moves import xrange # pylint: disable=redefined-builtin 23 24 25 def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): 26 """Converts 4D tensor between data formats.""" 27 28 valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] 29 if data_format_src not in valid_data_formats: 30 raise ValueError("data_format_src must be of %s, got %s." % 31 (valid_data_formats, data_format_src)) 32 if data_format_dst not in valid_data_formats: 33 raise ValueError("data_format_dst must be of %s, got %s." % 34 (valid_data_formats, data_format_dst)) 35 if len(x.shape) != 4: 36 raise ValueError("x must be 4D, got shape %s." % x.shape) 37 38 if data_format_src == data_format_dst: 39 return x 40 41 dim_map = {d: i for i, d in enumerate(data_format_src)} 42 transpose_dims = [dim_map[d] for d in data_format_dst] 43 return np.transpose(x, transpose_dims) 44 45 46 def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): 47 """Get new shape for converting between data formats.""" 48 49 valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] 50 if data_format_src not in valid_data_formats: 51 raise ValueError("data_format_src must be of %s, got %s." % 52 (valid_data_formats, data_format_src)) 53 if data_format_dst not in valid_data_formats: 54 raise ValueError("data_format_dst must be of %s, got %s." % 55 (valid_data_formats, data_format_dst)) 56 if len(dims) != 4: 57 raise ValueError("dims must be of length 4, got %s." % dims) 58 59 if data_format_src == data_format_dst: 60 return dims 61 62 dim_map = {d: i for i, d in enumerate(data_format_src)} 63 permuted_dims = [dims[dim_map[d]] for d in data_format_dst] 64 return permuted_dims 65 66 67 _JIT_WARMUP_ITERATIONS = 10 68 69 70 def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None): 71 """Runs a graph a few times to ensure that its clusters are compiled.""" 72 for _ in xrange(0, _JIT_WARMUP_ITERATIONS): 73 sess.run(op_to_run, feed_dict, options=options) 74 return sess.run( 75 op_to_run, feed_dict, options=options, run_metadata=run_metadata) 76