Skip to content

Commit 1e0d9d7

Browse files
committed
Add an optional Init func to the server, fix a data race
That takes a config struct from the client. The old way of configuring the server was to pass env vars (which still works), but this was at best very cumbersome. This also fixes a data race when both sending raw (e.g. log messages) and other responses. Closes #10
1 parent 380fc7d commit 1e0d9d7

12 files changed

Lines changed: 471 additions & 176 deletions

File tree

README.md

Lines changed: 73 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,60 @@ This library implements a simple, custom [RPC protocol](https://en.wikipedia.org
77
A strongly typed client may look like this:
88

99
```go
10-
package main
11-
12-
import (
13-
"fmt"
14-
"log"
15-
"time"
16-
17-
"github.com/bep/execrpc"
18-
"github.com/bep/execrpc/codecs"
19-
"github.com/bep/execrpc/examples/model"
20-
)
21-
22-
func main() {
23-
// Define the request, message and receipt types for the RPC call.
10+
// Define the request, message and receipt types for the RPC call.
11+
client, err := execrpc.StartClient(
2412
client, err := execrpc.StartClient(
25-
execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{
26-
ClientRawOptions: execrpc.ClientRawOptions{
27-
Version: 1,
28-
Cmd: "go",
29-
Dir: "./examples/servers/typed",
30-
Args: []string{"run", "."},
31-
Env: nil,
32-
Timeout: 30 * time.Second,
33-
},
34-
Codec: codecs.JSONCodec{},
13+
execrpc.ClientOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{
14+
ClientRawOptions: execrpc.ClientRawOptions{
15+
Version: 1,
16+
Cmd: "go",
17+
Dir: "./examples/servers/typed",
18+
Args: []string{"run", "."},
19+
Env: env,
20+
Timeout: 30 * time.Second,
3521
},
36-
)
37-
if err != nil {
38-
log.Fatal(err)
22+
Config: model.ExampleConfig{},
23+
Codec: codec,
24+
},
25+
)
26+
27+
if err != nil {
28+
log.Fatal(err)
29+
}
30+
31+
// Consume standalone messages (e.g. log messages) in its own goroutine.
32+
go func() {
33+
for msg := range client.MessagesRaw() {
34+
fmt.Println("got message", string(msg.Body))
3935
}
36+
}()
4037

41-
// Consume standalone messages (e.g. log messages) in its own goroutine.
42-
go func() {
43-
for msg := range client.MessagesRaw() {
44-
fmt.Println("got message", string(msg.Body))
45-
}
46-
}()
38+
// Execute the request.
39+
result := client.Execute(model.ExampleRequest{Text: "world"})
4740

48-
// Execute the request.
49-
result := client.Execute(model.ExampleRequest{Text: "world"})
41+
// Check for errors.
42+
if err := result.Err(); err != nil {
43+
log.Fatal(err)
44+
}
5045

51-
// Check for errors.
52-
if err := result.Err(); err != nil {
53-
log.Fatal(err)
54-
}
46+
// Consume the messages.
47+
for m := range result.Messages() {
48+
fmt.Println(m)
49+
}
5550

56-
// Consume the messages.
57-
for m := range result.Messages() {
58-
fmt.Println(m)
59-
}
51+
// Wait for the receipt.
52+
receipt := <-result.Receipt()
6053

61-
// Wait for the receipt.
62-
receipt := <-result.Receipt()
54+
// Check again for errors.
55+
if err := result.Err(); err != nil {
56+
log.Fatal(err)
57+
}
6358

64-
// Check again for errors.
65-
if err := result.Err(); err != nil {
66-
log.Fatal(err)
67-
}
59+
fmt.Println(receipt.Text)
6860

69-
fmt.Println(receipt.Text)
61+
// Close the client.
62+
if err := client.Close(); err != nil {
63+
log.Fatal(err)
7064
}
7165
```
7266

@@ -75,20 +69,32 @@ To get the best performance you should keep the client open as long as its neede
7569
And the server side of the above:
7670

7771
```go
72+
7873
func main() {
79-
getHasher := func() hash.Hash {
80-
return fnv.New64a()
81-
}
74+
log.SetFlags(0)
75+
log.SetPrefix("readme-example: ")
76+
77+
var clientConfig model.ExampleConfig
8278

8379
server, err := execrpc.NewServer(
84-
execrpc.ServerOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{
85-
// Optional function to get a hasher for the ETag.
86-
GetHasher: getHasher,
80+
execrpc.ServerOptions[model.ExampleConfig, model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{
81+
// Optional function to provide a hasher for the ETag.
82+
GetHasher: func() hash.Hash {
83+
return fnv.New64a()
84+
},
8785

8886
// Allows you to delay message delivery, and drop
8987
// them after reading the receipt (e.g. the ETag matches the ETag seen by client).
9088
DelayDelivery: false,
9189

90+
// Optional function to initialize the server
91+
// with the client configuration.
92+
// This will be called once on server start.
93+
Init: func(cfg model.ExampleConfig) error {
94+
clientConfig = cfg
95+
return clientConfig.Init()
96+
},
97+
9298
// Handle the incoming call.
9399
Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) {
94100
// Raw messages are passed directly to the client,
@@ -99,7 +105,7 @@ func main() {
99105
Version: 32,
100106
Status: 150,
101107
},
102-
Body: []byte("a log message"),
108+
Body: []byte("log message"),
103109
},
104110
)
105111

@@ -124,27 +130,32 @@ func main() {
124130

125131
// ETag provided by the framework.
126132
// A hash of all message bodies.
127-
fmt.Println("Receipt:", receipt.ETag)
133+
// fmt.Println("Receipt:", receipt.ETag)
128134

129135
// Modify if needed.
130136
receipt.Size = uint32(123)
137+
receipt.Text = "echoed: " + c.Request.Text
131138

132-
// Close the message stream.
139+
// Close the message stream and send the receipt.
133140
// Pass true to drop any queued messages,
134141
// this is only relevant if DelayDelivery is enabled.
135142
c.Close(false, receipt)
136143
},
137144
},
138145
)
139146
if err != nil {
140-
log.Fatal(err)
147+
handleErr(err)
141148
}
142149

143-
// Start the server. This will block.
144150
if err := server.Start(); err != nil {
145-
log.Fatal(err)
151+
handleErr(err)
146152
}
147153
}
154+
155+
func handleErr(err error) {
156+
log.Fatalf("error: failed to start typed echo server: %s", err)
157+
}
158+
148159
```
149160

150161
## Generate ETag

client.go

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ const (
2424
)
2525

2626
// StartClient starts a client for the given options.
27-
func StartClient[Q, M, R any](opts ClientOptions[Q, M, R]) (*Client[Q, M, R], error) {
27+
func StartClient[C, Q, M, R any](opts ClientOptions[C, Q, M, R]) (*Client[C, Q, M, R], error) {
2828
if opts.Codec == nil {
2929
return nil, errors.New("opts: Codec is required")
3030
}
@@ -37,16 +37,23 @@ func StartClient[Q, M, R any](opts ClientOptions[Q, M, R]) (*Client[Q, M, R], er
3737
return nil, err
3838
}
3939

40-
return &Client[Q, M, R]{
40+
c := &Client[C, Q, M, R]{
4141
rawClient: rawClient,
4242
opts: opts,
43-
}, nil
43+
}
44+
45+
err = c.init(opts.Config)
46+
if err != nil {
47+
return nil, err
48+
}
49+
50+
return c, nil
4451
}
4552

4653
// Client is a strongly typed RPC client.
47-
type Client[Q, M, R any] struct {
54+
type Client[C, Q, M, R any] struct {
4855
rawClient *ClientRaw
49-
opts ClientOptions[Q, M, R]
56+
opts ClientOptions[C, Q, M, R]
5057
}
5158

5259
// Result is the result of a request
@@ -85,13 +92,49 @@ func (r Result[M, R]) close() {
8592
// MessagesRaw returns the raw messages from the server.
8693
// These are not connected to the request-response flow,
8794
// typically used for log messages etc.
88-
func (c *Client[Q, M, R]) MessagesRaw() <-chan Message {
95+
func (c *Client[C, Q, M, R]) MessagesRaw() <-chan Message {
8996
return c.rawClient.Messages
9097
}
9198

99+
// init passes the configuration to the server.
100+
func (c *Client[C, Q, M, R]) init(cfg C) error {
101+
body, err := c.opts.Codec.Encode(cfg)
102+
if err != nil {
103+
return fmt.Errorf("failed to encode config: %w", err)
104+
}
105+
var (
106+
messagec = make(chan Message, 10)
107+
errc = make(chan error, 1)
108+
)
109+
110+
go func() {
111+
err := c.rawClient.Execute(
112+
func(m *Message) {
113+
m.Body = body
114+
m.Header.Status = MessageStatusInitServer
115+
},
116+
messagec,
117+
)
118+
if err != nil {
119+
errc <- fmt.Errorf("failed to execute init: %w", err)
120+
}
121+
}()
122+
123+
select {
124+
case err := <-errc:
125+
return err
126+
case m := <-messagec:
127+
if m.Header.Status != MessageStatusOK {
128+
return fmt.Errorf("failed to init: %s (error code %d)", m.Body, m.Header.Status)
129+
}
130+
}
131+
132+
return nil
133+
}
134+
92135
// Execute sends the request to the server and returns the result.
93136
// You should check Err() both before and after reading from the messages and receipt channels.
94-
func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] {
137+
func (c *Client[C, Q, M, R]) Execute(r Q) Result[M, R] {
95138
result := Result[M, R]{
96139
messages: make(chan M, 10),
97140
receipt: make(chan R, 1),
@@ -112,28 +155,31 @@ func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] {
112155

113156
messagesRaw := make(chan Message, 10)
114157
go func() {
115-
err := c.rawClient.Execute(body, messagesRaw)
158+
err := c.rawClient.Execute(func(m *Message) { m.Body = body }, messagesRaw)
116159
if err != nil {
117160
result.errc <- fmt.Errorf("failed to execute: %w", err)
118161
}
119162
}()
120163

121164
for message := range messagesRaw {
122-
if message.Header.Status > MessageStatusContinue && message.Header.Status <= MessageStatusSystemReservedMax {
165+
if message.Header.Status >= MessageStatusErrDecodeFailed && message.Header.Status <= MessageStatusSystemReservedMax {
123166
// All of these are currently error situations produced by the server.
124167
result.errc <- fmt.Errorf("%s (error code %d)", message.Body, message.Header.Status)
125168
return
126169
}
127170

128-
if message.Header.Status == MessageStatusContinue {
171+
switch message.Header.Status {
172+
case MessageStatusContinue:
129173
var resp M
130174
err = c.opts.Codec.Decode(message.Body, &resp)
131175
if err != nil {
132176
result.errc <- err
133177
return
134178
}
135179
result.messages <- resp
136-
} else {
180+
case MessageStatusInitServer:
181+
panic("unexpected status")
182+
default:
137183
// Receipt.
138184
var rec R
139185
err = c.opts.Codec.Decode(message.Body, &rec)
@@ -152,7 +198,7 @@ func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] {
152198
}
153199

154200
// Close closes the client.
155-
func (c *Client[Q, M, R]) Close() error {
201+
func (c *Client[C, Q, M, R]) Close() error {
156202
return c.rawClient.Close()
157203
}
158204

@@ -248,10 +294,10 @@ func (c *ClientRaw) Close() error {
248294
// Execute sends body to the server and sends any messages to the messages channel.
249295
// It's safe to call Execute from multiple goroutines.
250296
// The messages channel wil be closed when the call is done.
251-
func (c *ClientRaw) Execute(body []byte, messages chan<- Message) error {
297+
func (c *ClientRaw) Execute(withMessage func(m *Message), messages chan<- Message) error {
252298
defer close(messages)
253299

254-
call, err := c.newCall(body, messages)
300+
call, err := c.newCall(withMessage, messages)
255301
if err != nil {
256302
return err
257303
}
@@ -276,20 +322,21 @@ func (c *ClientRaw) addErrContext(op string, err error) error {
276322
return fmt.Errorf("%s: %s %s", op, err, c.conn.stdErr.String())
277323
}
278324

279-
func (c *ClientRaw) newCall(body []byte, messages chan<- Message) (*call, error) {
325+
func (c *ClientRaw) newCall(withMessage func(m *Message), messages chan<- Message) (*call, error) {
280326
c.mu.Lock()
281327
c.seq++
282328
id := c.seq
329+
m := Message{
330+
Header: Header{
331+
Version: c.version,
332+
ID: id,
333+
},
334+
}
335+
withMessage(&m)
283336

284337
call := &call{
285-
Done: make(chan *call, 1),
286-
Request: Message{
287-
Header: Header{
288-
Version: c.version,
289-
ID: id,
290-
},
291-
Body: body,
292-
},
338+
Done: make(chan *call, 1),
339+
Request: m,
293340
Messages: messages,
294341
}
295342

@@ -384,8 +431,13 @@ func (c *ClientRaw) send(call *call) error {
384431
}
385432

386433
// ClientOptions are options for the client.
387-
type ClientOptions[Q, M, R any] struct {
434+
type ClientOptions[C, Q, M, R any] struct {
388435
ClientRawOptions
436+
437+
// The configuration to pass to the server.
438+
Config C
439+
440+
// The codec to use.
389441
Codec codecs.Codec
390442
}
391443

0 commit comments

Comments
 (0)