Home | History | Annotate | Download | only in specs
      1 #
      2 # Copyright (C) 2017 The Android Open Source Project
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 #
     16 
     17 # LSTM Test, With Cifg, With Peephole, No Projection, No Clipping.
     18 
     19 model = Model()
     20 
     21 n_batch = 1
     22 n_input = 2
     23 # n_cell and n_output have the same size when there is no projection.
     24 n_cell = 4
     25 n_output = 4
     26 
     27 input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))
     28 
     29 input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     30 input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     31 input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     32 input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
     33 
     34 recurrent_to_input_weights = Input("recurrent_to_intput_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     35 recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     36 recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     37 recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
     38 
     39 cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
     40 cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{%d}" % (n_cell))
     41 cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{%d}" % (n_cell))
     42 
     43 input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     44 forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     45 cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     46 output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", "{%d}"%(n_cell))
     47 
     48 projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
     49 projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")
     50 
     51 output_state_in = Input("output_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     52 cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
     53 
     54 activation_param = Input("activation_param", "TENSOR_INT32", "{1}");
     55 cell_clip_param = Input("cell_clip_param", "TENSOR_FLOAT32", "{1}")
     56 proj_clip_param = Input("proj_clip_param", "TENSOR_FLOAT32", "{1}")
     57 
     58 scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell * 3))
     59 output_state_out = Output("output_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     60 cell_state_out = Output("cell_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
     61 output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
     62 
     63 model = model.Operation("LSTM",
     64                         input,
     65 
     66                         input_to_input_weights,
     67                         input_to_forget_weights,
     68                         input_to_cell_weights,
     69                         input_to_output_weights,
     70 
     71                         recurrent_to_input_weights,
     72                         recurrent_to_forget_weights,
     73                         recurrent_to_cell_weights,
     74                         recurrent_to_output_weights,
     75 
     76                         cell_to_input_weights,
     77                         cell_to_forget_weights,
     78                         cell_to_output_weights,
     79 
     80                         input_gate_bias,
     81                         forget_gate_bias,
     82                         cell_gate_bias,
     83                         output_gate_bias,
     84 
     85                         projection_weights,
     86                         projection_bias,
     87 
     88                         output_state_in,
     89                         cell_state_in,
     90 
     91                         activation_param,
     92                         cell_clip_param,
     93                         proj_clip_param
     94 ).To([scratch_buffer, output_state_out, cell_state_out, output])
     95 
     96 input0 = {input_to_input_weights:[],
     97           input_to_cell_weights: [-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, 0.48944736, -0.38535351, -0.17212132],
     98           input_to_forget_weights: [-0.55291498, -0.42866567, 0.13056988, -0.3633365, -0.22755712, 0.28253698, 0.24407166, 0.33826375],
     99           input_to_output_weights: [0.10725588, -0.02335852, -0.55932593, -0.09426838, -0.44257352, 0.54939759, 0.01533556, 0.42751634],
    100 
    101           input_gate_bias:  [],
    102           forget_gate_bias: [1.,1.,1.,1.],
    103           cell_gate_bias:   [0.,0.,0.,0.],
    104           output_gate_bias: [0.,0.,0.,0.],
    105 
    106           recurrent_to_input_weights: [],
    107           recurrent_to_cell_weights: [
    108               0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
    109               0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
    110               0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
    111               0.21193194],
    112 
    113           recurrent_to_forget_weights: [
    114               -0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
    115             0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
    116             -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349],
    117 
    118           recurrent_to_output_weights: [
    119               0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
    120               -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
    121               0.50248802, 0.26114327, -0.43736315, 0.33149987],
    122 
    123           cell_to_input_weights: [],
    124           cell_to_forget_weights: [0.47485286, -0.51955009, -0.24458408, 0.31544167],
    125           cell_to_output_weights: [-0.17135078, 0.82760304, 0.85573703, -0.77109635],
    126 
    127           projection_weights: [],
    128           projection_bias: [],
    129 
    130           activation_param: [4],  # Tanh
    131           cell_clip_param: [0.],
    132           proj_clip_param: [0.],
    133 }
    134 
    135 output0 = {
    136     scratch_buffer: [ 0 for x in range(n_batch * n_cell * 4) ],
    137     cell_state_out: [ -0.760444, -0.0180416, 0.182264, -0.0649371 ],
    138     output_state_out: [ -0.364445, -0.00352185, 0.128866, -0.0516365 ],
    139 }
    140 
    141 input0[input] = [2., 3.]
    142 input0[output_state_in] = [ 0 for _ in range(n_batch * n_output) ]
    143 input0[cell_state_in] = [ 0 for _ in range(n_batch * n_cell) ]
    144 output0[output] = [-0.36444446, -0.00352185, 0.12886585, -0.05163646]
    145 
    146 Example((input0, output0))
    147