# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.GrpcServer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


class MultipleContainersTest(test.TestCase):
  """Tests `Session.reset()` against a local server holding two containers."""

  # Verifies behavior of tf.Session.reset() with multiple containers using
  # tf.container.
  # TODO(b/34465411): Starting multiple servers with different configurations
  # in the same test is flaky. Move this test case back into
  # "server_lib_test.py" when this is no longer the case.
  def testMultipleContainers(self):
    """Resetting one container must leave variables in other containers intact.

    `v0` is placed in container "test0" and `v1` in container "test1".
    After `Session.reset(target, ["test0"])`:
      * the session that was open before the reset fails with AbortedError,
        even when reading `v1`;
      * a newly opened session fails to read `v0` (FailedPreconditionError)
        but still reads `v1` as 2.0.
    """
    # NOTE(review): both variables request the name "v0" -- presumably
    # deliberate, to show that containers keep same-named resources apart.
    with ops.container("test0"):
      v0 = variables.Variable(1.0, name="v0")
    with ops.container("test1"):
      v1 = variables.Variable(2.0, name="v0")
    server = server_lib.Server.create_local_server()
    sess = session.Session(server.target)
    sess.run(variables.global_variables_initializer())
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))

    # Reset only container "test0". The session opened above is aborted by
    # the reset, so every further run on it fails -- including a read of v1,
    # whose container was not reset.
    session.Session.reset(server.target, ["test0"])
    with self.assertRaises(errors_impl.AbortedError):
      sess.run(v1)

    # Reconnect to the same target. v0's container was cleared, so reading v0
    # fails with FailedPreconditionError (uninitialized), while v1 in the
    # untouched "test1" container still holds its value. The context manager
    # closes this session once the checks finish instead of leaking it.
    with session.Session(server.target) as new_sess:
      with self.assertRaises(errors_impl.FailedPreconditionError):
        new_sess.run(v0)
      self.assertAllEqual(2.0, new_sess.run(v1))


if __name__ == "__main__":
  test.main()