1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Non-core ops for LabeledTensor.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 import types 22 23 import numpy as np 24 from six import string_types 25 26 from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc 27 from tensorflow.contrib.labeled_tensor.python.ops import core 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import ops 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import functional_ops 32 from tensorflow.python.ops import math_ops 33 from tensorflow.python.ops import numerics 34 from tensorflow.python.ops import random_ops 35 from tensorflow.python.training import input # pylint: disable=redefined-builtin 36 37 38 @tc.returns(core.LabeledTensor) 39 @tc.accepts(core.LabeledTensor, ops.Tensor, core.Axis, 40 tc.Optional(string_types)) 41 def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None): 42 with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope: 43 temp_axes = core.Axes([axis] + list( 44 labeled_tensor.axes.remove(axis.name).values())) 45 transposed = core.transpose(labeled_tensor, temp_axes.keys()) 46 indexed = core.LabeledTensor( 47 array_ops.gather(transposed.tensor, indexer), temp_axes) 48 return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope) 49 50 51 @tc.returns(core.LabeledTensor) 52 @tc.accepts(core.LabeledTensorLike, 53 tc.Mapping(string_types, 54 tc.Union(slice, collections.Hashable, list)), 55 tc.Optional(string_types)) 56 def select(labeled_tensor, selection, name=None): 57 """Slice out a subset of the tensor. 58 59 Args: 60 labeled_tensor: The input tensor. 61 selection: A dictionary mapping an axis name to a scalar, slice or list of 62 values to select. Currently supports two types of selections: 63 (a) Any number of scalar and/or slice selections. 64 (b) Exactly one list selection, without any scalars or slices. 65 name: Optional op name. 66 67 Returns: 68 The selection as a `LabeledTensor`. 69 70 Raises: 71 ValueError: If the tensor doesn't have an axis in the selection or if 72 that axis lacks labels. 73 KeyError: If any labels in a selection are not found in the original axis. 74 NotImplementedError: If you attempt to combine a list selection with 75 scalar selection or another list selection. 76 """ 77 with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope: 78 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 79 80 slices = {} 81 indexers = {} 82 for axis_name, value in selection.items(): 83 if axis_name not in labeled_tensor.axes: 84 raise ValueError( 85 'The tensor does not have an axis named %s. Its axes are: %r' % 86 (axis_name, labeled_tensor.axes.keys())) 87 axis = labeled_tensor.axes[axis_name] 88 if axis.labels is None: 89 raise ValueError( 90 'The axis named %s does not have labels. The axis is: %r' % 91 (axis_name, axis)) 92 93 if isinstance(value, slice): 94 # TODO(shoyer): consider deprecating using slices in favor of lists 95 if value.start is None: 96 start = None 97 else: 98 start = axis.index(value.start) 99 100 if value.stop is None: 101 stop = None 102 else: 103 # For now, follow the pandas convention of making labeled slices 104 # inclusive of both bounds. 105 stop = axis.index(value.stop) + 1 106 107 if value.step is not None: 108 raise NotImplementedError('slicing with a step is not yet supported') 109 110 slices[axis_name] = slice(start, stop) 111 112 # Needs to be after checking for slices, since slice objects claim to be 113 # instances of collections.Hashable but hash() on them fails. 114 elif isinstance(value, collections.Hashable): 115 slices[axis_name] = axis.index(value) 116 117 elif isinstance(value, list): 118 if indexers: 119 raise NotImplementedError( 120 'select does not yet support more than one list selection at ' 121 'the same time') 122 indexer = [axis.index(v) for v in value] 123 indexers[axis_name] = ops.convert_to_tensor(indexer, dtype=dtypes.int64) 124 125 else: 126 # If type checking is working properly, this shouldn't be possible. 127 raise TypeError('cannot handle arbitrary types') 128 129 if indexers and slices: 130 raise NotImplementedError( 131 'select does not yet support combined scalar and list selection') 132 133 # For now, handle array selection separately, because tf.gather_nd does 134 # not support gradients yet. Later, using gather_nd will let us combine 135 # these paths. 136 if indexers: 137 (axis_name, indexer), = indexers.items() 138 axis = core.Axis(axis_name, selection[axis_name]) 139 return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope) 140 else: 141 return core.slice_function(labeled_tensor, slices, name=scope) 142 143 144 @tc.returns(core.LabeledTensor) 145 @tc.accepts( 146 tc.Collection(core.LabeledTensorLike), string_types, 147 tc.Optional(string_types)) 148 def concat(labeled_tensors, axis_name, name=None): 149 """Concatenate tensors along a dimension. 150 151 See tf.concat. 152 153 Args: 154 labeled_tensors: A list of input LabeledTensors. 155 axis_name: The name of the axis along which to concatenate. 156 name: Optional op name. 157 158 Returns: 159 The concatenated tensor. 160 The coordinate labels for the concatenation dimension are also concatenated, 161 if they are available for every tensor. 162 163 Raises: 164 ValueError: If fewer than one tensor inputs is provided, if the tensors 165 have incompatible axes, or if `axis_name` isn't the name of an axis. 166 """ 167 with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope: 168 labeled_tensors = [ 169 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 170 ] 171 172 if len(labeled_tensors) < 1: 173 raise ValueError('concat expects at least 1 tensor, but received %s' % 174 labeled_tensors) 175 176 # All tensors must have these axes. 177 axes_0 = labeled_tensors[0].axes 178 axis_names = list(axes_0.keys()) 179 180 if axis_name not in axis_names: 181 raise ValueError('%s not in %s' % (axis_name, axis_names)) 182 183 shared_axes = axes_0.remove(axis_name) 184 185 tensors = [labeled_tensors[0].tensor] 186 concat_axis_list = [axes_0[axis_name]] 187 for labeled_tensor in labeled_tensors[1:]: 188 current_shared_axes = labeled_tensor.axes.remove(axis_name) 189 if current_shared_axes != shared_axes: 190 # TODO(shoyer): add more specific checks about what went wrong, 191 # including raising AxisOrderError when appropriate 192 raise ValueError('Mismatched shared axes: the first tensor ' 193 'had axes %r but this tensor has axes %r.' % 194 (shared_axes, current_shared_axes)) 195 196 # Accumulate the axis labels, if they're available. 197 concat_axis_list.append(labeled_tensor.axes[axis_name]) 198 tensors.append(labeled_tensor.tensor) 199 200 concat_axis = core.concat_axes(concat_axis_list) 201 concat_dimension = axis_names.index(axis_name) 202 concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope) 203 values = list(axes_0.values()) 204 concat_axes = (values[:concat_dimension] + [concat_axis] + 205 values[concat_dimension + 1:]) 206 207 return core.LabeledTensor(concat_tensor, concat_axes) 208 209 210 # TODO(shoyer): rename pack/unpack to stack/unstack 211 212 213 @tc.returns(core.LabeledTensor) 214 @tc.accepts( 215 tc.Collection(core.LabeledTensorLike), 216 tc.Union(string_types, core.AxisLike), int, tc.Optional(string_types)) 217 def pack(labeled_tensors, new_axis, axis_position=0, name=None): 218 """Pack tensors along a new axis. 219 220 See tf.pack. 221 222 Args: 223 labeled_tensors: The input tensors, which must have identical axes. 224 new_axis: The name of the new axis, or a tuple containing the name 225 and coordinate labels. 226 axis_position: Optional integer position at which to insert the new axis. 227 name: Optional op name. 228 229 Returns: 230 The packed tensors as a single LabeledTensor, with `new_axis` in the given 231 `axis_position`. 232 233 Raises: 234 ValueError: If fewer than one input tensors is provided, or if the tensors 235 don't have identical axes. 236 """ 237 with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope: 238 labeled_tensors = [ 239 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 240 ] 241 242 if len(labeled_tensors) < 1: 243 raise ValueError('pack expects at least 1 tensors, but received %s' % 244 labeled_tensors) 245 246 axes_0 = labeled_tensors[0].axes 247 for t in labeled_tensors: 248 if t.axes != axes_0: 249 raise ValueError('Non-identical axes. Expected %s but got %s' % 250 (axes_0, t.axes)) 251 252 pack_op = array_ops.stack( 253 [t.tensor for t in labeled_tensors], axis=axis_position, name=scope) 254 axes = list(axes_0.values()) 255 axes.insert(axis_position, new_axis) 256 return core.LabeledTensor(pack_op, axes) 257 258 259 @tc.returns(tc.List(core.LabeledTensor)) 260 @tc.accepts(core.LabeledTensorLike, 261 tc.Optional(string_types), tc.Optional(string_types)) 262 def unpack(labeled_tensor, axis_name=None, name=None): 263 """Unpack the tensor. 264 265 See tf.unpack. 266 267 Args: 268 labeled_tensor: The input tensor. 269 axis_name: Optional name of axis to unpack. By default, the first axis is 270 used. 271 name: Optional op name. 272 273 Returns: 274 The list of unpacked LabeledTensors. 275 276 Raises: 277 ValueError: If `axis_name` is not an axis on the input. 278 """ 279 with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope: 280 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 281 282 axis_names = list(labeled_tensor.axes.keys()) 283 if axis_name is None: 284 axis_name = axis_names[0] 285 286 if axis_name not in axis_names: 287 raise ValueError('%s not in %s' % (axis_name, axis_names)) 288 axis = axis_names.index(axis_name) 289 290 unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis, name=scope) 291 axes = [a for i, a in enumerate(labeled_tensor.axes.values()) if i != axis] 292 return [core.LabeledTensor(t, axes) for t in unpack_ops] 293 294 295 @tc.returns(core.LabeledTensor) 296 @tc.accepts(core.LabeledTensorLike, 297 tc.Collection(string_types), 298 tc.Collection(tc.Union(string_types, core.AxisLike)), 299 tc.Optional(string_types)) 300 def reshape(labeled_tensor, existing_axes, new_axes, name=None): 301 """Reshape specific axes of a LabeledTensor. 302 303 Non-indicated axes remain in their original locations. 304 305 Args: 306 labeled_tensor: The input tensor. 307 existing_axes: List of axis names found on the input tensor. These must 308 appear sequentially in the list of axis names on the input. In other 309 words, they must be a valid slice of `list(labeled_tensor.axes.keys())`. 310 new_axes: List of strings, tuples of (axis_name, axis_value) or Axis objects 311 providing new axes with which to replace `existing_axes` in the reshaped 312 result. At most one element of `new_axes` may be a string, indicating an 313 axis with unknown size. 314 name: Optional op name. 315 316 Returns: 317 The reshaped LabeledTensor. 318 319 Raises: 320 ValueError: If `existing_axes` are not all axes on the input, or if more 321 than one of `new_axes` has unknown size. 322 AxisOrderError: If `existing_axes` are not a slice of axis names on the 323 input. 324 """ 325 with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope: 326 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 327 328 original_axis_names = list(labeled_tensor.axes.keys()) 329 existing_axes = list(existing_axes) 330 if not set(existing_axes) <= set(original_axis_names): 331 raise ValueError('existing_axes %r are not contained in the set of axis ' 332 'names %r on the input labeled tensor' % 333 (existing_axes, original_axis_names)) 334 335 start = original_axis_names.index(existing_axes[0]) 336 stop = original_axis_names.index(existing_axes[-1]) + 1 337 338 if existing_axes != original_axis_names[start:stop]: 339 # We could support existing_axes that aren't a slice by using transpose, 340 # but that could lead to unpredictable performance consequences because 341 # transposes are not free in TensorFlow. If we did transpose 342 # automatically, the user might never realize that their data is being 343 # produced with the wrong order. (The later will occur with some frequency 344 # because of how broadcasting automatically choose axis order.) 345 # So for now we've taken the strict approach. 346 raise core.AxisOrderError( 347 'existing_axes %r are not a slice of axis names %r on the input ' 348 'labeled tensor. Use `transpose` or `impose_axis_order` to reorder ' 349 'axes on the input explicitly.' % 350 (existing_axes, original_axis_names)) 351 352 if sum(isinstance(axis, string_types) for axis in new_axes) > 1: 353 raise ValueError( 354 'at most one axis in new_axes can have unknown size. All other ' 355 'axes must have an indicated integer size or labels: %r' % new_axes) 356 357 original_values = list(labeled_tensor.axes.values()) 358 axis_size = lambda axis: -1 if axis.size is None else axis.size 359 shape = [axis_size(axis) for axis in original_values[:start]] 360 for axis_ref in new_axes: 361 if isinstance(axis_ref, string_types): 362 shape.append(-1) 363 else: 364 axis = core.as_axis(axis_ref) 365 shape.append(axis_size(axis)) 366 shape.extend(axis_size(axis) for axis in original_values[stop:]) 367 368 reshaped_tensor = array_ops.reshape( 369 labeled_tensor.tensor, shape, name=scope) 370 axes = original_values[:start] + list(new_axes) + original_values[stop:] 371 return core.LabeledTensor(reshaped_tensor, axes) 372 373 374 @tc.returns(core.LabeledTensor) 375 @tc.accepts(core.LabeledTensorLike, string_types, string_types, 376 tc.Optional(string_types)) 377 def rename_axis(labeled_tensor, existing_name, new_name, name=None): 378 """Rename an axis of LabeledTensor. 379 380 Args: 381 labeled_tensor: The input tensor. 382 existing_name: Name for an existing axis on the input. 383 new_name: Desired replacement name. 384 name: Optional op name. 385 386 Returns: 387 LabeledTensor with renamed axis. 388 389 Raises: 390 ValueError: If `existing_name` is not an axis on the input. 391 """ 392 with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope: 393 if existing_name not in labeled_tensor.axes: 394 raise ValueError('existing_name %r are not contained in the set of axis ' 395 'names %r on the input labeled tensor' % 396 (existing_name, labeled_tensor.axes.keys())) 397 new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value) 398 return reshape(labeled_tensor, [existing_name], [new_axis], name=scope) 399 400 401 @tc.returns(tc.List(core.LabeledTensor)) 402 @tc.accepts(string_types, collections.Callable, int, bool, 403 tc.Collection(core.LabeledTensorLike), bool, 404 tc.Optional(string_types)) 405 def _batch_helper(default_name, 406 batch_fn, 407 batch_size, 408 enqueue_many, 409 labeled_tensors, 410 allow_smaller_final_batch, 411 name=None): 412 with ops.name_scope(name, default_name, labeled_tensors) as scope: 413 labeled_tensors = [ 414 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 415 ] 416 417 batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope) 418 # TODO(shoyer): Remove this when they sanitize the TF API. 419 if not isinstance(batch_ops, list): 420 assert isinstance(batch_ops, ops.Tensor) 421 batch_ops = [batch_ops] 422 423 if allow_smaller_final_batch: 424 batch_size = None 425 426 @tc.returns(core.Axes) 427 @tc.accepts(core.Axes) 428 def output_axes(axes): 429 if enqueue_many: 430 if 'batch' not in axes or list(axes.keys()).index('batch') != 0: 431 raise ValueError( 432 'When enqueue_many is True, input tensors must have an axis ' 433 'called "batch" as their first dimension, ' 434 'but axes were %s' % axes) 435 culled_axes = axes.remove('batch') 436 return core.Axes([('batch', batch_size)] + list(culled_axes.values())) 437 else: 438 return core.Axes([('batch', batch_size)] + list(axes.values())) 439 440 output_labeled_tensors = [] 441 for i, tensor in enumerate(batch_ops): 442 axes = output_axes(labeled_tensors[i].axes) 443 output_labeled_tensors.append(core.LabeledTensor(tensor, axes)) 444 445 return output_labeled_tensors 446 447 448 @tc.returns(tc.List(core.LabeledTensor)) 449 @tc.accepts( 450 tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool, 451 tc.Optional(string_types)) 452 def batch(labeled_tensors, 453 batch_size, 454 num_threads=1, 455 capacity=32, 456 enqueue_many=False, 457 allow_smaller_final_batch=False, 458 name=None): 459 """Rebatch a tensor. 460 461 See tf.batch. 462 463 Args: 464 labeled_tensors: The input tensors. 465 batch_size: The output batch size. 466 num_threads: See tf.batch. 467 capacity: See tf.batch. 468 enqueue_many: If true, the input tensors must contain a 'batch' axis as 469 their first axis. 470 If false, the input tensors must not contain a 'batch' axis. 471 See tf.batch. 472 allow_smaller_final_batch: See tf.batch. 473 name: Optional op name. 474 475 Returns: 476 The rebatched tensors. 477 If enqueue_many is false, the output tensors will have a new 'batch' axis 478 as their first axis. 479 480 Raises: 481 ValueError: If enqueue_many is True and the first axis of the tensors 482 isn't "batch". 483 """ 484 485 def fn(tensors, scope): 486 return input.batch( 487 tensors, 488 batch_size=batch_size, 489 num_threads=num_threads, 490 capacity=capacity, 491 enqueue_many=enqueue_many, 492 allow_smaller_final_batch=allow_smaller_final_batch, 493 name=scope) 494 495 return _batch_helper('lt_batch', fn, batch_size, enqueue_many, 496 labeled_tensors, allow_smaller_final_batch, name) 497 498 499 @tc.returns(tc.List(core.LabeledTensor)) 500 @tc.accepts( 501 tc.Collection(core.LabeledTensorLike), int, int, int, bool, int, 502 tc.Optional(int), bool, tc.Optional(string_types)) 503 def shuffle_batch(labeled_tensors, 504 batch_size, 505 num_threads=1, 506 capacity=32, 507 enqueue_many=False, 508 min_after_dequeue=0, 509 seed=None, 510 allow_smaller_final_batch=False, 511 name=None): 512 """Rebatch a tensor, with shuffling. 513 514 See tf.batch. 515 516 Args: 517 labeled_tensors: The input tensors. 518 batch_size: The output batch size. 519 num_threads: See tf.batch. 520 capacity: See tf.batch. 521 enqueue_many: If true, the input tensors must contain a 'batch' axis as 522 their first axis. 523 If false, the input tensors must not contain a 'batch' axis. 524 See tf.batch. 525 min_after_dequeue: Minimum number of elements in the queue after a dequeue, 526 used to ensure mixing. 527 seed: Optional random seed. 528 allow_smaller_final_batch: See tf.batch. 529 name: Optional op name. 530 531 Returns: 532 The rebatched tensors. 533 If enqueue_many is false, the output tensors will have a new 'batch' axis 534 as their first axis. 535 536 Raises: 537 ValueError: If enqueue_many is True and the first axis of the tensors 538 isn't "batch". 539 """ 540 541 def fn(tensors, scope): 542 return input.shuffle_batch( 543 tensors, 544 batch_size=batch_size, 545 num_threads=num_threads, 546 capacity=capacity, 547 enqueue_many=enqueue_many, 548 min_after_dequeue=min_after_dequeue, 549 seed=seed, 550 allow_smaller_final_batch=allow_smaller_final_batch, 551 name=scope) 552 553 return _batch_helper('lt_shuffle_batch', fn, batch_size, enqueue_many, 554 labeled_tensors, allow_smaller_final_batch, name) 555 556 557 @tc.returns(core.LabeledTensor) 558 @tc.accepts(core.LabeledTensorLike, 559 tc.Mapping(string_types, int), 560 tc.Optional(int), tc.Optional(string_types)) 561 def random_crop(labeled_tensor, shape_map, seed=None, name=None): 562 """Randomly crops a tensor to a given size. 563 564 See tf.random_crop. 565 566 Args: 567 labeled_tensor: The input tensor. 568 shape_map: A dictionary mapping axis names to the size of the random crop 569 for that dimension. 570 seed: An optional random seed. 571 name: An optional op name. 572 573 Returns: 574 A tensor of the same rank as `labeled_tensor`, cropped randomly in the 575 selected dimensions. 576 577 Raises: 578 ValueError: If the shape map contains an axis name not in the input tensor. 579 """ 580 with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope: 581 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 582 583 for axis_name in shape_map: 584 if axis_name not in labeled_tensor.axes: 585 raise ValueError('Selection axis %s not in axes %s' % 586 (axis_name, labeled_tensor.axes)) 587 588 shape = [] 589 axes = [] 590 for axis in labeled_tensor.axes.values(): 591 if axis.name in shape_map: 592 size = shape_map[axis.name] 593 shape.append(size) 594 # We lose labels for the axes we crop, leaving just the size. 595 axes.append((axis.name, size)) 596 else: 597 shape.append(len(axis)) 598 axes.append(axis) 599 600 crop_op = random_ops.random_crop( 601 labeled_tensor.tensor, shape, seed=seed, name=scope) 602 603 return core.LabeledTensor(crop_op, axes) 604 605 606 # TODO(shoyer): Allow the user to select the axis over which to map. 607 @tc.returns(core.LabeledTensor) 608 @tc.accepts(collections.Callable, core.LabeledTensorLike, 609 tc.Optional(string_types)) 610 def map_fn(fn, labeled_tensor, name=None): 611 """Map on the list of tensors unpacked from labeled_tensor. 612 613 See tf.map_fn. 614 615 Args: 616 fn: The function to apply to each unpacked LabeledTensor. 617 It should have type LabeledTensor -> LabeledTensor. 618 labeled_tensor: The input tensor. 619 name: Optional op name. 620 621 Returns: 622 A tensor that packs the results of applying fn to the list of tensors 623 unpacked from labeled_tensor. 624 """ 625 with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope: 626 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 627 628 unpack_lts = unpack(labeled_tensor) 629 630 # TODO(ericmc): Fix this upstream. 631 if labeled_tensor.dtype == dtypes.string: 632 # We must construct the full graph here, because functional_ops.map_fn 633 # doesn't work for string-valued tensors. 634 # Constructing the full graph may be slow. 635 map_lts = [fn(t) for t in unpack_lts] 636 return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope) 637 else: 638 # Figure out what the axis labels should be, but use tf.map_fn to 639 # construct the graph because it's efficient. 640 # It may be slow to construct the full graph, so we infer the labels from 641 # the first element. 642 # TODO(ericmc): This builds a subgraph which then gets thrown away. 643 # Find a more elegant solution. 644 first_map_lt = fn(unpack_lts[0]) 645 final_axes = list(labeled_tensor.axes.values())[:1] + list( 646 first_map_lt.axes.values()) 647 648 @tc.returns(ops.Tensor) 649 @tc.accepts(ops.Tensor) 650 def tf_fn(tensor): 651 original_axes = list(labeled_tensor.axes.values())[1:] 652 tensor_lt = core.LabeledTensor(tensor, original_axes) 653 return fn(tensor_lt).tensor 654 655 map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor) 656 map_lt = core.LabeledTensor(map_op, final_axes) 657 658 return core.identity(map_lt, name=scope) 659 660 661 @tc.returns(core.LabeledTensor) 662 @tc.accepts(collections.Callable, core.LabeledTensorLike, 663 core.LabeledTensorLike, tc.Optional(string_types)) 664 def foldl(fn, labeled_tensor, initial_value, name=None): 665 """Left fold on the list of tensors unpacked from labeled_tensor. 666 667 See tf.foldl. 668 669 Args: 670 fn: The function to apply to each unpacked LabeledTensor. 671 It should have type (LabeledTensor, LabeledTensor) -> LabeledTensor. 672 Its arguments are (accumulated_value, next_value). 673 labeled_tensor: The input tensor. 674 initial_value: The initial value of the accumulator. 675 name: Optional op name. 676 677 Returns: 678 The accumulated value. 679 """ 680 with ops.name_scope(name, 'lt_foldl', 681 [labeled_tensor, initial_value]) as scope: 682 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 683 initial_value = core.convert_to_labeled_tensor(initial_value) 684 685 @tc.returns(ops.Tensor) 686 @tc.accepts(ops.Tensor, ops.Tensor) 687 def tf_fn(accumulator, next_element): 688 accumulator_lt = core.LabeledTensor(accumulator, initial_value.axes) 689 next_element_lt = core.LabeledTensor( 690 next_element, list(labeled_tensor.axes.values())[1:]) 691 return fn(accumulator_lt, next_element_lt).tensor 692 693 foldl_op = functional_ops.foldl( 694 tf_fn, labeled_tensor.tensor, initializer=initial_value.tensor) 695 foldl_lt = core.LabeledTensor(foldl_op, initial_value.axes) 696 697 return core.identity(foldl_lt, name=scope) 698 699 700 @tc.returns(core.LabeledTensor) 701 @tc.accepts(core.LabeledTensorLike, 702 tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) 703 def squeeze(labeled_tensor, axis_names=None, name=None): 704 """Remove size-1 dimensions. 705 706 See tf.squeeze. 707 708 Args: 709 labeled_tensor: The input tensor. 710 axis_names: The names of the dimensions to remove, or None to remove 711 all size-1 dimensions. 712 name: Optional op name. 713 714 Returns: 715 A tensor with the specified dimensions removed. 716 717 Raises: 718 ValueError: If the named axes are not in the tensor, or if they are 719 not size-1. 720 """ 721 with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope: 722 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 723 724 if axis_names is None: 725 axis_names = [a.name for a in labeled_tensor.axes.values() if len(a) == 1] 726 727 for axis_name in axis_names: 728 if axis_name not in labeled_tensor.axes: 729 raise ValueError('axis %s is not in tensor axes %s' % 730 (axis_name, labeled_tensor.axes)) 731 elif len(labeled_tensor.axes[axis_name]) != 1: 732 raise ValueError( 733 'cannot squeeze axis with size greater than 1: (%s, %s)' % 734 (axis_name, labeled_tensor.axes[axis_name])) 735 736 squeeze_dimensions = [] 737 axes = [] 738 for i, axis in enumerate(labeled_tensor.axes.values()): 739 if axis.name in axis_names: 740 squeeze_dimensions.append(i) 741 else: 742 axes.append(axis) 743 744 if squeeze_dimensions: 745 squeeze_op = array_ops.squeeze( 746 labeled_tensor.tensor, squeeze_dimensions, name=scope) 747 else: 748 squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope) 749 750 return core.LabeledTensor(squeeze_op, axes) 751 752 753 # pylint: disable=invalid-name 754 ReduceAxis = tc.Union(string_types, 755 tc.Tuple(string_types, collections.Hashable)) 756 ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis))) 757 # pylint: enable=invalid-name 758 759 760 @tc.returns(core.LabeledTensor) 761 @tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 762 tc.Optional(string_types)) 763 def matmul(a, b, name=None): 764 """Matrix multiply two tensors with rank 1 or 2. 765 766 If both tensors have rank 2, a matrix-matrix product is performed. 767 If one tensor has rank 1 and the other has rank 2, then a matrix-vector 768 product is performed. 769 If both tensors have rank 1, then a vector dot-product is performed. 770 (This behavior matches that of `numpy.dot`.) 771 772 Both tensors must share exactly one dimension in common, which is the 773 dimension the operation is summed along. The inputs will be automatically 774 transposed if necessary as part of the matmul op. 775 776 We intend to eventually support `matmul` on higher rank input, and also 777 eventually support summing over any number shared dimensions (via an `axis` 778 argument), but neither of these features has been implemented yet. 779 780 Args: 781 a: First LabeledTensor. 782 b: Second LabeledTensor. 783 name: Optional op name. 784 785 Returns: 786 LabeledTensor with the result of matrix multiplication. Axes are ordered by 787 the current axis_order_scope, if set, or in or order of appearance on the 788 inputs. 789 790 Raises: 791 NotImplementedError: If inputs have rank >2 or share multiple axes. 792 ValueError: If the inputs have rank 0 or do not share any axes. 793 """ 794 with ops.name_scope(name, 'lt_matmul', [a, b]) as scope: 795 796 a = core.convert_to_labeled_tensor(a) 797 b = core.convert_to_labeled_tensor(b) 798 799 if len(a.axes) > 2 or len(b.axes) > 2: 800 # We could pass batched inputs to tf.matmul to make this work, but we 801 # would also need to use tf.tile and/or tf.transpose. These are more 802 # expensive than doing reshapes, so it's not clear if it's a good idea to 803 # do this automatically. 804 raise NotImplementedError( 805 'matmul currently requires inputs with rank 2 or less, but ' 806 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes))) 807 808 if not a.axes or not b.axes: 809 raise ValueError( 810 'matmul currently requires inputs with at least rank 1, but ' 811 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes))) 812 813 shared_axes = set(a.axes) & set(b.axes) 814 if len(shared_axes) > 1: 815 raise NotImplementedError( 816 'matmul does not yet support summing over multiple shared axes: %r. ' 817 'Use transpose and reshape to create a single shared axis to sum ' 818 'over.' % shared_axes) 819 if not shared_axes: 820 raise ValueError('there must have exactly one axis in common between ' 821 'input to matmul: %r, %r' % 822 (a.axes.keys(), b.axes.keys())) 823 shared_axis, = shared_axes 824 825 if a.axes[shared_axis] != b.axes[shared_axis]: 826 raise ValueError('axis %r does not match on input arguments: %r vs %r' % 827 (shared_axis, a.axes[shared_axis].value, 828 b.axes[shared_axis].value)) 829 830 result_axes = [] 831 for axes in [a.axes, b.axes]: 832 for axis in axes.values(): 833 if axis.name != shared_axis: 834 result_axes.append(axis) 835 836 axis_scope_order = core.get_axis_order() 837 if axis_scope_order is not None: 838 result_axis_names = [axis.name for axis in result_axes] 839 new_axis_names = [ 840 name for name in axis_scope_order if name in result_axis_names 841 ] 842 if new_axis_names != result_axis_names: 843 # switch a and b 844 b, a = a, b 845 # result_axes is a list of length 1 or 2 846 result_axes = result_axes[::-1] 847 848 squeeze_dims = [] 849 850 if len(a.axes) == 1: 851 a_tensor = array_ops.reshape(a.tensor, (1, -1)) 852 squeeze_dims.append(0) 853 transpose_a = False 854 else: 855 a_tensor = a.tensor 856 transpose_a = list(a.axes.keys()).index(shared_axis) == 0 857 858 if len(b.axes) == 1: 859 b_tensor = array_ops.reshape(b.tensor, (-1, 1)) 860 squeeze_dims.append(1) 861 transpose_b = False 862 else: 863 b_tensor = b.tensor 864 transpose_b = list(b.axes.keys()).index(shared_axis) == 1 865 866 result_op = math_ops.matmul( 867 a_tensor, b_tensor, transpose_a=transpose_a, transpose_b=transpose_b) 868 869 if squeeze_dims: 870 result_op = array_ops.squeeze(result_op, squeeze_dims) 871 result_op = array_ops.identity(result_op, name=scope) 872 873 return core.LabeledTensor(result_op, result_axes) 874 875 876 @tc.returns(types.FunctionType) 877 @tc.accepts(string_types, collections.Callable) 878 def define_reduce_op(op_name, reduce_fn): 879 """Define a reduction op for labeled tensors. 880 881 Args: 882 op_name: string name of the TensorFlow op. 883 reduce_fn: function to call to evaluate the op on a tf.Tensor. 884 885 Returns: 886 Function defining the given reduction op that acts on a LabeledTensor. 887 """ 888 889 default_name = 'lt_%s' % op_name 890 891 @tc.returns(core.LabeledTensor) 892 @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types)) 893 def op(labeled_tensor, axes=None, name=None): 894 """Computes the given reduction across the given axes of a LabeledTensor. 895 896 See `tf.{op_name}` for full details. 897 898 Args: 899 labeled_tensor: The input tensor. 900 axes: A set of axes or None. 901 If None, all axes will be reduced. 902 Axes must all be strings, in which case those dimensions will be 903 removed, or pairs of (name, None) or (name, label), in which case those 904 dimensions will be kept. 905 name: Optional op name. 906 907 Returns: 908 The reduced LabeledTensor. 909 910 Raises: 911 ValueError: if any of the axes to reduce over are not found on 912 `labeled_tensor`. 913 """ 914 with ops.name_scope(name, default_name, [labeled_tensor]) as scope: 915 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 916 917 if axes is None: 918 axes = labeled_tensor.axes.keys() 919 920 if isinstance(axes, (string_types, tuple)): 921 axes = [axes] 922 923 reduction_axes = {} 924 axes_to_squeeze = [] 925 for a in axes: 926 if isinstance(a, string_types): 927 # We squeeze out this axis. 928 reduction_axes[a] = a 929 axes_to_squeeze.append(a) 930 else: 931 # We keep this axis, with the user-provided labels. 932 (axis_name, label) = a 933 if label is not None: 934 # The input was a single label, so make it a list so it can be 935 # turned into an Axis. 936 label = [label] 937 reduction_axes[axis_name] = (axis_name, label) 938 939 for axis_name in reduction_axes: 940 if axis_name not in labeled_tensor.axes: 941 raise ValueError('Axis %s not in axes %s' % 942 (axis_name, labeled_tensor.axes)) 943 944 intermediate_axes = [] 945 reduction_dimensions = [] 946 for i, axis in enumerate(labeled_tensor.axes.values()): 947 if axis.name in reduction_axes: 948 intermediate_axes.append(reduction_axes[axis.name]) 949 reduction_dimensions.append(i) 950 else: 951 intermediate_axes.append(axis) 952 953 reduce_op = reduce_fn( 954 labeled_tensor.tensor, reduction_dimensions, keepdims=True) 955 reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes) 956 957 return squeeze(reduce_lt, axes_to_squeeze, name=scope) 958 959 op.__doc__ = op.__doc__.format(op_name=op_name) 960 op.__name__ = op_name 961 962 return op 963 964 965 reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all) 966 reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any) 967 reduce_logsumexp = define_reduce_op('reduce_logsumexp', 968 math_ops.reduce_logsumexp) 969 reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max) 970 reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean) 971 reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min) 972 reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod) 973 reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum) 974 975 976 @tc.returns(core.LabeledTensor) 977 @tc.accepts(core.LabeledTensorLike, 978 tc.Mapping(str, tc.Union(int, ops.Tensor)), 979 tc.Optional(string_types)) 980 def tile(labeled_tensor, multiples, name=None): 981 """Constructs a tensor by tiling a given tensor. 982 983 Only axes without tick-labels can be tiled. (Otherwise, axis labels on tiled 984 tensors would no longer be unique.) 985 986 See lt.tile. 987 988 Args: 989 labeled_tensor: The input tensor. 990 multiples: A mapping where the keys are axis names and the values are the 991 integer number of times to tile along that axis. Only axes with a multiple 992 different than 1 need be included. 993 name: Optional op name. 994 995 Returns: 996 A tensor with the indicated axes tiled. 997 998 Raises: 999 ValueError: If the tiled axes are not axes in the input tensor, or if any 1000 axes in multiples have tick labels. 1001 """ 1002 with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope: 1003 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1004 1005 if not set(multiples.keys()) <= set(labeled_tensor.axes.keys()): 1006 raise ValueError('tile axes %r are not contained in the set of axis ' 1007 'names %r on the input labeled tensor' % 1008 (multiples.keys(), labeled_tensor.axes)) 1009 1010 labeled_axes = [ 1011 name for name in multiples 1012 if labeled_tensor.axes[name].labels is not None 1013 ] 1014 if labeled_axes: 1015 raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes) 1016 1017 multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes] 1018 tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope) 1019 1020 new_axes = [ 1021 axis.name if axis.labels is None else axis 1022 for axis in labeled_tensor.axes.values() 1023 ] 1024 return core.LabeledTensor(tile_op, new_axes) 1025 1026 1027 @tc.returns(core.LabeledTensor) 1028 @tc.accepts(core.LabeledTensorLike, 1029 tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)), 1030 string_types, tc.Optional(string_types)) 1031 def pad(labeled_tensor, paddings, mode='CONSTANT', name=None): 1032 """Pads a tensor. 1033 1034 See tf.pad. 1035 1036 Args: 1037 labeled_tensor: The input tensor. 1038 paddings: A mapping where the keys are axis names and the values are 1039 tuples where the first element is the padding to insert at the beginning 1040 of the axis and the second is the padding to insert at the end of the 1041 axis. 1042 mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC". 1043 name: Optional op name. 1044 1045 Returns: 1046 A tensor with the indicated axes padded, optionally with those axes extended 1047 with the provided labels. 1048 1049 Raises: 1050 ValueError: If the padded axes are not axes in the input tensor. 1051 """ 1052 with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope: 1053 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1054 1055 if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()): 1056 raise ValueError('pad axes %r are not contained in the set of axis ' 1057 'names %r on the input labeled tensor' % 1058 (paddings.keys(), labeled_tensor.axes)) 1059 1060 new_axes = [] 1061 padding_pairs = [] 1062 for name, axis in labeled_tensor.axes.items(): 1063 if name in paddings: 1064 padding_before, padding_after = paddings[name] 1065 axis_before = core.Axis(name, padding_before) 1066 axis_after = core.Axis(name, padding_after) 1067 new_axes.append(core.concat_axes([axis_before, axis, axis_after])) 1068 padding_pairs.append((len(axis_before), len(axis_after))) 1069 else: 1070 new_axes.append(axis) 1071 padding_pairs.append((0, 0)) 1072 1073 pad_op = array_ops.pad(labeled_tensor.tensor, 1074 padding_pairs, 1075 mode, 1076 name=scope) 1077 1078 return core.LabeledTensor(pad_op, new_axes) 1079 1080 1081 @tc.returns(core.LabeledTensor) 1082 @tc.accepts( 1083 tc.Union(np.ndarray, list, tuple, core.Scalar), 1084 tc.Optional(dtypes.DType), 1085 tc.Optional( 1086 tc.Union(core.Axes, tc.Collection( 1087 tc.Union(string_types, core.AxisLike)))), tc.Optional(string_types)) 1088 def constant(value, dtype=None, axes=None, name=None): 1089 """Creates a constant tensor. 1090 1091 If `axes` includes any strings, shape is inferred from `value`. Otherwise, 1092 the sizes of the given `axes` are used to set `shape` for `tf.constant`. 1093 1094 See tf.constant for more details. 1095 1096 Args: 1097 value: The input tensor. 1098 dtype: The type of the returned tensor. 1099 axes: Optional Axes, list of strings or list of objects coercible to Axis 1100 objects. By default, axes are assumed to be an empty list (i.e., `value` 1101 is treated as a scalar). 1102 name: Optional op name. 1103 1104 Returns: 1105 The tensor with elements set to zero. 1106 """ 1107 with ops.name_scope(name, 'lt_constant', [value]) as scope: 1108 1109 if axes is None: 1110 axes = [] 1111 1112 if isinstance(axes, core.Axes): 1113 axes = axes.values() 1114 1115 if any(isinstance(ax, string_types) for ax in axes): 1116 # need to infer shape 1117 shape = None 1118 else: 1119 # axes already indicate shape 1120 axes = [core.as_axis(a) for a in axes] 1121 shape = [a.size for a in axes] 1122 1123 op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope) 1124 return core.LabeledTensor(op, axes) 1125 1126 1127 @tc.returns(core.LabeledTensor) 1128 @tc.accepts(core.LabeledTensorLike, 1129 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1130 def zeros_like(labeled_tensor, dtype=None, name=None): 1131 """Creates an identical tensor with all elements set to zero. 1132 1133 Args: 1134 labeled_tensor: The input tensor. 1135 dtype: The type of the returned tensor. 1136 name: Optional op name. 1137 1138 Returns: 1139 The tensor with elements set to zero. 1140 """ 1141 with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope: 1142 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1143 op = array_ops.zeros_like(labeled_tensor.tensor, dtype=dtype, name=scope) 1144 return core.LabeledTensor(op, labeled_tensor.axes) 1145 1146 1147 @tc.returns(core.LabeledTensor) 1148 @tc.accepts(core.LabeledTensorLike, 1149 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1150 def ones_like(labeled_tensor, dtype=None, name=None): 1151 """Creates an identical tensor with all elements set to one. 1152 1153 Args: 1154 labeled_tensor: The input tensor. 1155 dtype: The type of the returned tensor. 1156 name: Optional op name. 1157 1158 Returns: 1159 The tensor with elements set to one. 1160 """ 1161 with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope: 1162 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1163 op = array_ops.ones_like(labeled_tensor.tensor, dtype=dtype, name=scope) 1164 return core.LabeledTensor(op, labeled_tensor.axes) 1165 1166 1167 @tc.returns(core.LabeledTensor) 1168 @tc.accepts(core.LabeledTensorLike, 1169 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1170 def cast(labeled_tensor, dtype=None, name=None): 1171 """Casts a labeled tensor to a new type. 1172 1173 Args: 1174 labeled_tensor: The input tensor. 1175 dtype: The type of the returned tensor. 1176 name: Optional op name. 1177 1178 Returns: 1179 A labeled tensor with the new dtype. 1180 """ 1181 with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope: 1182 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1183 op = math_ops.cast(labeled_tensor.tensor, dtype=dtype, name=scope) 1184 return core.LabeledTensor(op, labeled_tensor.axes) 1185 1186 1187 @tc.returns(core.LabeledTensor) 1188 @tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types)) 1189 def verify_tensor_all_finite(labeled_tensor, message, name=None): 1190 """Asserts a tensor doesn't contain NaNs or Infs. 1191 1192 See tf.verify_tensor_all_finite. 1193 1194 Args: 1195 labeled_tensor: The input tensor. 1196 message: Message to log on failure. 1197 name: Optional op name. 1198 1199 Returns: 1200 The input tensor. 1201 """ 1202 with ops.name_scope(name, 'lt_verify_tensor_all_finite', 1203 [labeled_tensor]) as scope: 1204 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1205 op = numerics.verify_tensor_all_finite( 1206 labeled_tensor.tensor, msg=message, name=scope) 1207 return core.LabeledTensor(op, labeled_tensor.axes) 1208 1209 1210 @tc.returns(core.LabeledTensor) 1211 @tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 1212 tc.Optional(string_types)) 1213 def boolean_mask(labeled_tensor, mask, name=None): 1214 """Apply a boolean mask to a labeled tensor. 1215 1216 Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks. 1217 The mask is applied to the first axis of `labeled_tensor`. Labels on the first 1218 axis are removed, because True indices in `mask` may not be known dynamically. 1219 1220 Args: 1221 labeled_tensor: The input tensor. 1222 mask: The type of the returned tensor. 1223 name: Optional op name. 1224 1225 Returns: 1226 The masked labeled tensor. 1227 1228 Raises: 1229 ValueError: if the first axis of the mask 1230 """ 1231 with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope: 1232 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1233 mask = core.convert_to_labeled_tensor(mask) 1234 1235 if len(mask.axes) > 1: 1236 raise NotImplementedError( 1237 "LabeledTensor's boolean_mask currently only supports 1D masks") 1238 mask_axis = list(mask.axes.values())[0] 1239 lt_axis = list(labeled_tensor.axes.values())[0] 1240 if mask_axis != lt_axis: 1241 raise ValueError('the first axis of the labeled tensor and the mask ' 1242 'are not equal:\n%r\n%r' % (lt_axis, mask_axis)) 1243 op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope) 1244 # TODO(shoyer): attempt to infer labels for the masked values, by calling 1245 # tf.contrib.util.constant_value on the mask? 1246 axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:] 1247 return core.LabeledTensor(op, axes) 1248 1249 1250 @tc.returns(core.LabeledTensor) 1251 @tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 1252 core.LabeledTensorLike, tc.Optional(string_types)) 1253 def where(condition, x, y, name=None): 1254 """Return elements from x or y depending on condition. 1255 1256 See `tf.where` for more details. This function currently only implements the 1257 three argument version of where. 1258 1259 Args: 1260 condition: LabeledTensor of type `bool`. 1261 x: LabeledTensor for values where condition is true. 1262 y: LabeledTensor for values where condition is false. 1263 name: Optional op name. 1264 1265 Returns: 1266 The labeled tensor with values according to condition. 1267 1268 Raises: 1269 ValueError: if `x` and `y` have different axes, or if the axes of `x` do not 1270 start with the axes of `condition`. 1271 """ 1272 with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope: 1273 condition = core.convert_to_labeled_tensor(condition) 1274 x = core.convert_to_labeled_tensor(x) 1275 y = core.convert_to_labeled_tensor(y) 1276 1277 if not condition.axes == x.axes == y.axes: 1278 raise ValueError('all inputs to `where` must have equal axes') 1279 1280 op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope) 1281 return core.LabeledTensor(op, x.axes) 1282