1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ 17 #define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ 18 19 #include <functional> 20 #include "tensorflow/core/framework/allocator.h" 21 #include "tensorflow/core/framework/tracking_allocator.h" 22 23 namespace tensorflow { 24 25 // Subclass VisitableAllocator instead of Allocator when a memory 26 // allocator needs to enable some kind of registration/deregistration 27 // of memory areas. 28 class VisitableAllocator : public Allocator { 29 public: 30 // Visitor gets called with a pointer to a memory area and its 31 // size in bytes. 32 typedef std::function<void(void*, size_t)> Visitor; 33 34 // Register a visitor guaranteed to be called exactly once on each 35 // chunk of memory newly allocated from the underlying device. 36 // Typically, chunks will be reused and possibly sub-divided by a 37 // pool manager, so the calls will happen only once per process 38 // execution, not once per tensor (re)allocation. 39 virtual void AddAllocVisitor(Visitor visitor) = 0; 40 41 // Register a visitor guaranteed to be called on each chunk of 42 // memory returned to the underlying device. 43 virtual void AddFreeVisitor(Visitor visitor) = 0; 44 }; 45 46 // Needed for cases when a VisitableAllocator gets wrapped for tracking. 47 // Multiple-inheritance is considered acceptable in this case because 48 // VisitableAllocator is a pure virtual interface and only TrackingAllocator 49 // has default implementation. 50 class TrackingVisitableAllocator : public TrackingAllocator, 51 public VisitableAllocator { 52 public: 53 TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids) 54 : TrackingAllocator(allocator, track_ids), allocator_(allocator) {} 55 ~TrackingVisitableAllocator() override {} 56 57 string Name() override { return TrackingAllocator::Name(); } 58 59 void* AllocateRaw(size_t alignment, size_t num_bytes) override { 60 return TrackingAllocator::AllocateRaw(alignment, num_bytes); 61 } 62 63 void DeallocateRaw(void* ptr) override { 64 TrackingAllocator::DeallocateRaw(ptr); 65 } 66 67 void AddAllocVisitor(Visitor visitor) override { 68 allocator_->AddAllocVisitor(visitor); 69 } 70 71 void AddFreeVisitor(Visitor visitor) override { 72 allocator_->AddFreeVisitor(visitor); 73 } 74 75 protected: 76 VisitableAllocator* allocator_; 77 }; 78 } // namespace tensorflow 79 #endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ 80