/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/device_tracer.h"

#if GOOGLE_CUDA

#include <stdlib.h>
#include <memory>

#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cupti_wrapper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"

namespace {

// Maps a MemcpyKind enum to a const string.
const char *getMemcpyKindString(CUpti_ActivityMemcpyKind kind) {
  switch (kind) {
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
      return "HtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH:
      return "DtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA:
      return "HtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH:
      return "AtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA:
      return "AtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD:
      return "AtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA:
      return "DtoA";
    case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD:
      return "DtoD";
    case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH:
      return "HtoH";
    case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP:
      return "PtoP";
    default:
      break;
  }
  return "<unknown>";
}

// Maps a MemoryKind enum to a const string.
const char *getMemoryKindString(CUpti_ActivityMemoryKind kind) {
  switch (kind) {
    case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN:
      return "Unknown";
    case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE:
      return "Pageable";
    case CUPTI_ACTIVITY_MEMORY_KIND_PINNED:
      return "Pinned";
    case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE:
      return "Device";
    case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY:
      return "Array";
    default:
      break;
  }
  return "<unknown>";
}

// Maps an OverheadKind enum to a const string.
const char *getActivityOverheadKindString(CUpti_ActivityOverheadKind kind) {
  switch (kind) {
    case CUPTI_ACTIVITY_OVERHEAD_DRIVER_COMPILER:
      return "COMPILER";
    case CUPTI_ACTIVITY_OVERHEAD_CUPTI_BUFFER_FLUSH:
      return "BUFFER_FLUSH";
    case CUPTI_ACTIVITY_OVERHEAD_CUPTI_INSTRUMENTATION:
      return "INSTRUMENTATION";
    case CUPTI_ACTIVITY_OVERHEAD_CUPTI_RESOURCE:
      return "RESOURCE";
    default:
      break;
  }
  return "<unknown>";
}

}  // namespace

namespace tensorflow {
namespace devicetracer {

// Forward declaration.
class CUPTIManager;

// Returns a pointer to the CUPTIManager singleton.
CUPTIManager *GetCUPTIManager();

// Callback interface for consumers of CUPTI tracing.
class CUPTIClient {
 public:
  virtual ~CUPTIClient() {}

  // Invoked for each CUPTI activity reported.
  virtual void ActivityCallback(const CUpti_Activity &activity) = 0;
};
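
// DeviceTracerImpl (defined below) is the CUPTIClient in this file: the
// CUPTIManager parses each completed CUPTI activity buffer and hands every
// record to the registered client via ActivityCallback().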

#define CUPTI_CALL(call)                                            \
  do {                                                              \
    CUptiResult _status = cupti_wrapper_->call;                     \
    if (_status != CUPTI_SUCCESS) {                                 \
      LOG(ERROR) << "cuda call " << #call << " failed " << _status; \
    }                                                               \
  } while (0)
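// Used as, e.g., CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL)).
// Note that a failing call is only logged; the error is not propagated to the
// caller as a Status.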

// Singleton class to manage registration of CUPTI callbacks.
class CUPTIManager {
 public:
  CUPTIManager() {
    cupti_wrapper_.reset(new perftools::gputools::profiler::CuptiWrapper());
    CUPTI_CALL(ActivityRegisterCallbacks(BufferRequested, BufferCompleted));
  }

  // Enables tracing and delivers event callbacks to 'client'.
  // Does not take ownership of client.  Client's lifetime must persist
  // until tracing is disabled.
  Status EnableTrace(CUPTIClient *client);

  // Disable tracing.  No further events will be delivered to 'client'.
  Status DisableTrace();

 private:
  // Static functions which we can use as CUPTI callbacks.
  static void BufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
    GetCUPTIManager()->InternalBufferRequested(buffer, size, maxNumRecords);
  }
  static void BufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
                              size_t size, size_t validSize) {
    GetCUPTIManager()->InternalBufferCompleted(ctx, streamId, buffer, size,
                                               validSize);
  }
  // These methods are called by the static stubs above.
  void InternalBufferRequested(uint8_t **buffer, size_t *size,
                               size_t *maxNumRecords);
  void InternalBufferCompleted(CUcontext ctx, uint32_t streamId,
                               uint8_t *buffer, size_t size, size_t validSize);

  // Size of buffers used for CUPTI tracing.
  static constexpr size_t kBufferSize = 32 * 1024;
  // Required alignment of CUPTI buffers.
  static constexpr size_t kBufferAlignment = 8;

  mutex mu_;
  CUPTIClient *client_ GUARDED_BY(mu_);
  std::unique_ptr<perftools::gputools::profiler::CuptiWrapper> cupti_wrapper_;

  TF_DISALLOW_COPY_AND_ASSIGN(CUPTIManager);
};

Status CUPTIManager::EnableTrace(CUPTIClient *client) {
  mutex_lock l(mu_);
  // TODO(pbar) Work out the minimal set to trace.
  // We can currently manage without driver/runtime tracing.
  // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // These might be useful for annotations but require NVTX API.
  // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));

  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY2));
  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
  CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  client_ = client;
  return Status::OK();
}

Status CUPTIManager::DisableTrace() {
  // We turn off all tracing regardless.
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY2));
  CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
  CUPTI_CALL(ActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED));
  {
    // Don't acquire this lock until Flush returns, since Flush
    // will potentially cause callbacks into BufferCompleted.
    mutex_lock l(mu_);
    client_ = nullptr;
  }
  return Status::OK();
}

void CUPTIManager::InternalBufferRequested(uint8_t **buffer, size_t *size,
                                           size_t *maxNumRecords) {
  VLOG(2) << "BufferRequested";
  void *p = port::AlignedMalloc(kBufferSize, kBufferAlignment);
  *size = kBufferSize;
  *buffer = reinterpret_cast<uint8_t *>(p);
  // Zero tells CUPTI to fill the buffer with as many records as fit.
  *maxNumRecords = 0;
}

void CUPTIManager::InternalBufferCompleted(CUcontext ctx, uint32_t streamId,
                                           uint8_t *buffer, size_t size,
                                           size_t validSize) {
  VLOG(2) << "BufferCompleted";
  CUptiResult status;
  CUpti_Activity *record = nullptr;
  mutex_lock l(mu_);  // Hold mu_ while using client_.
  if (client_ && validSize > 0) {
    do {
      status =
          cupti_wrapper_->ActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        client_->ActivityCallback(*record);
      } else {
        break;
      }
    } while (1);

    // report any records dropped from the queue
    size_t dropped;
    CUPTI_CALL(ActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      LOG(WARNING) << "Dropped " << dropped << " activity records";
    }
  }
  port::AlignedFree(buffer);
}

CUPTIManager *GetCUPTIManager() {
  static CUPTIManager *manager = new CUPTIManager();
  return manager;
}
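
// The manager is created on first use and never destroyed, so the static
// BufferRequested/BufferCompleted callbacks registered with CUPTI always have
// a live object to forward to.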

#ifdef _MSC_VER
#define __thread __declspec(thread)
#endif

// TODO(pbar) Move this to platform specific header file?
// Static thread local variable for POD types.
#define TF_STATIC_THREAD_LOCAL_POD(_Type_, _var_)                  \
  static __thread _Type_ s_obj_##_var_;                            \
  namespace {                                                      \
  class ThreadLocal_##_var_ {                                      \
   public:                                                         \
    ThreadLocal_##_var_() {}                                       \
    void Init() {}                                                 \
    inline _Type_ *pointer() const { return &s_obj_##_var_; }      \
    inline _Type_ *safe_pointer() const { return &s_obj_##_var_; } \
    _Type_ &get() const { return s_obj_##_var_; }                  \
    bool is_native_tls() const { return true; }                    \
                                                                   \
   private:                                                        \
    TF_DISALLOW_COPY_AND_ASSIGN(ThreadLocal_##_var_);              \
  } _var_;                                                         \
  }  // namespace

// Thread-local state recording the most recent annotation (if any).
// When non-null, this points to a string in the active annotation
// of the current thread.  The annotation is guaranteed to remain live
// for the duration of the CUPTI API callback.
TF_STATIC_THREAD_LOCAL_POD(const char *, tls_current_annotation);
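// PushAnnotation() (below) stores the annotation string here, and
// ApiCallback() reads it on the same thread when the intercepted CUDA call
// is made.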

class DeviceTracerImpl : public DeviceTracer,
                         public CUPTIClient,
                         public port::Tracing::Engine {
 public:
  DeviceTracerImpl();
  ~DeviceTracerImpl() override;

  // DeviceTracer interface:
  Status Start() override;
  Status Stop() override;
  Status Collect(StepStatsCollector *collector) override;

  // port::Tracing::Engine interface:
  bool IsEnabled() const override {
    // We only register the Engine while tracing is enabled.
    return true;
  }
  Annotation *PushAnnotation(StringPiece name) override {
    VLOG(2) << "PushAnnotation " << name;
    struct Impl : public port::Tracing::Engine::Annotation {
      string annotation;
      explicit Impl(StringPiece n) : annotation(n.ToString()) {
        // Remember the most recent ScopedAnnotation for each thread.
        tls_current_annotation.get() = annotation.c_str();
      }
      ~Impl() override { tls_current_annotation.get() = nullptr; }
    };
    return new Impl(name);
  }
  Tracer *StartTracing(StringPiece label, bool is_expensive) override {
    // We don't do anything with 'TraceMe' regions yet.
    return nullptr;
  }

 protected:
  // This callback is used exclusively by CUPTIManager.
  friend class CUPTIManager;
  void ActivityCallback(const CUpti_Activity &activity) override;

 private:
  // Internal struct to record kernel launches.
  struct KernelRecord {
    uint64_t start_timestamp;
    uint64_t end_timestamp;
    uint32 device_id;
    uint32 stream_id;
    uint32 correlation_id;
  };
  // Internal struct to record memcpy operations.
  struct MemcpyRecord {
    uint64_t start_timestamp;
    uint64_t end_timestamp;
    uint32 device_id;
    uint32 stream_id;
    uint32 correlation_id;
    uint8 copyKind;
    uint8 srcKind;
    uint8 dstKind;
    uint64 bytes;
  };

  // This is the subscriber callback which is invoked directly by CUPTI.
  // The 'userdata' argument will be a pointer to the active 'DeviceTracerImpl'.
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata);

  // Records the mapping between correlation ID and kernel name.
  void AddCorrelationId(uint32 correlation_id, const string &name);

  // Returns the current system time in microseconds.
  inline int64 NowInUsec() { return Env::Default()->NowMicros(); }

  CUPTIManager *cupti_manager_;
  std::unique_ptr<perftools::gputools::profiler::CuptiWrapper> cupti_wrapper_;
  CUpti_SubscriberHandle subscriber_;

  mutex trace_mu_;
  static constexpr size_t kMaxRecords = 1024 * 1024;
  std::map<uint32, string> correlations_ GUARDED_BY(trace_mu_);
  std::vector<KernelRecord> kernel_records_ GUARDED_BY(trace_mu_);
  std::vector<MemcpyRecord> memcpy_records_ GUARDED_BY(trace_mu_);

  mutex mu_;
  bool enabled_ GUARDED_BY(mu_);
  int64 start_walltime_us_ GUARDED_BY(mu_);
  int64 end_walltime_us_ GUARDED_BY(mu_);
  uint64_t start_timestamp_ GUARDED_BY(mu_);
  uint64_t end_timestamp_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceTracerImpl);
};

DeviceTracerImpl::DeviceTracerImpl() {
  VLOG(1) << "DeviceTracer created.";
  cupti_manager_ = GetCUPTIManager();
  CHECK(cupti_manager_);
  cupti_wrapper_.reset(new perftools::gputools::profiler::CuptiWrapper());
  enabled_ = false;
}

DeviceTracerImpl::~DeviceTracerImpl() {
  // Unregister the CUPTI callbacks if needed to prevent them from accessing
  // freed memory.
  Stop().IgnoreError();
}

Status DeviceTracerImpl::Start() {
  VLOG(1) << "DeviceTracer::Start";
  mutex_lock l(mu_);
  if (enabled_) {
    return errors::FailedPrecondition("DeviceTracer is already enabled.");
  }
  // There can only be one CUPTI subscriber.  If we can't create one then
  // there is another trace in progress (possibly by external code).
  CUptiResult ret;
  ret = cupti_wrapper_->Subscribe(
      &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
  if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) {
    return errors::Unavailable("CUPTI subscriber limit reached.");
  } else if (ret != CUPTI_SUCCESS) {
    return errors::Internal("Failed to create CUPTI subscriber.");
  }

  // Register as a TraceEngine to receive ScopedAnnotations.
  port::Tracing::RegisterEngine(this);

  // Intercept launch and memcpy calls to capture the Op name annotation.
  // TODO(pbar) Add callbacks for memcpy variants.
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_RUNTIME_API,
                            CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020));
  CUPTI_CALL(EnableCallback(
      /*enable=*/1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API,
      CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020));

  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2));
  CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_,
                            CUPTI_CB_DOMAIN_DRIVER_API,
                            CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2));

  TF_RETURN_IF_ERROR(cupti_manager_->EnableTrace(this));

  CUPTI_CALL(GetTimestamp(&start_timestamp_));
  start_walltime_us_ = NowInUsec();
  enabled_ = true;
  return Status::OK();
}

Status DeviceTracerImpl::Stop() {
  VLOG(1) << "DeviceTracer::Stop";
  mutex_lock l(mu_);
  if (!enabled_) {
    return Status::OK();
  }
  CUPTI_CALL(Unsubscribe(subscriber_));
  port::Tracing::RegisterEngine(nullptr);
  TF_RETURN_IF_ERROR(cupti_manager_->DisableTrace());
  end_walltime_us_ = NowInUsec();
  CUPTI_CALL(GetTimestamp(&end_timestamp_));
  enabled_ = false;
  return Status::OK();
}

void DeviceTracerImpl::AddCorrelationId(uint32 correlation_id,
                                        const string &name) {
  VLOG(2) << correlation_id << " : " << name;
  mutex_lock l(trace_mu_);
  if (correlations_.size() >= kMaxRecords) return;
  correlations_.emplace(correlation_id, name);
}

/*static*/ void DeviceTracerImpl::ApiCallback(void *userdata,
                                              CUpti_CallbackDomain domain,
                                              CUpti_CallbackId cbid,
                                              const void *cbdata) {
  auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
  DeviceTracerImpl *tracer = reinterpret_cast<DeviceTracerImpl *>(userdata);
  VLOG(2) << "ApiCallback " << domain << ":" << cbid
          << " func: " << cbInfo->functionName;

  // API callbacks are invoked synchronously on the thread making the
  // CUDA API call.  If this pointer is non-null then the ScopedAnnotation
  // must be valid.
  const char *tls_annotation = tls_current_annotation.get();

  if ((domain == CUPTI_CB_DOMAIN_DRIVER_API) &&
      (cbid == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel)) {
    if (cbInfo->callbackSite == CUPTI_API_ENTER) {
      auto *params = reinterpret_cast<const cuLaunchKernel_params *>(
          cbInfo->functionParams);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "LAUNCH stream " << params->hStream << " correlation "
                << cbInfo->correlationId << " kernel " << cbInfo->symbolName;
      }
      const string annotation =
          tls_annotation ? tls_annotation : cbInfo->symbolName;
      tracer->AddCorrelationId(cbInfo->correlationId, annotation);
    }
  } else if ((domain == CUPTI_CB_DOMAIN_RUNTIME_API) &&
             (cbid == CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 ||
              cbid == CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020)) {
    if (cbInfo->callbackSite == CUPTI_API_ENTER) {
      if (VLOG_IS_ON(2)) {
        auto *funcParams = reinterpret_cast<const cudaMemcpy_v3020_params *>(
            cbInfo->functionParams);
        size_t count = funcParams->count;
        enum cudaMemcpyKind kind = funcParams->kind;
        VLOG(2) << "MEMCPY count " << count << " kind " << kind;
      }
      if (tls_annotation) {
        const string annotation = tls_annotation;
        tracer->AddCorrelationId(cbInfo->correlationId, annotation);
      }
    }
  } else if ((domain == CUPTI_CB_DOMAIN_DRIVER_API) &&
             (cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2 ||
              cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2 ||
              cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2 ||
              cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2 ||
              cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2 ||
              cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2)) {
    if (cbInfo->callbackSite == CUPTI_API_EXIT && tls_annotation) {
      const string annotation = tls_annotation;
      tracer->AddCorrelationId(cbInfo->correlationId, annotation);
    }
  } else {
    VLOG(1) << "Unhandled API Callback for " << domain << " " << cbid;
  }
}

void DeviceTracerImpl::ActivityCallback(const CUpti_Activity &record) {
  VLOG(2) << "ActivityCallback " << record.kind;
  mutex_lock l(trace_mu_);
  switch (record.kind) {
    case CUPTI_ACTIVITY_KIND_MEMCPY: {
      if (memcpy_records_.size() >= kMaxRecords) return;
      auto *memcpy = reinterpret_cast<const CUpti_ActivityMemcpy *>(&record);
      memcpy_records_.push_back(MemcpyRecord{
          memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
          memcpy->correlationId, memcpy->copyKind, memcpy->srcKind,
          memcpy->dstKind, memcpy->bytes});
      break;
    }
    case CUPTI_ACTIVITY_KIND_MEMCPY2: {
      if (memcpy_records_.size() >= kMaxRecords) return;
      auto *memcpy = reinterpret_cast<const CUpti_ActivityMemcpy2 *>(&record);
      memcpy_records_.push_back(MemcpyRecord{
          memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId,
          memcpy->correlationId, memcpy->copyKind, memcpy->srcKind,
          memcpy->dstKind, memcpy->bytes});
      break;
    }
    case CUPTI_ACTIVITY_KIND_KERNEL:
    case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
      if (kernel_records_.size() >= kMaxRecords) return;
      auto *kernel = reinterpret_cast<const CUpti_ActivityKernel3 *>(&record);
      kernel_records_.push_back(KernelRecord{kernel->start, kernel->end,
                                             kernel->deviceId, kernel->streamId,
                                             kernel->correlationId});
      break;
    }
    default:
      VLOG(1) << "ActivityCallback unhandled kind";
      break;
  }
}

Status DeviceTracerImpl::Collect(StepStatsCollector *collector) {
  mutex_lock l(mu_);
  if (enabled_) {
    return errors::FailedPrecondition("DeviceTracer is still enabled.");
  }

  // TODO(pbar) Handle device IDs and prefix properly.
  const string prefix = "";
  const int id = 0;
  const string stream_device =
      strings::StrCat(prefix, "/device:GPU:", id, "/stream:");
  const string memcpy_device =
      strings::StrCat(prefix, "/device:GPU:", id, "/memcpy");

  mutex_lock l2(trace_mu_);
  for (const auto &rec : kernel_records_) {
    auto it = correlations_.find(rec.correlation_id);
    const string name = (it != correlations_.cend()) ? it->second : "unknown";
    NodeExecStats *ns = new NodeExecStats;
    ns->set_all_start_micros(start_walltime_us_ +
                             ((rec.start_timestamp - start_timestamp_) / 1000));
    ns->set_op_start_rel_micros(0);
    auto elapsed_us =
        std::max<int64>((rec.end_timestamp - rec.start_timestamp) / 1000, 1);
    ns->set_op_end_rel_micros(elapsed_us);
    ns->set_all_end_rel_micros(elapsed_us);
    ns->set_node_name(name);
    // TODO(pbar) Generate details based on the kernel activity record.
    // ns->set_timeline_label(details);
    auto nscopy = new NodeExecStats;
    *nscopy = *ns;
    collector->Save(strings::StrCat(stream_device, "all"), ns);
    collector->Save(strings::StrCat(stream_device, rec.stream_id), nscopy);
  }
  for (const auto &rec : memcpy_records_) {
    auto it = correlations_.find(rec.correlation_id);
    const string name = (it != correlations_.cend()) ? it->second : "unknown";
    NodeExecStats *ns = new NodeExecStats;
    ns->set_all_start_micros(start_walltime_us_ +
                             ((rec.start_timestamp - start_timestamp_) / 1000));
    ns->set_op_start_rel_micros(0);
    auto elapsed_us =
        std::max<int64>((rec.end_timestamp - rec.start_timestamp) / 1000, 1);
    ns->set_op_end_rel_micros(elapsed_us);
    ns->set_all_end_rel_micros(elapsed_us);
    auto copyKind = static_cast<CUpti_ActivityMemcpyKind>(rec.copyKind);
    auto srcKind = static_cast<CUpti_ActivityMemoryKind>(rec.srcKind);
    auto dstKind = static_cast<CUpti_ActivityMemoryKind>(rec.dstKind);
    const string details = strings::Printf(
        "MEMCPY%s %llu bytes (%s to %s)", getMemcpyKindString(copyKind),
        rec.bytes, getMemoryKindString(srcKind), getMemoryKindString(dstKind));
    ns->set_node_name(
        strings::StrCat(name, ":MEMCPY", getMemcpyKindString(copyKind)));
    ns->set_timeline_label(details);
    auto nscopy = new NodeExecStats;
    *nscopy = *ns;
    collector->Save(memcpy_device, ns);
    collector->Save(strings::StrCat(stream_device, rec.stream_id), nscopy);
  }
  return Status::OK();
}

}  // namespace devicetracer

std::unique_ptr<DeviceTracer> CreateDeviceTracer() {
  std::unique_ptr<DeviceTracer> tracer(new devicetracer::DeviceTracerImpl());
  return tracer;
}
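
// A minimal usage sketch (assuming the DeviceTracer interface declared in
// device_tracer.h and a caller-provided StepStatsCollector 'collector'):
//
//   std::unique_ptr<DeviceTracer> tracer = CreateDeviceTracer();
//   TF_RETURN_IF_ERROR(tracer->Start());
//   ... run the GPU work to be profiled ...
//   TF_RETURN_IF_ERROR(tracer->Stop());
//   TF_RETURN_IF_ERROR(tracer->Collect(collector));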

}  // namespace tensorflow

#else  // GOOGLE_CUDA

namespace tensorflow {

std::unique_ptr<DeviceTracer> CreateDeviceTracer() { return nullptr; }

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
    653