diff --git a/lib/rb/lib/thrift/server/nonblocking_server.rb b/lib/rb/lib/thrift/server/nonblocking_server.rb index f2959f2fe16..8290bf2eb71 100644 --- a/lib/rb/lib/thrift/server/nonblocking_server.rb +++ b/lib/rb/lib/thrift/server/nonblocking_server.rb @@ -138,7 +138,12 @@ def shutdown(timeout = 0) def ensure_closed kill_worker_threads if @worker_threads - @iom_thread.kill + if @iom_thread&.alive? + @iom_thread.kill + @iom_thread.join + end + close_connections + close_signal_pipes end private @@ -246,6 +251,26 @@ def kill_worker_threads @worker_threads.clear end + def close_connections + @connections.each do |fd| + begin + fd.close + rescue IOError, SystemCallError, TransportException + end + end + @connections.clear + @buffers.clear + end + + def close_signal_pipes + @signal_pipes.each do |pipe| + begin + pipe.close unless pipe.closed? + rescue IOError + end + end + end + def slice_frame!(buf) if buf.length >= 4 size = buf.unpack('N').first diff --git a/lib/rb/spec/nonblocking_server_spec.rb b/lib/rb/spec/nonblocking_server_spec.rb index 572e0b5b018..58949e34920 100644 --- a/lib/rb/spec/nonblocking_server_spec.rb +++ b/lib/rb/spec/nonblocking_server_spec.rb @@ -101,7 +101,7 @@ def listen describe Thrift::NonblockingServer do before(:each) do - @port = 43251 + @port = available_port handler = Handler.new processor = SpecNamespace::NonblockingService::Processor.new(handler) queue = Queue.new @@ -121,6 +121,7 @@ def listen end end queue.pop + wait_until_listening(@transport, @server_thread) @clients = [] @catch_exceptions = false @@ -128,9 +129,11 @@ def listen after(:each) do @clients.each { |client, trans| trans.close } - # @server.shutdown(1) - @server_thread.kill - @transport.close + @server.shutdown(1, false) if @server + @server_thread.join(2) if @server_thread + @server_thread.kill if @server_thread && @server_thread.alive? + @server_thread.join(2) if @server_thread + @transport.close if @transport end def setup_client(queue = nil) @@ -261,6 +264,70 @@ def setup_client_thread(result) end end + describe Thrift::NonblockingServer::IOManager do + def build_io_manager + logger = Logger.new(IO::NULL) + logger.level = Logger::FATAL + Thrift::NonblockingServer::IOManager.new( + double('processor'), + double('server_transport'), + Thrift::BaseTransportFactory.new, + Thrift::BinaryProtocolFactory.new, + 1, + logger + ) + end + + it "closes tracked connections and signal pipes during forced cleanup" do + io_manager = build_io_manager + connection = double('connection', :close => nil) + pipe_a = double('pipe_a', :closed? => false, :close => nil) + pipe_b = double('pipe_b', :closed? => false, :close => nil) + + io_manager.instance_variable_set(:@connections, [connection]) + io_manager.instance_variable_set(:@buffers, { connection => 'frame' }) + io_manager.instance_variable_set(:@signal_pipes, [pipe_a, pipe_b]) + io_manager.instance_variable_set(:@worker_threads, []) + + io_manager.ensure_closed + + expect(connection).to have_received(:close) + expect(pipe_a).to have_received(:close) + expect(pipe_b).to have_received(:close) + expect(io_manager.instance_variable_get(:@connections)).to be_empty + expect(io_manager.instance_variable_get(:@buffers)).to be_empty + end + + it "continues closing remaining signal pipes when one close raises" do + io_manager = build_io_manager + pipe_a = double('pipe_a', :closed? => false) + pipe_b = double('pipe_b', :closed? => false, :close => nil) + + allow(pipe_a).to receive(:close).and_raise(IOError) + + io_manager.instance_variable_set(:@signal_pipes, [pipe_a, pipe_b]) + io_manager.instance_variable_set(:@worker_threads, []) + + io_manager.send(:close_signal_pipes) + + expect(pipe_a).to have_received(:close) + expect(pipe_b).to have_received(:close) + end + + it "drops removed connections from bookkeeping" do + io_manager = build_io_manager + connection = double('connection', :close => nil) + + io_manager.instance_variable_set(:@connections, [connection]) + io_manager.instance_variable_set(:@buffers, { connection => 'frame' }) + + io_manager.send(:remove_connection, connection) + + expect(io_manager.instance_variable_get(:@connections)).to be_empty + expect(io_manager.instance_variable_get(:@buffers)).to be_empty + end + end + describe "#{Thrift::NonblockingServer} with TLS transport" do before(:each) do @port = available_port @@ -282,7 +349,7 @@ def setup_client_thread(result) end @clients = [] - wait_until_listening + wait_until_listening(@transport, @server_thread) end after(:each) do @@ -313,19 +380,6 @@ def setup_tls_client client end - def wait_until_listening - Timeout.timeout(2) do - until @transport.handle - raise "Server thread exited unexpectedly" unless @server_thread.alive? - sleep 0.01 - end - end - end - - def available_port - TCPServer.open('localhost', 0) { |server| server.addr[1] } - end - def ssl_keys_dir File.expand_path('../../../test/keys', __dir__) end @@ -358,4 +412,17 @@ def create_client_ssl_context end end end + + def wait_until_listening(server_transport, server_thread) + Timeout.timeout(2) do + until server_transport.handle + raise "Server thread exited unexpectedly" unless server_thread.alive? + sleep 0.01 + end + end + end + + def available_port + TCPServer.open('localhost', 0) { |server| server.addr[1] } + end end