# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A hybrid model that samples paths when training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tensor_forest.hybrid.python.layers import decisions_to_data
from tensorflow.contrib.tensor_forest.hybrid.python.layers import fully_connected
from tensorflow.contrib.tensor_forest.hybrid.python.models import hard_decisions_to_data_then_nn
from tensorflow.python.training import adagrad


class StochasticHardDecisionsToDataThenNN(
    hard_decisions_to_data_then_nn.HardDecisionsToDataThenNN):
  """Hard-decisions-then-NN hybrid model that samples paths when training.

  Behaves like `HardDecisionsToDataThenNN`, except that after the parent
  constructor runs, `self.layers` is replaced by a two-layer stack: a
  `StochasticHardDecisionsToDataLayer` followed by a `FullyConnectedLayer`.
  """

  def __init__(self,
               params,
               device_assigner=None,
               optimizer_class=adagrad.AdagradOptimizer,
               **kwargs):
    """Initializes the model and installs its two-layer stack.

    Args:
      params: Model hyperparameters; forwarded to the parent constructor and
        passed as the first argument to both layers.
      device_assigner: Optional device assigner (default None); forwarded to
        the parent constructor and to both layers.
      optimizer_class: Optimizer class forwarded to the parent constructor;
        defaults to `adagrad.AdagradOptimizer`.
      **kwargs: Any further keyword arguments, forwarded unchanged to the
        parent constructor.
    """
    super(StochasticHardDecisionsToDataThenNN, self).__init__(
        params,
        device_assigner=device_assigner,
        optimizer_class=optimizer_class,
        **kwargs)

    # Overwrite self.layers (presumably populated by the parent constructor).
    # The second constructor argument (0, then 1) looks like the layer's
    # position in the stack -- confirm against the layer base class.
    sampling_layer = decisions_to_data.StochasticHardDecisionsToDataLayer(
        params, 0, device_assigner)
    dense_layer = fully_connected.FullyConnectedLayer(
        params, 1, device_assigner=device_assigner)
    self.layers = [sampling_layer, dense_layer]