Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/framework/reader_base.h"

#include <utility>

#include "tensorflow/core/framework/reader_base.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
     26 
     27 namespace tensorflow {
     28 
     29 // ReaderBase ------------------------------------------------------
     30 
     31 ReaderBase::ReaderBase(const string& name) : name_(name) {}
     32 
     33 int64 ReaderBase::NumRecordsProduced() {
     34   mutex_lock lock(mu_);
     35   return num_records_produced_;
     36 }
     37 
     38 int64 ReaderBase::NumWorkUnitsCompleted() {
     39   mutex_lock lock(mu_);
     40   return work_finished_;
     41 }
     42 
     43 Status ReaderBase::Reset() {
     44   mutex_lock lock(mu_);
     45   return ResetLocked();
     46 }
     47 
     48 Status ReaderBase::ResetLocked() {
     49   work_started_ = 0;
     50   work_finished_ = 0;
     51   num_records_produced_ = 0;
     52   work_.clear();
     53   return Status::OK();
     54 }
     55 
     56 Status ReaderBase::SerializeState(string* state) {
     57   mutex_lock lock(mu_);
     58   return SerializeStateLocked(state);
     59 }
     60 
     61 Status ReaderBase::SerializeStateLocked(string* state) {
     62   return errors::Unimplemented("Reader SerializeState");
     63 }
     64 
     65 Status ReaderBase::RestoreState(const string& state) {
     66   mutex_lock lock(mu_);
     67   Status status = RestoreStateLocked(state);
     68   if (!status.ok()) {
     69     ResetLocked().IgnoreError();
     70   }
     71   return status;
     72 }
     73 
     74 Status ReaderBase::RestoreStateLocked(const string& state) {
     75   return errors::Unimplemented("Reader RestoreState");
     76 }
     77 
     78 int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
     79                            std::vector<string>* keys,
     80                            std::vector<string>* values,
     81                            OpKernelContext* context) {
     82   mutex_lock lock(mu_);
     83   int64 records_produced_this_call = 0;
     84   while (true) {
     85     // Records produced by this iteration of the ReadUpToLocked call.
     86     int64 num_records_produced = 0;
     87     int64 remaining = num_records - records_produced_this_call;
     88     if (remaining == 0) {
     89       return records_produced_this_call;
     90     }
     91     if (!work_in_progress()) {
     92       work_ = GetNextWorkLocked(queue, context);
     93       if (!context->status().ok()) {
     94         return records_produced_this_call;
     95       }
     96       Status status = OnWorkStartedLocked();
     97       if (status.ok()) {
     98         work_started_++;
     99       } else {
    100         context->SetStatus(status);
    101         return records_produced_this_call;
    102       }
    103     }
    104     bool at_end = false;
    105 
    106     Status status =
    107         ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
    108     // This call so far.
    109     records_produced_this_call += num_records_produced;
    110 
    111     // In total, over the lifetime of the ReaderBase.
    112     num_records_produced_ += num_records_produced;
    113 
    114     if (!at_end && status.ok() && num_records_produced == 0) {
    115       status = errors::Internal(
    116           "ReadManyLocked() for ", name(),
    117           " must set *at_end=true, *num_produced > 0 or return an error.");
    118       context->SetStatus(status);
    119       return records_produced_this_call;
    120     }
    121     if (status.ok() && at_end) {
    122       status = OnWorkFinishedLocked();
    123       work_finished_ = work_started_;
    124       if (records_produced_this_call > 0) {
    125         return records_produced_this_call;
    126       }
    127     }
    128     if (!status.ok()) {
    129       context->SetStatus(status);
    130       return records_produced_this_call;
    131     }
    132   }
    133 }
    134 
    135 // Default implementation just reads one record at a time.
    136 Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
    137                                   std::vector<string>* values, int64* num_read,
    138                                   bool* at_end) {
    139   bool produced = false;
    140   string key;
    141   string value;
    142   Status status = ReadLocked(&key, &value, &produced, at_end);
    143   if (produced) {
    144     keys->emplace_back(key);
    145     values->emplace_back(value);
    146     *num_read = 1;
    147   } else {
    148     *num_read = 0;
    149   }
    150   return status;
    151 }
    152 
    153 void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
    154                       OpKernelContext* context) {
    155   mutex_lock lock(mu_);
    156   while (true) {
    157     if (!work_in_progress()) {
    158       work_ = GetNextWorkLocked(queue, context);
    159       if (!context->status().ok()) {
    160         return;
    161       }
    162       Status status = OnWorkStartedLocked();
    163       if (status.ok()) {
    164         work_started_++;
    165       } else {
    166         context->SetStatus(status);
    167         return;
    168       }
    169     }
    170 
    171     bool produced = false;
    172     bool at_end = false;
    173     Status status = ReadLocked(key, value, &produced, &at_end);
    174 
    175     if (!at_end && status.ok() && !produced) {
    176       status = errors::Internal(
    177           "ReadLocked() for ", name(),
    178           " must set *at_end=true, *produced=true, or return an error.");
    179     }
    180     if (!status.ok() && produced) {
    181       status = errors::Internal(
    182           "ReadLocked() for ", name(),
    183           " set *produced=true *and* returned an error: ", status.ToString());
    184     }
    185     if (status.ok() && at_end) {
    186       status = OnWorkFinishedLocked();
    187       work_finished_ = work_started_;
    188     }
    189     if (!status.ok()) {
    190       context->SetStatus(status);
    191       return;
    192     }
    193     if (produced) {
    194       ++num_records_produced_;
    195       return;
    196     }
    197   }
    198 }
    199 
    200 string ReaderBase::GetNextWorkLocked(QueueInterface* queue,
    201                                      OpKernelContext* context) const {
    202   string work;
    203   Notification n;
    204   queue->TryDequeue(
    205       context, [this, context, &n, &work](const QueueInterface::Tuple& tuple) {
    206         if (context->status().ok()) {
    207           if (tuple.size() != 1) {
    208             context->SetStatus(
    209                 errors::InvalidArgument("Expected single component queue"));
    210           } else if (tuple[0].dtype() != DT_STRING) {
    211             context->SetStatus(errors::InvalidArgument(
    212                 "Expected queue with single string component"));
    213           } else if (tuple[0].NumElements() != 1) {
    214             context->SetStatus(errors::InvalidArgument(
    215                 "Expected to dequeue a one-element string tensor"));
    216           } else {
    217             work = tuple[0].flat<string>()(0);
    218           }
    219         }
    220         n.Notify();
    221       });
    222   n.WaitForNotification();
    223   return work;
    224 }
    225 
    226 void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
    227   state->Clear();
    228   state->set_work_started(work_started_);
    229   state->set_work_finished(work_finished_);
    230   state->set_num_records_produced(num_records_produced_);
    231   state->set_current_work(work_);
    232 }
    233 
    234 string ReaderBase::KeyName(const string& key) const {
    235   return strings::StrCat(current_work(), ":", key);
    236 }
    237 
    238 Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
    239   work_started_ = state.work_started();
    240   work_finished_ = state.work_finished();
    241   num_records_produced_ = state.num_records_produced();
    242   work_ = state.current_work();
    243   if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
    244 #ifdef __ANDROID__
    245     const string debug_string = "<debug state not available>";
    246 #else
    247     const string debug_string = state.DebugString();
    248 #endif
    249     return errors::InvalidArgument(
    250         "Unexpected negative value when restoring in ", name(), ": ",
    251         debug_string);
    252   }
    253   if (work_started_ > work_finished_) {
    254 #ifdef __ANDROID__
    255     const string debug_string = "<debug state not available>";
    256 #else
    257     const string debug_string = state.DebugString();
    258 #endif
    259     return errors::InvalidArgument(
    260         "Inconsistent work started vs. finished when restoring in ", name(),
    261         ": ", debug_string);
    262   }
    263   return Status::OK();
    264 }
    265 
    266 }  // namespace tensorflow
    267