Home | History | Annotate | Download | only in go
      1 /*
      2 Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 
      4 Licensed under the Apache License, Version 2.0 (the "License");
      5 you may not use this file except in compliance with the License.
      6 You may obtain a copy of the License at
      7 
      8     http://www.apache.org/licenses/LICENSE-2.0
      9 
     10 Unless required by applicable law or agreed to in writing, software
     11 distributed under the License is distributed on an "AS IS" BASIS,
     12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 See the License for the specific language governing permissions and
     14 limitations under the License.
     15 */
     16 
     17 package tensorflow
     18 
     19 // #include <stdlib.h>
     20 // #include <string.h>
     21 // #include "tensorflow/c/c_api.h"
     22 import "C"
     23 
     24 import (
     25 	"bytes"
     26 	"encoding/binary"
     27 	"fmt"
     28 	"io"
     29 	"reflect"
     30 	"runtime"
     31 	"unsafe"
     32 )
     33 
     34 // DataType holds the type for a scalar value.  E.g., one slot in a tensor.
     35 type DataType C.TF_DataType
     36 
     37 // Types of scalar values in the TensorFlow type system.
     38 const (
     39 	Float      DataType = C.TF_FLOAT
     40 	Double     DataType = C.TF_DOUBLE
     41 	Int32      DataType = C.TF_INT32
     42 	Uint32     DataType = C.TF_UINT32
     43 	Uint8      DataType = C.TF_UINT8
     44 	Int16      DataType = C.TF_INT16
     45 	Int8       DataType = C.TF_INT8
     46 	String     DataType = C.TF_STRING
     47 	Complex64  DataType = C.TF_COMPLEX64
     48 	Complex    DataType = C.TF_COMPLEX
     49 	Int64      DataType = C.TF_INT64
     50 	Uint64     DataType = C.TF_UINT64
     51 	Bool       DataType = C.TF_BOOL
     52 	Qint8      DataType = C.TF_QINT8
     53 	Quint8     DataType = C.TF_QUINT8
     54 	Qint32     DataType = C.TF_QINT32
     55 	Bfloat16   DataType = C.TF_BFLOAT16
     56 	Qint16     DataType = C.TF_QINT16
     57 	Quint16    DataType = C.TF_QUINT16
     58 	Uint16     DataType = C.TF_UINT16
     59 	Complex128 DataType = C.TF_COMPLEX128
     60 	Half       DataType = C.TF_HALF
     61 )
     62 
     63 // Tensor holds a multi-dimensional array of elements of a single data type.
     64 type Tensor struct {
     65 	c     *C.TF_Tensor
     66 	shape []int64
     67 }
     68 
     69 // NewTensor converts from a Go value to a Tensor. Valid values are scalars,
     70 // slices, and arrays. Every element of a slice must have the same length so
     71 // that the resulting Tensor has a valid shape.
     72 func NewTensor(value interface{}) (*Tensor, error) {
     73 	val := reflect.ValueOf(value)
     74 	shape, dataType, err := shapeAndDataTypeOf(val)
     75 	if err != nil {
     76 		return nil, err
     77 	}
     78 	nflattened := numElements(shape)
     79 	nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
     80 	if dataType == String {
     81 		// TF_STRING tensors are encoded as an array of 8-byte offsets
     82 		// followed by string data. See c_api.h.
     83 		nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
     84 	}
     85 	var shapePtr *C.int64_t
     86 	if len(shape) > 0 {
     87 		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
     88 	}
     89 	t := &Tensor{
     90 		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
     91 		shape: shape,
     92 	}
     93 	runtime.SetFinalizer(t, (*Tensor).finalize)
     94 	raw := tensorData(t.c)
     95 	buf := bytes.NewBuffer(raw[:0:len(raw)])
     96 	if dataType != String {
     97 		if err := encodeTensor(buf, val, shape); err != nil {
     98 			return nil, err
     99 		}
    100 		if uintptr(buf.Len()) != nbytes {
    101 			return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
    102 		}
    103 	} else {
    104 		e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
    105 		if err := e.encode(reflect.ValueOf(value), shape); err != nil {
    106 			return nil, err
    107 		}
    108 		if int64(buf.Len()) != nflattened*8 {
    109 			return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
    110 		}
    111 	}
    112 	return t, nil
    113 }
    114 
    115 // ReadTensor constructs a Tensor with the provided type and shape from the
    116 // serialized tensor contents in r.
    117 //
    118 // See also WriteContentsTo.
    119 func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
    120 	if err := isTensorSerializable(dataType); err != nil {
    121 		return nil, err
    122 	}
    123 	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
    124 	var shapePtr *C.int64_t
    125 	if len(shape) > 0 {
    126 		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
    127 	}
    128 	t := &Tensor{
    129 		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
    130 		shape: shape,
    131 	}
    132 	runtime.SetFinalizer(t, (*Tensor).finalize)
    133 	raw := tensorData(t.c)
    134 	n, err := r.Read(raw)
    135 	if err != nil {
    136 		return nil, err
    137 	}
    138 	if uintptr(n) != nbytes {
    139 		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
    140 	}
    141 	return t, nil
    142 }
    143 
    144 // newTensorFromC takes ownership of c and returns the owning Tensor.
    145 func newTensorFromC(c *C.TF_Tensor) *Tensor {
    146 	var shape []int64
    147 	if ndims := int(C.TF_NumDims(c)); ndims > 0 {
    148 		shape = make([]int64, ndims)
    149 	}
    150 	for i := range shape {
    151 		shape[i] = int64(C.TF_Dim(c, C.int(i)))
    152 	}
    153 	t := &Tensor{c: c, shape: shape}
    154 	runtime.SetFinalizer(t, (*Tensor).finalize)
    155 	return t
    156 }
    157 
    158 func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) }
    159 
    160 // DataType returns the scalar datatype of the Tensor.
    161 func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
    162 
    163 // Shape returns the shape of the Tensor.
    164 func (t *Tensor) Shape() []int64 { return t.shape }
    165 
    166 // Value converts the Tensor to a Go value. For now, not all Tensor types are
    167 // supported, and this function may panic if it encounters an unsupported
    168 // DataType.
    169 //
    170 // The type of the output depends on the Tensor type and dimensions.
    171 // For example:
    172 // Tensor(int64, 0): int64
    173 // Tensor(float64, 3): [][][]float64
    174 func (t *Tensor) Value() interface{} {
    175 	typ := typeOf(t.DataType(), t.Shape())
    176 	val := reflect.New(typ)
    177 	raw := tensorData(t.c)
    178 	if t.DataType() != String {
    179 		if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil {
    180 			panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err))
    181 		}
    182 	} else {
    183 		nflattened := numElements(t.Shape())
    184 		d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()}
    185 		if err := d.decode(val, t.Shape()); err != nil {
    186 			panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err))
    187 		}
    188 	}
    189 	return reflect.Indirect(val).Interface()
    190 }
    191 
    192 // WriteContentsTo writes the serialized contents of t to w.
    193 //
    194 // Returns the number of bytes written. See ReadTensor for
    195 // reconstructing a Tensor from the serialized form.
    196 //
    197 // WARNING: WriteContentsTo is not comprehensive and will fail
    198 // if t.DataType() is non-numeric (e.g., String). See
    199 // https://github.com/tensorflow/tensorflow/issues/6003.
    200 func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
    201 	if err := isTensorSerializable(t.DataType()); err != nil {
    202 		return 0, err
    203 	}
    204 	return io.Copy(w, bytes.NewReader(tensorData(t.c)))
    205 }
    206 
    207 func tensorData(c *C.TF_Tensor) []byte {
    208 	// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
    209 	cbytes := C.TF_TensorData(c)
    210 	if cbytes == nil {
    211 		return nil
    212 	}
    213 	length := int(C.TF_TensorByteSize(c))
    214 	slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length]
    215 	return slice
    216 }
    217 
    218 var types = []struct {
    219 	typ      reflect.Type
    220 	dataType C.TF_DataType
    221 }{
    222 	{reflect.TypeOf(float32(0)), C.TF_FLOAT},
    223 	{reflect.TypeOf(float64(0)), C.TF_DOUBLE},
    224 	{reflect.TypeOf(int32(0)), C.TF_INT32},
    225 	{reflect.TypeOf(uint32(0)), C.TF_UINT32},
    226 	{reflect.TypeOf(uint8(0)), C.TF_UINT8},
    227 	{reflect.TypeOf(int16(0)), C.TF_INT16},
    228 	{reflect.TypeOf(int8(0)), C.TF_INT8},
    229 	{reflect.TypeOf(""), C.TF_STRING},
    230 	{reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64},
    231 	{reflect.TypeOf(int64(0)), C.TF_INT64},
    232 	{reflect.TypeOf(uint64(0)), C.TF_UINT64},
    233 	{reflect.TypeOf(false), C.TF_BOOL},
    234 	{reflect.TypeOf(uint16(0)), C.TF_UINT16},
    235 	{reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128},
    236 	// TODO(apassos): support DT_RESOURCE representation in go.
    237 	// TODO(keveman): support DT_VARIANT representation in go.
    238 }
    239 
    240 // shapeAndDataTypeOf returns the data type and shape of the Tensor
    241 // corresponding to a Go type.
    242 func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err error) {
    243 	typ := val.Type()
    244 	for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice {
    245 		shape = append(shape, int64(val.Len()))
    246 		if val.Len() > 0 {
    247 			// In order to check tensor structure properly in general case we need to iterate over all slices of the tensor to check sizes match
    248 			// Since we already going to iterate over all elements in encodeTensor() let's
    249 			// 1) do the actual check in encodeTensor() to save some cpu cycles here
    250 			// 2) assume the shape is represented by lengths of elements with zero index in each dimension
    251 			val = val.Index(0)
    252 		}
    253 		typ = typ.Elem()
    254 	}
    255 	for _, t := range types {
    256 		if typ.Kind() == t.typ.Kind() {
    257 			return shape, DataType(t.dataType), nil
    258 		}
    259 	}
    260 	return shape, dt, fmt.Errorf("unsupported type %v", typ)
    261 }
    262 
    263 // typeOf converts from a DataType and Shape to the equivalent Go type.
    264 func typeOf(dt DataType, shape []int64) reflect.Type {
    265 	var ret reflect.Type
    266 	for _, t := range types {
    267 		if dt == DataType(t.dataType) {
    268 			ret = t.typ
    269 			break
    270 		}
    271 	}
    272 	if ret == nil {
    273 		panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
    274 	}
    275 	for range shape {
    276 		ret = reflect.SliceOf(ret)
    277 	}
    278 	return ret
    279 }
    280 
    281 func numElements(shape []int64) int64 {
    282 	n := int64(1)
    283 	for _, d := range shape {
    284 		n *= d
    285 	}
    286 	return n
    287 }
    288 
    289 // byteSizeOfEncodedStrings returns the size of the encoded strings in val.
    290 // val MUST be a string, or a container (array/slice etc.) of strings.
    291 func byteSizeOfEncodedStrings(val interface{}) uintptr {
    292 	if s, ok := val.(string); ok {
    293 		return uintptr(C.TF_StringEncodedSize(C.size_t(len(s))))
    294 	}
    295 	// Otherwise must be an array or slice.
    296 	var size uintptr
    297 	v := reflect.ValueOf(val)
    298 	for i := 0; i < v.Len(); i++ {
    299 		size += byteSizeOfEncodedStrings(v.Index(i).Interface())
    300 	}
    301 	return size
    302 }
    303 
    304 // encodeTensor writes v to the specified buffer using the format specified in
    305 // c_api.h. Use stringEncoder for String tensors.
    306 func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
    307 	switch v.Kind() {
    308 	case reflect.Bool:
    309 		b := byte(0)
    310 		if v.Bool() {
    311 			b = 1
    312 		}
    313 		if err := w.WriteByte(b); err != nil {
    314 			return err
    315 		}
    316 	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
    317 		if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
    318 			return err
    319 		}
    320 
    321 	case reflect.Array, reflect.Slice:
    322 		// If current dimension is a slice, verify that it has the expected size
    323 		// Go's type system makes that guarantee for arrays.
    324 		if v.Kind() == reflect.Slice {
    325 			expected := int(shape[0])
    326 			if v.Len() != expected {
    327 				return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
    328 			}
    329 		}
    330 
    331 		// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
    332 		if len(shape) == 1 && v.Len() > 0 {
    333 			switch v.Index(0).Kind() {
    334 			case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
    335 				return binary.Write(w, nativeEndian, v.Interface())
    336 			}
    337 		}
    338 
    339 		subShape := shape[1:]
    340 		for i := 0; i < v.Len(); i++ {
    341 			err := encodeTensor(w, v.Index(i), subShape)
    342 			if err != nil {
    343 				return err
    344 			}
    345 		}
    346 
    347 	default:
    348 		return fmt.Errorf("unsupported type %v", v.Type())
    349 	}
    350 	return nil
    351 }
    352 
    353 // decodeTensor decodes the Tensor from the buffer to ptr using the format
    354 // specified in c_api.h. Use stringDecoder for String tensors.
    355 func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
    356 	switch typ.Kind() {
    357 	case reflect.Bool:
    358 		b, err := r.ReadByte()
    359 		if err != nil {
    360 			return err
    361 		}
    362 		ptr.Elem().SetBool(b == 1)
    363 	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
    364 		if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
    365 			return err
    366 		}
    367 
    368 	case reflect.Slice:
    369 		val := reflect.Indirect(ptr)
    370 		val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
    371 
    372 		// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
    373 		if len(shape) == 1 && val.Len() > 0 {
    374 			switch val.Index(0).Kind() {
    375 			case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
    376 				return binary.Read(r, nativeEndian, val.Interface())
    377 			}
    378 		}
    379 
    380 		for i := 0; i < val.Len(); i++ {
    381 			if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
    382 				return err
    383 			}
    384 		}
    385 
    386 	default:
    387 		return fmt.Errorf("unsupported type %v", typ)
    388 	}
    389 	return nil
    390 }
    391 
    392 type stringEncoder struct {
    393 	offsets io.Writer
    394 	data    []byte
    395 	offset  uint64
    396 	status  *status
    397 }
    398 
    399 func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
    400 	if v.Kind() == reflect.String {
    401 		if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
    402 			return err
    403 		}
    404 		var (
    405 			s      = v.Interface().(string)
    406 			src    = C.CString(s)
    407 			srcLen = C.size_t(len(s))
    408 			dst    = (*C.char)(unsafe.Pointer(&e.data[e.offset]))
    409 			dstLen = C.size_t(uint64(len(e.data)) - e.offset)
    410 		)
    411 		e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c))
    412 		C.free(unsafe.Pointer(src))
    413 		return e.status.Err()
    414 	}
    415 
    416 	if v.Kind() == reflect.Slice {
    417 		expected := int(shape[0])
    418 		if v.Len() != expected {
    419 			return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
    420 		}
    421 	}
    422 
    423 	subShape := shape[1:]
    424 	for i := 0; i < v.Len(); i++ {
    425 		if err := e.encode(v.Index(i), subShape); err != nil {
    426 			return err
    427 		}
    428 	}
    429 	return nil
    430 }
    431 
    432 type stringDecoder struct {
    433 	offsets io.Reader
    434 	data    []byte
    435 	status  *status
    436 }
    437 
    438 func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
    439 	if len(shape) == 0 {
    440 		var offset uint64
    441 		if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
    442 			return err
    443 		}
    444 		var (
    445 			src    = (*C.char)(unsafe.Pointer(&d.data[offset]))
    446 			srcLen = C.size_t(len(d.data)) - C.size_t(offset)
    447 			dst    *C.char
    448 			dstLen C.size_t
    449 		)
    450 		if offset > uint64(len(d.data)) {
    451 			return fmt.Errorf("invalid offsets in String Tensor")
    452 		}
    453 		C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
    454 		if err := d.status.Err(); err != nil {
    455 			return err
    456 		}
    457 		s := ptr.Interface().(*string)
    458 		*s = C.GoStringN(dst, C.int(dstLen))
    459 		return nil
    460 	}
    461 	val := reflect.Indirect(ptr)
    462 	val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
    463 	for i := 0; i < val.Len(); i++ {
    464 		if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
    465 			return err
    466 		}
    467 	}
    468 	return nil
    469 }
    470 
    471 func bug(format string, args ...interface{}) error {
    472 	return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
    473 }
    474 
    475 func isTensorSerializable(dataType DataType) error {
    476 	// For numeric types, the serialized Tensor matches the in-memory
    477 	// representation.  See the implementation of Tensor::AsProtoContent in
    478 	// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
    479 	//
    480 	// The more appropriate way to be in sync with Tensor::AsProtoContent
    481 	// would be to have the TensorFlow C library export functions for
    482 	// serialization and deserialization of Tensors.  Till then capitalize
    483 	// on knowledge of the implementation for numeric types.
    484 	switch dataType {
    485 	case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
    486 		return nil
    487 	default:
    488 		return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
    489 	}
    490 }
    491 
    492 // nativeEndian is the byte order for the local platform. Used to send back and
    493 // forth Tensors with the C API. We test for endianness at runtime because
    494 // some architectures can be booted into different endian modes.
    495 var nativeEndian binary.ByteOrder
    496 
    497 func init() {
    498 	buf := [2]byte{}
    499 	*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
    500 
    501 	switch buf {
    502 	case [2]byte{0xCD, 0xAB}:
    503 		nativeEndian = binary.LittleEndian
    504 	case [2]byte{0xAB, 0xCD}:
    505 		nativeEndian = binary.BigEndian
    506 	default:
    507 		panic("Could not determine native endianness.")
    508 	}
    509 }
    510