# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the layer abstraction for hybrid models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as framework_variables


class HybridLayer(object):
  """Layers are building blocks for hybrid models.

  This base class stores its configuration and then immediately calls
  `_define_vars`. Subclasses are expected to override both `_define_vars`
  and `inference_graph`; the base implementations raise
  `NotImplementedError`.

  Attributes:
    layer_num: The `layer_num` value given to the constructor.
    device_assigner: The `device_assigner` given to the constructor, or a
      new `framework_variables.VariableDeviceChooser()` when that value was
      falsy (for example `None`).
    params: The `params` value given to the constructor.
  """

  # pylint: disable=unused-argument
  def __init__(self, params, layer_num, device_assigner, *args, **kwargs):
    """Records the layer configuration and defines the layer's variables.

    Args:
      params: Layer parameters; kept on `self.params` and also passed as the
        first argument to `_define_vars`.
      layer_num: Kept on `self.layer_num`.
      device_assigner: Device assigner to keep on `self.device_assigner`. A
        falsy value is replaced by a default `VariableDeviceChooser`.
      *args: Accepted for call compatibility and otherwise ignored.
      **kwargs: Forwarded unchanged to `_define_vars`.

    Raises:
      NotImplementedError: If the subclass does not override `_define_vars`.
    """
    self.layer_num = layer_num

    if not device_assigner:
      # No usable assigner supplied: fall back to the framework default.
      device_assigner = framework_variables.VariableDeviceChooser()
    self.device_assigner = device_assigner

    self.params = params

    # All attributes are assigned before this call so that overrides of
    # `_define_vars` can rely on them.
    self._define_vars(params, **kwargs)

  def _define_vars(self,
                   params,
                   **kwargs):
    """Override to define the TensorFlow variables for the layer.

    Called exactly once, at the end of `__init__`.

    Args:
      params: The `params` object given to the constructor.
      **kwargs: Extra keyword arguments given to the constructor.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError

  def inference_graph(self, data, data_spec=None):
    """Override to build this layer's inference graph for `data`.

    Args:
      data: The input to the layer; its expected form is defined by the
        subclass.
      data_spec: Optional; its interpretation is left to the subclass.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError