1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/lookup_util.h" 17 18 #include "tensorflow/core/framework/tensor.h" 19 #include "tensorflow/core/framework/tensor_shape.h" 20 #include "tensorflow/core/lib/core/errors.h" 21 #include "tensorflow/core/lib/io/inputbuffer.h" 22 23 namespace tensorflow { 24 namespace lookup { 25 namespace { 26 27 static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */ 28 static const int kLineNumber = -1; 29 static const int kWholeLine = -2; 30 31 Status GetNumLinesInTextFile(Env* env, const string& vocab_file, 32 int64* num_lines) { 33 std::unique_ptr<RandomAccessFile> file; 34 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); 35 36 io::InputBuffer input_buffer(file.get(), kInputBufferSize); 37 string line; 38 Status s = input_buffer.ReadLine(&line); 39 int64 next_id = 0; 40 while (s.ok()) { 41 next_id++; 42 s = input_buffer.ReadLine(&line); 43 } 44 if (!errors::IsOutOfRange(s)) { 45 return s; 46 } 47 *num_lines = next_id; 48 return Status::OK(); 49 } 50 51 // Iterator that reads a text file. Each iteration process one line, it parses 52 // the line and populates the keys and values tensors used for initialization 53 // with a single key and corresponding value. 54 // 55 // What information of the line to populate the key or values is specified by 56 // providing key_index and value_index. 57 class TextFileLineIterator 58 : public InitializableLookupTable::InitTableIterator { 59 public: 60 TextFileLineIterator() 61 : valid_(false), 62 vocab_size_(-1), 63 status_(errors::FailedPrecondition("Not initialized")) {} 64 65 // Initialize iterator. 66 // 67 // Prepares the file 'filename' and sets the data types to return the keys and 68 // values tensors. It requires the indices of the tokens in the line given a 69 // delimiter to specify where to pick the data from. 70 // 71 // - Index -2 means the entire line as string. 72 // - Index -1 means the line number stored in int64. 73 // - Index >= 0 represent index (starting at zero) of the split line based on 74 // delimiter. 75 Status Init(const string& filename, int64 vocab_size, char delimiter, 76 DataType key_dtype, int64 key_index, DataType value_dtype, 77 int64 value_index, Env* env) { 78 if (vocab_size == -1) { 79 TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size)); 80 } 81 filename_ = filename; 82 vocab_size_ = vocab_size; 83 delimiter_ = delimiter; 84 key_ = Tensor(key_dtype, TensorShape({})); 85 value_ = Tensor(value_dtype, TensorShape({})); 86 key_index_ = key_index; 87 value_index_ = value_index; 88 89 status_ = env->NewRandomAccessFile(filename_, &file_); 90 if (!status_.ok()) return status_; 91 92 input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize)); 93 valid_ = true; 94 next_id_ = 0; 95 ignore_split_ = std::max(key_index_, value_index_) < 0; 96 Next(); 97 return status_; 98 } 99 100 void Next() override { 101 if (!valid_) return; 102 103 string line; 104 status_ = input_buffer_->ReadLine(&line); 105 if (!status_.ok()) { 106 if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) { 107 status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_, 108 ": expected ", vocab_size_, 109 " but got ", next_id_); 110 } 111 valid_ = false; 112 return; 113 } 114 if (next_id_ >= vocab_size_) { 115 LOG(WARNING) << "Truncated " << filename_ << " before its end at " 116 << vocab_size_ << " records."; 117 LOG(WARNING) << "next_id_ : " << next_id_; 118 status_ = errors::OutOfRange("Finished reading ", vocab_size_, 119 " of lines from ", filename_); 120 valid_ = false; 121 return; 122 } 123 if (line.empty()) { 124 status_ = errors::InvalidArgument("Invalid content in ", filename_, 125 ": empty line found at position ", 126 input_buffer_->Tell(), "."); 127 valid_ = false; 128 return; 129 } 130 131 std::vector<string> tokens; 132 if (!ignore_split_) { 133 tokens = str_util::Split(line, delimiter_); 134 if (std::max(key_index_, value_index_) >= tokens.size()) { 135 status_ = errors::InvalidArgument( 136 "Invalid number of columns in ", filename_, " line ", next_id_, 137 " (", line, ") : expected ", std::max(key_index_, value_index_), 138 " got ", tokens.size()); 139 valid_ = false; 140 return; 141 } 142 } 143 status_ = SetValue(line, tokens, key_index_, &key_); 144 if (!status_.ok()) { 145 valid_ = false; 146 return; 147 } 148 status_ = SetValue(line, tokens, value_index_, &value_); 149 if (!status_.ok()) { 150 valid_ = false; 151 return; 152 } 153 154 next_id_++; 155 } 156 157 bool Valid() const override { return valid_; } 158 159 const Tensor& keys() const override { return key_; } 160 161 const Tensor& values() const override { return value_; } 162 163 Status status() const override { return status_; } 164 165 int64 total_size() const override { return vocab_size_; } 166 167 private: 168 Tensor key_; 169 Tensor value_; 170 bool valid_; // true if the iterator points to an existing range. 171 int64 key_index_; 172 int64 value_index_; 173 int64 next_id_; 174 int64 vocab_size_; 175 string filename_; 176 char delimiter_; 177 Status status_; 178 bool ignore_split_; 179 std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_ 180 std::unique_ptr<io::InputBuffer> input_buffer_; 181 182 // Set the corresponding value from line or tokens based on 'index' into the 183 // tensor 't'. The value is transformed to the given data type 'dtype'. 184 Status SetValue(const string& line, const std::vector<string>& tokens, 185 int64 index, Tensor* tensor) { 186 if (index == kLineNumber) { 187 tensor->flat<int64>()(0) = next_id_; 188 return Status::OK(); 189 } 190 const string& token = (index == kWholeLine) ? line : tokens[index]; 191 const DataType& dtype = tensor->dtype(); 192 switch (dtype) { 193 case DT_INT32: { 194 int32 value; 195 if (!strings::safe_strto32(token.c_str(), &value)) { 196 valid_ = false; 197 return errors::InvalidArgument("Field ", token, " in line ", next_id_, 198 " is not a valid int32."); 199 } 200 tensor->flat<int32>()(0) = value; 201 } break; 202 case DT_INT64: { 203 int64 value; 204 if (!strings::safe_strto64(token.c_str(), &value)) { 205 valid_ = false; 206 return errors::InvalidArgument("Field ", token, " in line ", next_id_, 207 " is not a valid int64."); 208 } 209 tensor->flat<int64>()(0) = value; 210 } break; 211 case DT_FLOAT: { 212 float value; 213 if (!strings::safe_strtof(token.c_str(), &value)) { 214 valid_ = false; 215 return errors::InvalidArgument("Field ", token, " in line ", next_id_, 216 " is not a valid float."); 217 } 218 tensor->flat<float>()(0) = value; 219 } break; 220 case DT_DOUBLE: { 221 double value; 222 if (!strings::safe_strtod(token.c_str(), &value)) { 223 valid_ = false; 224 return errors::InvalidArgument("Field ", token, " in line ", next_id_, 225 " is not a valid double."); 226 } 227 tensor->flat<double>()(0) = value; 228 } break; 229 case DT_STRING: 230 tensor->flat<string>()(0) = token; 231 break; 232 default: 233 valid_ = false; 234 return errors::InvalidArgument("Data type ", dtype, " not supported."); 235 } 236 return Status::OK(); 237 } 238 239 TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator); 240 }; 241 242 Status GetTableHandle(const string& input_name, OpKernelContext* ctx, 243 string* container, string* table_handle) { 244 { 245 mutex* mu; 246 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); 247 mutex_lock l(*mu); 248 Tensor tensor; 249 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); 250 if (tensor.NumElements() != 2) { 251 return errors::InvalidArgument( 252 "Lookup table handle must be scalar, but had shape: ", 253 tensor.shape().DebugString()); 254 } 255 auto h = tensor.flat<string>(); 256 *container = h(0); 257 *table_handle = h(1); 258 } 259 return Status::OK(); 260 } 261 262 } // namespace 263 264 Status GetLookupTable(const string& input_name, OpKernelContext* ctx, 265 LookupInterface** table) { 266 string container; 267 string table_handle; 268 DataType handle_dtype; 269 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); 270 if (handle_dtype == DT_RESOURCE) { 271 ResourceHandle handle; 272 TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); 273 return LookupResource(ctx, handle, table); 274 } else { 275 TF_RETURN_IF_ERROR( 276 GetTableHandle(input_name, ctx, &container, &table_handle)); 277 return ctx->resource_manager()->Lookup(container, table_handle, table); 278 } 279 } 280 281 Status GetInitializableLookupTable(const string& input_name, 282 OpKernelContext* ctx, 283 InitializableLookupTable** table) { 284 LookupInterface* lookup_table; 285 DataType handle_dtype; 286 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); 287 if (handle_dtype == DT_RESOURCE) { 288 ResourceHandle handle; 289 TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); 290 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table)); 291 *table = lookup_table->GetInitializableLookupTable(); 292 if (*table == nullptr) { 293 lookup_table->Unref(); 294 return errors::InvalidArgument("Table ", handle.container(), " ", 295 handle.name(), " is not initializable"); 296 } 297 } else { 298 string container; 299 string table_handle; 300 TF_RETURN_IF_ERROR( 301 GetTableHandle(input_name, ctx, &container, &table_handle)); 302 TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle, 303 &lookup_table)); 304 *table = lookup_table->GetInitializableLookupTable(); 305 if (*table == nullptr) { 306 lookup_table->Unref(); 307 return errors::InvalidArgument("Table ", container, " ", table_handle, 308 " is not initializable"); 309 } 310 } 311 return Status::OK(); 312 } 313 314 Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, 315 DataType value_dtype, const string& table_name) { 316 if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) { 317 return errors::InvalidArgument( 318 "Conflicting key/value dtypes ", key_dtype, "->", value_dtype, " with ", 319 table.key_dtype(), "-", table.value_dtype(), " for table ", table_name); 320 } 321 return Status::OK(); 322 } 323 324 // Helper function to initialize an InitializableLookupTable from a text file. 325 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size, 326 char delimiter, int32 key_index, 327 int32 value_index, Env* env, 328 InitializableLookupTable* table) { 329 if (key_index == kLineNumber && table->key_dtype() != DT_INT64) { 330 return errors::InvalidArgument( 331 "Key index for line number requires table key dtype of int64, got ", 332 table->key_dtype()); 333 } 334 const DataType& key_dtype = table->key_dtype(); 335 const DataType& value_dtype = table->value_dtype(); 336 if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) && 337 key_dtype != DT_STRING) { 338 return errors::InvalidArgument( 339 "Key index for whole line requires string or integer table key, got ", 340 table->key_dtype()); 341 } 342 if (value_index == kLineNumber && value_dtype != DT_INT64) { 343 return errors::InvalidArgument( 344 "Value index for line number requires table value dtype of int64, got ", 345 table->value_dtype()); 346 } 347 if (value_index == kWholeLine && value_dtype != DT_STRING) { 348 return errors::InvalidArgument( 349 "Value index for whole line requires table value dtype of string, got ", 350 table->value_dtype()); 351 } 352 353 TextFileLineIterator iter; 354 TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype, 355 key_index, value_dtype, value_index, env)); 356 // For initialization from files, ignore if the table is already 357 // initialized. The table shared name should contain the filename to 358 // avoid trying to initialize the same table from the same file at the same 359 // time. 360 Status s = table->Initialize(iter); 361 if (errors::IsFailedPrecondition(s) && table->is_initialized()) { 362 LOG(INFO) << "Table trying to initialize from file " << filename 363 << " is already initialized."; 364 return Status::OK(); 365 } 366 return s; 367 } 368 369 } // namespace lookup 370 } // namespace tensorflow 371