Skip to content

Commit 687cc2e

Browse files
committed
Factor redundant code
1 parent d600edd commit 687cc2e

6 files changed

Lines changed: 316 additions & 350 deletions

File tree

dnscrypt-proxy/common.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ import (
44
"bytes"
55
"encoding/binary"
66
"errors"
7+
"fmt"
8+
"io"
79
"net"
810
"os"
911
"strconv"
1012
"strings"
1113
"sync"
14+
"time"
1215
"unicode"
1316
)
1417

@@ -171,3 +174,110 @@ func ReadTextFile(filename string) (string, error) {
171174
}
172175

173176
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
177+
178+
// ExtractClientIPStr extracts client IP string from pluginsState based on protocol
179+
func ExtractClientIPStr(pluginsState *PluginsState) (string, bool) {
180+
switch pluginsState.clientProto {
181+
case "udp":
182+
return (*pluginsState.clientAddr).(*net.UDPAddr).IP.String(), true
183+
case "tcp", "local_doh":
184+
return (*pluginsState.clientAddr).(*net.TCPAddr).IP.String(), true
185+
default:
186+
return "", false
187+
}
188+
}
189+
190+
// FormatLogLine formats a log line based on the specified format (tsv or ltsv)
191+
func FormatLogLine(format, clientIP, qName, reason string, additionalFields ...string) (string, error) {
192+
if format == "tsv" {
193+
now := time.Now()
194+
year, month, day := now.Date()
195+
hour, minute, second := now.Clock()
196+
tsStr := fmt.Sprintf("[%d-%02d-%02d %02d:%02d:%02d]", year, int(month), day, hour, minute, second)
197+
198+
line := fmt.Sprintf("%s\t%s\t%s\t%s", tsStr, clientIP, StringQuote(qName), StringQuote(reason))
199+
for _, field := range additionalFields {
200+
line += fmt.Sprintf("\t%s", StringQuote(field))
201+
}
202+
return line + "\n", nil
203+
} else if format == "ltsv" {
204+
line := fmt.Sprintf("time:%d\thost:%s\tqname:%s\tmessage:%s", time.Now().Unix(), clientIP, StringQuote(qName), StringQuote(reason))
205+
206+
// For LTSV format, additional fields are added with specific labels
207+
for i, field := range additionalFields {
208+
if i == 0 {
209+
line += fmt.Sprintf("\tip:%s", StringQuote(field))
210+
} else {
211+
line += fmt.Sprintf("\tfield%d:%s", i, StringQuote(field))
212+
}
213+
}
214+
return line + "\n", nil
215+
}
216+
return "", fmt.Errorf("unexpected log format: [%s]", format)
217+
}
218+
219+
// WritePluginLog writes a log entry for plugin actions
220+
func WritePluginLog(logger io.Writer, format, clientIP, qName, reason string, additionalFields ...string) error {
221+
if logger == nil {
222+
return errors.New("Log file not initialized")
223+
}
224+
225+
line, err := FormatLogLine(format, clientIP, qName, reason, additionalFields...)
226+
if err != nil {
227+
return err
228+
}
229+
230+
_, err = logger.Write([]byte(line))
231+
return err
232+
}
233+
234+
// ParseTimeBasedRule parses a rule line that may contain time-based restrictions (@timerange)
235+
func ParseTimeBasedRule(line string, lineNo int, allWeeklyRanges *map[string]WeeklyRanges) (rulePart string, weeklyRanges *WeeklyRanges, err error) {
236+
parts := strings.Split(line, "@")
237+
timeRangeName := ""
238+
239+
if len(parts) == 2 {
240+
rulePart = strings.TrimSpace(parts[0])
241+
timeRangeName = strings.TrimSpace(parts[1])
242+
} else if len(parts) > 2 {
243+
return "", nil, fmt.Errorf("syntax error at line %d -- Unexpected @ character", 1+lineNo)
244+
} else {
245+
rulePart = line
246+
}
247+
248+
if len(timeRangeName) > 0 {
249+
if weeklyRangesX, ok := (*allWeeklyRanges)[timeRangeName]; ok {
250+
weeklyRanges = &weeklyRangesX
251+
} else {
252+
return "", nil, fmt.Errorf("time range [%s] not found at line %d", timeRangeName, 1+lineNo)
253+
}
254+
}
255+
256+
return rulePart, weeklyRanges, nil
257+
}
258+
259+
// ParseIPRule parses and validates an IP rule line
260+
func ParseIPRule(line string, lineNo int) (cleanLine string, trailingStar bool, err error) {
261+
ip := net.ParseIP(line)
262+
trailingStar = strings.HasSuffix(line, "*")
263+
264+
if len(line) < 2 || (ip != nil && trailingStar) {
265+
return "", false, fmt.Errorf("suspicious IP rule [%s] at line %d", line, lineNo)
266+
}
267+
268+
cleanLine = line
269+
if trailingStar {
270+
cleanLine = cleanLine[:len(cleanLine)-1]
271+
}
272+
if strings.HasSuffix(cleanLine, ":") || strings.HasSuffix(cleanLine, ".") {
273+
cleanLine = cleanLine[:len(cleanLine)-1]
274+
}
275+
if len(cleanLine) == 0 {
276+
return "", false, fmt.Errorf("empty IP rule at line %d", lineNo)
277+
}
278+
if strings.Contains(cleanLine, "*") {
279+
return "", false, fmt.Errorf("invalid rule: [%s] - wildcards can only be used as a suffix at line %d", line, lineNo)
280+
}
281+
282+
return strings.ToLower(cleanLine), trailingStar, nil
283+
}

