# Home | History | Annotate | Download | only in P_lstm
#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

     17 # LSTM Test: No Cifg, No Peephole, No Projection, and No Clipping.
     18 
     19 model = Model()
     20 
     21 n_batch = 1
     22 n_input = 2
     23 # n_cell and n_output have the same size when there is no projection.
     24 n_cell = 4
     25 n_output = 4
     26 
     27 input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))
     28 
     29 input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     30 input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     31 input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     32 input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     33 
     34 recurrent_to_input_weights = Input("recurrent_to_intput_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     35 recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     36 recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     37 recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     38 
     39 cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
     40 cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{0}")
     41 cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{0}")
     42 
     43 input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     44 forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     45 cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     46 output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     47 
     48 projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
     49 projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")
     50 
     51 output_state_in = Input("output_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     52 cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
     53 
     54 activation_param = Input("activation_param", "TENSOR_INT32", "{1}")
     55 cell_clip_param = Input("cell_clip_param", "TENSOR_FLOAT32", "{1}")
     56 proj_clip_param = Input("proj_clip_param", "TENSOR_FLOAT32", "{1}")
     57 
     58 scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, (n_cell * 4)))
     59 output_state_out = IgnoredOutput("output_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     60 cell_state_out = IgnoredOutput("cell_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
     61 output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     62 
     63 model = model.Operation("LSTM",
     64                         input,
     65 
     66                         input_to_input_weights,
     67                         input_to_forget_weights,
     68                         input_to_cell_weights,
     69                         input_to_output_weights,
     70 
     71                         recurrent_to_input_weights,
     72                         recurrent_to_forget_weights,
     73                         recurrent_to_cell_weights,
     74                         recurrent_to_output_weights,
     75 
     76                         cell_to_input_weights,
     77                         cell_to_forget_weights,
     78                         cell_to_output_weights,
     79 
     80                         input_gate_bias,
     81                         forget_gate_bias,
     82                         cell_gate_bias,
     83                         output_gate_bias,
     84 
     85                         projection_weights,
     86                         projection_bias,
     87 
     88                         output_state_in,
     89                         cell_state_in,
     90 
     91                         activation_param,
     92                         cell_clip_param,
     93                         proj_clip_param
     94 ).To([scratch_buffer, output_state_out, cell_state_out, output])
     95 
     96 # Example 1. Input in operand 0,
     97 input0 = {input_to_input_weights:  [-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, -0.34856534, 0.43890524],
     98           input_to_forget_weights: [0.09701663, 0.20334584, -0.50592935, -0.31343272, -0.40032279, 0.44781327, 0.01387155, -0.35593212],
     99           input_to_cell_weights:   [-0.50013041, 0.1370284, 0.11810488, 0.2013163, -0.20583314, 0.44344562, 0.22077113, -0.29909778],
    100           input_to_output_weights: [-0.25065863, -0.28290087, 0.04613829, 0.40525138, 0.44272184, 0.03897077, -0.1556896, 0.19487578],
    101 
    102           input_gate_bias:  [0.,0.,0.,0.],
    103           forget_gate_bias: [1.,1.,1.,1.],
    104           cell_gate_bias:   [0.,0.,0.,0.],
    105           output_gate_bias: [0.,0.,0.,0.],
    106 
    107           recurrent_to_input_weights: [
    108               -0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
    109             -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
    110             -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296],
    111 
    112           recurrent_to_cell_weights: [
    113               -0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
    114             -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
    115             -0.46367589, 0.26016325, -0.03894562, -0.16368064],
    116 
    117           recurrent_to_forget_weights: [
    118               -0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
    119             -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
    120             0.28053468, 0.01560611, -0.20127171, -0.01140004],
    121 
    122           recurrent_to_output_weights: [
    123               0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
    124               0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
    125               -0.51818722, -0.15390486, 0.0468148, 0.39922136],
    126 
    127           cell_to_input_weights: [],
    128           cell_to_forget_weights: [],
    129           cell_to_output_weights: [],
    130 
    131           projection_weights: [],
    132           projection_bias: [],
    133 
    134           activation_param: [4],  # Tanh
    135           cell_clip_param: [0.],
    136           proj_clip_param: [0.],
    137 }
    138 
    139 # Instantiate examples
    140 # TODO: Add more examples after fixing the reference issue
    141 test_inputs = [
    142     [2., 3.],
    143 #    [3., 4.],[1., 1.]
    144 ]
    145 golden_outputs = [
    146     [-0.02973187, 0.1229473, 0.20885126, -0.15358765,],
    147 #    [-0.03716109, 0.12507336, 0.41193449,  -0.20860538],
    148 #    [-0.15053082, 0.09120187,  0.24278517,  -0.12222792]
    149 ]
    150 
    151 for (input_tensor, output_tensor) in zip(test_inputs, golden_outputs):
    152   output0 = {
    153       scratch_buffer: [ 0 for x in range(n_batch * n_cell * 4) ],
    154       cell_state_out: [ 0 for x in range(n_batch * n_cell) ],
    155       output_state_out: [ 0 for x in range(n_batch * n_output) ],
    156       output: output_tensor
    157   }
    158   input0[input] = input_tensor
    159   input0[output_state_in] = [ 0 for _ in range(n_batch * n_output) ]
    160   input0[cell_state_in] = [ 0 for _ in range(n_batch * n_cell) ]
    161   Example((input0, output0))
    162