Home | Sort by relevance | Sort by last modified time
    Searched refs:batch_dims (Results 1 - 12 of 12) sorted by null

  /external/tensorflow/tensorflow/python/kernel_tests/
matrix_triangular_solve_op_test.py 31 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
40 batch_dims=batch_dims,
44 def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
45 self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
47 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
48 self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
55 batch_dims=None,
72 if batch_dims is not None:
73 a = np.tile(a, batch_dims + [1, 1])
    [all...]
matrix_solve_op_test.py 37 def _verifySolve(self, x, y, batch_dims=None):
51 if batch_dims is not None:
52 a = np.tile(a, batch_dims + [1, 1])
53 a_np = np.tile(a_np, batch_dims + [1, 1])
54 b = np.tile(b, batch_dims + [1, 1])
89 for batch_dims in [[2], [2, 2], [7, 4]]:
90 self._verifySolve(matrix, rhs, batch_dims=batch_dims)
qr_op_test.py 202 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
204 shape = batch_dims + (rows, cols)
222 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
223 shape = batch_dims + (rows, cols)
svd_op_test.py 264 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
265 shape = batch_dims + (rows, cols)
282 for batch_dims in [(), (3,)]:
283 shape = batch_dims + mat_shape
matrix_exponential_op_test.py 134 for batch_dims in [(), (1,), (3,), (2, 2)]:
136 shape = batch_dims + (size, size)
matrix_logarithm_op_test.py 104 for batch_dims in [(), (1,), (3,), (2, 2)]:
106 shape = batch_dims + (size, size)
batch_matmul_op_test.py 37 batch_dims = x.shape[:-2]
38 num = np.prod(batch_dims)
39 z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
matrix_inverse_op_test.py 131 for batch_dims in [(), (1,), (3,), (2, 2)]:
133 shape = batch_dims + (size, size)
self_adjoint_eig_op_test.py 229 for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
230 shape = batch_dims + (size, size)
  /external/tensorflow/tensorflow/contrib/distributions/python/ops/
test_util.py 88 batch_dims = array_ops.shape(dist.batch_shape_tensor())[0]
89 edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]])
  /external/tensorflow/tensorflow/compiler/xla/service/
shape_inference.cc 588 tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
592 std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
614 tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
621 std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
    [all...]
  /external/tensorflow/tensorflow/core/ops/
math_ops.cc 88 ShapeHandle batch_dims;
91 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
100 batch_dims, c->Matrix(output_rows, output_cols), &out));
    [all...]

Completed in 279 milliseconds