/external/tensorflow/tensorflow/compiler/xla/service/ |
conditional_simplifier.cc | 36 // computation. If the given conditional has a constant branch_index, tries to 54 int branch_index = 0; local 57 VLOG(2) << "Not attempting to remove conditional as its branch_index is " 64 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; 66 branch_index = conditional->operand(0)->literal().Get<int32>({}); 67 if (branch_index < 0 || branch_index >= conditional->branch_count()) { 68 branch_index = conditional->branch_count() - 1; 75 conditional->shape(), {conditional->mutable_operand(branch_index + 1)}, 76 conditional->branch_computation(branch_index))); [all...] |
hlo_evaluator.cc | 1264 int branch_index; local [all...] |
shape_inference.h | 213 const Shape& branch_index,
|
shape_inference.cc | [all...] |
hlo_instruction.h | [all...] |
hlo_instruction.cc | 224 << "conditional should have one branch_index operand plus one " [all...] |
/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
conditional_thunk.cc | 64 int32 branch_index = -1; local 71 stream->ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); 77 "Failed to retrieve branch_index value on stream %p: %s.", stream, 81 branch_index = pred ? 0 : 1; 83 // Handle default scenario for branch_index not in [0, num_branches). 84 if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { 85 branch_index = hlo_instruction()->branch_count() - 1; 89 // Execute the branch computation corresponding to the value of branch_index. 91 TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream [all...] |
/external/tensorflow/tensorflow/compiler/tf2xla/ |
functionalize_cond.cc | 428 int branch_index = static_cast<int>(branch); 434 .Finalize(bodies_[branch_index].get(), 435 &cond_arg_node.branch_copy[branch_index])); 440 int branch_index = e->src_output(); 441 Node* src_copy = cond_arg_node.branch_copy[branch_index]; 442 Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; 449 << " on branch " << Branch_Name(BranchType(branch_index)); 454 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); 551 int branch_index = static_cast<int>(branch); 552 auto output = bodies_[branch_index].get() [all...] |
/external/tensorflow/tensorflow/core/kernels/data/experimental/ |
choose_fastest_branch_dataset_op.cc | 395 writer->WriteScalar(full_name("branch_index"), branch_index_)); 414 reader->ReadScalar(full_name("branch_index"), &branch_index_)); 471 Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index, 474 DCHECK_GE(branch_index, 0); 475 DCHECK_LT(branch_index, histograms_.size()); 484 params.node_name = strings::StrCat(params.type_string, branch_index); 497 strings::StrCat(take_dataset_params.type_string, branch_index); 509 ctx, {*wrapper_dataset_tensor_}, branch_index, 510 *instantiated_captured_funcs_[branch_index], prefix(),
|
/external/tensorflow/tensorflow/compiler/xla/tests/ |
conditional_test.cc | 197 XlaOp branch_index; local 199 &builder, &branch_index); 211 Conditional(branch_index, branches_p, operands); 240 XlaOp branch_index; local 242 &builder, &branch_index); 264 Conditional(branch_index, branches_p, operands); 415 XlaOp branch_index; local 417 CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index); 436 Conditional(branch_index, branches_p, [all...] |
/external/tensorflow/tensorflow/core/kernels/ |
functional_ops.cc | 243 const Tensor& branch_index = ctx->input(0); variable 244 OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()), 245 errors::InvalidArgument("branch_index must be scalar"), 247 int32 branch = branch_index.scalar<int32>()(); 317 Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
|
/external/tensorflow/tensorflow/compiler/xla/service/cpu/ |
ir_emitter.cc | [all...] |
/external/tensorflow/tensorflow/compiler/xla/client/ |
xla_builder.h | 536 XlaOp Conditional(const XlaOp& branch_index, [all...] |
xla_builder.cc | [all...] |