Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/util/saved_tensor_slice_util.h"

#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/ordered_code.h"
#include "tensorflow/core/lib/strings/str_util.h"
     21 
     22 namespace tensorflow {
     23 
     24 namespace checkpoint {
     25 
// Table key reserved for the saved-tensor-slices entry (the empty string).
// NOTE(review): presumably chosen so it sorts before every key produced by
// EncodeTensorNameSlice (those are never empty); confirm against the writer.
const char kSavedTensorSlicesKey[] = "";
     27 
     28 string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) {
     29   string buffer;
     30   // All the tensor slice keys will start with a 0
     31   tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0);
     32   tensorflow::strings::OrderedCode::WriteString(&buffer, name);
     33   tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims());
     34   for (int d = 0; d < slice.dims(); ++d) {
     35     // A trivial extent (meaning we take EVERYTHING) will default to -1 for both
     36     // start and end. These will be properly parsed.
     37     tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
     38                                                                slice.start(d));
     39     tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
     40                                                                slice.length(d));
     41   }
     42   return buffer;
     43 }
     44 
     45 Status DecodeTensorNameSlice(const string& code, string* name,
     46                              tensorflow::TensorSlice* slice) {
     47   StringPiece src(code);
     48   uint64 x;
     49   if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
     50     return errors::Internal("Failed to parse the leading number: src = ", src);
     51   }
     52   if (x != 0) {
     53     return errors::Internal(
     54         "The leading number should always be 0 for any valid key: src = ", src);
     55   }
     56   if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) {
     57     return errors::Internal("Failed to parse the tensor name: src = ", src);
     58   }
     59   if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
     60     return errors::Internal("Failed to parse the tensor rank: src = ", src);
     61   }
     62   if (x == 0) {
     63     return errors::Internal("Expecting positive rank of the tensor, got ", x,
     64                             ", src = ", src);
     65   }
     66   if (x >= kint32max) {
     67     return errors::Internal("Too many elements ", x);
     68   }
     69   slice->SetFullSlice(x);
     70   for (int d = 0; d < static_cast<int32>(x); ++d) {
     71     // We expected 2x integers
     72     int64 start, length;
     73     if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
     74                                                                    &start)) {
     75       return errors::Internal("Failed to parse start: src = ", src);
     76     }
     77     if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
     78                                                                    &length)) {
     79       return errors::Internal("Failed to parse length: src = ", src);
     80     }
     81     if (length >= 0) {
     82       // a non-trivial extent
     83       slice->set_start(d, start);
     84       slice->set_length(d, length);
     85     }
     86   }
     87   return Status::OK();
     88 }
     89 
     90 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
     91                           TensorSlice* slice, TensorShape* shape_slice) {
     92   CHECK(!shape_and_slice.empty());
     93   // Syntax: dim0 dim1 dim2 ... <slice string>
     94   // Where slice string is defined in core/framework/tensor_slice.h
     95   std::vector<string> splits = str_util::Split(shape_and_slice, ' ');
     96 
     97   // Must have at least 2 strings.
     98   if (splits.size() < 2) {
     99     return errors::InvalidArgument(
    100         "Need least two elements in shape_and_slice specification: ",
    101         shape_and_slice);
    102   }
    103 
    104   // The last split is the slice specification.
    105   slice->Clear();
    106   auto status = slice->Parse(splits.back(), slice);
    107   if (!status.ok()) return status;
    108 
    109   // The first n-1 are the shape specification.
    110   splits.pop_back();
    111   shape->Clear();
    112   for (const auto& s : splits) {
    113     int64 dim;
    114     if (!strings::safe_strto64(s, &dim)) {
    115       return errors::InvalidArgument(
    116           "Non numerical dimension in shape_and_slice: ", shape_and_slice);
    117     }
    118     shape->AddDim(dim);
    119   }
    120 
    121   // The specified slice must be compatible with the specified shape.
    122   return slice->SliceTensorShape(*shape, shape_slice);
    123 }
    124 
    125 }  // namespace checkpoint
    126 
    127 }  // namespace tensorflow
    128