Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
     17 #define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
     18 
     19 #include <functional>
     20 #include "tensorflow/core/framework/allocator.h"
     21 #include "tensorflow/core/framework/tracking_allocator.h"
     22 
     23 namespace tensorflow {
     24 
     25 // Subclass VisitableAllocator instead of Allocator when a memory
     26 // allocator needs to enable some kind of registration/deregistration
     27 // of memory areas.
     28 class VisitableAllocator : public Allocator {
     29  public:
     30   // Visitor gets called with a pointer to a memory area and its
     31   // size in bytes.
     32   typedef std::function<void(void*, size_t)> Visitor;
     33 
     34   // Register a visitor guaranteed to be called exactly once on each
     35   // chunk of memory newly allocated from the underlying device.
     36   // Typically, chunks will be reused and possibly sub-divided by a
     37   // pool manager, so the calls will happen only once per process
     38   // execution, not once per tensor (re)allocation.
     39   virtual void AddAllocVisitor(Visitor visitor) = 0;
     40 
     41   // Register a visitor guaranteed to be called on each chunk of
     42   // memory returned to the underlying device.
     43   virtual void AddFreeVisitor(Visitor visitor) = 0;
     44 };
     45 
     46 // Needed for cases when a VisitableAllocator gets wrapped for tracking.
     47 // Multiple-inheritance is considered acceptable in this case because
     48 // VisitableAllocator is a pure virtual interface and only TrackingAllocator
     49 // has default implementation.
     50 class TrackingVisitableAllocator : public TrackingAllocator,
     51                                    public VisitableAllocator {
     52  public:
     53   TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
     54       : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
     55   ~TrackingVisitableAllocator() override {}
     56 
     57   string Name() override { return TrackingAllocator::Name(); }
     58 
     59   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
     60     return TrackingAllocator::AllocateRaw(alignment, num_bytes);
     61   }
     62 
     63   void DeallocateRaw(void* ptr) override {
     64     TrackingAllocator::DeallocateRaw(ptr);
     65   }
     66 
     67   void AddAllocVisitor(Visitor visitor) override {
     68     allocator_->AddAllocVisitor(visitor);
     69   }
     70 
     71   void AddFreeVisitor(Visitor visitor) override {
     72     allocator_->AddFreeVisitor(visitor);
     73   }
     74 
     75  protected:
     76   VisitableAllocator* allocator_;
     77 };
     78 }  // namespace tensorflow
     79 #endif  // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
     80