1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/text_literal_reader.h" 17 18 #include <limits> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/literal_util.h" 24 #include "tensorflow/compiler/xla/ptr_util.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/compiler/xla/util.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/stringpiece.h" 31 #include "tensorflow/core/lib/io/buffered_inputstream.h" 32 #include "tensorflow/core/lib/io/random_inputstream.h" 33 #include "tensorflow/core/lib/strings/str_util.h" 34 #include "tensorflow/core/platform/protobuf.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace xla { 38 39 StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath( 40 tensorflow::StringPiece path) { 41 CHECK(!path.ends_with(".gz")) 42 << "TextLiteralReader no longer supports reading .gz files"; 43 std::unique_ptr<tensorflow::RandomAccessFile> file; 44 Status s = 45 tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); 46 if (!s.ok()) { 47 return s; 48 } 49 50 TextLiteralReader reader(file.release()); 51 return reader.ReadAllLines(); 52 } 53 54 TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) 55 : file_(file) {} 56 57 namespace { 58 // This is an optimized version of tensorflow::str_util::Split which uses 59 // StringPiece for the delimited strings and uses an out parameter for the 60 // result to avoid vector creation/destruction. 61 void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, 62 std::vector<tensorflow::StringPiece>* result) { 63 result->clear(); 64 65 if (text.empty()) { 66 return; 67 } 68 69 // The following loop is a little strange: its bound is text.size() + 1 70 // instead of the more typical text.size(). 71 // The final iteration of the loop (when i is equal to text.size()) handles 72 // the trailing token. 73 size_t token_start = 0; 74 for (size_t i = 0; i < text.size() + 1; i++) { 75 if (i == text.size() || text[i] == delim) { 76 tensorflow::StringPiece token(text.data() + token_start, i - token_start); 77 result->push_back(token); 78 token_start = i + 1; 79 } 80 } 81 } 82 } // namespace 83 84 StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() { 85 tensorflow::io::RandomAccessInputStream stream(file_.get()); 86 tensorflow::io::BufferedInputStream buf(&stream, 65536); 87 string shape_string; 88 Status s = buf.ReadLine(&shape_string); 89 if (!s.ok()) { 90 return s; 91 } 92 93 tensorflow::StringPiece sp(shape_string); 94 if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { 95 string tmp = sp.ToString(); 96 shape_string = tmp; 97 } 98 TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); 99 if (shape.element_type() != F32) { 100 return Unimplemented( 101 "unsupported element type for text literal reading: %s", 102 ShapeUtil::HumanString(shape).c_str()); 103 } 104 105 auto result = MakeUnique<Literal>(shape); 106 const float fill = std::numeric_limits<float>::quiet_NaN(); 107 result->PopulateWithValue<float>(fill); 108 std::vector<tensorflow::StringPiece> pieces; 109 std::vector<tensorflow::StringPiece> coordinates; 110 std::vector<int64> coordinate_values; 111 string line; 112 while (buf.ReadLine(&line).ok()) { 113 SplitByDelimToStringPieces(line, ':', &pieces); 114 tensorflow::StringPiece coordinates_string = pieces[0]; 115 tensorflow::StringPiece value_string = pieces[1]; 116 tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); 117 tensorflow::str_util::RemoveWhitespaceContext(&value_string); 118 if (!coordinates_string.Consume("(")) { 119 return InvalidArgument( 120 "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); 121 } 122 if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { 123 return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", 124 line.c_str()); 125 } 126 float value; 127 if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), 128 &value)) { 129 return InvalidArgument("could not parse value as float: \"%s\"", 130 value_string.ToString().c_str()); 131 } 132 SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); 133 coordinate_values.clear(); 134 for (tensorflow::StringPiece piece : coordinates) { 135 int64 coordinate_value; 136 if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { 137 return InvalidArgument( 138 "could not parse coordinate member as int64: \"%s\"", 139 piece.ToString().c_str()); 140 } 141 coordinate_values.push_back(coordinate_value); 142 } 143 if (coordinate_values.size() != shape.dimensions_size()) { 144 return InvalidArgument( 145 "line did not have expected number of coordinates; want %d got %zu: " 146 "\"%s\"", 147 shape.dimensions_size(), coordinate_values.size(), line.c_str()); 148 } 149 result->Set<float>(coordinate_values, value); 150 } 151 return std::move(result); 152 } 153 154 } // namespace xla 155