dnscrypt-proxy/plugin_allow_ip.go

Lines changed: 42 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@ package main
22

33
import (
44
"errors"
5-
"fmt"
65
"io"
7-
"net"
86
"strings"
97
"sync"
10-
"time"
118

129
iradix "github.com/hashicorp/go-immutable-radix"
1310
"github.com/jedisct1/dlog"
@@ -69,33 +66,16 @@ func (plugin *PluginAllowedIP) loadRules(lines string, prefixes *iradix.Tree, ip
6966
continue
7067
}
7168

72-
ip := net.ParseIP(line)
73-
trailingStar := strings.HasSuffix(line, "*")
74-
if len(line) < 2 || (ip != nil && trailingStar) {
75-
dlog.Errorf("Suspicious allowed IP rule [%s] at line %d", line, lineNo)
69+
cleanLine, trailingStar, err := ParseIPRule(line, lineNo)
70+
if err != nil {
71+
dlog.Error(err)
7672
continue
7773
}
7874

7975
if trailingStar {
80-
line = line[:len(line)-1]
81-
}
82-
if strings.HasSuffix(line, ":") || strings.HasSuffix(line, ".") {
83-
line = line[:len(line)-1]
84-
}
85-
if len(line) == 0 {
86-
dlog.Errorf("Empty allowed IP rule at line %d", lineNo)
87-
continue
88-
}
89-
if strings.Contains(line, "*") {
90-
dlog.Errorf("Invalid rule: [%s] - wildcards can only be used as a suffix at line %d", line, lineNo)
91-
continue
92-
}
93-
94-
line = strings.ToLower(line)
95-
if trailingStar {
96-
prefixes, _, _ = prefixes.Insert([]byte(line), 0)
76+
prefixes, _, _ = prefixes.Insert([]byte(cleanLine), 0)
9777
} else {
98-
ips[line] = true
78+
ips[cleanLine] = true
9979
}
10080
}
10181

@@ -111,41 +91,35 @@ func (plugin *PluginAllowedIP) Drop() error {
11191

11292
// PrepareReload loads new rules into staging structures but doesn't apply them yet
11393
func (plugin *PluginAllowedIP) PrepareReload() error {
114-
// Read the configuration file
115-
lines, err := SafeReadTextFile(plugin.configFile)
116-
if err != nil {
117-
return fmt.Errorf("error reading config file during reload preparation: %w", err)
118-
}
119-
120-
// Create staging structures
121-
plugin.stagingPrefixes = iradix.New()
122-
plugin.stagingIPs = make(map[string]interface{})
123-
124-
// Load rules into staging structures
125-
plugin.stagingPrefixes, err = plugin.loadRules(lines, plugin.stagingPrefixes, plugin.stagingIPs)
126-
if err != nil {
127-
return fmt.Errorf("error parsing config during reload preparation: %w", err)
128-
}
129-
130-
return nil
94+
return StandardPrepareReloadPattern(plugin.Name(), plugin.configFile, func(lines string) error {
95+
// Create staging structures
96+
plugin.stagingPrefixes = iradix.New()
97+
plugin.stagingIPs = make(map[string]interface{})
98+
99+
// Load rules into staging structures
100+
var err error
101+
plugin.stagingPrefixes, err = plugin.loadRules(lines, plugin.stagingPrefixes, plugin.stagingIPs)
102+
return err
103+
})
131104
}
132105

133106
// ApplyReload atomically replaces the active rules with the staging ones
134107
func (plugin *PluginAllowedIP) ApplyReload() error {
135-
if plugin.stagingPrefixes == nil || plugin.stagingIPs == nil {
136-
return errors.New("no staged configuration to apply")
137-
}
108+
return StandardApplyReloadPattern(plugin.Name(), func() error {
109+
if plugin.stagingPrefixes == nil || plugin.stagingIPs == nil {
110+
return errors.New("no staged configuration to apply")
111+
}
138112

139-
// Use write lock to swap rule structures
140-
plugin.rwLock.Lock()
141-
plugin.allowedPrefixes = plugin.stagingPrefixes
142-
plugin.allowedIPs = plugin.stagingIPs
143-
plugin.stagingPrefixes = nil
144-
plugin.stagingIPs = nil
145-
plugin.rwLock.Unlock()
113+
// Use write lock to swap rule structures
114+
plugin.rwLock.Lock()
115+
plugin.allowedPrefixes = plugin.stagingPrefixes
116+
plugin.allowedIPs = plugin.stagingIPs
117+
plugin.stagingPrefixes = nil
118+
plugin.stagingIPs = nil
119+
plugin.rwLock.Unlock()
146120

147-
dlog.Noticef("Applied new configuration for plugin [%s]", plugin.Name())
148-
return nil
121+
return nil
122+
})
149123
}
150124

151125
// CancelReload cleans up any staging resources
@@ -156,16 +130,16 @@ func (plugin *PluginAllowedIP) CancelReload() {
156130

157131
// Reload implements hot-reloading for the plugin
158132
func (plugin *PluginAllowedIP) Reload() error {
159-
dlog.Noticef("Reloading configuration for plugin [%s]", plugin.Name())
160-
161-
// Prepare the new configuration
162-
if err := plugin.PrepareReload(); err != nil {
163-
plugin.CancelReload()
164-
return err
165-
}
133+
return StandardReloadPattern(plugin.Name(), func() error {
134+
// Prepare the new configuration
135+
if err := plugin.PrepareReload(); err != nil {
136+
plugin.CancelReload()
137+
return err
138+
}
166139

167-
// Apply the new configuration
168-
return plugin.ApplyReload()
140+
// Apply the new configuration
141+
return plugin.ApplyReload()
142+
})
169143
}
170144

171145
// GetConfigPath returns the path to the plugin's configuration file
@@ -218,39 +192,15 @@ func (plugin *PluginAllowedIP) Eval(pluginsState *PluginsState, msg *dns.Msg) er
218192
pluginsState.sessionData["whitelisted"] = true
219193
if plugin.logger != nil {
220194
qName := pluginsState.qName
221-
var clientIPStr string
222-
switch pluginsState.clientProto {
223-
case "udp":
224-
clientIPStr = (*pluginsState.clientAddr).(*net.UDPAddr).IP.String()
225-
case "tcp", "local_doh":
226-
clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
227-
default:
195+
clientIPStr, ok := ExtractClientIPStr(pluginsState)
196+
if !ok {
228197
// Ignore internal flow.
229198
return nil
230199
}
231-
var line string
232-
if plugin.format == "tsv" {
233-
now := time.Now()
234-
year, month, day := now.Date()
235-
hour, minute, second := now.Clock()
236-
tsStr := fmt.Sprintf("[%d-%02d-%02d %02d:%02d:%02d]", year, int(month), day, hour, minute, second)
237-
line = fmt.Sprintf(
238-
"%s\t%s\t%s\t%s\t%s\n",
239-
tsStr,
240-
clientIPStr,
241-
StringQuote(qName),
242-
StringQuote(ipStr),
243-
StringQuote(reason),
244-
)
245-
} else if plugin.format == "ltsv" {
246-
line = fmt.Sprintf("time:%d\thost:%s\tqname:%s\tip:%s\tmessage:%s\n", time.Now().Unix(), clientIPStr, StringQuote(qName), StringQuote(ipStr), StringQuote(reason))
247-
} else {
248-
dlog.Fatalf("Unexpected log format: [%s]", plugin.format)
249-
}
250-
if plugin.logger == nil {
251-
return errors.New("Log file not initialized")
200+
201+
if err := WritePluginLog(plugin.logger, plugin.format, clientIPStr, qName, reason, ipStr); err != nil {
202+
return err
252203
}
253-
_, _ = plugin.logger.Write([]byte(line))
254204
}
255205
}
256206
return nil

0 commit comments

Comments
 (0)