Home | History | Annotate | Download | only in internal
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A test lib that defines some models."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import contextlib
     21 
     22 from tensorflow.python import pywrap_tensorflow as print_mdl
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import init_ops
     26 from tensorflow.python.ops import math_ops
     27 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
     28 from tensorflow.python.ops import nn_ops
     29 from tensorflow.python.ops import rnn
     30 from tensorflow.python.ops import rnn_cell
     31 from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
     32 from tensorflow.python.ops import variable_scope
     33 from tensorflow.python.profiler import model_analyzer
     34 from tensorflow.python.training import gradient_descent
     35 from tensorflow.python.util import compat
     36 
     37 
     38 def BuildSmallModel():
     39   """Build a small forward conv model."""
     40   image = array_ops.zeros([2, 6, 6, 3])
     41   _ = variable_scope.get_variable(
     42       'ScalarW', [],
     43       dtypes.float32,
     44       initializer=init_ops.random_normal_initializer(stddev=0.001))
     45   kernel = variable_scope.get_variable(
     46       'DW', [3, 3, 3, 6],
     47       dtypes.float32,
     48       initializer=init_ops.random_normal_initializer(stddev=0.001))
     49   x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
     50   kernel = variable_scope.get_variable(
     51       'DW2', [2, 2, 6, 12],
     52       dtypes.float32,
     53       initializer=init_ops.random_normal_initializer(stddev=0.001))
     54   x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
     55   return x
     56 
     57 
     58 def BuildFullModel():
     59   """Build the full model with conv,rnn,opt."""
     60   seq = []
     61   for i in range(4):
     62     with variable_scope.variable_scope('inp_%d' % i):
     63       seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
     64 
     65   cell = rnn_cell.BasicRNNCell(16)
     66   out = rnn.dynamic_rnn(
     67       cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
     68 
     69   target = array_ops.ones_like(out)
     70   loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
     71   sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
     72   return sgd_op.minimize(loss)
     73 
     74 
     75 def BuildSplitableModel():
     76   """Build a small model that can be run partially in each step."""
     77   image = array_ops.zeros([2, 6, 6, 3])
     78 
     79   kernel1 = variable_scope.get_variable(
     80       'DW', [3, 3, 3, 6],
     81       dtypes.float32,
     82       initializer=init_ops.random_normal_initializer(stddev=0.001))
     83   r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')
     84 
     85   kernel2 = variable_scope.get_variable(
     86       'DW2', [2, 3, 3, 6],
     87       dtypes.float32,
     88       initializer=init_ops.random_normal_initializer(stddev=0.001))
     89   r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')
     90 
     91   r3 = r1 + r2
     92   return r1, r2, r3
     93 
     94 
     95 def SearchTFProfNode(node, name):
     96   """Search a node in the tree."""
     97   if node.name == name:
     98     return node
     99   for c in node.children:
    100     r = SearchTFProfNode(c, name)
    101     if r: return r
    102   return None
    103 
    104 
    105 @contextlib.contextmanager
    106 def ProfilerFromFile(profile_file):
    107   """Initialize a profiler from profile file."""
    108   print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
    109   profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
    110   yield profiler
    111   print_mdl.DeleteProfiler()
    112 
    113 
    114 def CheckAndRemoveDoc(profile):
    115   assert 'Doc:' in profile
    116   start_pos = profile.find('Profile:')
    117   return profile[start_pos + 9:]
    118