Home | History | Annotate | Download | only in kernels

Lines Matching refs:updates

38 // Check whether updates.shape = indices.shape + params.shape[1:]
39 static bool ValidShapes(const Tensor& params, const Tensor& updates,
41 if (updates.dims() != indices.dims() + params.dims() - 1) return false;
43 if (updates.dim_size(d) != indices.dim_size(d)) {
48 if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
56 const Tensor& indices, const Tensor& updates) {
63 c, ValidShapes(params, updates, indices),
65 "Must have updates.shape = indices.shape + params.shape[1:], got ",
66 "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
85 // Hold mutex while we apply updates
99 const Tensor& updates = c->input(2);
100 DoValidationChecking(c, params, indices, updates);
125 auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
149 // Hold mutex while we apply updates
163 const Tensor& updates = c->input(2);
164 DoValidationChecking(c, params, indices, updates);
198 auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});