HomeSort by relevance Sort by last modified time
    Searched refs:branch_index (Results 1 - 14 of 14) sorted by null

  /external/tensorflow/tensorflow/compiler/xla/service/
conditional_simplifier.cc 36 // computation. If the given conditional has a constant branch_index, tries to
54 int branch_index = 0; local
57 VLOG(2) << "Not attempting to remove conditional as its branch_index is "
64 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
66 branch_index = conditional->operand(0)->literal().Get<int32>({});
67 if (branch_index < 0 || branch_index >= conditional->branch_count()) {
68 branch_index = conditional->branch_count() - 1;
75 conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
76 conditional->branch_computation(branch_index)));
    [all...]
hlo_evaluator.cc 1264 int branch_index; local
    [all...]
shape_inference.h 213 const Shape& branch_index,
shape_inference.cc     [all...]
hlo_instruction.h     [all...]
hlo_instruction.cc 224 << "conditional should have one branch_index operand plus one "
    [all...]
  /external/tensorflow/tensorflow/compiler/xla/service/gpu/
conditional_thunk.cc 64 int32 branch_index = -1; local
71 stream->ThenMemcpy(&branch_index, branch_index_address, sizeof(int32));
77 "Failed to retrieve branch_index value on stream %p: %s.", stream,
81 branch_index = pred ? 0 : 1;
83 // Handle default scenario for branch_index not in [0, num_branches).
84 if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) {
85 branch_index = hlo_instruction()->branch_count() - 1;
89 // Execute the branch computation corresponding to the value of branch_index.
91 TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream
    [all...]
  /external/tensorflow/tensorflow/compiler/tf2xla/
functionalize_cond.cc 428 int branch_index = static_cast<int>(branch);
434 .Finalize(bodies_[branch_index].get(),
435 &cond_arg_node.branch_copy[branch_index]));
440 int branch_index = e->src_output();
441 Node* src_copy = cond_arg_node.branch_copy[branch_index];
442 Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
449 << " on branch " << Branch_Name(BranchType(branch_index));
454 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
551 int branch_index = static_cast<int>(branch);
552 auto output = bodies_[branch_index].get()
    [all...]
  /external/tensorflow/tensorflow/core/kernels/data/experimental/
choose_fastest_branch_dataset_op.cc 395 writer->WriteScalar(full_name("branch_index"), branch_index_));
414 reader->ReadScalar(full_name("branch_index"), &branch_index_));
471 Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
474 DCHECK_GE(branch_index, 0);
475 DCHECK_LT(branch_index, histograms_.size());
484 params.node_name = strings::StrCat(params.type_string, branch_index);
497 strings::StrCat(take_dataset_params.type_string, branch_index);
509 ctx, {*wrapper_dataset_tensor_}, branch_index,
510 *instantiated_captured_funcs_[branch_index], prefix(),
  /external/tensorflow/tensorflow/compiler/xla/tests/
conditional_test.cc 197 XlaOp branch_index; local
199 &builder, &branch_index);
211 Conditional(branch_index, branches_p, operands);
240 XlaOp branch_index; local
242 &builder, &branch_index);
264 Conditional(branch_index, branches_p, operands);
415 XlaOp branch_index; local
417 CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index);
436 Conditional(branch_index, branches_p,
    [all...]
  /external/tensorflow/tensorflow/core/kernels/
functional_ops.cc 243 const Tensor& branch_index = ctx->input(0); variable
244 OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
245 errors::InvalidArgument("branch_index must be scalar"),
247 int32 branch = branch_index.scalar<int32>()();
317 Name("Case").Device(DEVICE_GPU).HostMemory("branch_index"), CaseOp);
  /external/tensorflow/tensorflow/compiler/xla/service/cpu/
ir_emitter.cc     [all...]
  /external/tensorflow/tensorflow/compiler/xla/client/
xla_builder.h 536 XlaOp Conditional(const XlaOp& branch_index,
    [all...]
xla_builder.cc     [all...]

Completed in 421 milliseconds