# Android NNAPI V1_0 RNN test model specification.
#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
     16 
# Test dimensions: 2 batches, 16 RNN units, 8 input features per time step.
batches = 2
units = 16
input_size = 8

model = Model()

# Model operands. Model/Input/Int32Scalar/IgnoredOutput/Output are presumably
# supplied by the NNAPI test-generator DSL that executes this spec -- confirm
# against the harness.
# NOTE(review): "input" shadows the Python builtin; kept as-is because the
# Example loop at the bottom of this file references it by this name.
input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (batches, input_size))
weights = Input("weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, input_size))
recurrent_weights = Input("recurrent_weights", "TENSOR_FLOAT32", "{%d, %d}" % (units, units))
bias = Input("bias", "TENSOR_FLOAT32", "{%d}" % (units))
hidden_state_in = Input("hidden_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

activation_param = Int32Scalar("activation_param", 1)  # Relu

# hidden_state_out is produced by the op but its values are not checked
# (declared via IgnoredOutput).
hidden_state_out = IgnoredOutput("hidden_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (batches, units))

# Single RNN operation: (input, weights, recurrent_weights, bias,
# hidden_state_in, activation) -> (hidden_state_out, output).
model = model.Operation("RNN", input, weights, recurrent_weights, bias, hidden_state_in,
                        activation_param).To([hidden_state_out, output])
     36 
# Operand values shared by every Example: the weight matrices and bias.
# The per-step "input" and "hidden_state_in" entries are filled in by the
# loop at the bottom of the file.
input0 = {
    # Input-to-hidden weights: units x input_size = 16 x 8 = 128 values.
    weights: [
        0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
        0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
        0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
        -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
        -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
        -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
        -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
        0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
        0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
        0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
        -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
        0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
        -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
        -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
        0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
        0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
        0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
        -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
        0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
        0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
        -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
        0.277308,    0.415818
    ],
    # Recurrent (hidden-to-hidden) weights: 16 x 16 = 256 values. The pattern
    # -- a 0.1 followed by sixteen zeros, repeated -- places 0.1 on the main
    # diagonal, i.e. this is 0.1 * identity(16).
    recurrent_weights: [
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0.1
    ],
    # One bias value per unit (16).
    bias: [
        0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
        -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
        0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
        -0.37609905
    ],
}
     87 
     88 
# Flat input sequence: 128 values = 16 time steps x input_size (8) for a
# single batch. The loop below slices one step at a time and duplicates it
# so both batches receive the same values.
test_inputs = [
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029
]
    117 
# Expected outputs: 16 groups (one per time step) of `units` (16) values for
# a single batch; the loop below duplicates each group across both batches.
# The many exact zeros are consistent with the Relu activation clamping
# negative pre-activations.
golden_outputs = [
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0
]
    183 
    184 input_sequence_size = int(len(test_inputs) / input_size / batches)
    185 
    186 # TODO: enable the other data points after fixing reference issues
    187 #for i in range(input_sequence_size):
    188 for i in range(1):
    189   input_begin = i * input_size
    190   input_end = input_begin + input_size
    191   input0[input] = test_inputs[input_begin:input_end]
    192   input0[input].extend(input0[input])
    193   input0[hidden_state_in] = [0 for x in range(batches * units)]
    194   output0 = {
    195     hidden_state_out: [0 for x in range(batches * units)],
    196   }
    197   golden_start = i * units
    198   golden_end = golden_start + units
    199   output0[output] = golden_outputs[golden_start:golden_end]
    200   output0[output].extend(output0[output])
    201   Example((input0, output0))
    202