Home | History | Annotate | Download | only in kernel_tests

Lines Matching refs:mvn

146       mvn = ds.MultivariateNormalDiag(
150 mvn.covariance().eval())
152 mvn = ds.MultivariateNormalDiag(
155 self.assertAllEqual([2], mvn.batch_shape)
156 self.assertAllEqual([3], mvn.event_shape)
164 mvn.covariance().eval())
166 mvn = ds.MultivariateNormalDiag(
169 self.assertAllEqual([2], mvn.batch_shape)
170 self.assertAllEqual([3], mvn.event_shape)
178 mvn.covariance().eval())
182 mvn = ds.MultivariateNormalDiag(
186 mvn.variance().eval())
188 mvn = ds.MultivariateNormalDiag(
194 mvn.variance().eval())
196 mvn = ds.MultivariateNormalDiag(
203 mvn.variance().eval())
207 mvn = ds.MultivariateNormalDiag(
211 mvn.stddev().eval())
213 mvn = ds.MultivariateNormalDiag(
219 mvn.stddev().eval())
221 mvn = ds.MultivariateNormalDiag(
227 mvn.stddev().eval())
255 mvn = ds.MultivariateNormalDiag(
259 # Typically you'd use `mvn.log_prob(x_pl)` which is always at least as
260 # numerically stable as `tf.log(mvn.prob(x_pl))`. However in this test
265 neg_log_likelihood = -math_ops.reduce_sum(math_ops.log(mvn.prob(x_pl)))
279 mvn = ds.MultivariateNormalDiag(
282 self.assertListEqual(mvn.batch_shape.as_list(), [None, None])
283 self.assertListEqual(mvn.event_shape.as_list(), [2])
286 mvn = ds.MultivariateNormalDiag(
289 self.assertListEqual(mvn.batch_shape.as_list(), [2, 3])
290 self.assertListEqual(mvn.event_shape.as_list(), [None])
296 mvn = ds.MultivariateNormalDiag(
299 g = gradients_impl.gradients(ds.kl_divergence(mvn, mvn), loc)