Skip to content

Commit 50264b0

Browse files
committed
Replace addr with factory in TServerSocket
1 parent be155ae commit 50264b0

1 file changed

Lines changed: 66 additions & 30 deletions

File tree

lib/go/thrift/server_socket.go

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* under the License.
1818
*/
1919

20+
2021
package thrift
2122

2223
import (
@@ -25,16 +26,20 @@ import (
2526
"time"
2627
)
2728

29+
// TServerSocketListenerFactory abstracts how listeners are created.
30+
type TServerSocketListenerFactory func(listen bool) (net.Addr, net.Listener, error)
31+
2832
type TServerSocket struct {
29-
addr net.Addr
33+
factory TServerSocketListenerFactory
3034
clientTimeout time.Duration
3135

32-
// Protects the listener and interrupted fields to make them thread safe.
3336
mu sync.RWMutex
3437
listener net.Listener
3538
interrupted bool
3639
}
3740

41+
// --- Constructors ---
42+
3843
func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
3944
return NewTServerSocketTimeout(listenAddr, 0)
4045
}
@@ -44,28 +49,62 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T
4449
if err != nil {
4550
return nil, err
4651
}
47-
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
52+
53+
return NewTServerSocketFromAddrTimeout(addr, clientTimeout), nil
4854
}
4955

50-
// Creates a TServerSocket from a net.Addr
5156
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
52-
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
57+
factory := func(listen bool) (net.Addr, net.Listener, error) {
58+
var listener net.Listener
59+
var err error
60+
if (listen){
61+
listener, err = net.Listen(addr.Network(), addr.String())
62+
}
63+
return addr, listener, err
64+
}
65+
66+
return NewTServerSocketFromFactoryTimeout(factory, clientTimeout)
5367
}
5468

55-
func (p *TServerSocket) Listen() error {
69+
// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.)
70+
func NewTServerSocketFromFactoryTimeout(factory TServerSocketListenerFactory, clientTimeout time.Duration) *TServerSocket {
71+
return &TServerSocket{
72+
factory: factory,
73+
clientTimeout: clientTimeout,
74+
}
75+
}
76+
77+
// --- Core methods ---
78+
79+
func (p *TServerSocket) try_listen(raise bool) error {
5680
p.mu.Lock()
5781
defer p.mu.Unlock()
58-
if p.IsListening() {
82+
83+
if p.listener != nil {
84+
if (raise) {
85+
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
86+
}
5987
return nil
6088
}
61-
l, err := net.Listen(p.addr.Network(), p.addr.String())
89+
90+
_, l, err := p.factory(true)
6291
if err != nil {
6392
return err
6493
}
94+
6595
p.listener = l
96+
p.interrupted = false
6697
return nil
6798
}
6899

100+
func (p *TServerSocket) Open() error {
101+
return p.try_listen(true)
102+
}
103+
104+
func (p *TServerSocket) Listen() error {
105+
return p.try_listen(false)
106+
}
107+
69108
func (p *TServerSocket) Accept() (TTransport, error) {
70109
p.mu.RLock()
71110
interrupted := p.interrupted
@@ -87,51 +126,48 @@ func (p *TServerSocket) Accept() (TTransport, error) {
87126
return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil
88127
}
89128

90-
// Checks whether the socket is listening.
129+
// --- State helpers ---
130+
91131
func (p *TServerSocket) IsListening() bool {
132+
p.mu.RLock()
133+
defer p.mu.RUnlock()
92134
return p.listener != nil
93135
}
94136

95-
// Connects the socket, creating a new socket object if necessary.
96-
func (p *TServerSocket) Open() error {
97-
p.mu.Lock()
98-
defer p.mu.Unlock()
99-
if p.IsListening() {
100-
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
101-
}
102-
if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil {
103-
return err
104-
} else {
105-
p.listener = l
106-
}
107-
return nil
108-
}
109-
110137
func (p *TServerSocket) Addr() net.Addr {
111138
p.mu.RLock()
112139
defer p.mu.RUnlock()
113-
if p.IsListening() {
140+
141+
if p.listener != nil {
114142
return p.listener.Addr()
115143
}
116-
return p.addr
144+
addr, _, _ := p.factory(false)
145+
return addr
117146
}
118147

148+
// --- Shutdown / control ---
149+
119150
func (p *TServerSocket) Close() error {
120-
var err error
121151
p.mu.Lock()
122-
if p.IsListening() {
152+
defer p.mu.Unlock()
153+
154+
var err error
155+
if p.listener != nil {
123156
err = p.listener.Close()
124157
p.listener = nil
125158
}
126-
p.mu.Unlock()
127159
return err
128160
}
129161

130162
func (p *TServerSocket) Interrupt() error {
131163
p.mu.Lock()
132164
p.interrupted = true
165+
listener := p.listener
166+
p.listener = nil
133167
p.mu.Unlock()
134-
p.Close()
135168

169+
if listener != nil {
170+
return listener.Close()
171+
}
136172
return nil
137173
}

0 commit comments

Comments
 (0)