Home | History | Annotate | Download | only in kernel_tests

Lines Matching refs:grad_val

870       grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
872 v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
873 state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
874 var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
878 v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
907 grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
908 grad_val[1, :] + grad_val[2, :], grad_val[2, :]
913 self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
914 self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)