1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A test lib that defines some models.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import contextlib 21 22 from tensorflow.python import pywrap_tensorflow as print_mdl 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.ops import array_ops 25 from tensorflow.python.ops import init_ops 26 from tensorflow.python.ops import math_ops 27 from tensorflow.python.ops import nn_grad # pylint: disable=unused-import 28 from tensorflow.python.ops import nn_ops 29 from tensorflow.python.ops import rnn 30 from tensorflow.python.ops import rnn_cell 31 from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import 32 from tensorflow.python.ops import variable_scope 33 from tensorflow.python.profiler import model_analyzer 34 from tensorflow.python.training import gradient_descent 35 from tensorflow.python.util import compat 36 37 38 def BuildSmallModel(): 39 """Build a small forward conv model.""" 40 image = array_ops.zeros([2, 6, 6, 3]) 41 _ = variable_scope.get_variable( 42 'ScalarW', [], 43 dtypes.float32, 44 initializer=init_ops.random_normal_initializer(stddev=0.001)) 45 kernel = variable_scope.get_variable( 46 'DW', [3, 3, 3, 6], 47 dtypes.float32, 48 initializer=init_ops.random_normal_initializer(stddev=0.001)) 49 x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME') 50 kernel = variable_scope.get_variable( 51 'DW2', [2, 2, 6, 12], 52 dtypes.float32, 53 initializer=init_ops.random_normal_initializer(stddev=0.001)) 54 x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME') 55 return x 56 57 58 def BuildFullModel(): 59 """Build the full model with conv,rnn,opt.""" 60 seq = [] 61 for i in range(4): 62 with variable_scope.variable_scope('inp_%d' % i): 63 seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1])) 64 65 cell = rnn_cell.BasicRNNCell(16) 66 out = rnn.dynamic_rnn( 67 cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0] 68 69 target = array_ops.ones_like(out) 70 loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out)) 71 sgd_op = gradient_descent.GradientDescentOptimizer(1e-2) 72 return sgd_op.minimize(loss) 73 74 75 def BuildSplitableModel(): 76 """Build a small model that can be run partially in each step.""" 77 image = array_ops.zeros([2, 6, 6, 3]) 78 79 kernel1 = variable_scope.get_variable( 80 'DW', [3, 3, 3, 6], 81 dtypes.float32, 82 initializer=init_ops.random_normal_initializer(stddev=0.001)) 83 r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME') 84 85 kernel2 = variable_scope.get_variable( 86 'DW2', [2, 3, 3, 6], 87 dtypes.float32, 88 initializer=init_ops.random_normal_initializer(stddev=0.001)) 89 r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME') 90 91 r3 = r1 + r2 92 return r1, r2, r3 93 94 95 def SearchTFProfNode(node, name): 96 """Search a node in the tree.""" 97 if node.name == name: 98 return node 99 for c in node.children: 100 r = SearchTFProfNode(c, name) 101 if r: return r 102 return None 103 104 105 @contextlib.contextmanager 106 def ProfilerFromFile(profile_file): 107 """Initialize a profiler from profile file.""" 108 print_mdl.ProfilerFromFile(compat.as_bytes(profile_file)) 109 profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler) 110 yield profiler 111 print_mdl.DeleteProfiler() 112 113 114 def CheckAndRemoveDoc(profile): 115 assert 'Doc:' in profile 116 start_pos = profile.find('Profile:') 117 return profile[start_pos + 9:] 118