/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tensorflow

// #include <stdlib.h>
// #include <string.h>
// #include "tensorflow/c/c_api.h"
import "C"

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"reflect"
	"runtime"
	"unsafe"
)

// DataType holds the type for a scalar value. E.g., one slot in a tensor.
type DataType C.TF_DataType

// Types of scalar values in the TensorFlow type system.
//
// Each constant mirrors the identically named TF_* value of the C API.
const (
	Float      DataType = C.TF_FLOAT
	Double     DataType = C.TF_DOUBLE
	Int32      DataType = C.TF_INT32
	Uint32     DataType = C.TF_UINT32
	Uint8      DataType = C.TF_UINT8
	Int16      DataType = C.TF_INT16
	Int8       DataType = C.TF_INT8
	String     DataType = C.TF_STRING
	Complex64  DataType = C.TF_COMPLEX64
	Complex    DataType = C.TF_COMPLEX
	Int64      DataType = C.TF_INT64
	Uint64     DataType = C.TF_UINT64
	Bool       DataType = C.TF_BOOL
	Qint8      DataType = C.TF_QINT8
	Quint8     DataType = C.TF_QUINT8
	Qint32     DataType = C.TF_QINT32
	Bfloat16   DataType = C.TF_BFLOAT16
	Qint16     DataType = C.TF_QINT16
	Quint16    DataType = C.TF_QUINT16
	Uint16     DataType = C.TF_UINT16
	Complex128 DataType = C.TF_COMPLEX128
	Half       DataType = C.TF_HALF
)

// Tensor holds a multi-dimensional array of elements of a single data type.
type Tensor struct {
	// c is the underlying C tensor. It is released with TF_DeleteTensor by
	// the finalizer registered when the Tensor is created.
	c *C.TF_Tensor
	// shape holds the size of each dimension; it is what Shape returns.
	shape []int64
}

