Skip to content

Commit 160aec1

Browse files
committed
Update server_socket.go
1 parent 50264b0 commit 160aec1

1 file changed

Lines changed: 27 additions & 17 deletions

File tree

lib/go/thrift/server_socket.go

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@ import (
2626
"time"
2727
)
2828

29-
// TServerSocketListenerFactory abstracts how listeners are created.
30-
type TServerSocketListenerFactory func(listen bool) (net.Addr, net.Listener, error)
31-
3229
type TServerSocket struct {
33-
factory TServerSocketListenerFactory
30+
// TServerSocketListenerFactory abstracts how listeners are created.
31+
factory func(bool) (net.Addr, net.Listener, error)
3432
clientTimeout time.Duration
3533

3634
mu sync.RWMutex
@@ -67,7 +65,20 @@ func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration)
6765
}
6866

6967
// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.)
70-
func NewTServerSocketFromFactoryTimeout(factory TServerSocketListenerFactory, clientTimeout time.Duration) *TServerSocket {
68+
func NewTServerSocketFromFactoryAddrTimeout(proc func(addr net.Addr) (listener net.Listener, err error),addr net.Addr, clientTimeout time.Duration) *TServerSocket {
69+
factory := func(listen bool) (net.Addr, net.Listener, error) {
70+
var listener net.Listener
71+
var err error
72+
if (listen){
73+
listener, err = proc(addr)
74+
}
75+
return addr, listener, err
76+
}
77+
return NewTServerSocketFromFactoryTimeout(factory, clientTimeout)
78+
}
79+
80+
// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.)
81+
func NewTServerSocketFromFactoryTimeout(factory func(listen bool) (addr net.Addr, listener net.Listener, err error), clientTimeout time.Duration) *TServerSocket {
7182
return &TServerSocket{
7283
factory: factory,
7384
clientTimeout: clientTimeout,
@@ -147,27 +158,26 @@ func (p *TServerSocket) Addr() net.Addr {
147158

148159
// --- Shutdown / control ---
149160

150-
func (p *TServerSocket) Close() error {
161+
func (p *TServerSocket) try_close(interrupt bool) error {
151162
p.mu.Lock()
152163
defer p.mu.Unlock()
164+
if (interrupt){
165+
p.interrupted = true
166+
}
153167

154-
var err error
168+
var err error = nil
155169
if p.listener != nil {
156170
err = p.listener.Close()
157171
p.listener = nil
158172
}
159173
return err
160174
}
161175

162-
func (p *TServerSocket) Interrupt() error {
163-
p.mu.Lock()
164-
p.interrupted = true
165-
listener := p.listener
166-
p.listener = nil
167-
p.mu.Unlock()
168176

169-
if listener != nil {
170-
return listener.Close()
171-
}
172-
return nil
177+
func (p *TServerSocket) Close() error {
178+
return p.try_close(false)
179+
}
180+
181+
func (p *TServerSocket) Interrupt() error {
182+
return p.try_close(true)
173183
}

0 commit comments

Comments
 (0)