Home | History | Annotate | Download | only in internal
      1 /*
      2 Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 
      4 Licensed under the Apache License, Version 2.0 (the "License");
      5 you may not use this file except in compliance with the License.
      6 You may obtain a copy of the License at
      7 
      8 http://www.apache.org/licenses/LICENSE-2.0
      9 
     10 Unless required by applicable law or agreed to in writing, software
     11 distributed under the License is distributed on an "AS IS" BASIS,
     12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 See the License for the specific language governing permissions and
     14 limitations under the License.
     15 */
     16 
     17 package internal
     18 
     19 /*
     20 #include <stdlib.h>
     21 #include <string.h>
     22 
     23 #include "tensorflow/c/c_api.h"
     24 */
     25 import "C"
     26 
     27 import (
     28 	"errors"
     29 	"fmt"
     30 	"runtime"
     31 	"unsafe"
     32 
     33 	"github.com/golang/protobuf/proto"
     34 	pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
     35 )
     36 
     37 // Encapsulates a collection of API definitions.
     38 //
     39 // apiDefMap represents a map from operation name to corresponding
     40 // ApiDef proto (see
     41 // https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto
     42 // for ApiDef proto definition).
     43 type apiDefMap struct {
     44 	c *C.TF_ApiDefMap
     45 }
     46 
     47 // Creates and returns a new apiDefMap instance.
     48 //
     49 // oplist is and OpList proto instance (see
     50 // https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
     51 // for OpList proto definition).
     52 
     53 func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) {
     54 	// Create a buffer containing the serialized OpList.
     55 	opdefSerialized, err := proto.Marshal(oplist)
     56 	if err != nil {
     57 		return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String())
     58 	}
     59 	data := C.CBytes(opdefSerialized)
     60 	defer C.free(data)
     61 
     62 	opbuf := C.TF_NewBuffer()
     63 	defer C.TF_DeleteBuffer(opbuf)
     64 	opbuf.data = data
     65 	opbuf.length = C.size_t(len(opdefSerialized))
     66 
     67 	// Create ApiDefMap.
     68 	status := C.TF_NewStatus()
     69 	defer C.TF_DeleteStatus(status)
     70 	capimap := C.TF_NewApiDefMap(opbuf, status)
     71 	if C.TF_GetCode(status) != C.TF_OK {
     72 		return nil, errors.New(C.GoString(C.TF_Message(status)))
     73 	}
     74 	apimap := &apiDefMap{capimap}
     75 	runtime.SetFinalizer(
     76 		apimap,
     77 		func(a *apiDefMap) {
     78 			C.TF_DeleteApiDefMap(a.c)
     79 		})
     80 	return apimap, nil
     81 }
     82 
     83 // Updates apiDefMap with the overrides specified in `data`.
     84 //
     85 // data - ApiDef text proto.
     86 func (m *apiDefMap) Put(data string) error {
     87 	cdata := C.CString(data)
     88 	defer C.free(unsafe.Pointer(cdata))
     89 	status := C.TF_NewStatus()
     90 	defer C.TF_DeleteStatus(status)
     91 	C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status)
     92 	if C.TF_GetCode(status) != C.TF_OK {
     93 		return errors.New(C.GoString(C.TF_Message(status)))
     94 	}
     95 	return nil
     96 }
     97 
     98 // Returns ApiDef proto instance for the TensorFlow operation
     99 // named `opname`.
    100 func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) {
    101 	cname := C.CString(opname)
    102 	defer C.free(unsafe.Pointer(cname))
    103 	status := C.TF_NewStatus()
    104 	defer C.TF_DeleteStatus(status)
    105 	apidefBuf := C.TF_ApiDefMapGet(
    106 		m.c, cname, C.size_t(len(opname)), status)
    107 	defer C.TF_DeleteBuffer(apidefBuf)
    108 	if C.TF_GetCode(status) != C.TF_OK {
    109 		return nil, errors.New(C.GoString(C.TF_Message(status)))
    110 	}
    111 	if apidefBuf == nil {
    112 		return nil, fmt.Errorf("could not find ApiDef for %s", opname)
    113 	}
    114 
    115 	var (
    116 		apidef = new(pb.ApiDef)
    117 		size   = int(apidefBuf.length)
    118 		// A []byte backed by C memory.
    119 		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
    120 		data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size]
    121 		err  = proto.Unmarshal(data, apidef)
    122 	)
    123 	if err != nil {
    124 		return nil, err
    125 	}
    126 	return apidef, nil
    127 }
    128