// NewTensor converts from a Go value to a Tensor.
Valid values are scalars, 70 // slices, and arrays. Every element of a slice must have the same length so 71 // that the resulting Tensor has a valid shape. 72 func NewTensor(value interface{}) (*Tensor, error) { 73 val := reflect.ValueOf(value) 74 shape, dataType, err := shapeAndDataTypeOf(val) 75 if err != nil { 76 return nil, err 77 } 78 nflattened := numElements(shape) 79 nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened) 80 if dataType == String { 81 // TF_STRING tensors are encoded as an array of 8-byte offsets 82 // followed by string data. See c_api.h. 83 nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value) 84 } 85 var shapePtr *C.int64_t 86 if len(shape) > 0 { 87 shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 88 } 89 t := &Tensor{ 90 c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 91 shape: shape, 92 } 93 runtime.SetFinalizer(t, (*Tensor).finalize) 94 raw := tensorData(t.c) 95 buf := bytes.NewBuffer(raw[:0:len(raw)]) 96 if dataType != String { 97 if err := encodeTensor(buf, val, shape); err != nil { 98 return nil, err 99 } 100 if uintptr(buf.Len()) != nbytes { 101 return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) 102 } 103 } else { 104 e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()} 105 if err := e.encode(reflect.ValueOf(value), shape); err != nil { 106 return nil, err 107 } 108 if int64(buf.Len()) != nflattened*8 { 109 return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8) 110 } 111 } 112 return t, nil 113 } 114 115 // ReadTensor constructs a Tensor with the provided type and shape from the 116 // serialized tensor contents in r. 117 // 118 // See also WriteContentsTo. 
119 func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) { 120 if err := isTensorSerializable(dataType); err != nil { 121 return nil, err 122 } 123 nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape)) 124 var shapePtr *C.int64_t 125 if len(shape) > 0 { 126 shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 127 } 128 t := &Tensor{ 129 c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 130 shape: shape, 131 } 132 runtime.SetFinalizer(t, (*Tensor).finalize) 133 raw := tensorData(t.c) 134 n, err := r.Read(raw) 135 if err != nil { 136 return nil, err 137 } 138 if uintptr(n) != nbytes { 139 return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) 140 } 141 return t, nil 142 } 143 144 // newTensorFromC takes ownership of c and returns the owning Tensor. 145 func newTensorFromC(c *C.TF_Tensor) *Tensor { 146 var shape []int64 147 if ndims := int(C.TF_NumDims(c)); ndims > 0 { 148 shape = make([]int64, ndims) 149 } 150 for i := range shape { 151 shape[i] = int64(C.TF_Dim(c, C.int(i))) 152 } 153 t := &Tensor{c: c, shape: shape} 154 runtime.SetFinalizer(t, (*Tensor).finalize) 155 return t 156 } 157 158 func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) } 159 160 // DataType returns the scalar datatype of the Tensor. 161 func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) } 162 163 // Shape returns the shape of the Tensor. 164 func (t *Tensor) Shape() []int64 { return t.shape } 165 166 // Value converts the Tensor to a Go value. For now, not all Tensor types are 167 // supported, and this function may panic if it encounters an unsupported 168 // DataType. 169 // 170 // The type of the output depends on the Tensor type and dimensions. 
171 // For example: 172 // Tensor(int64, 0): int64 173 // Tensor(float64, 3): [][][]float64 174 func (t *Tensor) Value() interface{} { 175 typ := typeOf(t.DataType(), t.Shape()) 176 val := reflect.New(typ) 177 raw := tensorData(t.c) 178 if t.DataType() != String { 179 if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil { 180 panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err)) 181 } 182 } else { 183 nflattened := numElements(t.Shape()) 184 d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()} 185 if err := d.decode(val, t.Shape()); err != nil { 186 panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err)) 187 } 188 } 189 return reflect.Indirect(val).Interface() 190 } 191 192 // WriteContentsTo writes the serialized contents of t to w. 193 // 194 // Returns the number of bytes written. See ReadTensor for 195 // reconstructing a Tensor from the serialized form. 196 // 197 // WARNING: WriteContentsTo is not comprehensive and will fail 198 // if t.DataType() is non-numeric (e.g., String). See 199 // https://github.com/tensorflow/tensorflow/issues/6003. 
200 func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) { 201 if err := isTensorSerializable(t.DataType()); err != nil { 202 return 0, err 203 } 204 return io.Copy(w, bytes.NewReader(tensorData(t.c))) 205 } 206 207 func tensorData(c *C.TF_Tensor) []byte { 208 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 209 cbytes := C.TF_TensorData(c) 210 if cbytes == nil { 211 return nil 212 } 213 length := int(C.TF_TensorByteSize(c)) 214 slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length] 215 return slice 216 } 217 218 var types = []struct { 219 typ reflect.Type 220 dataType C.TF_DataType 221 }{ 222 {reflect.TypeOf(float32(0)), C.TF_FLOAT}, 223 {reflect.TypeOf(float64(0)), C.TF_DOUBLE}, 224 {reflect.TypeOf(int32(0)), C.TF_INT32}, 225 {reflect.TypeOf(uint32(0)), C.TF_UINT32}, 226 {reflect.TypeOf(uint8(0)), C.TF_UINT8}, 227 {reflect.TypeOf(int16(0)), C.TF_INT16}, 228 {reflect.TypeOf(int8(0)), C.TF_INT8}, 229 {reflect.TypeOf(""), C.TF_STRING}, 230 {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64}, 231 {reflect.TypeOf(int64(0)), C.TF_INT64}, 232 {reflect.TypeOf(uint64(0)), C.TF_UINT64}, 233 {reflect.TypeOf(false), C.TF_BOOL}, 234 {reflect.TypeOf(uint16(0)), C.TF_UINT16}, 235 {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128}, 236 // TODO(apassos): support DT_RESOURCE representation in go. 237 // TODO(keveman): support DT_VARIANT representation in go. 238 } 239 240 // shapeAndDataTypeOf returns the data type and shape of the Tensor 241 // corresponding to a Go type. 
242 func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err error) { 243 typ := val.Type() 244 for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice { 245 shape = append(shape, int64(val.Len())) 246 if val.Len() > 0 { 247 // In order to check tensor structure properly in general case we need to iterate over all slices of the tensor to check sizes match 248 // Since we already going to iterate over all elements in encodeTensor() let's 249 // 1) do the actual check in encodeTensor() to save some cpu cycles here 250 // 2) assume the shape is represented by lengths of elements with zero index in each dimension 251 val = val.Index(0) 252 } 253 typ = typ.Elem() 254 } 255 for _, t := range types { 256 if typ.Kind() == t.typ.Kind() { 257 return shape, DataType(t.dataType), nil 258 } 259 } 260 return shape, dt, fmt.Errorf("unsupported type %v", typ) 261 } 262 263 // typeOf converts from a DataType and Shape to the equivalent Go type. 264 func typeOf(dt DataType, shape []int64) reflect.Type { 265 var ret reflect.Type 266 for _, t := range types { 267 if dt == DataType(t.dataType) { 268 ret = t.typ 269 break 270 } 271 } 272 if ret == nil { 273 panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt)) 274 } 275 for range shape { 276 ret = reflect.SliceOf(ret) 277 } 278 return ret 279 } 280 281 func numElements(shape []int64) int64 { 282 n := int64(1) 283 for _, d := range shape { 284 n *= d 285 } 286 return n 287 } 288 289 // byteSizeOfEncodedStrings returns the size of the encoded strings in val. 290 // val MUST be a string, or a container (array/slice etc.) of strings. 291 func byteSizeOfEncodedStrings(val interface{}) uintptr { 292 if s, ok := val.(string); ok { 293 return uintptr(C.TF_StringEncodedSize(C.size_t(len(s)))) 294 } 295 // Otherwise must be an array or slice. 
296 var size uintptr 297 v := reflect.ValueOf(val) 298 for i := 0; i < v.Len(); i++ { 299 size += byteSizeOfEncodedStrings(v.Index(i).Interface()) 300 } 301 return size 302 } 303 304 // encodeTensor writes v to the specified buffer using the format specified in 305 // c_api.h. Use stringEncoder for String tensors. 306 func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { 307 switch v.Kind() { 308 case reflect.Bool: 309 b := byte(0) 310 if v.Bool() { 311 b = 1 312 } 313 if err := w.WriteByte(b); err != nil { 314 return err 315 } 316 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 317 if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { 318 return err 319 } 320 321 case reflect.Array, reflect.Slice: 322 // If current dimension is a slice, verify that it has the expected size 323 // Go's type system makes that guarantee for arrays. 
324 if v.Kind() == reflect.Slice { 325 expected := int(shape[0]) 326 if v.Len() != expected { 327 return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected) 328 } 329 } 330 331 // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice 332 if len(shape) == 1 && v.Len() > 0 { 333 switch v.Index(0).Kind() { 334 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 335 return binary.Write(w, nativeEndian, v.Interface()) 336 } 337 } 338 339 subShape := shape[1:] 340 for i := 0; i < v.Len(); i++ { 341 err := encodeTensor(w, v.Index(i), subShape) 342 if err != nil { 343 return err 344 } 345 } 346 347 default: 348 return fmt.Errorf("unsupported type %v", v.Type()) 349 } 350 return nil 351 } 352 353 // decodeTensor decodes the Tensor from the buffer to ptr using the format 354 // specified in c_api.h. Use stringDecoder for String tensors. 
355 func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { 356 switch typ.Kind() { 357 case reflect.Bool: 358 b, err := r.ReadByte() 359 if err != nil { 360 return err 361 } 362 ptr.Elem().SetBool(b == 1) 363 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 364 if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { 365 return err 366 } 367 368 case reflect.Slice: 369 val := reflect.Indirect(ptr) 370 val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0]))) 371 372 // Optimization: if only one dimension is left we can use binary.Read() directly for this slice 373 if len(shape) == 1 && val.Len() > 0 { 374 switch val.Index(0).Kind() { 375 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: 376 return binary.Read(r, nativeEndian, val.Interface()) 377 } 378 } 379 380 for i := 0; i < val.Len(); i++ { 381 if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil { 382 return err 383 } 384 } 385 386 default: 387 return fmt.Errorf("unsupported type %v", typ) 388 } 389 return nil 390 } 391 392 type stringEncoder struct { 393 offsets io.Writer 394 data []byte 395 offset uint64 396 status *status 397 } 398 399 func (e *stringEncoder) encode(v reflect.Value, shape []int64) error { 400 if v.Kind() == reflect.String { 401 if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil { 402 return err 403 } 404 var ( 405 s = v.Interface().(string) 406 src = C.CString(s) 407 srcLen = C.size_t(len(s)) 408 dst = (*C.char)(unsafe.Pointer(&e.data[e.offset])) 409 dstLen = C.size_t(uint64(len(e.data)) - e.offset) 410 ) 411 e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c)) 412 
C.free(unsafe.Pointer(src)) 413 return e.status.Err() 414 } 415 416 if v.Kind() == reflect.Slice { 417 expected := int(shape[0]) 418 if v.Len() != expected { 419 return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected) 420 } 421 } 422 423 subShape := shape[1:] 424 for i := 0; i < v.Len(); i++ { 425 if err := e.encode(v.Index(i), subShape); err != nil { 426 return err 427 } 428 } 429 return nil 430 } 431 432 type stringDecoder struct { 433 offsets io.Reader 434 data []byte 435 status *status 436 } 437 438 func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error { 439 if len(shape) == 0 { 440 var offset uint64 441 if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil { 442 return err 443 } 444 var ( 445 src = (*C.char)(unsafe.Pointer(&d.data[offset])) 446 srcLen = C.size_t(len(d.data)) - C.size_t(offset) 447 dst *C.char 448 dstLen C.size_t 449 ) 450 if offset > uint64(len(d.data)) { 451 return fmt.Errorf("invalid offsets in String Tensor") 452 } 453 C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c) 454 if err := d.status.Err(); err != nil { 455 return err 456 } 457 s := ptr.Interface().(*string) 458 *s = C.GoStringN(dst, C.int(dstLen)) 459 return nil 460 } 461 val := reflect.Indirect(ptr) 462 val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0]))) 463 for i := 0; i < val.Len(); i++ { 464 if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil { 465 return err 466 } 467 } 468 return nil 469 } 470 471 func bug(format string, args ...interface{}) error { 472 return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...)) 473 } 474 475 func isTensorSerializable(dataType DataType) error { 476 // For numeric types, the serialized Tensor matches the in-memory 477 // representation. 
See the implementation of Tensor::AsProtoContent in 478 // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc 479 // 480 // The more appropriate way to be in sync with Tensor::AsProtoContent 481 // would be to have the TensorFlow C library export functions for 482 // serialization and deserialization of Tensors. Till then capitalize 483 // on knowledge of the implementation for numeric types. 484 switch dataType { 485 case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half: 486 return nil 487 default: 488 return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType) 489 } 490 } 491 492 // nativeEndian is the byte order for the local platform. Used to send back and 493 // forth Tensors with the C API. We test for endianness at runtime because 494 // some architectures can be booted into different endian modes. 495 var nativeEndian binary.ByteOrder 496 497 func init() { 498 buf := [2]byte{} 499 *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD) 500 501 switch buf { 502 case [2]byte{0xCD, 0xAB}: 503 nativeEndian = binary.LittleEndian 504 case [2]byte{0xAB, 0xCD}: 505 nativeEndian = binary.BigEndian 506 default: 507 panic("Could not determine native endianness.") 508 } 509 } 510