Home | History | Annotate | Download | only in tests
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities for helping test ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from six.moves import xrange  # pylint: disable=redefined-builtin
     23 
     24 
     25 def ConvertBetweenDataFormats(x, data_format_src, data_format_dst):
     26   """Converts 4D tensor between data formats."""
     27 
     28   valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
     29   if data_format_src not in valid_data_formats:
     30     raise ValueError("data_format_src must be of %s, got %s." %
     31                      (valid_data_formats, data_format_src))
     32   if data_format_dst not in valid_data_formats:
     33     raise ValueError("data_format_dst must be of %s, got %s." %
     34                      (valid_data_formats, data_format_dst))
     35   if len(x.shape) != 4:
     36     raise ValueError("x must be 4D, got shape %s." % x.shape)
     37 
     38   if data_format_src == data_format_dst:
     39     return x
     40 
     41   dim_map = {d: i for i, d in enumerate(data_format_src)}
     42   transpose_dims = [dim_map[d] for d in data_format_dst]
     43   return np.transpose(x, transpose_dims)
     44 
     45 
     46 def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst):
     47   """Get new shape for converting between data formats."""
     48 
     49   valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
     50   if data_format_src not in valid_data_formats:
     51     raise ValueError("data_format_src must be of %s, got %s." %
     52                      (valid_data_formats, data_format_src))
     53   if data_format_dst not in valid_data_formats:
     54     raise ValueError("data_format_dst must be of %s, got %s." %
     55                      (valid_data_formats, data_format_dst))
     56   if len(dims) != 4:
     57     raise ValueError("dims must be of length 4, got %s." % dims)
     58 
     59   if data_format_src == data_format_dst:
     60     return dims
     61 
     62   dim_map = {d: i for i, d in enumerate(data_format_src)}
     63   permuted_dims = [dims[dim_map[d]] for d in data_format_dst]
     64   return permuted_dims
     65 
     66 
     67 _JIT_WARMUP_ITERATIONS = 10
     68 
     69 
     70 def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None):
     71   """Runs a graph a few times to ensure that its clusters are compiled."""
     72   for _ in xrange(0, _JIT_WARMUP_ITERATIONS):
     73     sess.run(op_to_run, feed_dict, options=options)
     74   return sess.run(
     75       op_to_run, feed_dict, options=options, run_metadata=run_metadata)
     76