Skip to content

Commit cd69ec3

Browse files
committed
Refactor config reloading
1 parent a8704e6 commit cd69ec3

3 files changed

Lines changed: 83 additions & 58 deletions

File tree

pkg/cmd/cmd.go

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ import (
88

99
"github.com/alecthomas/kong"
1010
"github.com/cedws/sshgate/pkg/sshgate"
11-
"github.com/fsnotify/fsnotify"
1211
)
1312

1413
type cli struct {
15-
ListenAddr string `help:"Address to listen on" default:":2222"`
16-
Ruleless bool `help:"Run in ruleless mode"`
17-
LogFormat string `help:"Log format"`
18-
Config string `help:"Path to JSON config file"`
14+
ListenAddr string `help:"Address to listen on" default:":2222"`
15+
Ruleless bool `help:"Run in ruleless mode"`
16+
NoConfigReload bool `help:"Disable config reload"`
17+
LogFormat string `help:"Log format"`
18+
Config string `help:"Path to JSON config file"`
1919

2020
Serve serveCmd `cmd:"" default:"1" help:"Start the server"`
2121
JSONSchema jsonschemaCmd `cmd:"" name:"jsonschema" help:"Print config JSON schema"`
@@ -43,74 +43,27 @@ func (s *serveCmd) Run(cli *cli) error {
4343
return err
4444
}
4545

46-
if err := serveUntilReload(context.Background(), cli, config); err != nil && !errors.Is(err, context.Canceled) {
46+
if err := serve(context.Background(), cli, config); err != nil && !errors.Is(err, context.Canceled) {
4747
return err
4848
}
4949
}
5050
}
5151

52-
func serveUntilReload(ctx context.Context, cli *cli, config *sshgate.Config) error {
53-
watcher, err := fsnotify.NewWatcher()
54-
if err != nil {
55-
return err
56-
}
57-
defer watcher.Close()
58-
59-
if err := watcher.Add(cli.Config); err != nil {
60-
return err
61-
}
62-
63-
for _, hostKeyPath := range []string{
64-
config.HostKeyPaths.ECDSA,
65-
config.HostKeyPaths.ED25519,
66-
config.HostKeyPaths.RSA,
67-
} {
68-
if hostKeyPath != "" {
69-
if err := watcher.Add(hostKeyPath); err != nil {
70-
return err
71-
}
72-
}
73-
}
74-
75-
ctx, cancel := fsnotifyContext(ctx, watcher)
76-
defer cancel()
77-
78-
return serve(ctx, cli, config)
79-
}
80-
8152
func serve(ctx context.Context, c *cli, config *sshgate.Config) error {
8253
var opts []sshgate.Option
54+
8355
if c.Ruleless {
8456
opts = append(opts, sshgate.WithRulelessMode())
8557
}
58+
if !c.NoConfigReload {
59+
opts = append(opts, sshgate.WithConfigReload())
60+
}
8661

8762
server := sshgate.New(config, c.ListenAddr, opts...)
8863

8964
return server.ListenAndServe(ctx)
9065
}
9166

92-
func fsnotifyContext(ctx context.Context, watcher *fsnotify.Watcher) (context.Context, context.CancelFunc) {
93-
ctx, cancel := context.WithCancel(ctx)
94-
95-
go func() {
96-
for {
97-
select {
98-
case evt := <-watcher.Events:
99-
if evt.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
100-
slog.Info("config file changed, reloading")
101-
cancel()
102-
}
103-
case <-watcher.Errors:
104-
cancel()
105-
case <-ctx.Done():
106-
return
107-
}
108-
}
109-
}()
110-
111-
return ctx, cancel
112-
}
113-
11467
func Execute() {
11568
var cli cli
11669

pkg/sshgate/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ type Config struct {
4343
Tsnet Tsnet `json:"tsnet"`
4444
HostKeyPaths HostKeyPaths `json:"host_key_paths"`
4545

46+
path string
4647
signers []ssh.Signer
4748
parsedPolicies parsedPolicies
4849
}
@@ -92,6 +93,8 @@ func ReadConfig(path string) (*Config, error) {
9293
Hostname: defaultTsnetHostname,
9394
Port: defaultTsnetPort,
9495
},
96+
97+
path: path,
9598
}
9699
if err := json.Unmarshal(data, &config); err != nil {
97100
return nil, fmt.Errorf("failed to parse config JSON: %w", err)

pkg/sshgate/sshgate.go

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"sync/atomic"
1818
"time"
1919

20+
"github.com/fsnotify/fsnotify"
2021
"golang.org/x/crypto/ssh"
2122
"golang.org/x/sync/errgroup"
2223
"tailscale.com/client/local"
@@ -112,7 +113,8 @@ func (d *directTCPIPExtraData) UnmarshalBinary(b []byte) error {
112113
}
113114

114115
type Options struct {
115-
Ruleless bool
116+
Ruleless bool
117+
ConfigReload bool
116118
}
117119

118120
type Option func(*Options)
@@ -123,6 +125,12 @@ func WithRulelessMode() Option {
123125
}
124126
}
125127

128+
func WithConfigReload() Option {
129+
return func(o *Options) {
130+
o.ConfigReload = true
131+
}
132+
}
133+
126134
type Server struct {
127135
config *Config
128136
listenAddr string
@@ -158,6 +166,14 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
158166
slog.Warn("running in ruleless mode")
159167
}
160168

169+
if s.options.ConfigReload {
170+
var err error
171+
ctx, err = notifyConfigReload(ctx, s.config)
172+
if err != nil {
173+
return err
174+
}
175+
}
176+
161177
errgroup, ctx := errgroup.WithContext(ctx)
162178

163179
errgroup.Go(func() error {
@@ -177,6 +193,59 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
177193
return errgroup.Wait()
178194
}
179195

196+
func notifyConfigReload(ctx context.Context, config *Config) (context.Context, error) {
197+
watcher, err := fsnotify.NewWatcher()
198+
if err != nil {
199+
return nil, err
200+
}
201+
go func() {
202+
<-ctx.Done()
203+
watcher.Close()
204+
}()
205+
206+
if err := watcher.Add(config.path); err != nil {
207+
return nil, err
208+
}
209+
210+
for _, hostKeyPath := range []string{
211+
config.HostKeyPaths.ECDSA,
212+
config.HostKeyPaths.ED25519,
213+
config.HostKeyPaths.RSA,
214+
} {
215+
if hostKeyPath != "" {
216+
if err := watcher.Add(hostKeyPath); err != nil {
217+
return nil, err
218+
}
219+
}
220+
}
221+
222+
return fsnotifyContext(ctx, watcher), nil
223+
}
224+
225+
func fsnotifyContext(ctx context.Context, watcher *fsnotify.Watcher) context.Context {
226+
ctx, cancel := context.WithCancel(ctx)
227+
228+
go func() {
229+
for {
230+
select {
231+
case evt := <-watcher.Events:
232+
if evt.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
233+
slog.Info("config file changed, reloading")
234+
cancel()
235+
return
236+
}
237+
case <-watcher.Errors:
238+
cancel()
239+
return
240+
case <-ctx.Done():
241+
return
242+
}
243+
}
244+
}()
245+
246+
return ctx
247+
}
248+
180249
func (s *Server) logStats(ctx context.Context) error {
181250
for {
182251
select {

0 commit comments

Comments
 (0)