Home | History | Annotate | Download | only in experimental
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <queue>
     16 #include "tensorflow/core/framework/op_kernel.h"
     17 #include "tensorflow/core/framework/partial_tensor_shape.h"
     18 #include "tensorflow/core/framework/tensor.h"
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 #include "tensorflow/core/kernels/data/dataset.h"
     21 #include "tensorflow/core/lib/core/blocking_counter.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/core/threadpool.h"
     24 #include "tensorflow/core/lib/io/buffered_inputstream.h"
     25 #include "tensorflow/core/lib/io/inputbuffer.h"
     26 #include "tensorflow/core/lib/io/path.h"
     27 #include "tensorflow/core/lib/io/random_inputstream.h"
     28 #include "tensorflow/core/lib/io/record_reader.h"
     29 #include "tensorflow/core/lib/io/zlib_compression_options.h"
     30 #include "tensorflow/core/lib/io/zlib_inputstream.h"
     31 #include "tensorflow/core/platform/env.h"
     32 
     33 namespace tensorflow {
     34 namespace data {
     35 namespace {
     36 
     37 class MatchingFilesDatasetOp : public DatasetOpKernel {
     38  public:
     39   using DatasetOpKernel::DatasetOpKernel;
     40 
     41   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     42     const Tensor* patterns_t;
     43     OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t));
     44     const auto patterns = patterns_t->flat<string>();
     45     size_t num_patterns = static_cast<size_t>(patterns.size());
     46     std::vector<string> pattern_strs;
     47     pattern_strs.reserve(num_patterns);
     48 
     49     for (size_t i = 0; i < num_patterns; i++) {
     50       pattern_strs.push_back(patterns(i));
     51     }
     52 
     53     *output = new Dataset(ctx, std::move(pattern_strs));
     54   }
     55 
     56  private:
     57   class Dataset : public DatasetBase {
     58    public:
     59     Dataset(OpKernelContext* ctx, std::vector<string> patterns)
     60         : DatasetBase(DatasetContext(ctx)), patterns_(std::move(patterns)) {}
     61 
     62     std::unique_ptr<IteratorBase> MakeIteratorInternal(
     63         const string& prefix) const override {
     64       return absl::make_unique<Iterator>(
     65           Iterator::Params{this, strings::StrCat(prefix, "::MatchingFiles")});
     66     }
     67 
     68     const DataTypeVector& output_dtypes() const override {
     69       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
     70       return *dtypes;
     71     }
     72 
     73     const std::vector<PartialTensorShape>& output_shapes() const override {
     74       static std::vector<PartialTensorShape>* shapes =
     75           new std::vector<PartialTensorShape>({{}});
     76       return *shapes;
     77     }
     78 
     79     string DebugString() const override {
     80       return "MatchingFilesDatasetOp::Dataset";
     81     }
     82 
     83    protected:
     84     Status AsGraphDefInternal(SerializationContext* ctx,
     85                               DatasetGraphDefBuilder* b,
     86                               Node** output) const override {
     87       Node* patterns_node = nullptr;
     88       TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node));
     89       TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output));
     90       return Status::OK();
     91     }
     92 
     93    private:
     94     class Iterator : public DatasetIterator<Dataset> {
     95      public:
     96       explicit Iterator(const Params& params)
     97           : DatasetIterator<Dataset>(params) {}
     98 
     99       Status GetNextInternal(IteratorContext* ctx,
    100                              std::vector<Tensor>* out_tensors,
    101                              bool* end_of_sequence) override {
    102         mutex_lock l(mu_);
    103         FileSystem* fs;
    104 
    105         TF_RETURN_IF_ERROR(ctx->env()->GetFileSystemForFile(
    106             dataset()->patterns_[(current_pattern_index_ > 0)
    107                                      ? current_pattern_index_ - 1
    108                                      : 0],
    109             &fs));
    110 
    111         while (!filepath_queue_.empty() ||
    112                current_pattern_index_ < dataset()->patterns_.size()) {
    113           // All the elements in the heap will be the matched filenames or the
    114           // potential directories.
    115           if (!filepath_queue_.empty()) {
    116             PathStatus current_path = filepath_queue_.top();
    117             filepath_queue_.pop();
    118 
    119             if (!current_path.second) {
    120               Tensor filepath_tensor(ctx->allocator({}), DT_STRING, {});
    121 
    122               // Replace the forward slash with the backslash for Windows path
    123               if (isWindows_) {
    124                 std::replace(current_path.first.begin(),
    125                              current_path.first.end(), '/', '\\');
    126               }
    127 
    128               filepath_tensor.scalar<string>()() =
    129                   std::move(current_path.first);
    130               out_tensors->emplace_back(std::move(filepath_tensor));
    131               *end_of_sequence = false;
    132               hasMatch_ = true;
    133               return Status::OK();
    134             }
    135 
    136             // In this case, current_path is a directory. Then continue the
    137             // search.
    138             TF_RETURN_IF_ERROR(
    139                 UpdateIterator(ctx, fs, current_path.first, current_pattern_));
    140           } else {
    141             // search a new pattern
    142             current_pattern_ = dataset()->patterns_[current_pattern_index_];
    143 
    144             // Windows paths contain backslashes and Windows APIs accept forward
    145             // and backslashes equivalently, so we convert the pattern to use
    146             // forward slashes exclusively. The backslash is used as the
    147             // indicator of Windows paths. Note that this is not ideal, since
    148             // the API expects backslash as an escape character, but no code
    149             // appears to rely on this behavior
    150             if (current_pattern_.find('\\') != std::string::npos) {
    151               isWindows_ = true;
    152               std::replace(current_pattern_.begin(), current_pattern_.end(),
    153                            '\\', '/');
    154             } else {
    155               isWindows_ = false;
    156             }
    157 
    158             StringPiece fixed_prefix =
    159                 StringPiece(current_pattern_)
    160                     .substr(0, current_pattern_.find_first_of("*?[\\"));
    161             string current_dir(io::Dirname(fixed_prefix));
    162 
    163             // If current_dir is empty then we need to fix up fixed_prefix and
    164             // current_pattern_ to include . as the top level directory.
    165             if (current_dir.empty()) {
    166               current_dir = ".";
    167               current_pattern_ = io::JoinPath(current_dir, current_pattern_);
    168             }
    169 
    170             TF_RETURN_IF_ERROR(
    171                 UpdateIterator(ctx, fs, current_dir, current_pattern_));
    172             ++current_pattern_index_;
    173           }
    174         }
    175 
    176         *end_of_sequence = true;
    177         if (hasMatch_) {
    178           return Status::OK();
    179         } else {
    180           return errors::NotFound("Don't find any matched files");
    181         }
    182       }
    183 
    184      protected:
    185       std::shared_ptr<model::Node> CreateNode(
    186           IteratorContext* ctx, model::Node::Args args) const override {
    187         return model::MakeSourceNode(std::move(args));
    188       }
    189 
    190       Status SaveInternal(IteratorStateWriter* writer) override {
    191         mutex_lock l(mu_);
    192         TF_RETURN_IF_ERROR(writer->WriteScalar(
    193             full_name("current_pattern_index"), current_pattern_index_));
    194 
    195         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_pattern"),
    196                                                current_pattern_));
    197         TF_RETURN_IF_ERROR(
    198             writer->WriteScalar(full_name("hasMatch"), hasMatch_));
    199         TF_RETURN_IF_ERROR(
    200             writer->WriteScalar(full_name("isWindows"), isWindows_));
    201 
    202         if (!filepath_queue_.empty()) {
    203           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("queue_size"),
    204                                                  filepath_queue_.size()));
    205           int i = 0;
    206           while (!filepath_queue_.empty()) {
    207             TF_RETURN_IF_ERROR(
    208                 writer->WriteScalar(full_name(strings::StrCat("path_", i)),
    209                                     filepath_queue_.top().first));
    210             TF_RETURN_IF_ERROR(writer->WriteScalar(
    211                 full_name(strings::StrCat("path_status_", i)),
    212                 filepath_queue_.top().second));
    213             filepath_queue_.pop();
    214             i++;
    215           }
    216         }
    217 
    218         return Status::OK();
    219       }
    220 
    221       Status RestoreInternal(IteratorContext* ctx,
    222                              IteratorStateReader* reader) override {
    223         mutex_lock l(mu_);
    224         int64 current_pattern_index;
    225         TF_RETURN_IF_ERROR(reader->ReadScalar(
    226             full_name("current_pattern_index"), &current_pattern_index));
    227         current_pattern_index_ = size_t(current_pattern_index);
    228 
    229         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_pattern"),
    230                                               &current_pattern_));
    231         int64 hasMatch;
    232         TF_RETURN_IF_ERROR(
    233             reader->ReadScalar(full_name("hasMatch"), &hasMatch));
    234         hasMatch_ = static_cast<bool>(hasMatch);
    235 
    236         int64 isWindows;
    237         TF_RETURN_IF_ERROR(
    238             reader->ReadScalar(full_name("isWindows"), &isWindows));
    239         isWindows_ = static_cast<bool>(isWindows);
    240 
    241         if (reader->Contains(full_name("queue_size"))) {
    242           int64 queue_size;
    243           TF_RETURN_IF_ERROR(
    244               reader->ReadScalar(full_name("queue_size"), &queue_size));
    245           for (int i = 0; i < queue_size; i++) {
    246             string path;
    247             int64 path_status;
    248             TF_RETURN_IF_ERROR(reader->ReadScalar(
    249                 full_name(strings::StrCat("path_", i)), &path));
    250             TF_RETURN_IF_ERROR(reader->ReadScalar(
    251                 full_name(strings::StrCat("path_status_", i)), &path_status));
    252             filepath_queue_.push(
    253                 PathStatus(path, static_cast<bool>(path_status)));
    254           }
    255         }
    256 
    257         return Status::OK();
    258       }
    259 
    260      private:
    261       Status UpdateIterator(IteratorContext* ctx, FileSystem* fs,
    262                             const string& dir, const string& eval_pattern)
    263           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    264         StringPiece fixed_prefix =
    265             StringPiece(eval_pattern)
    266                 .substr(0, eval_pattern.find_first_of("*?[\\"));
    267 
    268         filepath_queue_.push(PathStatus(dir, true));
    269         Status ret;  // Status to return
    270 
    271         // DFS to find the first element in the iterator.
    272         while (!filepath_queue_.empty()) {
    273           const PathStatus current_path = filepath_queue_.top();
    274 
    275           // All the files in the heap are matched with the pattern, so finish
    276           // the search if current_path is a file.
    277           if (!current_path.second) {
    278             return Status::OK();
    279           }
    280 
    281           filepath_queue_.pop();
    282 
    283           // If current_path is a directory, search its children.
    284           const string& current_dir = current_path.first;
    285           std::vector<string> children;
    286           ret.Update(fs->GetChildren(current_dir, &children));
    287 
    288           // Handle the error cases: 1) continue the search if the status is
    289           // NOT_FOUND; 2) return the non-ok status immediately if it is not
    290           // NOT_FOUND.
    291           if (ret.code() == error::NOT_FOUND) {
    292             continue;
    293           } else if (!ret.ok()) {
    294             return ret;
    295           }
    296 
    297           // children_dir_status holds is_dir status for children. It can have
    298           // three possible values: OK for true; FAILED_PRECONDITION for false;
    299           // CANCELLED if we don't calculate IsDirectory (we might do that
    300           // because there isn't any point in exploring that child path).
    301           std::vector<Status> children_dir_status;
    302           children_dir_status.resize(children.size());
    303 
    304           // This IsDirectory call can be expensive for some FS. Parallelizing
    305           // it.
    306           auto is_directory_fn = [fs, current_dir, &children, &fixed_prefix,
    307                                   &children_dir_status](int i) {
    308             const string child_path = io::JoinPath(current_dir, children[i]);
    309             // In case the child_path doesn't start with the fixed_prefix, then
    310             // we don't need to explore this path.
    311             if (!str_util::StartsWith(child_path, fixed_prefix)) {
    312               children_dir_status[i] =
    313                   errors::Cancelled("Operation not needed");
    314             } else {
    315               children_dir_status[i] = fs->IsDirectory(child_path);
    316             }
    317           };
    318 
    319           BlockingCounter counter(children.size());
    320           for (int i = 0; i < children.size(); i++) {
    321             (*ctx->runner())([&is_directory_fn, &counter, i] {
    322               is_directory_fn(i);
    323               counter.DecrementCount();
    324             });
    325           }
    326           counter.Wait();
    327 
    328           for (int i = 0; i < children.size(); i++) {
    329             const string& child_dir_path =
    330                 io::JoinPath(current_dir, children[i]);
    331             const Status& child_dir_status = children_dir_status[i];
    332 
    333             // If the IsDirectory call was cancelled we bail.
    334             if (child_dir_status.code() == tensorflow::error::CANCELLED) {
    335               continue;
    336             }
    337 
    338             if (child_dir_status.ok()) {
    339               // push the child dir for next search
    340               filepath_queue_.push(PathStatus(child_dir_path, true));
    341             } else {
    342               // This case will be a file: if the file matches the pattern, push
    343               // it to the heap; otherwise, ignore it.
    344               if (ctx->env()->MatchPath(child_dir_path, eval_pattern)) {
    345                 filepath_queue_.push(PathStatus(child_dir_path, false));
    346               }
    347             }
    348           }
    349         }
    350         return ret;
    351       }
    352 
    353       mutex mu_;
    354       // True means the path is a directory; False means the path is a filename.
    355       typedef std::pair<string, bool> PathStatus;
    356       std::priority_queue<PathStatus, std::vector<PathStatus>,
    357                           std::greater<PathStatus>>
    358           filepath_queue_ GUARDED_BY(mu_);
    359       size_t current_pattern_index_ GUARDED_BY(mu_) = 0;
    360       string current_pattern_ GUARDED_BY(mu_);
    361       bool hasMatch_ GUARDED_BY(mu_) = false;
    362       bool isWindows_ GUARDED_BY(mu_) = false;
    363     };
    364 
    365     const std::vector<string> patterns_;
    366   };
    367 };
    368 
    369 REGISTER_KERNEL_BUILDER(
    370     Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU),
    371     MatchingFilesDatasetOp);
    372 
    373 }  // namespace
    374 }  // namespace data
    375 }  // namespace tensorflow
    376