Home | History | Annotate | Download | only in tests

Lines Matching refs:builder

57   ComputationBuilder builder(client_, TestName());
58 auto lhs = builder.ConstantR1<float>({});
59 auto rhs = builder.ConstantR1<float>({});
60 auto result = builder.Dot(lhs, rhs);
62 ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
66 ComputationBuilder builder(client_, TestName());
67 auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
68 auto rhs = builder.ConstantR1<float>({3.0, 4.0});
69 auto result = builder.Dot(lhs, rhs);
71 ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
76 ComputationBuilder builder(client_, TestName());
77 auto lhs = builder.ConstantR1<Element>({2.0});
78 auto rhs = builder.ConstantR1<Element>({3.0});
79 auto result = builder.Dot(lhs, rhs);
81 ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_);
94 ComputationBuilder builder(client_, TestName());
95 auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0});
96 auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5});
97 auto result = builder.Dot(lhs, rhs);
99 ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_);
115 ComputationBuilder builder(client_, TestName());
116 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
117 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
118 auto result = builder.Dot(lhs, rhs);
120 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
124 ComputationBuilder builder(client_, TestName());
125 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
126 auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}});
127 auto result = builder.Dot(lhs, rhs);
129 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_);
133 ComputationBuilder builder(client_, TestName());
135 builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}});
136 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
137 auto result = builder.Dot(lhs, rhs);
139 ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_);
143 ComputationBuilder builder(client_, TestName());
144 auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
145 auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
146 auto result = builder.Dot(lhs, rhs);
148 ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {},
153 ComputationBuilder builder(client_, TestName());
154 auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0");
155 auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1");
156 auto exp0 = builder.Exp(param0);
157 auto result = builder.Dot(exp0, param1);
169 &builder, Array2D<float>({{296.14560492846033}, {0.8611737683031964}}),
189 ComputationBuilder builder(client_, TestName());
191 auto result = builder.Dot(
192 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
193 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
197 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
261 ComputationBuilder builder(client_, TestName());
263 auto result = builder.Dot(
264 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}),
266 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}),
270 result = builder.Add(
272 builder.Parameter(
291 ComputeAndCompareR2<float>(&builder, *expected, args, ErrorSpec(0.3, 3e-3));
398 ComputationBuilder builder(client_, TestName());
400 auto result = builder.Dot(
401 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
402 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
407 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
459 ComputationBuilder builder(client_, TestName());
461 auto result = builder.Dot(
462 builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
463 builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
468 &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
472 ComputationBuilder builder(client_, TestName());
473 auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
474 auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
475 auto matrix12 = builder.Dot(matrix1, matrix2);
476 auto matrix21 = builder.Dot(matrix2, matrix1);
477 builder.Add(matrix12, matrix21);
480 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
487 ComputationBuilder builder(client_, TestName());
488 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x");
489 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y");
491 auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
492 auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
498 auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
499 x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
500 auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
501 y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
503 auto out = builder.Dot(x_slice, y_slice);
504 out = builder.Reshape(out, {0, 1}, {1, 2, 2});
507 auto out_flat = builder.ConcatInDim(out_slices, 0);
508 builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
522 &builder,
530 ComputationBuilder builder(client_, TestName());
531 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x");
532 auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y");
540 auto out = builder.DotGeneral(x, y, dnums);
553 &builder,
589 ComputationBuilder builder(client_, TestName());
591 auto lhs_arg = builder.Parameter(
594 auto rhs_arg = builder.Parameter(
598 lhs_arg = builder.Transpose(lhs_arg, {1, 0});
601 rhs_arg = builder.Transpose(rhs_arg, {1, 0});
603 auto result = builder.Dot(lhs_arg, rhs_arg);
608 ComputeAndCompareR2<float>(&builder, expected,
622 ComputationBuilder builder(client_, TestName());
623 auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
624 auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
626 auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}),
628 auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}),
630 auto result = builder.Dot(
631 lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));
655 &builder, expected,
670 ComputationBuilder builder(client_, TestName());
671 auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
672 auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
674 auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}),
676 auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}),
678 auto result = builder.Dot(
679 builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant);
703 &builder, expected,