1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A library of helpers for use with SamplingDecoders. 16 """ 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import abc 23 24 import six 25 26 from tensorflow.contrib.seq2seq.python.ops import decoder 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import ops 29 from tensorflow.python.framework import tensor_shape 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import control_flow_ops 32 from tensorflow.python.ops import embedding_ops 33 from tensorflow.python.ops import gen_array_ops 34 from tensorflow.python.ops import math_ops 35 from tensorflow.python.ops import tensor_array_ops 36 from tensorflow.python.ops.distributions import bernoulli 37 from tensorflow.python.ops.distributions import categorical 38 from tensorflow.python.util import nest 39 40 __all__ = [ 41 "Helper", 42 "TrainingHelper", 43 "GreedyEmbeddingHelper", 44 "SampleEmbeddingHelper", 45 "CustomHelper", 46 "ScheduledEmbeddingTrainingHelper", 47 "ScheduledOutputTrainingHelper", 48 "InferenceHelper", 49 ] 50 51 _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access 52 53 54 def _unstack_ta(inp): 55 return tensor_array_ops.TensorArray( 56 dtype=inp.dtype, size=array_ops.shape(inp)[0], 57 element_shape=inp.get_shape()[1:]).unstack(inp) 58 59 60 @six.add_metaclass(abc.ABCMeta) 61 class Helper(object): 62 """Interface for implementing sampling in seq2seq decoders. 63 64 Helper instances are used by `BasicDecoder`. 65 """ 66 67 @abc.abstractproperty 68 def batch_size(self): 69 """Batch size of tensor returned by `sample`. 70 71 Returns a scalar int32 tensor. 72 """ 73 raise NotImplementedError("batch_size has not been implemented") 74 75 @abc.abstractproperty 76 def sample_ids_shape(self): 77 """Shape of tensor returned by `sample`, excluding the batch dimension. 78 79 Returns a `TensorShape`. 80 """ 81 raise NotImplementedError("sample_ids_shape has not been implemented") 82 83 @abc.abstractproperty 84 def sample_ids_dtype(self): 85 """DType of tensor returned by `sample`. 86 87 Returns a DType. 88 """ 89 raise NotImplementedError("sample_ids_dtype has not been implemented") 90 91 @abc.abstractmethod 92 def initialize(self, name=None): 93 """Returns `(initial_finished, initial_inputs)`.""" 94 pass 95 96 @abc.abstractmethod 97 def sample(self, time, outputs, state, name=None): 98 """Returns `sample_ids`.""" 99 pass 100 101 @abc.abstractmethod 102 def next_inputs(self, time, outputs, state, sample_ids, name=None): 103 """Returns `(finished, next_inputs, next_state)`.""" 104 pass 105 106 107 class CustomHelper(Helper): 108 """Base abstract class that allows the user to customize sampling.""" 109 110 def __init__(self, initialize_fn, sample_fn, next_inputs_fn, 111 sample_ids_shape=None, sample_ids_dtype=None): 112 """Initializer. 113 114 Args: 115 initialize_fn: callable that returns `(finished, next_inputs)` 116 for the first iteration. 117 sample_fn: callable that takes `(time, outputs, state)` 118 and emits tensor `sample_ids`. 119 next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` 120 and emits `(finished, next_inputs, next_state)`. 121 sample_ids_shape: Either a list of integers, or a 1-D Tensor of type 122 `int32`, the shape of each value in the `sample_ids` batch. Defaults to 123 a scalar. 124 sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32. 125 """ 126 self._initialize_fn = initialize_fn 127 self._sample_fn = sample_fn 128 self._next_inputs_fn = next_inputs_fn 129 self._batch_size = None 130 self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) 131 self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 132 133 @property 134 def batch_size(self): 135 if self._batch_size is None: 136 raise ValueError("batch_size accessed before initialize was called") 137 return self._batch_size 138 139 @property 140 def sample_ids_shape(self): 141 return self._sample_ids_shape 142 143 @property 144 def sample_ids_dtype(self): 145 return self._sample_ids_dtype 146 147 def initialize(self, name=None): 148 with ops.name_scope(name, "%sInitialize" % type(self).__name__): 149 (finished, next_inputs) = self._initialize_fn() 150 if self._batch_size is None: 151 self._batch_size = array_ops.size(finished) 152 return (finished, next_inputs) 153 154 def sample(self, time, outputs, state, name=None): 155 with ops.name_scope( 156 name, "%sSample" % type(self).__name__, (time, outputs, state)): 157 return self._sample_fn(time=time, outputs=outputs, state=state) 158 159 def next_inputs(self, time, outputs, state, sample_ids, name=None): 160 with ops.name_scope( 161 name, "%sNextInputs" % type(self).__name__, (time, outputs, state)): 162 return self._next_inputs_fn( 163 time=time, outputs=outputs, state=state, sample_ids=sample_ids) 164 165 166 class TrainingHelper(Helper): 167 """A helper for use during training. Only reads inputs. 168 169 Returned sample_ids are the argmax of the RNN output logits. 170 """ 171 172 def __init__(self, inputs, sequence_length, time_major=False, name=None): 173 """Initializer. 174 175 Args: 176 inputs: A (structure of) input tensors. 177 sequence_length: An int32 vector tensor. 178 time_major: Python bool. Whether the tensors in `inputs` are time major. 179 If `False` (default), they are assumed to be batch major. 180 name: Name scope for any created operations. 181 182 Raises: 183 ValueError: if `sequence_length` is not a 1D tensor. 184 """ 185 with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): 186 inputs = ops.convert_to_tensor(inputs, name="inputs") 187 self._inputs = inputs 188 if not time_major: 189 inputs = nest.map_structure(_transpose_batch_time, inputs) 190 191 self._input_tas = nest.map_structure(_unstack_ta, inputs) 192 self._sequence_length = ops.convert_to_tensor( 193 sequence_length, name="sequence_length") 194 if self._sequence_length.get_shape().ndims != 1: 195 raise ValueError( 196 "Expected sequence_length to be a vector, but received shape: %s" % 197 self._sequence_length.get_shape()) 198 199 self._zero_inputs = nest.map_structure( 200 lambda inp: array_ops.zeros_like(inp[0, :]), inputs) 201 202 self._batch_size = array_ops.size(sequence_length) 203 204 @property 205 def inputs(self): 206 return self._inputs 207 208 @property 209 def sequence_length(self): 210 return self._sequence_length 211 212 @property 213 def batch_size(self): 214 return self._batch_size 215 216 @property 217 def sample_ids_shape(self): 218 return tensor_shape.TensorShape([]) 219 220 @property 221 def sample_ids_dtype(self): 222 return dtypes.int32 223 224 def initialize(self, name=None): 225 with ops.name_scope(name, "TrainingHelperInitialize"): 226 finished = math_ops.equal(0, self._sequence_length) 227 all_finished = math_ops.reduce_all(finished) 228 next_inputs = control_flow_ops.cond( 229 all_finished, lambda: self._zero_inputs, 230 lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas)) 231 return (finished, next_inputs) 232 233 def sample(self, time, outputs, name=None, **unused_kwargs): 234 with ops.name_scope(name, "TrainingHelperSample", [time, outputs]): 235 sample_ids = math_ops.cast( 236 math_ops.argmax(outputs, axis=-1), dtypes.int32) 237 return sample_ids 238 239 def next_inputs(self, time, outputs, state, name=None, **unused_kwargs): 240 """next_inputs_fn for TrainingHelper.""" 241 with ops.name_scope(name, "TrainingHelperNextInputs", 242 [time, outputs, state]): 243 next_time = time + 1 244 finished = (next_time >= self._sequence_length) 245 all_finished = math_ops.reduce_all(finished) 246 def read_from_ta(inp): 247 return inp.read(next_time) 248 next_inputs = control_flow_ops.cond( 249 all_finished, lambda: self._zero_inputs, 250 lambda: nest.map_structure(read_from_ta, self._input_tas)) 251 return (finished, next_inputs, state) 252 253 254 class ScheduledEmbeddingTrainingHelper(TrainingHelper): 255 """A training helper that adds scheduled sampling. 256 257 Returns -1s for sample_ids where no sampling took place; valid sample id 258 values elsewhere. 259 """ 260 261 def __init__(self, inputs, sequence_length, embedding, sampling_probability, 262 time_major=False, seed=None, scheduling_seed=None, name=None): 263 """Initializer. 264 265 Args: 266 inputs: A (structure of) input tensors. 267 sequence_length: An int32 vector tensor. 268 embedding: A callable that takes a vector tensor of `ids` (argmax ids), 269 or the `params` argument for `embedding_lookup`. 270 sampling_probability: A 0D `float32` tensor: the probability of sampling 271 categorically from the output ids instead of reading directly from the 272 inputs. 273 time_major: Python bool. Whether the tensors in `inputs` are time major. 274 If `False` (default), they are assumed to be batch major. 275 seed: The sampling seed. 276 scheduling_seed: The schedule decision rule sampling seed. 277 name: Name scope for any created operations. 278 279 Raises: 280 ValueError: if `sampling_probability` is not a scalar or vector. 281 """ 282 with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper", 283 [embedding, sampling_probability]): 284 if callable(embedding): 285 self._embedding_fn = embedding 286 else: 287 self._embedding_fn = ( 288 lambda ids: embedding_ops.embedding_lookup(embedding, ids)) 289 self._sampling_probability = ops.convert_to_tensor( 290 sampling_probability, name="sampling_probability") 291 if self._sampling_probability.get_shape().ndims not in (0, 1): 292 raise ValueError( 293 "sampling_probability must be either a scalar or a vector. " 294 "saw shape: %s" % (self._sampling_probability.get_shape())) 295 self._seed = seed 296 self._scheduling_seed = scheduling_seed 297 super(ScheduledEmbeddingTrainingHelper, self).__init__( 298 inputs=inputs, 299 sequence_length=sequence_length, 300 time_major=time_major, 301 name=name) 302 303 def initialize(self, name=None): 304 return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name) 305 306 def sample(self, time, outputs, state, name=None): 307 with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", 308 [time, outputs, state]): 309 # Return -1s where we did not sample, and sample_ids elsewhere 310 select_sampler = bernoulli.Bernoulli( 311 probs=self._sampling_probability, dtype=dtypes.bool) 312 select_sample = select_sampler.sample( 313 sample_shape=self.batch_size, seed=self._scheduling_seed) 314 sample_id_sampler = categorical.Categorical(logits=outputs) 315 return array_ops.where( 316 select_sample, 317 sample_id_sampler.sample(seed=self._seed), 318 gen_array_ops.fill([self.batch_size], -1)) 319 320 def next_inputs(self, time, outputs, state, sample_ids, name=None): 321 with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs", 322 [time, outputs, state, sample_ids]): 323 (finished, base_next_inputs, state) = ( 324 super(ScheduledEmbeddingTrainingHelper, self).next_inputs( 325 time=time, 326 outputs=outputs, 327 state=state, 328 sample_ids=sample_ids, 329 name=name)) 330 331 def maybe_sample(): 332 """Perform scheduled sampling.""" 333 where_sampling = math_ops.cast( 334 array_ops.where(sample_ids > -1), dtypes.int32) 335 where_not_sampling = math_ops.cast( 336 array_ops.where(sample_ids <= -1), dtypes.int32) 337 sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling) 338 inputs_not_sampling = array_ops.gather_nd( 339 base_next_inputs, where_not_sampling) 340 sampled_next_inputs = self._embedding_fn(sample_ids_sampling) 341 base_shape = array_ops.shape(base_next_inputs) 342 return (array_ops.scatter_nd(indices=where_sampling, 343 updates=sampled_next_inputs, 344 shape=base_shape) 345 + array_ops.scatter_nd(indices=where_not_sampling, 346 updates=inputs_not_sampling, 347 shape=base_shape)) 348 349 all_finished = math_ops.reduce_all(finished) 350 next_inputs = control_flow_ops.cond( 351 all_finished, lambda: base_next_inputs, maybe_sample) 352 return (finished, next_inputs, state) 353 354 355 class ScheduledOutputTrainingHelper(TrainingHelper): 356 """A training helper that adds scheduled sampling directly to outputs. 357 358 Returns False for sample_ids where no sampling took place; True elsewhere. 359 """ 360 361 def __init__(self, inputs, sequence_length, sampling_probability, 362 time_major=False, seed=None, next_inputs_fn=None, 363 auxiliary_inputs=None, name=None): 364 """Initializer. 365 366 Args: 367 inputs: A (structure) of input tensors. 368 sequence_length: An int32 vector tensor. 369 sampling_probability: A 0D `float32` tensor: the probability of sampling 370 from the outputs instead of reading directly from the inputs. 371 time_major: Python bool. Whether the tensors in `inputs` are time major. 372 If `False` (default), they are assumed to be batch major. 373 seed: The sampling seed. 374 next_inputs_fn: (Optional) callable to apply to the RNN outputs to create 375 the next input when sampling. If `None` (default), the RNN outputs will 376 be used as the next inputs. 377 auxiliary_inputs: An optional (structure of) auxiliary input tensors with 378 a shape that matches `inputs` in all but (potentially) the final 379 dimension. These tensors will be concatenated to the sampled output or 380 the `inputs` when not sampling for use as the next input. 381 name: Name scope for any created operations. 382 383 Raises: 384 ValueError: if `sampling_probability` is not a scalar or vector. 385 """ 386 with ops.name_scope(name, "ScheduledOutputTrainingHelper", 387 [inputs, auxiliary_inputs, sampling_probability]): 388 self._sampling_probability = ops.convert_to_tensor( 389 sampling_probability, name="sampling_probability") 390 if self._sampling_probability.get_shape().ndims not in (0, 1): 391 raise ValueError( 392 "sampling_probability must be either a scalar or a vector. " 393 "saw shape: %s" % (self._sampling_probability.get_shape())) 394 395 if auxiliary_inputs is None: 396 maybe_concatenated_inputs = inputs 397 else: 398 inputs = ops.convert_to_tensor(inputs, name="inputs") 399 auxiliary_inputs = ops.convert_to_tensor( 400 auxiliary_inputs, name="auxiliary_inputs") 401 maybe_concatenated_inputs = nest.map_structure( 402 lambda x, y: array_ops.concat((x, y), -1), 403 inputs, auxiliary_inputs) 404 if not time_major: 405 auxiliary_inputs = nest.map_structure( 406 _transpose_batch_time, auxiliary_inputs) 407 408 self._auxiliary_input_tas = ( 409 nest.map_structure(_unstack_ta, auxiliary_inputs) 410 if auxiliary_inputs is not None else None) 411 412 self._seed = seed 413 414 self._next_inputs_fn = next_inputs_fn 415 416 super(ScheduledOutputTrainingHelper, self).__init__( 417 inputs=maybe_concatenated_inputs, 418 sequence_length=sequence_length, 419 time_major=time_major, 420 name=name) 421 422 def initialize(self, name=None): 423 return super(ScheduledOutputTrainingHelper, self).initialize(name=name) 424 425 def sample(self, time, outputs, state, name=None): 426 with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", 427 [time, outputs, state]): 428 sampler = bernoulli.Bernoulli(probs=self._sampling_probability) 429 return sampler.sample(sample_shape=self.batch_size, seed=self._seed) 430 431 def next_inputs(self, time, outputs, state, sample_ids, name=None): 432 with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", 433 [time, outputs, state, sample_ids]): 434 (finished, base_next_inputs, state) = ( 435 super(ScheduledOutputTrainingHelper, self).next_inputs( 436 time=time, 437 outputs=outputs, 438 state=state, 439 sample_ids=sample_ids, 440 name=name)) 441 sample_ids = math_ops.cast(sample_ids, dtypes.bool) 442 443 def maybe_sample(): 444 """Perform scheduled sampling.""" 445 446 def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): 447 """Concatenate outputs with auxiliary inputs, if they exist.""" 448 if self._auxiliary_input_tas is None: 449 return outputs_ 450 451 next_time = time + 1 452 auxiliary_inputs = nest.map_structure( 453 lambda ta: ta.read(next_time), self._auxiliary_input_tas) 454 if indices is not None: 455 auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices) 456 return nest.map_structure( 457 lambda x, y: array_ops.concat((x, y), -1), 458 outputs_, auxiliary_inputs) 459 460 if self._next_inputs_fn is None: 461 return array_ops.where( 462 sample_ids, maybe_concatenate_auxiliary_inputs(outputs), 463 base_next_inputs) 464 465 where_sampling = math_ops.cast( 466 array_ops.where(sample_ids), dtypes.int32) 467 where_not_sampling = math_ops.cast( 468 array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32) 469 outputs_sampling = array_ops.gather_nd(outputs, where_sampling) 470 inputs_not_sampling = array_ops.gather_nd(base_next_inputs, 471 where_not_sampling) 472 sampled_next_inputs = maybe_concatenate_auxiliary_inputs( 473 self._next_inputs_fn(outputs_sampling), where_sampling) 474 475 base_shape = array_ops.shape(base_next_inputs) 476 return (array_ops.scatter_nd(indices=where_sampling, 477 updates=sampled_next_inputs, 478 shape=base_shape) 479 + array_ops.scatter_nd(indices=where_not_sampling, 480 updates=inputs_not_sampling, 481 shape=base_shape)) 482 483 all_finished = math_ops.reduce_all(finished) 484 no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids)) 485 next_inputs = control_flow_ops.cond( 486 math_ops.logical_or(all_finished, no_samples), 487 lambda: base_next_inputs, maybe_sample) 488 return (finished, next_inputs, state) 489 490 491 class GreedyEmbeddingHelper(Helper): 492 """A helper for use during inference. 493 494 Uses the argmax of the output (treated as logits) and passes the 495 result through an embedding layer to get the next input. 496 """ 497 498 def __init__(self, embedding, start_tokens, end_token): 499 """Initializer. 500 501 Args: 502 embedding: A callable that takes a vector tensor of `ids` (argmax ids), 503 or the `params` argument for `embedding_lookup`. The returned tensor 504 will be passed to the decoder input. 505 start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. 506 end_token: `int32` scalar, the token that marks end of decoding. 507 508 Raises: 509 ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a 510 scalar. 511 """ 512 if callable(embedding): 513 self._embedding_fn = embedding 514 else: 515 self._embedding_fn = ( 516 lambda ids: embedding_ops.embedding_lookup(embedding, ids)) 517 518 self._start_tokens = ops.convert_to_tensor( 519 start_tokens, dtype=dtypes.int32, name="start_tokens") 520 self._end_token = ops.convert_to_tensor( 521 end_token, dtype=dtypes.int32, name="end_token") 522 if self._start_tokens.get_shape().ndims != 1: 523 raise ValueError("start_tokens must be a vector") 524 self._batch_size = array_ops.size(start_tokens) 525 if self._end_token.get_shape().ndims != 0: 526 raise ValueError("end_token must be a scalar") 527 self._start_inputs = self._embedding_fn(self._start_tokens) 528 529 @property 530 def batch_size(self): 531 return self._batch_size 532 533 @property 534 def sample_ids_shape(self): 535 return tensor_shape.TensorShape([]) 536 537 @property 538 def sample_ids_dtype(self): 539 return dtypes.int32 540 541 def initialize(self, name=None): 542 finished = array_ops.tile([False], [self._batch_size]) 543 return (finished, self._start_inputs) 544 545 def sample(self, time, outputs, state, name=None): 546 """sample for GreedyEmbeddingHelper.""" 547 del time, state # unused by sample_fn 548 # Outputs are logits, use argmax to get the most probable id 549 if not isinstance(outputs, ops.Tensor): 550 raise TypeError("Expected outputs to be a single Tensor, got: %s" % 551 type(outputs)) 552 sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32) 553 return sample_ids 554 555 def next_inputs(self, time, outputs, state, sample_ids, name=None): 556 """next_inputs_fn for GreedyEmbeddingHelper.""" 557 del time, outputs # unused by next_inputs_fn 558 finished = math_ops.equal(sample_ids, self._end_token) 559 all_finished = math_ops.reduce_all(finished) 560 next_inputs = control_flow_ops.cond( 561 all_finished, 562 # If we're finished, the next_inputs value doesn't matter 563 lambda: self._start_inputs, 564 lambda: self._embedding_fn(sample_ids)) 565 return (finished, next_inputs, state) 566 567 568 class SampleEmbeddingHelper(GreedyEmbeddingHelper): 569 """A helper for use during inference. 570 571 Uses sampling (from a distribution) instead of argmax and passes the 572 result through an embedding layer to get the next input. 573 """ 574 575 def __init__(self, embedding, start_tokens, end_token, 576 softmax_temperature=None, seed=None): 577 """Initializer. 578 579 Args: 580 embedding: A callable that takes a vector tensor of `ids` (argmax ids), 581 or the `params` argument for `embedding_lookup`. The returned tensor 582 will be passed to the decoder input. 583 start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. 584 end_token: `int32` scalar, the token that marks end of decoding. 585 softmax_temperature: (Optional) `float32` scalar, value to divide the 586 logits by before computing the softmax. Larger values (above 1.0) result 587 in more random samples, while smaller values push the sampling 588 distribution towards the argmax. Must be strictly greater than 0. 589 Defaults to 1.0. 590 seed: (Optional) The sampling seed. 591 592 Raises: 593 ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a 594 scalar. 595 """ 596 super(SampleEmbeddingHelper, self).__init__( 597 embedding, start_tokens, end_token) 598 self._softmax_temperature = softmax_temperature 599 self._seed = seed 600 601 def sample(self, time, outputs, state, name=None): 602 """sample for SampleEmbeddingHelper.""" 603 del time, state # unused by sample_fn 604 # Outputs are logits, we sample instead of argmax (greedy). 605 if not isinstance(outputs, ops.Tensor): 606 raise TypeError("Expected outputs to be a single Tensor, got: %s" % 607 type(outputs)) 608 if self._softmax_temperature is None: 609 logits = outputs 610 else: 611 logits = outputs / self._softmax_temperature 612 613 sample_id_sampler = categorical.Categorical(logits=logits) 614 sample_ids = sample_id_sampler.sample(seed=self._seed) 615 616 return sample_ids 617 618 619 class InferenceHelper(Helper): 620 """A helper to use during inference with a custom sampling function.""" 621 622 def __init__(self, sample_fn, sample_shape, sample_dtype, 623 start_inputs, end_fn, next_inputs_fn=None): 624 """Initializer. 625 626 Args: 627 sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`. 628 sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`, 629 the shape of the each sample in the batch returned by `sample_fn`. 630 sample_dtype: the dtype of the sample returned by `sample_fn`. 631 start_inputs: The initial batch of inputs. 632 end_fn: A callable that takes `sample_ids` and emits a `bool` vector 633 shaped `[batch_size]` indicating whether each sample is an end token. 634 next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns 635 the next batch of inputs. If not provided, `sample_ids` is used as the 636 next batch of inputs. 637 """ 638 self._sample_fn = sample_fn 639 self._end_fn = end_fn 640 self._sample_shape = tensor_shape.TensorShape(sample_shape) 641 self._sample_dtype = sample_dtype 642 self._next_inputs_fn = next_inputs_fn 643 self._batch_size = array_ops.shape(start_inputs)[0] 644 self._start_inputs = ops.convert_to_tensor( 645 start_inputs, name="start_inputs") 646 647 @property 648 def batch_size(self): 649 return self._batch_size 650 651 @property 652 def sample_ids_shape(self): 653 return self._sample_shape 654 655 @property 656 def sample_ids_dtype(self): 657 return self._sample_dtype 658 659 def initialize(self, name=None): 660 finished = array_ops.tile([False], [self._batch_size]) 661 return (finished, self._start_inputs) 662 663 def sample(self, time, outputs, state, name=None): 664 del time, state # unused by sample 665 return self._sample_fn(outputs) 666 667 def next_inputs(self, time, outputs, state, sample_ids, name=None): 668 del time, outputs # unused by next_inputs 669 if self._next_inputs_fn is None: 670 next_inputs = sample_ids 671 else: 672 next_inputs = self._next_inputs_fn(sample_ids) 673 finished = self._end_fn(sample_ids) 674 return (finished, next_inputs, state) 675