1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <iostream> 16 #include <string> 17 #include <vector> 18 19 #include "tensorflow/contrib/lite/toco/model.h" 20 #include "tensorflow/contrib/lite/toco/tooling_util.h" 21 22 namespace toco { 23 24 // For consistency with the parameters defined in extended LstmCell's kernel 25 // (tensorflow/contrib/lite/kernels/lstm.cc), 26 // use lowercase for these constants. 27 28 enum ExtendedLstmCellInputs { 29 kInputTensor = 0, 30 kInputToInputWeightsTensor = 1, // Optional 31 kInputToForgetWeightsTensor = 2, 32 kInputToCellWeightsTensor = 3, 33 kInputToOutputWeightsTensor = 4, 34 kRecurrentToInputWeightsTensor = 5, // Optional 35 kRecurrentToForgetWeightsTensor = 6, 36 kRecurrentToCellWeightsTensor = 7, 37 kRecurrentToOutputWeightsTensor = 8, 38 kCellToInputWeightsTensor = 9, // Optional 39 kCellToForgetWeightsTensor = 10, // Optional 40 kCellToOutputWeightsTensor = 11, // Optional 41 kInputGateBiasTensor = 12, // Optional 42 kForgetGateBiasTensor = 13, 43 kCellGateBiasTensor = 14, 44 kOutputGateBiasTensor = 15, 45 kProjectionWeightsTensor = 16, // Optional 46 kProjectionBiasTensor = 17, // Optional 47 kExtendedLstmInputCount = 18 48 }; 49 50 enum ExtendedLstmCellOutputs { 51 kScratchBufferTensor = 0, 52 kOutputStateTensor = 1, 53 kCellStateTensor = 2, 54 kOutputTensor = 3 55 }; 56 57 // Create optional array used for optional tensor in ExtendedLstmCell inputs. 58 void CreateOptionalArray(Model* model, string* input_array_buffer, 59 const string& array_name); 60 61 // Create float array and get its buffer. 62 Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model, 63 string* array_name, 64 const Shape& shape); 65 66 // Copy data from one array to the other one (supports 1D and 2D array), 67 // for 1D array, the 2nd dim's size is 1. 68 // Arguments: 69 // src_buffer: the source buffer 70 // src_stride: the stride of source buffer, i.e., 2nd dim's size 71 // src_start_idx1: the 1st dim index of start point in src matrix 72 // src_start_idx2: the 2nd dim index of start point in src matrix 73 // dst_buffer: the destination buffer 74 // dst_stride: the stride of destination buffer, i.e., 2nd dim's size 75 // dst_start_idx1: the 1st dim index of start point in dst matrix 76 // dst_start_idx2: the 2nd dim index of start point in dst matrix 77 // dim1_copy_size: 1st dim size of copy data 78 // dim2_copy_size: 2nd dim size of copy data 79 void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer, 80 int src_stride, int src_start_idx1, int src_start_idx2, 81 Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride, 82 int dst_start_idx1, int dst_start_idx2, int dim1_copy_size, 83 int dim2_copy_size); 84 85 // Copy a subset of array data and create a smaller array, 86 // mostly used for spliting weights and bias for Lstm cell. 87 void CopySubArrayToArray(Model* model, string* array_name, 88 const string& tensor_name, int dim1_size, 89 int dim2_size, const Array& original_array, 90 int start_idx1, int start_idx2); 91 92 // Copy array data to a large array's submatrix, 93 // mostly used for merging weights and bias for Lstm cell. 94 void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer, 95 int tensor_stride, const Array& sub_array, 96 int start_idx1, int start_idx2); 97 98 // Get mating rnn array inputs using rnn_states flag. 99 bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, 100 string* rnn_array); 101 102 } // namespace toco 103