Home | History | Annotate | Download | only in internal
      1 /*
      2 Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 
      4 Licensed under the Apache License, Version 2.0 (the "License");
      5 you may not use this file except in compliance with the License.
      6 You may obtain a copy of the License at
      7 
      8     http://www.apache.org/licenses/LICENSE-2.0
      9 
     10 Unless required by applicable law or agreed to in writing, software
     11 distributed under the License is distributed on an "AS IS" BASIS,
     12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 See the License for the specific language governing permissions and
     14 limitations under the License.
     15 */
     16 
     17 // Package internal generates Go source code with functions for TensorFlow operations.
     18 //
     19 // The basic outline of the generated API is as follows:
     20 //
     21 // - One function for each TensorFlow operation
     22 // - The arguments to the function are the inputs and required attributes of the operation
     23 // - The function returns the outputs
     24 // - A function is also generated for each optional attribute of the operation.
     25 //
     26 // There is a possibility that there are name collisions between the functions
     27 // generated for ops and the functions generated for optional attributes. For
     28 // now, we ignore those, but will need to revisit if a collision is actually
     29 // encountered.
     30 package internal
     31 
     32 /*
     33 #include <stdlib.h>
     34 
     35 #include "tensorflow/c/c_api.h"
     36 */
     37 import "C"
     38 
     39 import (
     40 	"fmt"
     41 	"io"
     42 	"io/ioutil"
     43 	"path"
     44 	"reflect"
     45 	"strings"
     46 	"text/template"
     47 	"unsafe"
     48 
     49 	"github.com/golang/protobuf/proto"
     50 	pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
     51 )
     52 
     53 // GenerateFunctionsForRegisteredOps writes a Go source code file to w
     54 // containing functions for each TensorFlow operation registered in the address
     55 // space of the calling process.
     56 // apidefDirs should be a contain of directories containing api_def_*.pbtxt
     57 // files to load.
     58 func GenerateFunctionsForRegisteredOps(
     59 	w io.Writer, apidefDirs []string) error {
     60 	ops, apimap, err := registeredOps()
     61 	if err != nil {
     62 		return err
     63 	}
     64 	for _, dir := range apidefDirs {
     65 		if err = updateAPIDefs(apimap, dir); err != nil {
     66 			return err
     67 		}
     68 	}
     69 	return generateFunctionsForOps(w, ops, apimap)
     70 }
     71 
     72 func registeredOps() (*pb.OpList, *apiDefMap, error) {
     73 	buf := C.TF_GetAllOpList()
     74 	defer C.TF_DeleteBuffer(buf)
     75 	var (
     76 		list = new(pb.OpList)
     77 		size = int(buf.length)
     78 		// A []byte backed by C memory.
     79 		// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
     80 		data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
     81 		err  = proto.Unmarshal(data, list)
     82 	)
     83 	if err != nil {
     84 		return nil, nil, err
     85 	}
     86 	apimap, err := newAPIDefMap(list)
     87 	return list, apimap, err
     88 }
     89 
     90 func updateAPIDefs(m *apiDefMap, dir string) error {
     91 	files, err := ioutil.ReadDir(dir)
     92 	if err != nil {
     93 		return err
     94 	}
     95 	for _, file := range files {
     96 		data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
     97 		if err != nil {
     98 			return fmt.Errorf("failed to read %q: %v", file.Name(), err)
     99 		}
    100 		if err = m.Put(string(data)); err != nil {
    101 			return fmt.Errorf("failed to process %q: %v", file.Name(), err)
    102 		}
    103 	}
    104 	return nil
    105 }
    106 
    107 func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error {
    108 	thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
    109 	if err := tmplHeader.Execute(w, thisPackage); err != nil {
    110 		return err
    111 	}
    112 	blacklist := map[string]bool{
    113 		"Const":           true,
    114 		"PyFunc":          true,
    115 		"PyFuncStateless": true,
    116 	}
    117 	for _, op := range ops.Op {
    118 		if blacklist[op.Name] {
    119 			continue
    120 		}
    121 		apidef, err := apimap.Get(op.Name)
    122 		if err != nil {
    123 			return err
    124 		}
    125 		if err := generateFunctionForOp(w, op, apidef); err != nil {
    126 			return err
    127 		}
    128 	}
    129 	return nil
    130 }
    131 
    132 func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error {
    133 	if strings.HasPrefix(op.Name, "_") { // Internal operation
    134 		return nil
    135 	}
    136 	// Ignore operations where the Go types corresponding to the TensorFlow
    137 	// type haven't been worked out (such as "func"s).
    138 	for _, a := range op.Attr {
    139 		if _, err := goType(a.Type); err != nil {
    140 			return nil
    141 		}
    142 	}
    143 	// Also, haven't figured out reference types yet, so ignore those too.
    144 	for _, a := range op.InputArg {
    145 		if a.IsRef {
    146 			return nil
    147 		}
    148 	}
    149 	for _, a := range op.OutputArg {
    150 		if a.IsRef {
    151 			return nil
    152 		}
    153 	}
    154 	if apidef.Summary == "" {
    155 		// Undocumented operation, perhaps a sign of not being ready to
    156 		// export.
    157 		return nil
    158 	}
    159 	tmplArgs, err := newTmplArgs(op, apidef)
    160 	if err != nil {
    161 		return err
    162 	}
    163 	return tmplOp.Execute(w, tmplArgs)
    164 }
    165 
    166 var (
    167 	// Go keywords that cannot be used as identifiers.
    168 	// From https://golang.org/ref/spec#Keywords
    169 	keywords = []string{
    170 		"break", "default", "func", "interface", "select", "case",
    171 		"defer", "go", "map", "struct", "chan", "else", "goto",
    172 		"package", "switch", "const", "fallthrough", "if", "range",
    173 		"type", "continue", "for", "import", "return", "var",
    174 	}
    175 
    176 	tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
    177 // This file was machine generated by {{.}}
    178 //
    179 // WARNING: This generation of wrapper function for TensorFlow ops is in an
    180 // experimental state. The generated API can change without notice.
    181 
    182 package op
    183 
    184 import tf "github.com/tensorflow/tensorflow/tensorflow/go"
    185 
    186 // optionalAttr is an intentionally un-exported type to hide
    187 // details of how optional attributes to operations are implemented.
    188 type optionalAttr map[string]interface{}
    189 
    190 func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
    191 	size, err := op.OutputListSize(output)
    192 	if err != nil {
    193 		return nil, start, err
    194 	}
    195 	list := make([]tf.Output, size)
    196 	for i := 0; i < size; i++ {
    197 		list[i] = op.Output(start + i)
    198 	}
    199 	return list, start + size, nil
    200 }
    201 `))
    202 
    203 	tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
    204 		"MakeComment":       makeComment,
    205 		"GoType":            goType,
    206 		"CamelCase":         camelCase,
    207 		"Identifier":        identifier,
    208 		"IsListArg":         isListArg,
    209 		"IsListAttr":        isListAttr,
    210 		"StripLeadingColon": stripLeadingColon,
    211 	}).Parse(`
    212 {{if .OptionalAttrs -}}
    213 {{/* Type for specifying all optional attributes. */ -}}
    214 // {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
    215 type {{.Op.Name}}Attr func(optionalAttr)
    216 
    217 {{range .OptionalAttrs}}
    218 // {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
    219 {{- if .Description}}
    220 //
    221 // value: {{MakeComment .Description}}
    222 {{- end}}
    223 // If not specified, defaults to {{StripLeadingColon .DefaultValue}}
    224 {{- if .HasMinimum}}
    225 //
    226 // {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
    227 {{- end}}
    228 func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
    229 	return func(m optionalAttr) {
    230 		m[{{printf "%q" .Name}}] = value
    231 	}
    232 }
    233 {{end}}
    234 {{end}}
    235 
    236 {{- /* Create a godoc friendly comment. */ -}}
    237 
    238 // {{MakeComment .APIDef.Summary}}
    239 
    240 {{- with .Op.Deprecation}}
    241 //
    242 // DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
    243 {{- end -}}
    244 
    245 {{- with .APIDef.Description}}
    246 //
    247 // {{MakeComment .}}
    248 {{- end -}}
    249 
    250 {{- if .DescribeArguments}}
    251 //
    252 // Arguments:
    253 {{- range .InArgsReordered}}
    254 //	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
    255 {{- end -}}
    256 {{- range .RequiredAttrs}}
    257 //	{{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
    258 {{- end -}}
    259 {{- end -}}
    260 
    261 {{- if (not .Op.OutputArg) }}
    262 //
    263 // Returns the created operation.
    264 {{- else }}
    265 {{- if .DescribeOutputs}}
    266 //
    267 {{- if ((len .OutArgs) eq 1) }}
    268 // Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
    269 {{- else }}
    270 // Returns:
    271 {{- range .OutArgs}}
    272 //	{{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
    273 {{- end -}}
    274 {{- end -}}
    275 {{- end -}}
    276 {{- end -}}
    277 {{- /*
    278 
    279   The function signature.
    280   Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
    281 */}}
    282 func {{.Op.Name}}
    283 
    284 {{- /*
    285   Fill in input arguments:
    286   (1) The Scope
    287   (2) All input arguments (which may be either []tf.Output or tf.Output)
    288   (3) All required attributes
    289   (4) Variadic list of optional attributes
    290 */ -}}
    291 
    292 (scope *Scope
    293 {{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
    294 {{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
    295 {{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
    296 )
    297 
    298 {{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
    299 
    300 {{if .OutArgs -}}
    301 ({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
    302 {{- else -}}
    303 (o *tf.Operation)
    304 {{- end }} {
    305 	if scope.Err() != nil {
    306 		return
    307 	}
    308 	{{if .HasAttrs -}}
    309 	attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
    310 	{{if .OptionalAttrs -}}
    311 	for _, a := range optional {
    312 		a(attrs)
    313 	}
    314 	{{end -}}
    315 	{{end -}}
    316 	opspec := tf.OpSpec{
    317 		Type: {{printf "%q" .Op.Name}},
    318 		{{if .InArgs -}}
    319 		Input: []tf.Input{
    320 			{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
    321 		},
    322 		{{- end}}
    323 		{{- if .HasAttrs}}
    324 		Attrs: attrs,
    325 		{{- end}}
    326 	}
    327 	{{- if .OutArgs}}
    328 	{{- if .HasListOutput}}
    329 	op := scope.AddOperation(opspec)
    330 	if scope.Err() != nil {
    331 		return
    332 	}
    333 	var idx int
    334 	var err error
    335 	{{- range $i, $a := .OutArgs}}
    336 	{{- if $a.IsListArg}}
    337 	if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
    338 		scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
    339 		return
    340 	}
    341 	{{- else }}
    342 	{{Identifier .RenameTo}} = op.Output(idx)
    343 	{{- end }}{{- /* if IsListArg */}}
    344 	{{- end }}{{- /* range .OutArgs */}}
    345 	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
    346 	{{- else }}
    347 	op := scope.AddOperation(opspec)
    348 	return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
    349 	{{- end }}{{- /* if .HasListOutput */}}
    350 	{{- else }}
    351 	return scope.AddOperation(opspec)
    352 	{{- end }}{{- /* if .OutArgs */}}
    353 }
    354 `))
    355 )
    356 
    357 type attrWrapper struct {
    358 	op  *pb.OpDef_AttrDef
    359 	api *pb.ApiDef_Attr
    360 }
    361 
    362 func (a *attrWrapper) Name() string             { return a.api.Name }
    363 func (a *attrWrapper) RenameTo() string         { return a.api.RenameTo }
    364 func (a *attrWrapper) Description() string      { return a.api.Description }
    365 func (a *attrWrapper) Type() string             { return a.op.Type }
    366 func (a *attrWrapper) IsListAttr() bool         { return isListAttr(a.op) }
    367 func (a *attrWrapper) HasMinimum() bool         { return a.op.HasMinimum }
    368 func (a *attrWrapper) Minimum() int64           { return a.op.Minimum }
    369 func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
    370 
    371 type argWrapper struct {
    372 	op  *pb.OpDef_ArgDef
    373 	api *pb.ApiDef_Arg
    374 }
    375 
    376 func (a *argWrapper) Name() string        { return a.api.Name }
    377 func (a *argWrapper) RenameTo() string    { return a.api.RenameTo }
    378 func (a *argWrapper) Description() string { return a.api.Description }
    379 func (a *argWrapper) IsListArg() bool     { return isListArg(a.op) }
    380 
    381 type tmplArgs struct {
    382 	Op     *pb.OpDef
    383 	APIDef *pb.ApiDef
    384 	// Op.Attr is split into two categories
    385 	// (1) Required: These must be specified by the client and are thus
    386 	//     included in the function signature.
    387 	// (2) Optional: These need not be specified (as they have default
    388 	//     values) and thus do not appear in the function signature.
    389 	RequiredAttrs []*attrWrapper
    390 	OptionalAttrs []*attrWrapper
    391 	InArgs        []*argWrapper
    392 	// Input arguments ordered based on arg_order field of ApiDef.
    393 	InArgsReordered []*argWrapper
    394 	OutArgs         []*argWrapper
    395 }
    396 
    397 func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) {
    398 	ret := tmplArgs{Op: op, APIDef: apidef}
    399 
    400 	// Setup InArgs field
    401 	for i, in := range op.InputArg {
    402 		argCombined := argWrapper{op: in, api: apidef.InArg[i]}
    403 		ret.InArgs = append(ret.InArgs, &argCombined)
    404 	}
    405 
    406 	// Setup OutArgs field
    407 	for i, out := range op.OutputArg {
    408 		argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
    409 		ret.OutArgs = append(ret.OutArgs, &argCombined)
    410 	}
    411 
    412 	// Setup InArgsReordered field
    413 	for _, argName := range apidef.ArgOrder {
    414 		// Find the argument in op.InputArg
    415 		argIndex := -1
    416 		for i, in := range op.InputArg {
    417 			if in.Name == argName {
    418 				argIndex = i
    419 				break
    420 			}
    421 		}
    422 		if argIndex == -1 {
    423 			return nil, fmt.Errorf(
    424 				"couldn't find argument %s in ApiDef for op %s",
    425 				argName, op.Name)
    426 		}
    427 		argCombined := argWrapper{
    428 			op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
    429 		ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
    430 	}
    431 
    432 	if len(op.Attr) == 0 {
    433 		return &ret, nil
    434 	}
    435 	// Attributes related to the InputArg's type are inferred automatically
    436 	// and are not exposed to the client.
    437 	inferred := make(map[string]bool)
    438 	for _, in := range op.InputArg {
    439 		switch {
    440 		case in.TypeAttr != "":
    441 			inferred[in.TypeAttr] = true
    442 		case in.TypeListAttr != "":
    443 			inferred[in.TypeListAttr] = true
    444 		}
    445 		if in.NumberAttr != "" {
    446 			inferred[in.NumberAttr] = true
    447 		}
    448 	}
    449 	for i, attr := range op.Attr {
    450 		if inferred[attr.Name] {
    451 			continue
    452 		}
    453 		attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
    454 		if attr.DefaultValue == nil {
    455 			ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
    456 		} else {
    457 			ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
    458 		}
    459 	}
    460 	return &ret, nil
    461 }
    462 
    463 func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
    464 func (a *tmplArgs) DescribeArguments() bool {
    465 	for _, arg := range a.InArgs {
    466 		if arg.Description() != "" {
    467 			return true
    468 		}
    469 	}
    470 	for _, attr := range a.RequiredAttrs {
    471 		if attr.Description() != "" {
    472 			return true
    473 		}
    474 	}
    475 	return false
    476 
    477 }
    478 func (a *tmplArgs) DescribeOutputs() bool {
    479 	for _, arg := range a.OutArgs {
    480 		if arg.Description() != "" {
    481 			return true
    482 		}
    483 	}
    484 	return false
    485 }
    486 func (a *tmplArgs) HasListOutput() bool {
    487 	for _, arg := range a.OutArgs {
    488 		if arg.IsListArg() {
    489 			return true
    490 		}
    491 	}
    492 	return false
    493 }
    494 
    495 func makeComment(lines string) string {
    496 	return strings.Join(strings.SplitAfter(lines, "\n"), "// ")
    497 }
    498 
    499 // goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
    500 // to the corresponding type in Go.
    501 func goType(tfType string) (string, error) {
    502 	list, tfType := parseTFType(tfType)
    503 	var gotype string
    504 	switch tfType {
    505 	case "int":
    506 		gotype = "int64"
    507 	case "float":
    508 		gotype = "float32"
    509 	case "bool":
    510 		gotype = "bool"
    511 	case "type":
    512 		gotype = "tf.DataType"
    513 	case "shape":
    514 		gotype = "tf.Shape"
    515 	case "tensor":
    516 		gotype = "tf.Tensor"
    517 	case "string":
    518 		gotype = "string"
    519 	default:
    520 		return "", fmt.Errorf("%q is not a recognized DataType", tfType)
    521 	}
    522 	if list {
    523 		gotype = "[]" + gotype
    524 	}
    525 	return gotype, nil
    526 }
    527 
    528 func camelCase(snakeCase string) string {
    529 	words := strings.Split(snakeCase, "_")
    530 	for i, w := range words {
    531 		words[i] = strings.ToUpper(string(w[0])) + w[1:]
    532 	}
    533 	return strings.Join(words, "")
    534 }
    535 
    536 // identifier creates an identifier for s usable in the generated Go source
    537 // code.
    538 //
    539 // Avoids collisions with keywords and other identifiers used in the generated
    540 // code.
    541 func identifier(s string) string {
    542 	// Identifiers used in the generated code.
    543 	if s == "tf" || s == "scope" || s == "err" || s == "op" {
    544 		return s + "_"
    545 	}
    546 	for _, k := range keywords {
    547 		if s == k {
    548 			// Alternatively, make the first letter upper case.
    549 			return s + "_"
    550 		}
    551 	}
    552 	return s
    553 }
    554 
    555 func isListArg(argdef *pb.OpDef_ArgDef) bool {
    556 	return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
    557 }
    558 
    559 func isListAttr(attrdef *pb.OpDef_AttrDef) bool {
    560 	list, _ := parseTFType(attrdef.Type)
    561 	return list
    562 }
    563 
    564 // stripLeadingColon removes the prefix of the string up to the first colon.
    565 //
    566 // This is useful when 's' corresponds to a "oneof" protocol buffer message.
    567 // For example, consider the protocol buffer message:
    568 //   oneof value { bool b = 1;  int64 i = 2; }
    569 // String() on a Go corresponding object (using proto.CompactTextString) will
    570 // print "b:true", or "i:7" etc. This function strips out the leading "b:" or
    571 // "i:".
    572 func stripLeadingColon(s fmt.Stringer) string {
    573 	x := s.String()
    574 	y := strings.SplitN(x, ":", 2)
    575 	if len(y) < 2 {
    576 		return x
    577 	}
    578 	return y[1]
    579 }
    580 
    581 func parseTFType(tfType string) (list bool, typ string) {
    582 	const (
    583 		listPrefix = "list("
    584 		listSuffix = ")"
    585 	)
    586 	if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) {
    587 		return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix)
    588 	}
    589 	return false, tfType
    590 }
    591