Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/text_literal_reader.h"
     17 
     18 #include <limits>
     19 #include <string>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/literal_util.h"
     24 #include "tensorflow/compiler/xla/ptr_util.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/types.h"
     28 #include "tensorflow/compiler/xla/util.h"
     29 #include "tensorflow/compiler/xla/xla_data.pb.h"
     30 #include "tensorflow/core/lib/core/stringpiece.h"
     31 #include "tensorflow/core/lib/io/buffered_inputstream.h"
     32 #include "tensorflow/core/lib/io/random_inputstream.h"
     33 #include "tensorflow/core/lib/strings/str_util.h"
     34 #include "tensorflow/core/platform/protobuf.h"
     35 #include "tensorflow/core/platform/types.h"
     36 
     37 namespace xla {
     38 
     39 StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
     40     tensorflow::StringPiece path) {
     41   CHECK(!path.ends_with(".gz"))
     42       << "TextLiteralReader no longer supports reading .gz files";
     43   std::unique_ptr<tensorflow::RandomAccessFile> file;
     44   Status s =
     45       tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file);
     46   if (!s.ok()) {
     47     return s;
     48   }
     49 
     50   TextLiteralReader reader(file.release());
     51   return reader.ReadAllLines();
     52 }
     53 
     54 TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
     55     : file_(file) {}
     56 
     57 namespace {
     58 // This is an optimized version of tensorflow::str_util::Split which uses
     59 // StringPiece for the delimited strings and uses an out parameter for the
     60 // result to avoid vector creation/destruction.
     61 void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim,
     62                                 std::vector<tensorflow::StringPiece>* result) {
     63   result->clear();
     64 
     65   if (text.empty()) {
     66     return;
     67   }
     68 
     69   // The following loop is a little strange: its bound is text.size() + 1
     70   // instead of the more typical text.size().
     71   // The final iteration of the loop (when i is equal to text.size()) handles
     72   // the trailing token.
     73   size_t token_start = 0;
     74   for (size_t i = 0; i < text.size() + 1; i++) {
     75     if (i == text.size() || text[i] == delim) {
     76       tensorflow::StringPiece token(text.data() + token_start, i - token_start);
     77       result->push_back(token);
     78       token_start = i + 1;
     79     }
     80   }
     81 }
     82 }  // namespace
     83 
     84 StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
     85   tensorflow::io::RandomAccessInputStream stream(file_.get());
     86   tensorflow::io::BufferedInputStream buf(&stream, 65536);
     87   string shape_string;
     88   Status s = buf.ReadLine(&shape_string);
     89   if (!s.ok()) {
     90     return s;
     91   }
     92 
     93   tensorflow::StringPiece sp(shape_string);
     94   if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) {
     95     string tmp = sp.ToString();
     96     shape_string = tmp;
     97   }
     98   TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
     99   if (shape.element_type() != F32) {
    100     return Unimplemented(
    101         "unsupported element type for text literal reading: %s",
    102         ShapeUtil::HumanString(shape).c_str());
    103   }
    104 
    105   auto result = MakeUnique<Literal>(shape);
    106   const float fill = std::numeric_limits<float>::quiet_NaN();
    107   result->PopulateWithValue<float>(fill);
    108   std::vector<tensorflow::StringPiece> pieces;
    109   std::vector<tensorflow::StringPiece> coordinates;
    110   std::vector<int64> coordinate_values;
    111   string line;
    112   while (buf.ReadLine(&line).ok()) {
    113     SplitByDelimToStringPieces(line, ':', &pieces);
    114     tensorflow::StringPiece coordinates_string = pieces[0];
    115     tensorflow::StringPiece value_string = pieces[1];
    116     tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string);
    117     tensorflow::str_util::RemoveWhitespaceContext(&value_string);
    118     if (!coordinates_string.Consume("(")) {
    119       return InvalidArgument(
    120           "expected '(' at the beginning of coordinates: \"%s\"", line.c_str());
    121     }
    122     if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) {
    123       return InvalidArgument("expected ')' at the end of coordinates: \"%s\"",
    124                              line.c_str());
    125     }
    126     float value;
    127     if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(),
    128                                           &value)) {
    129       return InvalidArgument("could not parse value as float: \"%s\"",
    130                              value_string.ToString().c_str());
    131     }
    132     SplitByDelimToStringPieces(coordinates_string, ',', &coordinates);
    133     coordinate_values.clear();
    134     for (tensorflow::StringPiece piece : coordinates) {
    135       int64 coordinate_value;
    136       if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) {
    137         return InvalidArgument(
    138             "could not parse coordinate member as int64: \"%s\"",
    139             piece.ToString().c_str());
    140       }
    141       coordinate_values.push_back(coordinate_value);
    142     }
    143     if (coordinate_values.size() != shape.dimensions_size()) {
    144       return InvalidArgument(
    145           "line did not have expected number of coordinates; want %d got %zu: "
    146           "\"%s\"",
    147           shape.dimensions_size(), coordinate_values.size(), line.c_str());
    148     }
    149     result->Set<float>(coordinate_values, value);
    150   }
    151   return std::move(result);
    152 }
    153 
    154 }  // namespace xla
    155