ScuttleBot
feat: config API exposes agent_policy + logging, HTTP connector self-registers, gemini-relay input loop starts from online time
Commit
763c8734c4f192335e5b32af25f551756f40dcfafd911d810b852cf8e48e98fb
Parent
7597cf0c9f6cc8e…
10 files changed
+5
-3
+78
-15
+310
+33
-2
+27
-7
+143
+158
+65
-8
+77
-16
+2
~
cmd/gemini-relay/main.go
~
internal/api/config_handlers.go
~
internal/api/config_handlers_test.go
~
internal/api/settings.go
~
internal/config/config.go
~
internal/config/config_test.go
~
internal/registry/registry_test.go
~
pkg/sessionrelay/http.go
~
pkg/sessionrelay/irc.go
~
pkg/sessionrelay/sessionrelay_test.go
+5
-3
| --- cmd/gemini-relay/main.go | ||
| +++ cmd/gemini-relay/main.go | ||
| @@ -104,10 +104,11 @@ | ||
| 104 | 104 | _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) |
| 105 | 105 | defer func() { _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) }() |
| 106 | 106 | |
| 107 | 107 | var relay sessionrelay.Connector |
| 108 | 108 | relayActive := false |
| 109 | + var onlineAt time.Time | |
| 109 | 110 | if relayRequested { |
| 110 | 111 | conn, err := sessionrelay.New(sessionrelay.Config{ |
| 111 | 112 | Transport: cfg.Transport, |
| 112 | 113 | URL: cfg.URL, |
| 113 | 114 | Token: cfg.Token, |
| @@ -132,10 +133,11 @@ | ||
| 132 | 133 | relay = conn |
| 133 | 134 | relayActive = true |
| 134 | 135 | if err := sessionrelay.WriteChannelStateFile(cfg.ChannelStateFile, relay.ControlChannel(), relay.Channels()); err != nil { |
| 135 | 136 | fmt.Fprintf(os.Stderr, "gemini-relay: channel state disabled: %v\n", err) |
| 136 | 137 | } |
| 138 | + onlineAt = time.Now() | |
| 137 | 139 | _ = relay.Post(context.Background(), fmt.Sprintf( |
| 138 | 140 | "online in %s; mention %s to interrupt before the next action", |
| 139 | 141 | filepath.Base(cfg.TargetCWD), cfg.Nick, |
| 140 | 142 | )) |
| 141 | 143 | } |
| @@ -215,11 +217,11 @@ | ||
| 215 | 217 | }() |
| 216 | 218 | go func() { |
| 217 | 219 | copyPTYOutput(ptmx, os.Stdout, state) |
| 218 | 220 | }() |
| 219 | 221 | if relayActive { |
| 220 | - go relayInputLoop(ctx, relay, cfg, state, ptmx) | |
| 222 | + go relayInputLoop(ctx, relay, cfg, state, ptmx, onlineAt) | |
| 221 | 223 | } |
| 222 | 224 | |
| 223 | 225 | err = cmd.Wait() |
| 224 | 226 | cancel() |
| 225 | 227 | |
| @@ -228,12 +230,12 @@ | ||
| 228 | 230 | _ = relay.Post(context.Background(), fmt.Sprintf("offline (exit %d)", exitCode)) |
| 229 | 231 | } |
| 230 | 232 | return err |
| 231 | 233 | } |
| 232 | 234 | |
| 233 | -func relayInputLoop(ctx context.Context, relay sessionrelay.Connector, cfg config, state *relayState, ptyFile *os.File) { | |
| 234 | - lastSeen := time.Now() | |
| 235 | +func relayInputLoop(ctx context.Context, relay sessionrelay.Connector, cfg config, state *relayState, ptyFile *os.File, since time.Time) { | |
| 236 | + lastSeen := since | |
| 235 | 237 | ticker := time.NewTicker(cfg.PollInterval) |
| 236 | 238 | defer ticker.Stop() |
| 237 | 239 | |
| 238 | 240 | for { |
| 239 | 241 | select { |
| 240 | 242 |
| --- cmd/gemini-relay/main.go | |
| +++ cmd/gemini-relay/main.go | |
| @@ -104,10 +104,11 @@ | |
| 104 | _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) |
| 105 | defer func() { _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) }() |
| 106 | |
| 107 | var relay sessionrelay.Connector |
| 108 | relayActive := false |
| 109 | if relayRequested { |
| 110 | conn, err := sessionrelay.New(sessionrelay.Config{ |
| 111 | Transport: cfg.Transport, |
| 112 | URL: cfg.URL, |
| 113 | Token: cfg.Token, |
| @@ -132,10 +133,11 @@ | |
| 132 | relay = conn |
| 133 | relayActive = true |
| 134 | if err := sessionrelay.WriteChannelStateFile(cfg.ChannelStateFile, relay.ControlChannel(), relay.Channels()); err != nil { |
| 135 | fmt.Fprintf(os.Stderr, "gemini-relay: channel state disabled: %v\n", err) |
| 136 | } |
| 137 | _ = relay.Post(context.Background(), fmt.Sprintf( |
| 138 | "online in %s; mention %s to interrupt before the next action", |
| 139 | filepath.Base(cfg.TargetCWD), cfg.Nick, |
| 140 | )) |
| 141 | } |
| @@ -215,11 +217,11 @@ | |
| 215 | }() |
| 216 | go func() { |
| 217 | copyPTYOutput(ptmx, os.Stdout, state) |
| 218 | }() |
| 219 | if relayActive { |
| 220 | go relayInputLoop(ctx, relay, cfg, state, ptmx) |
| 221 | } |
| 222 | |
| 223 | err = cmd.Wait() |
| 224 | cancel() |
| 225 | |
| @@ -228,12 +230,12 @@ | |
| 228 | _ = relay.Post(context.Background(), fmt.Sprintf("offline (exit %d)", exitCode)) |
| 229 | } |
| 230 | return err |
| 231 | } |
| 232 | |
| 233 | func relayInputLoop(ctx context.Context, relay sessionrelay.Connector, cfg config, state *relayState, ptyFile *os.File) { |
| 234 | lastSeen := time.Now() |
| 235 | ticker := time.NewTicker(cfg.PollInterval) |
| 236 | defer ticker.Stop() |
| 237 | |
| 238 | for { |
| 239 | select { |
| 240 |
| --- cmd/gemini-relay/main.go | |
| +++ cmd/gemini-relay/main.go | |
| @@ -104,10 +104,11 @@ | |
| 104 | _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) |
| 105 | defer func() { _ = sessionrelay.RemoveChannelStateFile(cfg.ChannelStateFile) }() |
| 106 | |
| 107 | var relay sessionrelay.Connector |
| 108 | relayActive := false |
| 109 | var onlineAt time.Time |
| 110 | if relayRequested { |
| 111 | conn, err := sessionrelay.New(sessionrelay.Config{ |
| 112 | Transport: cfg.Transport, |
| 113 | URL: cfg.URL, |
| 114 | Token: cfg.Token, |
| @@ -132,10 +133,11 @@ | |
| 133 | relay = conn |
| 134 | relayActive = true |
| 135 | if err := sessionrelay.WriteChannelStateFile(cfg.ChannelStateFile, relay.ControlChannel(), relay.Channels()); err != nil { |
| 136 | fmt.Fprintf(os.Stderr, "gemini-relay: channel state disabled: %v\n", err) |
| 137 | } |
| 138 | onlineAt = time.Now() |
| 139 | _ = relay.Post(context.Background(), fmt.Sprintf( |
| 140 | "online in %s; mention %s to interrupt before the next action", |
| 141 | filepath.Base(cfg.TargetCWD), cfg.Nick, |
| 142 | )) |
| 143 | } |
| @@ -215,11 +217,11 @@ | |
| 217 | }() |
| 218 | go func() { |
| 219 | copyPTYOutput(ptmx, os.Stdout, state) |
| 220 | }() |
| 221 | if relayActive { |
| 222 | go relayInputLoop(ctx, relay, cfg, state, ptmx, onlineAt) |
| 223 | } |
| 224 | |
| 225 | err = cmd.Wait() |
| 226 | cancel() |
| 227 | |
| @@ -228,12 +230,12 @@ | |
| 230 | _ = relay.Post(context.Background(), fmt.Sprintf("offline (exit %d)", exitCode)) |
| 231 | } |
| 232 | return err |
| 233 | } |
| 234 | |
| 235 | func relayInputLoop(ctx context.Context, relay sessionrelay.Connector, cfg config, state *relayState, ptyFile *os.File, since time.Time) { |
| 236 | lastSeen := since |
| 237 | ticker := time.NewTicker(cfg.PollInterval) |
| 238 | defer ticker.Stop() |
| 239 | |
| 240 | for { |
| 241 | select { |
| 242 |
+78
-15
| --- internal/api/config_handlers.go | ||
| +++ internal/api/config_handlers.go | ||
| @@ -9,18 +9,20 @@ | ||
| 9 | 9 | ) |
| 10 | 10 | |
| 11 | 11 | // configView is the JSON shape returned by GET /v1/config. |
| 12 | 12 | // Secrets are masked — zero values mean "no change" on PUT. |
| 13 | 13 | type configView struct { |
| 14 | - APIAddr string `json:"api_addr"` | |
| 15 | - MCPAddr string `json:"mcp_addr"` | |
| 16 | - Bridge bridgeConfigView `json:"bridge"` | |
| 17 | - Ergo ergoConfigView `json:"ergo"` | |
| 18 | - TLS tlsConfigView `json:"tls"` | |
| 19 | - LLM llmConfigView `json:"llm"` | |
| 20 | - Topology config.TopologyConfig `json:"topology"` | |
| 21 | - History config.ConfigHistoryConfig `json:"config_history"` | |
| 14 | + APIAddr string `json:"api_addr"` | |
| 15 | + MCPAddr string `json:"mcp_addr"` | |
| 16 | + Bridge bridgeConfigView `json:"bridge"` | |
| 17 | + Ergo ergoConfigView `json:"ergo"` | |
| 18 | + TLS tlsConfigView `json:"tls"` | |
| 19 | + LLM llmConfigView `json:"llm"` | |
| 20 | + Topology config.TopologyConfig `json:"topology"` | |
| 21 | + History config.ConfigHistoryConfig `json:"config_history"` | |
| 22 | + AgentPolicy config.AgentPolicyConfig `json:"agent_policy"` | |
| 23 | + Logging config.LoggingConfig `json:"logging"` | |
| 22 | 24 | } |
| 23 | 25 | |
| 24 | 26 | type bridgeConfigView struct { |
| 25 | 27 | Enabled bool `json:"enabled"` |
| 26 | 28 | Nick string `json:"nick"` |
| @@ -95,13 +97,15 @@ | ||
| 95 | 97 | TLS: tlsConfigView{ |
| 96 | 98 | Domain: cfg.TLS.Domain, |
| 97 | 99 | Email: cfg.TLS.Email, |
| 98 | 100 | AllowInsecure: cfg.TLS.AllowInsecure, |
| 99 | 101 | }, |
| 100 | - LLM: llmConfigView{Backends: backends}, | |
| 101 | - Topology: cfg.Topology, | |
| 102 | - History: cfg.History, | |
| 102 | + LLM: llmConfigView{Backends: backends}, | |
| 103 | + Topology: cfg.Topology, | |
| 104 | + History: cfg.History, | |
| 105 | + AgentPolicy: cfg.AgentPolicy, | |
| 106 | + Logging: cfg.Logging, | |
| 103 | 107 | } |
| 104 | 108 | } |
| 105 | 109 | |
| 106 | 110 | // handleGetConfig handles GET /v1/config. |
| 107 | 111 | func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { |
| @@ -111,18 +115,35 @@ | ||
| 111 | 115 | |
| 112 | 116 | // configUpdateRequest is the body accepted by PUT /v1/config. |
| 113 | 117 | // Only the mutable, hot-reloadable sections. Restart-required fields (ergo IRC |
| 114 | 118 | // addr, TLS domain, api_addr) are accepted but flagged in the response. |
| 115 | 119 | type configUpdateRequest struct { |
| 116 | - Bridge *bridgeConfigUpdate `json:"bridge,omitempty"` | |
| 117 | - Topology *config.TopologyConfig `json:"topology,omitempty"` | |
| 118 | - History *config.ConfigHistoryConfig `json:"config_history,omitempty"` | |
| 119 | - LLM *llmConfigUpdate `json:"llm,omitempty"` | |
| 120 | + Bridge *bridgeConfigUpdate `json:"bridge,omitempty"` | |
| 121 | + Topology *config.TopologyConfig `json:"topology,omitempty"` | |
| 122 | + History *config.ConfigHistoryConfig `json:"config_history,omitempty"` | |
| 123 | + LLM *llmConfigUpdate `json:"llm,omitempty"` | |
| 124 | + AgentPolicy *config.AgentPolicyConfig `json:"agent_policy,omitempty"` | |
| 125 | + Logging *config.LoggingConfig `json:"logging,omitempty"` | |
| 126 | + Ergo *ergoConfigUpdate `json:"ergo,omitempty"` | |
| 127 | + TLS *tlsConfigUpdate `json:"tls,omitempty"` | |
| 120 | 128 | // These fields trigger a restart_required notice but are still persisted. |
| 121 | 129 | APIAddr *string `json:"api_addr,omitempty"` |
| 122 | 130 | MCPAddr *string `json:"mcp_addr,omitempty"` |
| 123 | 131 | } |
| 132 | + | |
| 133 | +type ergoConfigUpdate struct { | |
| 134 | + NetworkName *string `json:"network_name,omitempty"` | |
| 135 | + ServerName *string `json:"server_name,omitempty"` | |
| 136 | + IRCAddr *string `json:"irc_addr,omitempty"` | |
| 137 | + External *bool `json:"external,omitempty"` | |
| 138 | +} | |
| 139 | + | |
| 140 | +type tlsConfigUpdate struct { | |
| 141 | + Domain *string `json:"domain,omitempty"` | |
| 142 | + Email *string `json:"email,omitempty"` | |
| 143 | + AllowInsecure *bool `json:"allow_insecure,omitempty"` | |
| 144 | +} | |
| 124 | 145 | |
| 125 | 146 | type bridgeConfigUpdate struct { |
| 126 | 147 | Enabled *bool `json:"enabled,omitempty"` |
| 127 | 148 | Nick *string `json:"nick,omitempty"` |
| 128 | 149 | Channels []string `json:"channels,omitempty"` |
| @@ -189,10 +210,52 @@ | ||
| 189 | 210 | } |
| 190 | 211 | |
| 191 | 212 | if req.LLM != nil { |
| 192 | 213 | next.LLM.Backends = req.LLM.Backends |
| 193 | 214 | } |
| 215 | + | |
| 216 | + if req.AgentPolicy != nil { | |
| 217 | + next.AgentPolicy = *req.AgentPolicy | |
| 218 | + } | |
| 219 | + | |
| 220 | + if req.Logging != nil { | |
| 221 | + next.Logging = *req.Logging | |
| 222 | + } | |
| 223 | + | |
| 224 | + if req.Ergo != nil { | |
| 225 | + e := req.Ergo | |
| 226 | + if e.NetworkName != nil { | |
| 227 | + next.Ergo.NetworkName = *e.NetworkName | |
| 228 | + restartRequired = appendUniq(restartRequired, "ergo.network_name") | |
| 229 | + } | |
| 230 | + if e.ServerName != nil { | |
| 231 | + next.Ergo.ServerName = *e.ServerName | |
| 232 | + restartRequired = appendUniq(restartRequired, "ergo.server_name") | |
| 233 | + } | |
| 234 | + if e.IRCAddr != nil { | |
| 235 | + next.Ergo.IRCAddr = *e.IRCAddr | |
| 236 | + restartRequired = appendUniq(restartRequired, "ergo.irc_addr") | |
| 237 | + } | |
| 238 | + if e.External != nil { | |
| 239 | + next.Ergo.External = *e.External | |
| 240 | + restartRequired = appendUniq(restartRequired, "ergo.external") | |
| 241 | + } | |
| 242 | + } | |
| 243 | + | |
| 244 | + if req.TLS != nil { | |
| 245 | + t := req.TLS | |
| 246 | + if t.Domain != nil { | |
| 247 | + next.TLS.Domain = *t.Domain | |
| 248 | + restartRequired = appendUniq(restartRequired, "tls.domain") | |
| 249 | + } | |
| 250 | + if t.Email != nil { | |
| 251 | + next.TLS.Email = *t.Email | |
| 252 | + } | |
| 253 | + if t.AllowInsecure != nil { | |
| 254 | + next.TLS.AllowInsecure = *t.AllowInsecure | |
| 255 | + } | |
| 256 | + } | |
| 194 | 257 | |
| 195 | 258 | if req.APIAddr != nil && *req.APIAddr != "" { |
| 196 | 259 | next.APIAddr = *req.APIAddr |
| 197 | 260 | restartRequired = appendUniq(restartRequired, "api_addr") |
| 198 | 261 | } |
| 199 | 262 |
| --- internal/api/config_handlers.go | |
| +++ internal/api/config_handlers.go | |
| @@ -9,18 +9,20 @@ | |
| 9 | ) |
| 10 | |
| 11 | // configView is the JSON shape returned by GET /v1/config. |
| 12 | // Secrets are masked — zero values mean "no change" on PUT. |
| 13 | type configView struct { |
| 14 | APIAddr string `json:"api_addr"` |
| 15 | MCPAddr string `json:"mcp_addr"` |
| 16 | Bridge bridgeConfigView `json:"bridge"` |
| 17 | Ergo ergoConfigView `json:"ergo"` |
| 18 | TLS tlsConfigView `json:"tls"` |
| 19 | LLM llmConfigView `json:"llm"` |
| 20 | Topology config.TopologyConfig `json:"topology"` |
| 21 | History config.ConfigHistoryConfig `json:"config_history"` |
| 22 | } |
| 23 | |
| 24 | type bridgeConfigView struct { |
| 25 | Enabled bool `json:"enabled"` |
| 26 | Nick string `json:"nick"` |
| @@ -95,13 +97,15 @@ | |
| 95 | TLS: tlsConfigView{ |
| 96 | Domain: cfg.TLS.Domain, |
| 97 | Email: cfg.TLS.Email, |
| 98 | AllowInsecure: cfg.TLS.AllowInsecure, |
| 99 | }, |
| 100 | LLM: llmConfigView{Backends: backends}, |
| 101 | Topology: cfg.Topology, |
| 102 | History: cfg.History, |
| 103 | } |
| 104 | } |
| 105 | |
| 106 | // handleGetConfig handles GET /v1/config. |
| 107 | func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { |
| @@ -111,18 +115,35 @@ | |
| 111 | |
| 112 | // configUpdateRequest is the body accepted by PUT /v1/config. |
| 113 | // Only the mutable, hot-reloadable sections. Restart-required fields (ergo IRC |
| 114 | // addr, TLS domain, api_addr) are accepted but flagged in the response. |
| 115 | type configUpdateRequest struct { |
| 116 | Bridge *bridgeConfigUpdate `json:"bridge,omitempty"` |
| 117 | Topology *config.TopologyConfig `json:"topology,omitempty"` |
| 118 | History *config.ConfigHistoryConfig `json:"config_history,omitempty"` |
| 119 | LLM *llmConfigUpdate `json:"llm,omitempty"` |
| 120 | // These fields trigger a restart_required notice but are still persisted. |
| 121 | APIAddr *string `json:"api_addr,omitempty"` |
| 122 | MCPAddr *string `json:"mcp_addr,omitempty"` |
| 123 | } |
| 124 | |
| 125 | type bridgeConfigUpdate struct { |
| 126 | Enabled *bool `json:"enabled,omitempty"` |
| 127 | Nick *string `json:"nick,omitempty"` |
| 128 | Channels []string `json:"channels,omitempty"` |
| @@ -189,10 +210,52 @@ | |
| 189 | } |
| 190 | |
| 191 | if req.LLM != nil { |
| 192 | next.LLM.Backends = req.LLM.Backends |
| 193 | } |
| 194 | |
| 195 | if req.APIAddr != nil && *req.APIAddr != "" { |
| 196 | next.APIAddr = *req.APIAddr |
| 197 | restartRequired = appendUniq(restartRequired, "api_addr") |
| 198 | } |
| 199 |
| --- internal/api/config_handlers.go | |
| +++ internal/api/config_handlers.go | |
| @@ -9,18 +9,20 @@ | |
| 9 | ) |
| 10 | |
| 11 | // configView is the JSON shape returned by GET /v1/config. |
| 12 | // Secrets are masked — zero values mean "no change" on PUT. |
| 13 | type configView struct { |
| 14 | APIAddr string `json:"api_addr"` |
| 15 | MCPAddr string `json:"mcp_addr"` |
| 16 | Bridge bridgeConfigView `json:"bridge"` |
| 17 | Ergo ergoConfigView `json:"ergo"` |
| 18 | TLS tlsConfigView `json:"tls"` |
| 19 | LLM llmConfigView `json:"llm"` |
| 20 | Topology config.TopologyConfig `json:"topology"` |
| 21 | History config.ConfigHistoryConfig `json:"config_history"` |
| 22 | AgentPolicy config.AgentPolicyConfig `json:"agent_policy"` |
| 23 | Logging config.LoggingConfig `json:"logging"` |
| 24 | } |
| 25 | |
| 26 | type bridgeConfigView struct { |
| 27 | Enabled bool `json:"enabled"` |
| 28 | Nick string `json:"nick"` |
| @@ -95,13 +97,15 @@ | |
| 97 | TLS: tlsConfigView{ |
| 98 | Domain: cfg.TLS.Domain, |
| 99 | Email: cfg.TLS.Email, |
| 100 | AllowInsecure: cfg.TLS.AllowInsecure, |
| 101 | }, |
| 102 | LLM: llmConfigView{Backends: backends}, |
| 103 | Topology: cfg.Topology, |
| 104 | History: cfg.History, |
| 105 | AgentPolicy: cfg.AgentPolicy, |
| 106 | Logging: cfg.Logging, |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | // handleGetConfig handles GET /v1/config. |
| 111 | func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { |
| @@ -111,18 +115,35 @@ | |
| 115 | |
| 116 | // configUpdateRequest is the body accepted by PUT /v1/config. |
| 117 | // Only the mutable, hot-reloadable sections. Restart-required fields (ergo IRC |
| 118 | // addr, TLS domain, api_addr) are accepted but flagged in the response. |
| 119 | type configUpdateRequest struct { |
| 120 | Bridge *bridgeConfigUpdate `json:"bridge,omitempty"` |
| 121 | Topology *config.TopologyConfig `json:"topology,omitempty"` |
| 122 | History *config.ConfigHistoryConfig `json:"config_history,omitempty"` |
| 123 | LLM *llmConfigUpdate `json:"llm,omitempty"` |
| 124 | AgentPolicy *config.AgentPolicyConfig `json:"agent_policy,omitempty"` |
| 125 | Logging *config.LoggingConfig `json:"logging,omitempty"` |
| 126 | Ergo *ergoConfigUpdate `json:"ergo,omitempty"` |
| 127 | TLS *tlsConfigUpdate `json:"tls,omitempty"` |
| 128 | // These fields trigger a restart_required notice but are still persisted. |
| 129 | APIAddr *string `json:"api_addr,omitempty"` |
| 130 | MCPAddr *string `json:"mcp_addr,omitempty"` |
| 131 | } |
| 132 | |
| 133 | type ergoConfigUpdate struct { |
| 134 | NetworkName *string `json:"network_name,omitempty"` |
| 135 | ServerName *string `json:"server_name,omitempty"` |
| 136 | IRCAddr *string `json:"irc_addr,omitempty"` |
| 137 | External *bool `json:"external,omitempty"` |
| 138 | } |
| 139 | |
| 140 | type tlsConfigUpdate struct { |
| 141 | Domain *string `json:"domain,omitempty"` |
| 142 | Email *string `json:"email,omitempty"` |
| 143 | AllowInsecure *bool `json:"allow_insecure,omitempty"` |
| 144 | } |
| 145 | |
| 146 | type bridgeConfigUpdate struct { |
| 147 | Enabled *bool `json:"enabled,omitempty"` |
| 148 | Nick *string `json:"nick,omitempty"` |
| 149 | Channels []string `json:"channels,omitempty"` |
| @@ -189,10 +210,52 @@ | |
| 210 | } |
| 211 | |
| 212 | if req.LLM != nil { |
| 213 | next.LLM.Backends = req.LLM.Backends |
| 214 | } |
| 215 | |
| 216 | if req.AgentPolicy != nil { |
| 217 | next.AgentPolicy = *req.AgentPolicy |
| 218 | } |
| 219 | |
| 220 | if req.Logging != nil { |
| 221 | next.Logging = *req.Logging |
| 222 | } |
| 223 | |
| 224 | if req.Ergo != nil { |
| 225 | e := req.Ergo |
| 226 | if e.NetworkName != nil { |
| 227 | next.Ergo.NetworkName = *e.NetworkName |
| 228 | restartRequired = appendUniq(restartRequired, "ergo.network_name") |
| 229 | } |
| 230 | if e.ServerName != nil { |
| 231 | next.Ergo.ServerName = *e.ServerName |
| 232 | restartRequired = appendUniq(restartRequired, "ergo.server_name") |
| 233 | } |
| 234 | if e.IRCAddr != nil { |
| 235 | next.Ergo.IRCAddr = *e.IRCAddr |
| 236 | restartRequired = appendUniq(restartRequired, "ergo.irc_addr") |
| 237 | } |
| 238 | if e.External != nil { |
| 239 | next.Ergo.External = *e.External |
| 240 | restartRequired = appendUniq(restartRequired, "ergo.external") |
| 241 | } |
| 242 | } |
| 243 | |
| 244 | if req.TLS != nil { |
| 245 | t := req.TLS |
| 246 | if t.Domain != nil { |
| 247 | next.TLS.Domain = *t.Domain |
| 248 | restartRequired = appendUniq(restartRequired, "tls.domain") |
| 249 | } |
| 250 | if t.Email != nil { |
| 251 | next.TLS.Email = *t.Email |
| 252 | } |
| 253 | if t.AllowInsecure != nil { |
| 254 | next.TLS.AllowInsecure = *t.AllowInsecure |
| 255 | } |
| 256 | } |
| 257 | |
| 258 | if req.APIAddr != nil && *req.APIAddr != "" { |
| 259 | next.APIAddr = *req.APIAddr |
| 260 | restartRequired = appendUniq(restartRequired, "api_addr") |
| 261 | } |
| 262 |
| --- internal/api/config_handlers_test.go | ||
| +++ internal/api/config_handlers_test.go | ||
| @@ -7,10 +7,11 @@ | ||
| 7 | 7 | "log/slog" |
| 8 | 8 | "net/http" |
| 9 | 9 | "net/http/httptest" |
| 10 | 10 | "path/filepath" |
| 11 | 11 | "testing" |
| 12 | + "time" | |
| 12 | 13 | |
| 13 | 14 | "github.com/conflicthq/scuttlebot/internal/config" |
| 14 | 15 | "github.com/conflicthq/scuttlebot/internal/registry" |
| 15 | 16 | ) |
| 16 | 17 | |
| @@ -104,10 +105,319 @@ | ||
| 104 | 105 | } |
| 105 | 106 | if len(got.Topology.Channels) != 1 || got.Topology.Channels[0].Name != "#general" { |
| 106 | 107 | t.Errorf("topology.channels = %+v", got.Topology.Channels) |
| 107 | 108 | } |
| 108 | 109 | } |
| 110 | + | |
| 111 | +func TestHandlePutConfigAgentPolicy(t *testing.T) { | |
| 112 | + srv, store := newCfgTestServer(t) | |
| 113 | + | |
| 114 | + update := map[string]any{ | |
| 115 | + "agent_policy": map[string]any{ | |
| 116 | + "require_checkin": true, | |
| 117 | + "checkin_channel": "#fleet", | |
| 118 | + "required_channels": []string{"#general"}, | |
| 119 | + }, | |
| 120 | + } | |
| 121 | + body, _ := json.Marshal(update) | |
| 122 | + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) | |
| 123 | + req.Header.Set("Authorization", "Bearer tok") | |
| 124 | + req.Header.Set("Content-Type", "application/json") | |
| 125 | + resp, err := http.DefaultClient.Do(req) | |
| 126 | + if err != nil { | |
| 127 | + t.Fatal(err) | |
| 128 | + } | |
| 129 | + defer resp.Body.Close() | |
| 130 | + if resp.StatusCode != http.StatusOK { | |
| 131 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 132 | + } | |
| 133 | + | |
| 134 | + got := store.Get() | |
| 135 | + if !got.AgentPolicy.RequireCheckin { | |
| 136 | + t.Error("agent_policy.require_checkin should be true") | |
| 137 | + } | |
| 138 | + if got.AgentPolicy.CheckinChannel != "#fleet" { | |
| 139 | + t.Errorf("agent_policy.checkin_channel = %q, want #fleet", got.AgentPolicy.CheckinChannel) | |
| 140 | + } | |
| 141 | + if len(got.AgentPolicy.RequiredChannels) != 1 || got.AgentPolicy.RequiredChannels[0] != "#general" { | |
| 142 | + t.Errorf("agent_policy.required_channels = %v", got.AgentPolicy.RequiredChannels) | |
| 143 | + } | |
| 144 | +} | |
| 145 | + | |
| 146 | +func TestHandlePutConfigLogging(t *testing.T) { | |
| 147 | + srv, store := newCfgTestServer(t) | |
| 148 | + | |
| 149 | + update := map[string]any{ | |
| 150 | + "logging": map[string]any{ | |
| 151 | + "enabled": true, | |
| 152 | + "dir": "./data/logs", | |
| 153 | + "format": "jsonl", | |
| 154 | + "rotation": "daily", | |
| 155 | + "per_channel": true, | |
| 156 | + "max_age_days": 30, | |
| 157 | + }, | |
| 158 | + } | |
| 159 | + body, _ := json.Marshal(update) | |
| 160 | + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) | |
| 161 | + req.Header.Set("Authorization", "Bearer tok") | |
| 162 | + req.Header.Set("Content-Type", "application/json") | |
| 163 | + resp, err := http.DefaultClient.Do(req) | |
| 164 | + if err != nil { | |
| 165 | + t.Fatal(err) | |
| 166 | + } | |
| 167 | + defer resp.Body.Close() | |
| 168 | + if resp.StatusCode != http.StatusOK { | |
| 169 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 170 | + } | |
| 171 | + | |
| 172 | + got := store.Get() | |
| 173 | + if !got.Logging.Enabled { | |
| 174 | + t.Error("logging.enabled should be true") | |
| 175 | + } | |
| 176 | + if got.Logging.Dir != "./data/logs" { | |
| 177 | + t.Errorf("logging.dir = %q, want ./data/logs", got.Logging.Dir) | |
| 178 | + } | |
| 179 | + if got.Logging.Format != "jsonl" { | |
| 180 | + t.Errorf("logging.format = %q, want jsonl", got.Logging.Format) | |
| 181 | + } | |
| 182 | + if got.Logging.Rotation != "daily" { | |
| 183 | + t.Errorf("logging.rotation = %q, want daily", got.Logging.Rotation) | |
| 184 | + } | |
| 185 | + if !got.Logging.PerChannel { | |
| 186 | + t.Error("logging.per_channel should be true") | |
| 187 | + } | |
| 188 | + if got.Logging.MaxAgeDays != 30 { | |
| 189 | + t.Errorf("logging.max_age_days = %d, want 30", got.Logging.MaxAgeDays) | |
| 190 | + } | |
| 191 | +} | |
| 192 | + | |
| 193 | +func TestHandlePutConfigErgo(t *testing.T) { | |
| 194 | + srv, store := newCfgTestServer(t) | |
| 195 | + | |
| 196 | + update := map[string]any{ | |
| 197 | + "ergo": map[string]any{ | |
| 198 | + "network_name": "testnet", | |
| 199 | + "server_name": "irc.test.local", | |
| 200 | + }, | |
| 201 | + } | |
| 202 | + body, _ := json.Marshal(update) | |
| 203 | + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) | |
| 204 | + req.Header.Set("Authorization", "Bearer tok") | |
| 205 | + req.Header.Set("Content-Type", "application/json") | |
| 206 | + resp, err := http.DefaultClient.Do(req) | |
| 207 | + if err != nil { | |
| 208 | + t.Fatal(err) | |
| 209 | + } | |
| 210 | + defer resp.Body.Close() | |
| 211 | + if resp.StatusCode != http.StatusOK { | |
| 212 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 213 | + } | |
| 214 | + | |
| 215 | + // Ergo changes should be flagged as restart_required. | |
| 216 | + var result struct { | |
| 217 | + Saved bool `json:"saved"` | |
| 218 | + RestartRequired []string `json:"restart_required"` | |
| 219 | + } | |
| 220 | + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { | |
| 221 | + t.Fatal(err) | |
| 222 | + } | |
| 223 | + if !result.Saved { | |
| 224 | + t.Error("expected saved=true") | |
| 225 | + } | |
| 226 | + if len(result.RestartRequired) == 0 { | |
| 227 | + t.Error("expected restart_required to be non-empty for ergo changes") | |
| 228 | + } | |
| 229 | + | |
| 230 | + got := store.Get() | |
| 231 | + if got.Ergo.NetworkName != "testnet" { | |
| 232 | + t.Errorf("ergo.network_name = %q, want testnet", got.Ergo.NetworkName) | |
| 233 | + } | |
| 234 | + if got.Ergo.ServerName != "irc.test.local" { | |
| 235 | + t.Errorf("ergo.server_name = %q, want irc.test.local", got.Ergo.ServerName) | |
| 236 | + } | |
| 237 | +} | |
| 238 | + | |
| 239 | +func TestHandlePutConfigTLS(t *testing.T) { | |
| 240 | + srv, store := newCfgTestServer(t) | |
| 241 | + | |
| 242 | + update := map[string]any{ | |
| 243 | + "tls": map[string]any{ | |
| 244 | + "domain": "example.com", | |
| 245 | + "email": "[email protected]", | |
| 246 | + "allow_insecure": true, | |
| 247 | + }, | |
| 248 | + } | |
| 249 | + body, _ := json.Marshal(update) | |
| 250 | + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) | |
| 251 | + req.Header.Set("Authorization", "Bearer tok") | |
| 252 | + req.Header.Set("Content-Type", "application/json") | |
| 253 | + resp, err := http.DefaultClient.Do(req) | |
| 254 | + if err != nil { | |
| 255 | + t.Fatal(err) | |
| 256 | + } | |
| 257 | + defer resp.Body.Close() | |
| 258 | + if resp.StatusCode != http.StatusOK { | |
| 259 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 260 | + } | |
| 261 | + | |
| 262 | + var result struct { | |
| 263 | + RestartRequired []string `json:"restart_required"` | |
| 264 | + } | |
| 265 | + json.NewDecoder(resp.Body).Decode(&result) | |
| 266 | + if len(result.RestartRequired) == 0 { | |
| 267 | + t.Error("expected restart_required for tls.domain change") | |
| 268 | + } | |
| 269 | + | |
| 270 | + got := store.Get() | |
| 271 | + if got.TLS.Domain != "example.com" { | |
| 272 | + t.Errorf("tls.domain = %q, want example.com", got.TLS.Domain) | |
| 273 | + } | |
| 274 | + if got.TLS.Email != "[email protected]" { | |
| 275 | + t.Errorf("tls.email = %q, want [email protected]", got.TLS.Email) | |
| 276 | + } | |
| 277 | + if !got.TLS.AllowInsecure { | |
| 278 | + t.Error("tls.allow_insecure should be true") | |
| 279 | + } | |
| 280 | +} | |
| 281 | + | |
| 282 | +func TestHandleGetConfigIncludesAgentPolicyAndLogging(t *testing.T) { | |
| 283 | + srv, store := newCfgTestServer(t) | |
| 284 | + | |
| 285 | + cfg := store.Get() | |
| 286 | + cfg.AgentPolicy.RequireCheckin = true | |
| 287 | + cfg.AgentPolicy.CheckinChannel = "#ops" | |
| 288 | + cfg.Logging.Enabled = true | |
| 289 | + cfg.Logging.Format = "csv" | |
| 290 | + if err := store.Save(cfg); err != nil { | |
| 291 | + t.Fatalf("store.Save: %v", err) | |
| 292 | + } | |
| 293 | + | |
| 294 | + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config", nil) | |
| 295 | + req.Header.Set("Authorization", "Bearer tok") | |
| 296 | + resp, err := http.DefaultClient.Do(req) | |
| 297 | + if err != nil { | |
| 298 | + t.Fatal(err) | |
| 299 | + } | |
| 300 | + defer resp.Body.Close() | |
| 301 | + if resp.StatusCode != http.StatusOK { | |
| 302 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 303 | + } | |
| 304 | + | |
| 305 | + var body map[string]any | |
| 306 | + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { | |
| 307 | + t.Fatal(err) | |
| 308 | + } | |
| 309 | + ap, ok := body["agent_policy"].(map[string]any) | |
| 310 | + if !ok { | |
| 311 | + t.Fatal("response missing agent_policy section") | |
| 312 | + } | |
| 313 | + if ap["require_checkin"] != true { | |
| 314 | + t.Error("agent_policy.require_checkin should be true") | |
| 315 | + } | |
| 316 | + if ap["checkin_channel"] != "#ops" { | |
| 317 | + t.Errorf("agent_policy.checkin_channel = %v, want #ops", ap["checkin_channel"]) | |
| 318 | + } | |
| 319 | + lg, ok := body["logging"].(map[string]any) | |
| 320 | + if !ok { | |
| 321 | + t.Fatal("response missing logging section") | |
| 322 | + } | |
| 323 | + if lg["enabled"] != true { | |
| 324 | + t.Error("logging.enabled should be true") | |
| 325 | + } | |
| 326 | + if lg["format"] != "csv" { | |
| 327 | + t.Errorf("logging.format = %v, want csv", lg["format"]) | |
| 328 | + } | |
| 329 | +} | |
| 330 | + | |
| 331 | +func TestHandleGetConfigHistoryEntry(t *testing.T) { | |
| 332 | + srv, store := newCfgTestServer(t) | |
| 333 | + | |
| 334 | + // Save twice so a snapshot exists. | |
| 335 | + cfg := store.Get() | |
| 336 | + cfg.Bridge.WebUserTTLMinutes = 11 | |
| 337 | + if err := store.Save(cfg); err != nil { | |
| 338 | + t.Fatalf("first save: %v", err) | |
| 339 | + } | |
| 340 | + cfg2 := store.Get() | |
| 341 | + cfg2.Bridge.WebUserTTLMinutes = 22 | |
| 342 | + if err := store.Save(cfg2); err != nil { | |
| 343 | + t.Fatalf("second save: %v", err) | |
| 344 | + } | |
| 345 | + | |
| 346 | + // List history to find a real filename. | |
| 347 | + entries, err := store.ListHistory() | |
| 348 | + if err != nil { | |
| 349 | + t.Fatalf("ListHistory: %v", err) | |
| 350 | + } | |
| 351 | + if len(entries) == 0 { | |
| 352 | + t.Skip("no history entries; snapshot may not have been created") | |
| 353 | + } | |
| 354 | + filename := entries[0].Filename | |
| 355 | + | |
| 356 | + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config/history/"+filename, nil) | |
| 357 | + req.Header.Set("Authorization", "Bearer tok") | |
| 358 | + resp, err := http.DefaultClient.Do(req) | |
| 359 | + if err != nil { | |
| 360 | + t.Fatal(err) | |
| 361 | + } | |
| 362 | + defer resp.Body.Close() | |
| 363 | + | |
| 364 | + if resp.StatusCode != http.StatusOK { | |
| 365 | + t.Fatalf("want 200, got %d", resp.StatusCode) | |
| 366 | + } | |
| 367 | + var body map[string]any | |
| 368 | + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { | |
| 369 | + t.Fatal(err) | |
| 370 | + } | |
| 371 | + if _, ok := body["bridge"]; !ok { | |
| 372 | + t.Error("history entry response missing bridge section") | |
| 373 | + } | |
| 374 | +} | |
| 375 | + | |
| 376 | +func TestHandleGetConfigHistoryEntryNotFound(t *testing.T) { | |
| 377 | + srv, _ := newCfgTestServer(t) | |
| 378 | + | |
| 379 | + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config/history/nonexistent.yaml", nil) | |
| 380 | + req.Header.Set("Authorization", "Bearer tok") | |
| 381 | + resp, err := http.DefaultClient.Do(req) | |
| 382 | + if err != nil { | |
| 383 | + t.Fatal(err) | |
| 384 | + } | |
| 385 | + defer resp.Body.Close() | |
| 386 | + | |
| 387 | + if resp.StatusCode != http.StatusNotFound { | |
| 388 | + t.Fatalf("want 404, got %d", resp.StatusCode) | |
| 389 | + } | |
| 390 | +} | |
| 391 | + | |
| 392 | +func TestConfigStoreOnChange(t *testing.T) { | |
| 393 | + dir := t.TempDir() | |
| 394 | + path := filepath.Join(dir, "scuttlebot.yaml") | |
| 395 | + | |
| 396 | + var cfg config.Config | |
| 397 | + cfg.Defaults() | |
| 398 | + cfg.Ergo.DataDir = dir | |
| 399 | + store := NewConfigStore(path, cfg) | |
| 400 | + | |
| 401 | + done := make(chan config.Config, 1) | |
| 402 | + store.OnChange(func(c config.Config) { done <- c }) | |
| 403 | + | |
| 404 | + next := store.Get() | |
| 405 | + next.Bridge.WebUserTTLMinutes = 99 | |
| 406 | + if err := store.Save(next); err != nil { | |
| 407 | + t.Fatalf("Save: %v", err) | |
| 408 | + } | |
| 409 | + | |
| 410 | + select { | |
| 411 | + case c := <-done: | |
| 412 | + if c.Bridge.WebUserTTLMinutes != 99 { | |
| 413 | + t.Errorf("OnChange got TTL=%d, want 99", c.Bridge.WebUserTTLMinutes) | |
| 414 | + } | |
| 415 | + case <-time.After(2 * time.Second): | |
| 416 | + t.Error("OnChange callback not called within timeout") | |
| 417 | + } | |
| 418 | +} | |
| 109 | 419 | |
| 110 | 420 | func TestHandleGetConfigHistory(t *testing.T) { |
| 111 | 421 | srv, store := newCfgTestServer(t) |
| 112 | 422 | |
| 113 | 423 | // Trigger a save to create a snapshot. |
| 114 | 424 |
| --- internal/api/config_handlers_test.go | |
| +++ internal/api/config_handlers_test.go | |
| @@ -7,10 +7,11 @@ | |
| 7 | "log/slog" |
| 8 | "net/http" |
| 9 | "net/http/httptest" |
| 10 | "path/filepath" |
| 11 | "testing" |
| 12 | |
| 13 | "github.com/conflicthq/scuttlebot/internal/config" |
| 14 | "github.com/conflicthq/scuttlebot/internal/registry" |
| 15 | ) |
| 16 | |
| @@ -104,10 +105,319 @@ | |
| 104 | } |
| 105 | if len(got.Topology.Channels) != 1 || got.Topology.Channels[0].Name != "#general" { |
| 106 | t.Errorf("topology.channels = %+v", got.Topology.Channels) |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | func TestHandleGetConfigHistory(t *testing.T) { |
| 111 | srv, store := newCfgTestServer(t) |
| 112 | |
| 113 | // Trigger a save to create a snapshot. |
| 114 |
| --- internal/api/config_handlers_test.go | |
| +++ internal/api/config_handlers_test.go | |
| @@ -7,10 +7,11 @@ | |
| 7 | "log/slog" |
| 8 | "net/http" |
| 9 | "net/http/httptest" |
| 10 | "path/filepath" |
| 11 | "testing" |
| 12 | "time" |
| 13 | |
| 14 | "github.com/conflicthq/scuttlebot/internal/config" |
| 15 | "github.com/conflicthq/scuttlebot/internal/registry" |
| 16 | ) |
| 17 | |
| @@ -104,10 +105,319 @@ | |
| 105 | } |
| 106 | if len(got.Topology.Channels) != 1 || got.Topology.Channels[0].Name != "#general" { |
| 107 | t.Errorf("topology.channels = %+v", got.Topology.Channels) |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | func TestHandlePutConfigAgentPolicy(t *testing.T) { |
| 112 | srv, store := newCfgTestServer(t) |
| 113 | |
| 114 | update := map[string]any{ |
| 115 | "agent_policy": map[string]any{ |
| 116 | "require_checkin": true, |
| 117 | "checkin_channel": "#fleet", |
| 118 | "required_channels": []string{"#general"}, |
| 119 | }, |
| 120 | } |
| 121 | body, _ := json.Marshal(update) |
| 122 | req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) |
| 123 | req.Header.Set("Authorization", "Bearer tok") |
| 124 | req.Header.Set("Content-Type", "application/json") |
| 125 | resp, err := http.DefaultClient.Do(req) |
| 126 | if err != nil { |
| 127 | t.Fatal(err) |
| 128 | } |
| 129 | defer resp.Body.Close() |
| 130 | if resp.StatusCode != http.StatusOK { |
| 131 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 132 | } |
| 133 | |
| 134 | got := store.Get() |
| 135 | if !got.AgentPolicy.RequireCheckin { |
| 136 | t.Error("agent_policy.require_checkin should be true") |
| 137 | } |
| 138 | if got.AgentPolicy.CheckinChannel != "#fleet" { |
| 139 | t.Errorf("agent_policy.checkin_channel = %q, want #fleet", got.AgentPolicy.CheckinChannel) |
| 140 | } |
| 141 | if len(got.AgentPolicy.RequiredChannels) != 1 || got.AgentPolicy.RequiredChannels[0] != "#general" { |
| 142 | t.Errorf("agent_policy.required_channels = %v", got.AgentPolicy.RequiredChannels) |
| 143 | } |
| 144 | } |
| 145 | |
| 146 | func TestHandlePutConfigLogging(t *testing.T) { |
| 147 | srv, store := newCfgTestServer(t) |
| 148 | |
| 149 | update := map[string]any{ |
| 150 | "logging": map[string]any{ |
| 151 | "enabled": true, |
| 152 | "dir": "./data/logs", |
| 153 | "format": "jsonl", |
| 154 | "rotation": "daily", |
| 155 | "per_channel": true, |
| 156 | "max_age_days": 30, |
| 157 | }, |
| 158 | } |
| 159 | body, _ := json.Marshal(update) |
| 160 | req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) |
| 161 | req.Header.Set("Authorization", "Bearer tok") |
| 162 | req.Header.Set("Content-Type", "application/json") |
| 163 | resp, err := http.DefaultClient.Do(req) |
| 164 | if err != nil { |
| 165 | t.Fatal(err) |
| 166 | } |
| 167 | defer resp.Body.Close() |
| 168 | if resp.StatusCode != http.StatusOK { |
| 169 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 170 | } |
| 171 | |
| 172 | got := store.Get() |
| 173 | if !got.Logging.Enabled { |
| 174 | t.Error("logging.enabled should be true") |
| 175 | } |
| 176 | if got.Logging.Dir != "./data/logs" { |
| 177 | t.Errorf("logging.dir = %q, want ./data/logs", got.Logging.Dir) |
| 178 | } |
| 179 | if got.Logging.Format != "jsonl" { |
| 180 | t.Errorf("logging.format = %q, want jsonl", got.Logging.Format) |
| 181 | } |
| 182 | if got.Logging.Rotation != "daily" { |
| 183 | t.Errorf("logging.rotation = %q, want daily", got.Logging.Rotation) |
| 184 | } |
| 185 | if !got.Logging.PerChannel { |
| 186 | t.Error("logging.per_channel should be true") |
| 187 | } |
| 188 | if got.Logging.MaxAgeDays != 30 { |
| 189 | t.Errorf("logging.max_age_days = %d, want 30", got.Logging.MaxAgeDays) |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | func TestHandlePutConfigErgo(t *testing.T) { |
| 194 | srv, store := newCfgTestServer(t) |
| 195 | |
| 196 | update := map[string]any{ |
| 197 | "ergo": map[string]any{ |
| 198 | "network_name": "testnet", |
| 199 | "server_name": "irc.test.local", |
| 200 | }, |
| 201 | } |
| 202 | body, _ := json.Marshal(update) |
| 203 | req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) |
| 204 | req.Header.Set("Authorization", "Bearer tok") |
| 205 | req.Header.Set("Content-Type", "application/json") |
| 206 | resp, err := http.DefaultClient.Do(req) |
| 207 | if err != nil { |
| 208 | t.Fatal(err) |
| 209 | } |
| 210 | defer resp.Body.Close() |
| 211 | if resp.StatusCode != http.StatusOK { |
| 212 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 213 | } |
| 214 | |
| 215 | // Ergo changes should be flagged as restart_required. |
| 216 | var result struct { |
| 217 | Saved bool `json:"saved"` |
| 218 | RestartRequired []string `json:"restart_required"` |
| 219 | } |
| 220 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { |
| 221 | t.Fatal(err) |
| 222 | } |
| 223 | if !result.Saved { |
| 224 | t.Error("expected saved=true") |
| 225 | } |
| 226 | if len(result.RestartRequired) == 0 { |
| 227 | t.Error("expected restart_required to be non-empty for ergo changes") |
| 228 | } |
| 229 | |
| 230 | got := store.Get() |
| 231 | if got.Ergo.NetworkName != "testnet" { |
| 232 | t.Errorf("ergo.network_name = %q, want testnet", got.Ergo.NetworkName) |
| 233 | } |
| 234 | if got.Ergo.ServerName != "irc.test.local" { |
| 235 | t.Errorf("ergo.server_name = %q, want irc.test.local", got.Ergo.ServerName) |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | func TestHandlePutConfigTLS(t *testing.T) { |
| 240 | srv, store := newCfgTestServer(t) |
| 241 | |
| 242 | update := map[string]any{ |
| 243 | "tls": map[string]any{ |
| 244 | "domain": "example.com", |
| 245 | "email": "[email protected]", |
| 246 | "allow_insecure": true, |
| 247 | }, |
| 248 | } |
| 249 | body, _ := json.Marshal(update) |
| 250 | req, _ := http.NewRequest(http.MethodPut, srv.URL+"/v1/config", bytes.NewReader(body)) |
| 251 | req.Header.Set("Authorization", "Bearer tok") |
| 252 | req.Header.Set("Content-Type", "application/json") |
| 253 | resp, err := http.DefaultClient.Do(req) |
| 254 | if err != nil { |
| 255 | t.Fatal(err) |
| 256 | } |
| 257 | defer resp.Body.Close() |
| 258 | if resp.StatusCode != http.StatusOK { |
| 259 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 260 | } |
| 261 | |
| 262 | var result struct { |
| 263 | RestartRequired []string `json:"restart_required"` |
| 264 | } |
| 265 | json.NewDecoder(resp.Body).Decode(&result) |
| 266 | if len(result.RestartRequired) == 0 { |
| 267 | t.Error("expected restart_required for tls.domain change") |
| 268 | } |
| 269 | |
| 270 | got := store.Get() |
| 271 | if got.TLS.Domain != "example.com" { |
| 272 | t.Errorf("tls.domain = %q, want example.com", got.TLS.Domain) |
| 273 | } |
| 274 | if got.TLS.Email != "[email protected]" { |
| 275 | t.Errorf("tls.email = %q, want [email protected]", got.TLS.Email) |
| 276 | } |
| 277 | if !got.TLS.AllowInsecure { |
| 278 | t.Error("tls.allow_insecure should be true") |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | func TestHandleGetConfigIncludesAgentPolicyAndLogging(t *testing.T) { |
| 283 | srv, store := newCfgTestServer(t) |
| 284 | |
| 285 | cfg := store.Get() |
| 286 | cfg.AgentPolicy.RequireCheckin = true |
| 287 | cfg.AgentPolicy.CheckinChannel = "#ops" |
| 288 | cfg.Logging.Enabled = true |
| 289 | cfg.Logging.Format = "csv" |
| 290 | if err := store.Save(cfg); err != nil { |
| 291 | t.Fatalf("store.Save: %v", err) |
| 292 | } |
| 293 | |
| 294 | req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config", nil) |
| 295 | req.Header.Set("Authorization", "Bearer tok") |
| 296 | resp, err := http.DefaultClient.Do(req) |
| 297 | if err != nil { |
| 298 | t.Fatal(err) |
| 299 | } |
| 300 | defer resp.Body.Close() |
| 301 | if resp.StatusCode != http.StatusOK { |
| 302 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 303 | } |
| 304 | |
| 305 | var body map[string]any |
| 306 | if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { |
| 307 | t.Fatal(err) |
| 308 | } |
| 309 | ap, ok := body["agent_policy"].(map[string]any) |
| 310 | if !ok { |
| 311 | t.Fatal("response missing agent_policy section") |
| 312 | } |
| 313 | if ap["require_checkin"] != true { |
| 314 | t.Error("agent_policy.require_checkin should be true") |
| 315 | } |
| 316 | if ap["checkin_channel"] != "#ops" { |
| 317 | t.Errorf("agent_policy.checkin_channel = %v, want #ops", ap["checkin_channel"]) |
| 318 | } |
| 319 | lg, ok := body["logging"].(map[string]any) |
| 320 | if !ok { |
| 321 | t.Fatal("response missing logging section") |
| 322 | } |
| 323 | if lg["enabled"] != true { |
| 324 | t.Error("logging.enabled should be true") |
| 325 | } |
| 326 | if lg["format"] != "csv" { |
| 327 | t.Errorf("logging.format = %v, want csv", lg["format"]) |
| 328 | } |
| 329 | } |
| 330 | |
| 331 | func TestHandleGetConfigHistoryEntry(t *testing.T) { |
| 332 | srv, store := newCfgTestServer(t) |
| 333 | |
| 334 | // Save twice so a snapshot exists. |
| 335 | cfg := store.Get() |
| 336 | cfg.Bridge.WebUserTTLMinutes = 11 |
| 337 | if err := store.Save(cfg); err != nil { |
| 338 | t.Fatalf("first save: %v", err) |
| 339 | } |
| 340 | cfg2 := store.Get() |
| 341 | cfg2.Bridge.WebUserTTLMinutes = 22 |
| 342 | if err := store.Save(cfg2); err != nil { |
| 343 | t.Fatalf("second save: %v", err) |
| 344 | } |
| 345 | |
| 346 | // List history to find a real filename. |
| 347 | entries, err := store.ListHistory() |
| 348 | if err != nil { |
| 349 | t.Fatalf("ListHistory: %v", err) |
| 350 | } |
| 351 | if len(entries) == 0 { |
| 352 | t.Skip("no history entries; snapshot may not have been created") |
| 353 | } |
| 354 | filename := entries[0].Filename |
| 355 | |
| 356 | req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config/history/"+filename, nil) |
| 357 | req.Header.Set("Authorization", "Bearer tok") |
| 358 | resp, err := http.DefaultClient.Do(req) |
| 359 | if err != nil { |
| 360 | t.Fatal(err) |
| 361 | } |
| 362 | defer resp.Body.Close() |
| 363 | |
| 364 | if resp.StatusCode != http.StatusOK { |
| 365 | t.Fatalf("want 200, got %d", resp.StatusCode) |
| 366 | } |
| 367 | var body map[string]any |
| 368 | if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { |
| 369 | t.Fatal(err) |
| 370 | } |
| 371 | if _, ok := body["bridge"]; !ok { |
| 372 | t.Error("history entry response missing bridge section") |
| 373 | } |
| 374 | } |
| 375 | |
| 376 | func TestHandleGetConfigHistoryEntryNotFound(t *testing.T) { |
| 377 | srv, _ := newCfgTestServer(t) |
| 378 | |
| 379 | req, _ := http.NewRequest(http.MethodGet, srv.URL+"/v1/config/history/nonexistent.yaml", nil) |
| 380 | req.Header.Set("Authorization", "Bearer tok") |
| 381 | resp, err := http.DefaultClient.Do(req) |
| 382 | if err != nil { |
| 383 | t.Fatal(err) |
| 384 | } |
| 385 | defer resp.Body.Close() |
| 386 | |
| 387 | if resp.StatusCode != http.StatusNotFound { |
| 388 | t.Fatalf("want 404, got %d", resp.StatusCode) |
| 389 | } |
| 390 | } |
| 391 | |
| 392 | func TestConfigStoreOnChange(t *testing.T) { |
| 393 | dir := t.TempDir() |
| 394 | path := filepath.Join(dir, "scuttlebot.yaml") |
| 395 | |
| 396 | var cfg config.Config |
| 397 | cfg.Defaults() |
| 398 | cfg.Ergo.DataDir = dir |
| 399 | store := NewConfigStore(path, cfg) |
| 400 | |
| 401 | done := make(chan config.Config, 1) |
| 402 | store.OnChange(func(c config.Config) { done <- c }) |
| 403 | |
| 404 | next := store.Get() |
| 405 | next.Bridge.WebUserTTLMinutes = 99 |
| 406 | if err := store.Save(next); err != nil { |
| 407 | t.Fatalf("Save: %v", err) |
| 408 | } |
| 409 | |
| 410 | select { |
| 411 | case c := <-done: |
| 412 | if c.Bridge.WebUserTTLMinutes != 99 { |
| 413 | t.Errorf("OnChange got TTL=%d, want 99", c.Bridge.WebUserTTLMinutes) |
| 414 | } |
| 415 | case <-time.After(2 * time.Second): |
| 416 | t.Error("OnChange callback not called within timeout") |
| 417 | } |
| 418 | } |
| 419 | |
| 420 | func TestHandleGetConfigHistory(t *testing.T) { |
| 421 | srv, store := newCfgTestServer(t) |
| 422 | |
| 423 | // Trigger a save to create a snapshot. |
| 424 |
+33
-2
| --- internal/api/settings.go | ||
| +++ internal/api/settings.go | ||
| @@ -1,8 +1,12 @@ | ||
| 1 | 1 | package api |
| 2 | 2 | |
| 3 | -import "net/http" | |
| 3 | +import ( | |
| 4 | + "net/http" | |
| 5 | + | |
| 6 | + "github.com/conflicthq/scuttlebot/internal/config" | |
| 7 | +) | |
| 4 | 8 | |
| 5 | 9 | type settingsResponse struct { |
| 6 | 10 | TLS tlsInfo `json:"tls"` |
| 7 | 11 | Policies Policies `json:"policies"` |
| 8 | 12 | } |
| @@ -16,13 +20,40 @@ | ||
| 16 | 20 | func (s *Server) handleGetSettings(w http.ResponseWriter, r *http.Request) { |
| 17 | 21 | resp := settingsResponse{ |
| 18 | 22 | TLS: tlsInfo{ |
| 19 | 23 | Enabled: s.tlsDomain != "", |
| 20 | 24 | Domain: s.tlsDomain, |
| 21 | - AllowInsecure: true, // always true in current build | |
| 25 | + AllowInsecure: true, | |
| 22 | 26 | }, |
| 23 | 27 | } |
| 24 | 28 | if s.policies != nil { |
| 25 | 29 | resp.Policies = s.policies.Get() |
| 26 | 30 | } |
| 31 | + // Prefer ConfigStore for fields that have migrated to scuttlebot.yaml. | |
| 32 | + if s.cfgStore != nil { | |
| 33 | + cfg := s.cfgStore.Get() | |
| 34 | + resp.Policies.AgentPolicy = toAPIAgentPolicy(cfg.AgentPolicy) | |
| 35 | + resp.Policies.Logging = toAPILogging(cfg.Logging) | |
| 36 | + resp.Policies.Bridge.WebUserTTLMinutes = cfg.Bridge.WebUserTTLMinutes | |
| 37 | + } | |
| 27 | 38 | writeJSON(w, http.StatusOK, resp) |
| 28 | 39 | } |
| 40 | + | |
| 41 | +func toAPIAgentPolicy(c config.AgentPolicyConfig) AgentPolicy { | |
| 42 | + return AgentPolicy{ | |
| 43 | + RequireCheckin: c.RequireCheckin, | |
| 44 | + CheckinChannel: c.CheckinChannel, | |
| 45 | + RequiredChannels: c.RequiredChannels, | |
| 46 | + } | |
| 47 | +} | |
| 48 | + | |
| 49 | +func toAPILogging(c config.LoggingConfig) LoggingPolicy { | |
| 50 | + return LoggingPolicy{ | |
| 51 | + Enabled: c.Enabled, | |
| 52 | + Dir: c.Dir, | |
| 53 | + Format: c.Format, | |
| 54 | + Rotation: c.Rotation, | |
| 55 | + MaxSizeMB: c.MaxSizeMB, | |
| 56 | + PerChannel: c.PerChannel, | |
| 57 | + MaxAgeDays: c.MaxAgeDays, | |
| 58 | + } | |
| 59 | +} | |
| 29 | 60 |
| --- internal/api/settings.go | |
| +++ internal/api/settings.go | |
| @@ -1,8 +1,12 @@ | |
| 1 | package api |
| 2 | |
| 3 | import "net/http" |
| 4 | |
| 5 | type settingsResponse struct { |
| 6 | TLS tlsInfo `json:"tls"` |
| 7 | Policies Policies `json:"policies"` |
| 8 | } |
| @@ -16,13 +20,40 @@ | |
| 16 | func (s *Server) handleGetSettings(w http.ResponseWriter, r *http.Request) { |
| 17 | resp := settingsResponse{ |
| 18 | TLS: tlsInfo{ |
| 19 | Enabled: s.tlsDomain != "", |
| 20 | Domain: s.tlsDomain, |
| 21 | AllowInsecure: true, // always true in current build |
| 22 | }, |
| 23 | } |
| 24 | if s.policies != nil { |
| 25 | resp.Policies = s.policies.Get() |
| 26 | } |
| 27 | writeJSON(w, http.StatusOK, resp) |
| 28 | } |
| 29 |
| --- internal/api/settings.go | |
| +++ internal/api/settings.go | |
| @@ -1,8 +1,12 @@ | |
| 1 | package api |
| 2 | |
| 3 | import ( |
| 4 | "net/http" |
| 5 | |
| 6 | "github.com/conflicthq/scuttlebot/internal/config" |
| 7 | ) |
| 8 | |
| 9 | type settingsResponse struct { |
| 10 | TLS tlsInfo `json:"tls"` |
| 11 | Policies Policies `json:"policies"` |
| 12 | } |
| @@ -16,13 +20,40 @@ | |
| 20 | func (s *Server) handleGetSettings(w http.ResponseWriter, r *http.Request) { |
| 21 | resp := settingsResponse{ |
| 22 | TLS: tlsInfo{ |
| 23 | Enabled: s.tlsDomain != "", |
| 24 | Domain: s.tlsDomain, |
| 25 | AllowInsecure: true, |
| 26 | }, |
| 27 | } |
| 28 | if s.policies != nil { |
| 29 | resp.Policies = s.policies.Get() |
| 30 | } |
| 31 | // Prefer ConfigStore for fields that have migrated to scuttlebot.yaml. |
| 32 | if s.cfgStore != nil { |
| 33 | cfg := s.cfgStore.Get() |
| 34 | resp.Policies.AgentPolicy = toAPIAgentPolicy(cfg.AgentPolicy) |
| 35 | resp.Policies.Logging = toAPILogging(cfg.Logging) |
| 36 | resp.Policies.Bridge.WebUserTTLMinutes = cfg.Bridge.WebUserTTLMinutes |
| 37 | } |
| 38 | writeJSON(w, http.StatusOK, resp) |
| 39 | } |
| 40 | |
| 41 | func toAPIAgentPolicy(c config.AgentPolicyConfig) AgentPolicy { |
| 42 | return AgentPolicy{ |
| 43 | RequireCheckin: c.RequireCheckin, |
| 44 | CheckinChannel: c.CheckinChannel, |
| 45 | RequiredChannels: c.RequiredChannels, |
| 46 | } |
| 47 | } |
| 48 | |
| 49 | func toAPILogging(c config.LoggingConfig) LoggingPolicy { |
| 50 | return LoggingPolicy{ |
| 51 | Enabled: c.Enabled, |
| 52 | Dir: c.Dir, |
| 53 | Format: c.Format, |
| 54 | Rotation: c.Rotation, |
| 55 | MaxSizeMB: c.MaxSizeMB, |
| 56 | PerChannel: c.PerChannel, |
| 57 | MaxAgeDays: c.MaxAgeDays, |
| 58 | } |
| 59 | } |
| 60 |
+27
-7
| --- internal/config/config.go | ||
| +++ internal/config/config.go | ||
| @@ -9,17 +9,19 @@ | ||
| 9 | 9 | "gopkg.in/yaml.v3" |
| 10 | 10 | ) |
| 11 | 11 | |
| 12 | 12 | // Config is the top-level scuttlebot configuration. |
| 13 | 13 | type Config struct { |
| 14 | - Ergo ErgoConfig `yaml:"ergo"` | |
| 15 | - Datastore DatastoreConfig `yaml:"datastore"` | |
| 16 | - Bridge BridgeConfig `yaml:"bridge"` | |
| 17 | - TLS TLSConfig `yaml:"tls"` | |
| 18 | - LLM LLMConfig `yaml:"llm"` | |
| 19 | - Topology TopologyConfig `yaml:"topology"` | |
| 20 | - History ConfigHistoryConfig `yaml:"config_history"` | |
| 14 | + Ergo ErgoConfig `yaml:"ergo"` | |
| 15 | + Datastore DatastoreConfig `yaml:"datastore"` | |
| 16 | + Bridge BridgeConfig `yaml:"bridge"` | |
| 17 | + TLS TLSConfig `yaml:"tls"` | |
| 18 | + LLM LLMConfig `yaml:"llm"` | |
| 19 | + Topology TopologyConfig `yaml:"topology"` | |
| 20 | + History ConfigHistoryConfig `yaml:"config_history"` | |
| 21 | + AgentPolicy AgentPolicyConfig `yaml:"agent_policy" json:"agent_policy"` | |
| 22 | + Logging LoggingConfig `yaml:"logging" json:"logging"` | |
| 21 | 23 | |
| 22 | 24 | // APIAddr is the address for scuttlebot's own HTTP management API. |
| 23 | 25 | // Ignored when TLS.Domain is set (HTTPS runs on :443, HTTP on :80). |
| 24 | 26 | // Default: ":8080" |
| 25 | 27 | APIAddr string `yaml:"api_addr"` |
| @@ -26,10 +28,28 @@ | ||
| 26 | 28 | |
| 27 | 29 | // MCPAddr is the address for the MCP server. |
| 28 | 30 | // Default: ":8081" |
| 29 | 31 | MCPAddr string `yaml:"mcp_addr"` |
| 30 | 32 | } |
| 33 | + | |
// AgentPolicyConfig holds the requirements applied to every registering
// agent. It is serialized under the same key names in YAML and JSON.
type AgentPolicyConfig struct {
	RequireCheckin   bool     `yaml:"require_checkin" json:"require_checkin"`
	CheckinChannel   string   `yaml:"checkin_channel" json:"checkin_channel"`
	RequiredChannels []string `yaml:"required_channels" json:"required_channels"`
}
| 40 | + | |
// LoggingConfig holds the message-logging settings. It is serialized under
// the same key names in YAML and JSON.
type LoggingConfig struct {
	Enabled    bool   `yaml:"enabled" json:"enabled"`
	Dir        string `yaml:"dir" json:"dir"`
	Format     string `yaml:"format" json:"format"`     // one of "jsonl", "csv", "text"
	Rotation   string `yaml:"rotation" json:"rotation"` // one of "none", "daily", "weekly", "size"
	MaxSizeMB  int    `yaml:"max_size_mb" json:"max_size_mb"`
	PerChannel bool   `yaml:"per_channel" json:"per_channel"`
	MaxAgeDays int    `yaml:"max_age_days" json:"max_age_days"`
}
| 31 | 51 | |
| 32 | 52 | // ConfigHistoryConfig controls config write-back history retention. |
| 33 | 53 | type ConfigHistoryConfig struct { |
| 34 | 54 | // Keep is the number of config snapshots to retain in Dir. |
| 35 | 55 | // 0 disables history. Default: 20. |
| 36 | 56 |
| --- internal/config/config.go | |
| +++ internal/config/config.go | |
| @@ -9,17 +9,19 @@ | |
| 9 | "gopkg.in/yaml.v3" |
| 10 | ) |
| 11 | |
| 12 | // Config is the top-level scuttlebot configuration. |
| 13 | type Config struct { |
| 14 | Ergo ErgoConfig `yaml:"ergo"` |
| 15 | Datastore DatastoreConfig `yaml:"datastore"` |
| 16 | Bridge BridgeConfig `yaml:"bridge"` |
| 17 | TLS TLSConfig `yaml:"tls"` |
| 18 | LLM LLMConfig `yaml:"llm"` |
| 19 | Topology TopologyConfig `yaml:"topology"` |
| 20 | History ConfigHistoryConfig `yaml:"config_history"` |
| 21 | |
| 22 | // APIAddr is the address for scuttlebot's own HTTP management API. |
| 23 | // Ignored when TLS.Domain is set (HTTPS runs on :443, HTTP on :80). |
| 24 | // Default: ":8080" |
| 25 | APIAddr string `yaml:"api_addr"` |
| @@ -26,10 +28,28 @@ | |
| 26 | |
| 27 | // MCPAddr is the address for the MCP server. |
| 28 | // Default: ":8081" |
| 29 | MCPAddr string `yaml:"mcp_addr"` |
| 30 | } |
| 31 | |
| 32 | // ConfigHistoryConfig controls config write-back history retention. |
| 33 | type ConfigHistoryConfig struct { |
| 34 | // Keep is the number of config snapshots to retain in Dir. |
| 35 | // 0 disables history. Default: 20. |
| 36 |
| --- internal/config/config.go | |
| +++ internal/config/config.go | |
| @@ -9,17 +9,19 @@ | |
| 9 | "gopkg.in/yaml.v3" |
| 10 | ) |
| 11 | |
| 12 | // Config is the top-level scuttlebot configuration. |
| 13 | type Config struct { |
| 14 | Ergo ErgoConfig `yaml:"ergo"` |
| 15 | Datastore DatastoreConfig `yaml:"datastore"` |
| 16 | Bridge BridgeConfig `yaml:"bridge"` |
| 17 | TLS TLSConfig `yaml:"tls"` |
| 18 | LLM LLMConfig `yaml:"llm"` |
| 19 | Topology TopologyConfig `yaml:"topology"` |
| 20 | History ConfigHistoryConfig `yaml:"config_history"` |
| 21 | AgentPolicy AgentPolicyConfig `yaml:"agent_policy" json:"agent_policy"` |
| 22 | Logging LoggingConfig `yaml:"logging" json:"logging"` |
| 23 | |
| 24 | // APIAddr is the address for scuttlebot's own HTTP management API. |
| 25 | // Ignored when TLS.Domain is set (HTTPS runs on :443, HTTP on :80). |
| 26 | // Default: ":8080" |
| 27 | APIAddr string `yaml:"api_addr"` |
| @@ -26,10 +28,28 @@ | |
| 28 | |
| 29 | // MCPAddr is the address for the MCP server. |
| 30 | // Default: ":8081" |
| 31 | MCPAddr string `yaml:"mcp_addr"` |
| 32 | } |
| 33 | |
// AgentPolicyConfig holds the requirements applied to every registering
// agent. It is serialized under the same key names in YAML and JSON.
type AgentPolicyConfig struct {
	RequireCheckin   bool     `yaml:"require_checkin" json:"require_checkin"`
	CheckinChannel   string   `yaml:"checkin_channel" json:"checkin_channel"`
	RequiredChannels []string `yaml:"required_channels" json:"required_channels"`
}
| 40 | |
// LoggingConfig holds the message-logging settings. It is serialized under
// the same key names in YAML and JSON.
type LoggingConfig struct {
	Enabled    bool   `yaml:"enabled" json:"enabled"`
	Dir        string `yaml:"dir" json:"dir"`
	Format     string `yaml:"format" json:"format"`     // one of "jsonl", "csv", "text"
	Rotation   string `yaml:"rotation" json:"rotation"` // one of "none", "daily", "weekly", "size"
	MaxSizeMB  int    `yaml:"max_size_mb" json:"max_size_mb"`
	PerChannel bool   `yaml:"per_channel" json:"per_channel"`
	MaxAgeDays int    `yaml:"max_age_days" json:"max_age_days"`
}
| 51 | |
| 52 | // ConfigHistoryConfig controls config write-back history retention. |
| 53 | type ConfigHistoryConfig struct { |
| 54 | // Keep is the number of config snapshots to retain in Dir. |
| 55 | // 0 disables history. Default: 20. |
| 56 |
| --- internal/config/config_test.go | ||
| +++ internal/config/config_test.go | ||
| @@ -1,8 +1,9 @@ | ||
| 1 | 1 | package config |
| 2 | 2 | |
| 3 | 3 | import ( |
| 4 | + "encoding/json" | |
| 4 | 5 | "os" |
| 5 | 6 | "path/filepath" |
| 6 | 7 | "testing" |
| 7 | 8 | "time" |
| 8 | 9 | ) |
| @@ -141,5 +142,147 @@ | ||
| 141 | 142 | if got != tc.want { |
| 142 | 143 | t.Errorf("input %q: got %v, want %v", tc.input, got, tc.want) |
| 143 | 144 | } |
| 144 | 145 | } |
| 145 | 146 | } |
| 147 | + | |
| 148 | +func TestDurationJSONRoundTrip(t *testing.T) { | |
| 149 | + cases := []struct { | |
| 150 | + dur time.Duration | |
| 151 | + want string | |
| 152 | + }{ | |
| 153 | + {72 * time.Hour, `"72h0m0s"`}, | |
| 154 | + {30 * time.Minute, `"30m0s"`}, | |
| 155 | + {0, `"0s"`}, | |
| 156 | + } | |
| 157 | + for _, tc := range cases { | |
| 158 | + d := Duration{tc.dur} | |
| 159 | + b, err := json.Marshal(d) | |
| 160 | + if err != nil { | |
| 161 | + t.Fatalf("Marshal(%v): %v", tc.dur, err) | |
| 162 | + } | |
| 163 | + if string(b) != tc.want { | |
| 164 | + t.Errorf("Marshal(%v) = %s, want %s", tc.dur, b, tc.want) | |
| 165 | + } | |
| 166 | + var back Duration | |
| 167 | + if err := json.Unmarshal(b, &back); err != nil { | |
| 168 | + t.Fatalf("Unmarshal(%s): %v", b, err) | |
| 169 | + } | |
| 170 | + if back.Duration != tc.dur { | |
| 171 | + t.Errorf("round-trip(%v): got %v", tc.dur, back.Duration) | |
| 172 | + } | |
| 173 | + } | |
| 174 | +} | |
| 175 | + | |
| 176 | +func TestDurationJSONUnmarshalErrors(t *testing.T) { | |
| 177 | + cases := []struct{ input string }{ | |
| 178 | + {`123`}, // not a quoted string | |
| 179 | + {`"notadur"`}, // not parseable | |
| 180 | + {`""`}, // empty string | |
| 181 | + } | |
| 182 | + for _, tc := range cases { | |
| 183 | + var d Duration | |
| 184 | + if err := json.Unmarshal([]byte(tc.input), &d); err == nil { | |
| 185 | + t.Errorf("Unmarshal(%s): expected error, got nil", tc.input) | |
| 186 | + } | |
| 187 | + } | |
| 188 | +} | |
| 189 | + | |
| 190 | +func TestApplyEnv(t *testing.T) { | |
| 191 | + cases := []struct { | |
| 192 | + envKey string | |
| 193 | + check func(c Config) bool | |
| 194 | + }{ | |
| 195 | + {"SCUTTLEBOT_API_ADDR", func(c Config) bool { return c.APIAddr == ":9999" }}, | |
| 196 | + {"SCUTTLEBOT_MCP_ADDR", func(c Config) bool { return c.MCPAddr == ":9998" }}, | |
| 197 | + {"SCUTTLEBOT_DB_DRIVER", func(c Config) bool { return c.Datastore.Driver == "postgres" }}, | |
| 198 | + {"SCUTTLEBOT_DB_DSN", func(c Config) bool { return c.Datastore.DSN == "postgres://test" }}, | |
| 199 | + {"SCUTTLEBOT_ERGO_EXTERNAL", func(c Config) bool { return c.Ergo.External }}, | |
| 200 | + {"SCUTTLEBOT_ERGO_API_ADDR", func(c Config) bool { return c.Ergo.APIAddr == "http://ergo:8089" }}, | |
| 201 | + {"SCUTTLEBOT_ERGO_API_TOKEN", func(c Config) bool { return c.Ergo.APIToken == "tok123" }}, | |
| 202 | + {"SCUTTLEBOT_ERGO_IRC_ADDR", func(c Config) bool { return c.Ergo.IRCAddr == "ergo:6667" }}, | |
| 203 | + {"SCUTTLEBOT_ERGO_NETWORK_NAME", func(c Config) bool { return c.Ergo.NetworkName == "testnet" }}, | |
| 204 | + {"SCUTTLEBOT_ERGO_SERVER_NAME", func(c Config) bool { return c.Ergo.ServerName == "irc.test.local" }}, | |
| 205 | + } | |
| 206 | + | |
| 207 | + envValues := map[string]string{ | |
| 208 | + "SCUTTLEBOT_API_ADDR": ":9999", | |
| 209 | + "SCUTTLEBOT_MCP_ADDR": ":9998", | |
| 210 | + "SCUTTLEBOT_DB_DRIVER": "postgres", | |
| 211 | + "SCUTTLEBOT_DB_DSN": "postgres://test", | |
| 212 | + "SCUTTLEBOT_ERGO_EXTERNAL": "true", | |
| 213 | + "SCUTTLEBOT_ERGO_API_ADDR": "http://ergo:8089", | |
| 214 | + "SCUTTLEBOT_ERGO_API_TOKEN": "tok123", | |
| 215 | + "SCUTTLEBOT_ERGO_IRC_ADDR": "ergo:6667", | |
| 216 | + "SCUTTLEBOT_ERGO_NETWORK_NAME": "testnet", | |
| 217 | + "SCUTTLEBOT_ERGO_SERVER_NAME": "irc.test.local", | |
| 218 | + } | |
| 219 | + | |
| 220 | + for _, tc := range cases { | |
| 221 | + t.Run(tc.envKey, func(t *testing.T) { | |
| 222 | + t.Setenv(tc.envKey, envValues[tc.envKey]) | |
| 223 | + var c Config | |
| 224 | + c.Defaults() | |
| 225 | + c.ApplyEnv() | |
| 226 | + if !tc.check(c) { | |
| 227 | + t.Errorf("%s=%q did not apply correctly", tc.envKey, envValues[tc.envKey]) | |
| 228 | + } | |
| 229 | + }) | |
| 230 | + } | |
| 231 | +} | |
| 232 | + | |
| 233 | +func TestApplyEnvErgoExternalFalseByDefault(t *testing.T) { | |
| 234 | + // SCUTTLEBOT_ERGO_EXTERNAL absent — should not force External=true. | |
| 235 | + var c Config | |
| 236 | + c.Defaults() | |
| 237 | + c.ApplyEnv() | |
| 238 | + if c.Ergo.External { | |
| 239 | + t.Error("Ergo.External should be false when env var is absent") | |
| 240 | + } | |
| 241 | +} | |
| 242 | + | |
| 243 | +func TestConfigSaveAndLoad(t *testing.T) { | |
| 244 | + dir := t.TempDir() | |
| 245 | + path := filepath.Join(dir, "scuttlebot.yaml") | |
| 246 | + | |
| 247 | + var orig Config | |
| 248 | + orig.Defaults() | |
| 249 | + orig.Bridge.WebUserTTLMinutes = 42 | |
| 250 | + orig.AgentPolicy.RequireCheckin = true | |
| 251 | + orig.AgentPolicy.CheckinChannel = "#fleet" | |
| 252 | + orig.Logging.Enabled = true | |
| 253 | + orig.Logging.Format = "jsonl" | |
| 254 | + | |
| 255 | + if err := orig.Save(path); err != nil { | |
| 256 | + t.Fatalf("Save: %v", err) | |
| 257 | + } | |
| 258 | + | |
| 259 | + var loaded Config | |
| 260 | + loaded.Defaults() | |
| 261 | + if err := loaded.LoadFile(path); err != nil { | |
| 262 | + t.Fatalf("LoadFile: %v", err) | |
| 263 | + } | |
| 264 | + | |
| 265 | + if loaded.Bridge.WebUserTTLMinutes != 42 { | |
| 266 | + t.Errorf("WebUserTTLMinutes = %d, want 42", loaded.Bridge.WebUserTTLMinutes) | |
| 267 | + } | |
| 268 | + if !loaded.AgentPolicy.RequireCheckin { | |
| 269 | + t.Error("AgentPolicy.RequireCheckin should be true") | |
| 270 | + } | |
| 271 | + if loaded.AgentPolicy.CheckinChannel != "#fleet" { | |
| 272 | + t.Errorf("CheckinChannel = %q, want #fleet", loaded.AgentPolicy.CheckinChannel) | |
| 273 | + } | |
| 274 | + if !loaded.Logging.Enabled { | |
| 275 | + t.Error("Logging.Enabled should be true") | |
| 276 | + } | |
| 277 | + if loaded.Logging.Format != "jsonl" { | |
| 278 | + t.Errorf("Logging.Format = %q, want jsonl", loaded.Logging.Format) | |
| 279 | + } | |
| 280 | +} | |
| 281 | + | |
| 282 | +func TestLoadFileMissingIsNotError(t *testing.T) { | |
| 283 | + var c Config | |
| 284 | + c.Defaults() | |
| 285 | + if err := c.LoadFile("/nonexistent/path/scuttlebot.yaml"); err != nil { | |
| 286 | + t.Errorf("LoadFile on missing file should return nil, got %v", err) | |
| 287 | + } | |
| 288 | +} | |
| 146 | 289 |
| --- internal/config/config_test.go | |
| +++ internal/config/config_test.go | |
| @@ -1,8 +1,9 @@ | |
| 1 | package config |
| 2 | |
| 3 | import ( |
| 4 | "os" |
| 5 | "path/filepath" |
| 6 | "testing" |
| 7 | "time" |
| 8 | ) |
| @@ -141,5 +142,147 @@ | |
| 141 | if got != tc.want { |
| 142 | t.Errorf("input %q: got %v, want %v", tc.input, got, tc.want) |
| 143 | } |
| 144 | } |
| 145 | } |
| 146 |
| --- internal/config/config_test.go | |
| +++ internal/config/config_test.go | |
| @@ -1,8 +1,9 @@ | |
| 1 | package config |
| 2 | |
| 3 | import ( |
| 4 | "encoding/json" |
| 5 | "os" |
| 6 | "path/filepath" |
| 7 | "testing" |
| 8 | "time" |
| 9 | ) |
| @@ -141,5 +142,147 @@ | |
| 142 | if got != tc.want { |
| 143 | t.Errorf("input %q: got %v, want %v", tc.input, got, tc.want) |
| 144 | } |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | func TestDurationJSONRoundTrip(t *testing.T) { |
| 149 | cases := []struct { |
| 150 | dur time.Duration |
| 151 | want string |
| 152 | }{ |
| 153 | {72 * time.Hour, `"72h0m0s"`}, |
| 154 | {30 * time.Minute, `"30m0s"`}, |
| 155 | {0, `"0s"`}, |
| 156 | } |
| 157 | for _, tc := range cases { |
| 158 | d := Duration{tc.dur} |
| 159 | b, err := json.Marshal(d) |
| 160 | if err != nil { |
| 161 | t.Fatalf("Marshal(%v): %v", tc.dur, err) |
| 162 | } |
| 163 | if string(b) != tc.want { |
| 164 | t.Errorf("Marshal(%v) = %s, want %s", tc.dur, b, tc.want) |
| 165 | } |
| 166 | var back Duration |
| 167 | if err := json.Unmarshal(b, &back); err != nil { |
| 168 | t.Fatalf("Unmarshal(%s): %v", b, err) |
| 169 | } |
| 170 | if back.Duration != tc.dur { |
| 171 | t.Errorf("round-trip(%v): got %v", tc.dur, back.Duration) |
| 172 | } |
| 173 | } |
| 174 | } |
| 175 | |
| 176 | func TestDurationJSONUnmarshalErrors(t *testing.T) { |
| 177 | cases := []struct{ input string }{ |
| 178 | {`123`}, // not a quoted string |
| 179 | {`"notadur"`}, // not parseable |
| 180 | {`""`}, // empty string |
| 181 | } |
| 182 | for _, tc := range cases { |
| 183 | var d Duration |
| 184 | if err := json.Unmarshal([]byte(tc.input), &d); err == nil { |
| 185 | t.Errorf("Unmarshal(%s): expected error, got nil", tc.input) |
| 186 | } |
| 187 | } |
| 188 | } |
| 189 | |
| 190 | func TestApplyEnv(t *testing.T) { |
| 191 | cases := []struct { |
| 192 | envKey string |
| 193 | check func(c Config) bool |
| 194 | }{ |
| 195 | {"SCUTTLEBOT_API_ADDR", func(c Config) bool { return c.APIAddr == ":9999" }}, |
| 196 | {"SCUTTLEBOT_MCP_ADDR", func(c Config) bool { return c.MCPAddr == ":9998" }}, |
| 197 | {"SCUTTLEBOT_DB_DRIVER", func(c Config) bool { return c.Datastore.Driver == "postgres" }}, |
| 198 | {"SCUTTLEBOT_DB_DSN", func(c Config) bool { return c.Datastore.DSN == "postgres://test" }}, |
| 199 | {"SCUTTLEBOT_ERGO_EXTERNAL", func(c Config) bool { return c.Ergo.External }}, |
| 200 | {"SCUTTLEBOT_ERGO_API_ADDR", func(c Config) bool { return c.Ergo.APIAddr == "http://ergo:8089" }}, |
| 201 | {"SCUTTLEBOT_ERGO_API_TOKEN", func(c Config) bool { return c.Ergo.APIToken == "tok123" }}, |
| 202 | {"SCUTTLEBOT_ERGO_IRC_ADDR", func(c Config) bool { return c.Ergo.IRCAddr == "ergo:6667" }}, |
| 203 | {"SCUTTLEBOT_ERGO_NETWORK_NAME", func(c Config) bool { return c.Ergo.NetworkName == "testnet" }}, |
| 204 | {"SCUTTLEBOT_ERGO_SERVER_NAME", func(c Config) bool { return c.Ergo.ServerName == "irc.test.local" }}, |
| 205 | } |
| 206 | |
| 207 | envValues := map[string]string{ |
| 208 | "SCUTTLEBOT_API_ADDR": ":9999", |
| 209 | "SCUTTLEBOT_MCP_ADDR": ":9998", |
| 210 | "SCUTTLEBOT_DB_DRIVER": "postgres", |
| 211 | "SCUTTLEBOT_DB_DSN": "postgres://test", |
| 212 | "SCUTTLEBOT_ERGO_EXTERNAL": "true", |
| 213 | "SCUTTLEBOT_ERGO_API_ADDR": "http://ergo:8089", |
| 214 | "SCUTTLEBOT_ERGO_API_TOKEN": "tok123", |
| 215 | "SCUTTLEBOT_ERGO_IRC_ADDR": "ergo:6667", |
| 216 | "SCUTTLEBOT_ERGO_NETWORK_NAME": "testnet", |
| 217 | "SCUTTLEBOT_ERGO_SERVER_NAME": "irc.test.local", |
| 218 | } |
| 219 | |
| 220 | for _, tc := range cases { |
| 221 | t.Run(tc.envKey, func(t *testing.T) { |
| 222 | t.Setenv(tc.envKey, envValues[tc.envKey]) |
| 223 | var c Config |
| 224 | c.Defaults() |
| 225 | c.ApplyEnv() |
| 226 | if !tc.check(c) { |
| 227 | t.Errorf("%s=%q did not apply correctly", tc.envKey, envValues[tc.envKey]) |
| 228 | } |
| 229 | }) |
| 230 | } |
| 231 | } |
| 232 | |
| 233 | func TestApplyEnvErgoExternalFalseByDefault(t *testing.T) { |
| 234 | // SCUTTLEBOT_ERGO_EXTERNAL absent — should not force External=true. |
| 235 | var c Config |
| 236 | c.Defaults() |
| 237 | c.ApplyEnv() |
| 238 | if c.Ergo.External { |
| 239 | t.Error("Ergo.External should be false when env var is absent") |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | func TestConfigSaveAndLoad(t *testing.T) { |
| 244 | dir := t.TempDir() |
| 245 | path := filepath.Join(dir, "scuttlebot.yaml") |
| 246 | |
| 247 | var orig Config |
| 248 | orig.Defaults() |
| 249 | orig.Bridge.WebUserTTLMinutes = 42 |
| 250 | orig.AgentPolicy.RequireCheckin = true |
| 251 | orig.AgentPolicy.CheckinChannel = "#fleet" |
| 252 | orig.Logging.Enabled = true |
| 253 | orig.Logging.Format = "jsonl" |
| 254 | |
| 255 | if err := orig.Save(path); err != nil { |
| 256 | t.Fatalf("Save: %v", err) |
| 257 | } |
| 258 | |
| 259 | var loaded Config |
| 260 | loaded.Defaults() |
| 261 | if err := loaded.LoadFile(path); err != nil { |
| 262 | t.Fatalf("LoadFile: %v", err) |
| 263 | } |
| 264 | |
| 265 | if loaded.Bridge.WebUserTTLMinutes != 42 { |
| 266 | t.Errorf("WebUserTTLMinutes = %d, want 42", loaded.Bridge.WebUserTTLMinutes) |
| 267 | } |
| 268 | if !loaded.AgentPolicy.RequireCheckin { |
| 269 | t.Error("AgentPolicy.RequireCheckin should be true") |
| 270 | } |
| 271 | if loaded.AgentPolicy.CheckinChannel != "#fleet" { |
| 272 | t.Errorf("CheckinChannel = %q, want #fleet", loaded.AgentPolicy.CheckinChannel) |
| 273 | } |
| 274 | if !loaded.Logging.Enabled { |
| 275 | t.Error("Logging.Enabled should be true") |
| 276 | } |
| 277 | if loaded.Logging.Format != "jsonl" { |
| 278 | t.Errorf("Logging.Format = %q, want jsonl", loaded.Logging.Format) |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | func TestLoadFileMissingIsNotError(t *testing.T) { |
| 283 | var c Config |
| 284 | c.Defaults() |
| 285 | if err := c.LoadFile("/nonexistent/path/scuttlebot.yaml"); err != nil { |
| 286 | t.Errorf("LoadFile on missing file should return nil, got %v", err) |
| 287 | } |
| 288 | } |
| 289 |
| --- internal/registry/registry_test.go | ||
| +++ internal/registry/registry_test.go | ||
| @@ -255,5 +255,163 @@ | ||
| 255 | 255 | // Account should not have been created. |
| 256 | 256 | if p.passphrase("bad-agent") != "" { |
| 257 | 257 | t.Error("account should not be created when config is invalid") |
| 258 | 258 | } |
| 259 | 259 | } |
| 260 | + | |
| 261 | +func TestAdopt(t *testing.T) { | |
| 262 | + p := newMockProvisioner() | |
| 263 | + r := registry.New(p, testKey) | |
| 264 | + | |
| 265 | + payload, err := r.Adopt("preexisting-bot", registry.AgentTypeWorker, | |
| 266 | + cfg([]string{"#fleet"}, []string{"read"})) | |
| 267 | + if err != nil { | |
| 268 | + t.Fatalf("Adopt: %v", err) | |
| 269 | + } | |
| 270 | + if payload.Payload.Nick != "preexisting-bot" { | |
| 271 | + t.Errorf("payload Nick = %q, want preexisting-bot", payload.Payload.Nick) | |
| 272 | + } | |
| 273 | + // Adopt must NOT create a NickServ account (password should be empty in mock). | |
| 274 | + if p.passphrase("preexisting-bot") != "" { | |
| 275 | + t.Error("Adopt should not create a NickServ account") | |
| 276 | + } | |
| 277 | + // Agent should be visible in the registry. | |
| 278 | + agent, err := r.Get("preexisting-bot") | |
| 279 | + if err != nil { | |
| 280 | + t.Fatalf("Get after Adopt: %v", err) | |
| 281 | + } | |
| 282 | + if agent.Nick != "preexisting-bot" { | |
| 283 | + t.Errorf("Get Nick = %q", agent.Nick) | |
| 284 | + } | |
| 285 | +} | |
| 286 | + | |
| 287 | +func TestAdoptDuplicate(t *testing.T) { | |
| 288 | + p := newMockProvisioner() | |
| 289 | + r := registry.New(p, testKey) | |
| 290 | + | |
| 291 | + if _, err := r.Adopt("bot-dup", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { | |
| 292 | + t.Fatalf("first Adopt: %v", err) | |
| 293 | + } | |
| 294 | + if _, err := r.Adopt("bot-dup", registry.AgentTypeWorker, registry.EngagementConfig{}); err == nil { | |
| 295 | + t.Error("expected error on duplicate Adopt, got nil") | |
| 296 | + } | |
| 297 | +} | |
| 298 | + | |
| 299 | +func TestDelete(t *testing.T) { | |
| 300 | + p := newMockProvisioner() | |
| 301 | + r := registry.New(p, testKey) | |
| 302 | + | |
| 303 | + if _, _, err := r.Register("del-agent", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { | |
| 304 | + t.Fatalf("Register: %v", err) | |
| 305 | + } | |
| 306 | + | |
| 307 | + if err := r.Delete("del-agent"); err != nil { | |
| 308 | + t.Fatalf("Delete: %v", err) | |
| 309 | + } | |
| 310 | + | |
| 311 | + // Agent must no longer appear in List. | |
| 312 | + for _, a := range r.List() { | |
| 313 | + if a.Nick == "del-agent" { | |
| 314 | + t.Error("deleted agent should not appear in List()") | |
| 315 | + } | |
| 316 | + } | |
| 317 | + | |
| 318 | + // Get must fail. | |
| 319 | + if _, err := r.Get("del-agent"); err == nil { | |
| 320 | + t.Error("Get should fail for deleted agent") | |
| 321 | + } | |
| 322 | +} | |
| 323 | + | |
| 324 | +func TestDeleteRevoked(t *testing.T) { | |
| 325 | + // Deleting a revoked agent should succeed (lockout step skipped). | |
| 326 | + p := newMockProvisioner() | |
| 327 | + r := registry.New(p, testKey) | |
| 328 | + | |
| 329 | + if _, _, err := r.Register("rev-del", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { | |
| 330 | + t.Fatalf("Register: %v", err) | |
| 331 | + } | |
| 332 | + if err := r.Revoke("rev-del"); err != nil { | |
| 333 | + t.Fatalf("Revoke: %v", err) | |
| 334 | + } | |
| 335 | + if err := r.Delete("rev-del"); err != nil { | |
| 336 | + t.Fatalf("Delete of revoked agent: %v", err) | |
| 337 | + } | |
| 338 | +} | |
| 339 | + | |
| 340 | +func TestDeleteNotFound(t *testing.T) { | |
| 341 | + p := newMockProvisioner() | |
| 342 | + r := registry.New(p, testKey) | |
| 343 | + if err := r.Delete("nobody"); err == nil { | |
| 344 | + t.Error("expected error deleting non-existent agent, got nil") | |
| 345 | + } | |
| 346 | +} | |
| 347 | + | |
| 348 | +func TestUpdateChannels(t *testing.T) { | |
| 349 | + p := newMockProvisioner() | |
| 350 | + r := registry.New(p, testKey) | |
| 351 | + | |
| 352 | + if _, _, err := r.Register("chan-agent", registry.AgentTypeWorker, | |
| 353 | + cfg([]string{"#fleet"}, nil)); err != nil { | |
| 354 | + t.Fatalf("Register: %v", err) | |
| 355 | + } | |
| 356 | + | |
| 357 | + newChans := []string{"#fleet", "#project.foo"} | |
| 358 | + if err := r.UpdateChannels("chan-agent", newChans); err != nil { | |
| 359 | + t.Fatalf("UpdateChannels: %v", err) | |
| 360 | + } | |
| 361 | + | |
| 362 | + agent, err := r.Get("chan-agent") | |
| 363 | + if err != nil { | |
| 364 | + t.Fatalf("Get: %v", err) | |
| 365 | + } | |
| 366 | + if len(agent.Channels) != 2 { | |
| 367 | + t.Errorf("Channels len = %d, want 2", len(agent.Channels)) | |
| 368 | + } | |
| 369 | + if agent.Channels[1] != "#project.foo" { | |
| 370 | + t.Errorf("Channels[1] = %q, want #project.foo", agent.Channels[1]) | |
| 371 | + } | |
| 372 | +} | |
| 373 | + | |
| 374 | +func TestUpdateChannelsNotFound(t *testing.T) { | |
| 375 | + p := newMockProvisioner() | |
| 376 | + r := registry.New(p, testKey) | |
| 377 | + if err := r.UpdateChannels("ghost", []string{"#fleet"}); err == nil { | |
| 378 | + t.Error("expected error for unknown agent, got nil") | |
| 379 | + } | |
| 380 | +} | |
| 381 | + | |
| 382 | +func TestSetDataPathPersistence(t *testing.T) { | |
| 383 | + dataPath := t.TempDir() + "/agents.json" | |
| 384 | + p := newMockProvisioner() | |
| 385 | + r := registry.New(p, testKey) | |
| 386 | + | |
| 387 | + if err := r.SetDataPath(dataPath); err != nil { | |
| 388 | + t.Fatalf("SetDataPath: %v", err) | |
| 389 | + } | |
| 390 | + | |
| 391 | + if _, _, err := r.Register("persist-me", registry.AgentTypeWorker, | |
| 392 | + cfg([]string{"#fleet"}, nil)); err != nil { | |
| 393 | + t.Fatalf("Register: %v", err) | |
| 394 | + } | |
| 395 | + | |
| 396 | + // New registry loaded from the same path — must contain the persisted agent. | |
| 397 | + r2 := registry.New(newMockProvisioner(), testKey) | |
| 398 | + if err := r2.SetDataPath(dataPath); err != nil { | |
| 399 | + t.Fatalf("SetDataPath (r2): %v", err) | |
| 400 | + } | |
| 401 | + | |
| 402 | + agent, err := r2.Get("persist-me") | |
| 403 | + if err != nil { | |
| 404 | + t.Fatalf("Get after reload: %v", err) | |
| 405 | + } | |
| 406 | + if agent.Nick != "persist-me" { | |
| 407 | + t.Errorf("reloaded Nick = %q, want persist-me", agent.Nick) | |
| 408 | + } | |
| 409 | +} | |
| 410 | + | |
| 411 | +func TestSetDataPathMissingFileOK(t *testing.T) { | |
| 412 | + r := registry.New(newMockProvisioner(), testKey) | |
| 413 | + // Path doesn't exist yet — should not error. | |
| 414 | + if err := r.SetDataPath(t.TempDir() + "/agents.json"); err != nil { | |
| 415 | + t.Errorf("SetDataPath on missing file: %v", err) | |
| 416 | + } | |
| 417 | +} | |
| 260 | 418 |
| --- internal/registry/registry_test.go | |
| +++ internal/registry/registry_test.go | |
| @@ -255,5 +255,163 @@ | |
| 255 | // Account should not have been created. |
| 256 | if p.passphrase("bad-agent") != "" { |
| 257 | t.Error("account should not be created when config is invalid") |
| 258 | } |
| 259 | } |
| 260 |
| --- internal/registry/registry_test.go | |
| +++ internal/registry/registry_test.go | |
| @@ -255,5 +255,163 @@ | |
| 255 | // Account should not have been created. |
| 256 | if p.passphrase("bad-agent") != "" { |
| 257 | t.Error("account should not be created when config is invalid") |
| 258 | } |
| 259 | } |
| 260 | |
| 261 | func TestAdopt(t *testing.T) { |
| 262 | p := newMockProvisioner() |
| 263 | r := registry.New(p, testKey) |
| 264 | |
| 265 | payload, err := r.Adopt("preexisting-bot", registry.AgentTypeWorker, |
| 266 | cfg([]string{"#fleet"}, []string{"read"})) |
| 267 | if err != nil { |
| 268 | t.Fatalf("Adopt: %v", err) |
| 269 | } |
| 270 | if payload.Payload.Nick != "preexisting-bot" { |
| 271 | t.Errorf("payload Nick = %q, want preexisting-bot", payload.Payload.Nick) |
| 272 | } |
| 273 | // Adopt must NOT create a NickServ account (password should be empty in mock). |
| 274 | if p.passphrase("preexisting-bot") != "" { |
| 275 | t.Error("Adopt should not create a NickServ account") |
| 276 | } |
| 277 | // Agent should be visible in the registry. |
| 278 | agent, err := r.Get("preexisting-bot") |
| 279 | if err != nil { |
| 280 | t.Fatalf("Get after Adopt: %v", err) |
| 281 | } |
| 282 | if agent.Nick != "preexisting-bot" { |
| 283 | t.Errorf("Get Nick = %q", agent.Nick) |
| 284 | } |
| 285 | } |
| 286 | |
| 287 | func TestAdoptDuplicate(t *testing.T) { |
| 288 | p := newMockProvisioner() |
| 289 | r := registry.New(p, testKey) |
| 290 | |
| 291 | if _, err := r.Adopt("bot-dup", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { |
| 292 | t.Fatalf("first Adopt: %v", err) |
| 293 | } |
| 294 | if _, err := r.Adopt("bot-dup", registry.AgentTypeWorker, registry.EngagementConfig{}); err == nil { |
| 295 | t.Error("expected error on duplicate Adopt, got nil") |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | func TestDelete(t *testing.T) { |
| 300 | p := newMockProvisioner() |
| 301 | r := registry.New(p, testKey) |
| 302 | |
| 303 | if _, _, err := r.Register("del-agent", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { |
| 304 | t.Fatalf("Register: %v", err) |
| 305 | } |
| 306 | |
| 307 | if err := r.Delete("del-agent"); err != nil { |
| 308 | t.Fatalf("Delete: %v", err) |
| 309 | } |
| 310 | |
| 311 | // Agent must no longer appear in List. |
| 312 | for _, a := range r.List() { |
| 313 | if a.Nick == "del-agent" { |
| 314 | t.Error("deleted agent should not appear in List()") |
| 315 | } |
| 316 | } |
| 317 | |
| 318 | // Get must fail. |
| 319 | if _, err := r.Get("del-agent"); err == nil { |
| 320 | t.Error("Get should fail for deleted agent") |
| 321 | } |
| 322 | } |
| 323 | |
| 324 | func TestDeleteRevoked(t *testing.T) { |
| 325 | // Deleting a revoked agent should succeed (lockout step skipped). |
| 326 | p := newMockProvisioner() |
| 327 | r := registry.New(p, testKey) |
| 328 | |
| 329 | if _, _, err := r.Register("rev-del", registry.AgentTypeWorker, registry.EngagementConfig{}); err != nil { |
| 330 | t.Fatalf("Register: %v", err) |
| 331 | } |
| 332 | if err := r.Revoke("rev-del"); err != nil { |
| 333 | t.Fatalf("Revoke: %v", err) |
| 334 | } |
| 335 | if err := r.Delete("rev-del"); err != nil { |
| 336 | t.Fatalf("Delete of revoked agent: %v", err) |
| 337 | } |
| 338 | } |
| 339 | |
| 340 | func TestDeleteNotFound(t *testing.T) { |
| 341 | p := newMockProvisioner() |
| 342 | r := registry.New(p, testKey) |
| 343 | if err := r.Delete("nobody"); err == nil { |
| 344 | t.Error("expected error deleting non-existent agent, got nil") |
| 345 | } |
| 346 | } |
| 347 | |
| 348 | func TestUpdateChannels(t *testing.T) { |
| 349 | p := newMockProvisioner() |
| 350 | r := registry.New(p, testKey) |
| 351 | |
| 352 | if _, _, err := r.Register("chan-agent", registry.AgentTypeWorker, |
| 353 | cfg([]string{"#fleet"}, nil)); err != nil { |
| 354 | t.Fatalf("Register: %v", err) |
| 355 | } |
| 356 | |
| 357 | newChans := []string{"#fleet", "#project.foo"} |
| 358 | if err := r.UpdateChannels("chan-agent", newChans); err != nil { |
| 359 | t.Fatalf("UpdateChannels: %v", err) |
| 360 | } |
| 361 | |
| 362 | agent, err := r.Get("chan-agent") |
| 363 | if err != nil { |
| 364 | t.Fatalf("Get: %v", err) |
| 365 | } |
| 366 | if len(agent.Channels) != 2 { |
| 367 | t.Errorf("Channels len = %d, want 2", len(agent.Channels)) |
| 368 | } |
| 369 | if agent.Channels[1] != "#project.foo" { |
| 370 | t.Errorf("Channels[1] = %q, want #project.foo", agent.Channels[1]) |
| 371 | } |
| 372 | } |
| 373 | |
| 374 | func TestUpdateChannelsNotFound(t *testing.T) { |
| 375 | p := newMockProvisioner() |
| 376 | r := registry.New(p, testKey) |
| 377 | if err := r.UpdateChannels("ghost", []string{"#fleet"}); err == nil { |
| 378 | t.Error("expected error for unknown agent, got nil") |
| 379 | } |
| 380 | } |
| 381 | |
| 382 | func TestSetDataPathPersistence(t *testing.T) { |
| 383 | dataPath := t.TempDir() + "/agents.json" |
| 384 | p := newMockProvisioner() |
| 385 | r := registry.New(p, testKey) |
| 386 | |
| 387 | if err := r.SetDataPath(dataPath); err != nil { |
| 388 | t.Fatalf("SetDataPath: %v", err) |
| 389 | } |
| 390 | |
| 391 | if _, _, err := r.Register("persist-me", registry.AgentTypeWorker, |
| 392 | cfg([]string{"#fleet"}, nil)); err != nil { |
| 393 | t.Fatalf("Register: %v", err) |
| 394 | } |
| 395 | |
| 396 | // New registry loaded from the same path — must contain the persisted agent. |
| 397 | r2 := registry.New(newMockProvisioner(), testKey) |
| 398 | if err := r2.SetDataPath(dataPath); err != nil { |
| 399 | t.Fatalf("SetDataPath (r2): %v", err) |
| 400 | } |
| 401 | |
| 402 | agent, err := r2.Get("persist-me") |
| 403 | if err != nil { |
| 404 | t.Fatalf("Get after reload: %v", err) |
| 405 | } |
| 406 | if agent.Nick != "persist-me" { |
| 407 | t.Errorf("reloaded Nick = %q, want persist-me", agent.Nick) |
| 408 | } |
| 409 | } |
| 410 | |
| 411 | func TestSetDataPathMissingFileOK(t *testing.T) { |
| 412 | r := registry.New(newMockProvisioner(), testKey) |
| 413 | // Path doesn't exist yet — should not error. |
| 414 | if err := r.SetDataPath(t.TempDir() + "/agents.json"); err != nil { |
| 415 | t.Errorf("SetDataPath on missing file: %v", err) |
| 416 | } |
| 417 | } |
| 418 |
+65
-8
| --- pkg/sessionrelay/http.go | ||
| +++ pkg/sessionrelay/http.go | ||
| @@ -18,10 +18,14 @@ | ||
| 18 | 18 | baseURL string |
| 19 | 19 | token string |
| 20 | 20 | primary string |
| 21 | 21 | nick string |
| 22 | 22 | |
| 23 | + agentType string | |
| 24 | + deleteOnClose bool | |
| 25 | + registeredByConnector bool | |
| 26 | + | |
| 23 | 27 | mu sync.RWMutex |
| 24 | 28 | channels []string |
| 25 | 29 | } |
| 26 | 30 | |
| 27 | 31 | type httpMessage struct { |
| @@ -30,26 +34,63 @@ | ||
| 30 | 34 | Text string `json:"text"` |
| 31 | 35 | } |
| 32 | 36 | |
| 33 | 37 | func newHTTPConnector(cfg Config) Connector { |
| 34 | 38 | return &httpConnector{ |
| 35 | - http: cfg.HTTPClient, | |
| 36 | - baseURL: stringsTrimRightSlash(cfg.URL), | |
| 37 | - token: cfg.Token, | |
| 38 | - primary: normalizeChannel(cfg.Channel), | |
| 39 | - nick: cfg.Nick, | |
| 40 | - channels: append([]string(nil), cfg.Channels...), | |
| 39 | + http: cfg.HTTPClient, | |
| 40 | + baseURL: stringsTrimRightSlash(cfg.URL), | |
| 41 | + token: cfg.Token, | |
| 42 | + primary: normalizeChannel(cfg.Channel), | |
| 43 | + nick: cfg.Nick, | |
| 44 | + agentType: cfg.IRC.AgentType, | |
| 45 | + deleteOnClose: cfg.IRC.DeleteOnClose, | |
| 46 | + channels: append([]string(nil), cfg.Channels...), | |
| 41 | 47 | } |
| 42 | 48 | } |
| 43 | 49 | |
| 44 | -func (c *httpConnector) Connect(context.Context) error { | |
| 50 | +func (c *httpConnector) Connect(ctx context.Context) error { | |
| 45 | 51 | if c.baseURL == "" { |
| 46 | 52 | return fmt.Errorf("sessionrelay: http transport requires url") |
| 47 | 53 | } |
| 48 | 54 | if c.token == "" { |
| 49 | 55 | return fmt.Errorf("sessionrelay: http transport requires token") |
| 50 | 56 | } |
| 57 | + if c.nick != "" { | |
| 58 | + if err := c.registerAgent(ctx); err != nil { | |
| 59 | + return err | |
| 60 | + } | |
| 61 | + } | |
| 62 | + return nil | |
| 63 | +} | |
| 64 | + | |
| 65 | +func (c *httpConnector) registerAgent(ctx context.Context) error { | |
| 66 | + body, _ := json.Marshal(map[string]any{ | |
| 67 | + "nick": c.nick, | |
| 68 | + "type": c.agentType, | |
| 69 | + "channels": c.Channels(), | |
| 70 | + }) | |
| 71 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/agents/register", bytes.NewReader(body)) | |
| 72 | + if err != nil { | |
| 73 | + return err | |
| 74 | + } | |
| 75 | + c.authorize(req) | |
| 76 | + req.Header.Set("Content-Type", "application/json") | |
| 77 | + | |
| 78 | + resp, err := c.http.Do(req) | |
| 79 | + if err != nil { | |
| 80 | + return err | |
| 81 | + } | |
| 82 | + defer resp.Body.Close() | |
| 83 | + | |
| 84 | + switch resp.StatusCode { | |
| 85 | + case http.StatusCreated: | |
| 86 | + c.registeredByConnector = true | |
| 87 | + case http.StatusConflict: | |
| 88 | + // agent already exists; registration is best-effort, not an error | |
| 89 | + default: | |
| 90 | + return fmt.Errorf("sessionrelay: register %s: %s", c.nick, resp.Status) | |
| 91 | + } | |
| 51 | 92 | return nil |
| 52 | 93 | } |
| 53 | 94 | |
| 54 | 95 | func (c *httpConnector) Post(ctx context.Context, text string) error { |
| 55 | 96 | for _, channel := range c.Channels() { |
| @@ -175,11 +216,27 @@ | ||
| 175 | 216 | |
| 176 | 217 | func (c *httpConnector) ControlChannel() string { |
| 177 | 218 | return c.primary |
| 178 | 219 | } |
| 179 | 220 | |
| 180 | -func (c *httpConnector) Close(context.Context) error { | |
| 221 | +func (c *httpConnector) Close(ctx context.Context) error { | |
| 222 | + if !c.deleteOnClose || !c.registeredByConnector || c.baseURL == "" || c.token == "" { | |
| 223 | + return nil | |
| 224 | + } | |
| 225 | + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL+"/v1/agents/"+c.nick, nil) | |
| 226 | + if err != nil { | |
| 227 | + return err | |
| 228 | + } | |
| 229 | + c.authorize(req) | |
| 230 | + resp, err := c.http.Do(req) | |
| 231 | + if err != nil { | |
| 232 | + return err | |
| 233 | + } | |
| 234 | + defer resp.Body.Close() | |
| 235 | + if resp.StatusCode/100 != 2 && resp.StatusCode != http.StatusNotFound { | |
| 236 | + return fmt.Errorf("sessionrelay: delete %s: %s", c.nick, resp.Status) | |
| 237 | + } | |
| 181 | 238 | return nil |
| 182 | 239 | } |
| 183 | 240 | |
| 184 | 241 | func (c *httpConnector) postJSON(ctx context.Context, path string, body any) error { |
| 185 | 242 | data, _ := json.Marshal(body) |
| 186 | 243 |
| --- pkg/sessionrelay/http.go | |
| +++ pkg/sessionrelay/http.go | |
| @@ -18,10 +18,14 @@ | |
| 18 | baseURL string |
| 19 | token string |
| 20 | primary string |
| 21 | nick string |
| 22 | |
| 23 | mu sync.RWMutex |
| 24 | channels []string |
| 25 | } |
| 26 | |
| 27 | type httpMessage struct { |
| @@ -30,26 +34,63 @@ | |
| 30 | Text string `json:"text"` |
| 31 | } |
| 32 | |
| 33 | func newHTTPConnector(cfg Config) Connector { |
| 34 | return &httpConnector{ |
| 35 | http: cfg.HTTPClient, |
| 36 | baseURL: stringsTrimRightSlash(cfg.URL), |
| 37 | token: cfg.Token, |
| 38 | primary: normalizeChannel(cfg.Channel), |
| 39 | nick: cfg.Nick, |
| 40 | channels: append([]string(nil), cfg.Channels...), |
| 41 | } |
| 42 | } |
| 43 | |
| 44 | func (c *httpConnector) Connect(context.Context) error { |
| 45 | if c.baseURL == "" { |
| 46 | return fmt.Errorf("sessionrelay: http transport requires url") |
| 47 | } |
| 48 | if c.token == "" { |
| 49 | return fmt.Errorf("sessionrelay: http transport requires token") |
| 50 | } |
| 51 | return nil |
| 52 | } |
| 53 | |
| 54 | func (c *httpConnector) Post(ctx context.Context, text string) error { |
| 55 | for _, channel := range c.Channels() { |
| @@ -175,11 +216,27 @@ | |
| 175 | |
| 176 | func (c *httpConnector) ControlChannel() string { |
| 177 | return c.primary |
| 178 | } |
| 179 | |
| 180 | func (c *httpConnector) Close(context.Context) error { |
| 181 | return nil |
| 182 | } |
| 183 | |
| 184 | func (c *httpConnector) postJSON(ctx context.Context, path string, body any) error { |
| 185 | data, _ := json.Marshal(body) |
| 186 |
| --- pkg/sessionrelay/http.go | |
| +++ pkg/sessionrelay/http.go | |
| @@ -18,10 +18,14 @@ | |
| 18 | baseURL string |
| 19 | token string |
| 20 | primary string |
| 21 | nick string |
| 22 | |
| 23 | agentType string |
| 24 | deleteOnClose bool |
| 25 | registeredByConnector bool |
| 26 | |
| 27 | mu sync.RWMutex |
| 28 | channels []string |
| 29 | } |
| 30 | |
| 31 | type httpMessage struct { |
| @@ -30,26 +34,63 @@ | |
| 34 | Text string `json:"text"` |
| 35 | } |
| 36 | |
| 37 | func newHTTPConnector(cfg Config) Connector { |
| 38 | return &httpConnector{ |
| 39 | http: cfg.HTTPClient, |
| 40 | baseURL: stringsTrimRightSlash(cfg.URL), |
| 41 | token: cfg.Token, |
| 42 | primary: normalizeChannel(cfg.Channel), |
| 43 | nick: cfg.Nick, |
| 44 | agentType: cfg.IRC.AgentType, |
| 45 | deleteOnClose: cfg.IRC.DeleteOnClose, |
| 46 | channels: append([]string(nil), cfg.Channels...), |
| 47 | } |
| 48 | } |
| 49 | |
| 50 | func (c *httpConnector) Connect(ctx context.Context) error { |
| 51 | if c.baseURL == "" { |
| 52 | return fmt.Errorf("sessionrelay: http transport requires url") |
| 53 | } |
| 54 | if c.token == "" { |
| 55 | return fmt.Errorf("sessionrelay: http transport requires token") |
| 56 | } |
| 57 | if c.nick != "" { |
| 58 | if err := c.registerAgent(ctx); err != nil { |
| 59 | return err |
| 60 | } |
| 61 | } |
| 62 | return nil |
| 63 | } |
| 64 | |
| 65 | func (c *httpConnector) registerAgent(ctx context.Context) error { |
| 66 | body, _ := json.Marshal(map[string]any{ |
| 67 | "nick": c.nick, |
| 68 | "type": c.agentType, |
| 69 | "channels": c.Channels(), |
| 70 | }) |
| 71 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/agents/register", bytes.NewReader(body)) |
| 72 | if err != nil { |
| 73 | return err |
| 74 | } |
| 75 | c.authorize(req) |
| 76 | req.Header.Set("Content-Type", "application/json") |
| 77 | |
| 78 | resp, err := c.http.Do(req) |
| 79 | if err != nil { |
| 80 | return err |
| 81 | } |
| 82 | defer resp.Body.Close() |
| 83 | |
| 84 | switch resp.StatusCode { |
| 85 | case http.StatusCreated: |
| 86 | c.registeredByConnector = true |
| 87 | case http.StatusConflict: |
| 88 | // agent already exists; registration is best-effort, not an error |
| 89 | default: |
| 90 | return fmt.Errorf("sessionrelay: register %s: %s", c.nick, resp.Status) |
| 91 | } |
| 92 | return nil |
| 93 | } |
| 94 | |
| 95 | func (c *httpConnector) Post(ctx context.Context, text string) error { |
| 96 | for _, channel := range c.Channels() { |
| @@ -175,11 +216,27 @@ | |
| 216 | |
| 217 | func (c *httpConnector) ControlChannel() string { |
| 218 | return c.primary |
| 219 | } |
| 220 | |
| 221 | func (c *httpConnector) Close(ctx context.Context) error { |
| 222 | if !c.deleteOnClose || !c.registeredByConnector || c.baseURL == "" || c.token == "" { |
| 223 | return nil |
| 224 | } |
| 225 | req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL+"/v1/agents/"+c.nick, nil) |
| 226 | if err != nil { |
| 227 | return err |
| 228 | } |
| 229 | c.authorize(req) |
| 230 | resp, err := c.http.Do(req) |
| 231 | if err != nil { |
| 232 | return err |
| 233 | } |
| 234 | defer resp.Body.Close() |
| 235 | if resp.StatusCode/100 != 2 && resp.StatusCode != http.StatusNotFound { |
| 236 | return fmt.Errorf("sessionrelay: delete %s: %s", c.nick, resp.Status) |
| 237 | } |
| 238 | return nil |
| 239 | } |
| 240 | |
| 241 | func (c *httpConnector) postJSON(ctx context.Context, path string, body any) error { |
| 242 | data, _ := json.Marshal(body) |
| 243 |
+77
-16
| --- pkg/sessionrelay/irc.go | ||
| +++ pkg/sessionrelay/irc.go | ||
| @@ -53,10 +53,15 @@ | ||
| 53 | 53 | channels: append([]string(nil), cfg.Channels...), |
| 54 | 54 | messages: make([]Message, 0, defaultBufferSize), |
| 55 | 55 | errCh: make(chan error, 1), |
| 56 | 56 | }, nil |
| 57 | 57 | } |
| 58 | + | |
| 59 | +const ( | |
| 60 | + ircReconnectMin = 2 * time.Second | |
| 61 | + ircReconnectMax = 30 * time.Second | |
| 62 | +) | |
| 58 | 63 | |
| 59 | 64 | func (c *ircConnector) Connect(ctx context.Context) error { |
| 60 | 65 | if err := c.ensureCredentials(ctx); err != nil { |
| 61 | 66 | return err |
| 62 | 67 | } |
| @@ -66,10 +71,34 @@ | ||
| 66 | 71 | return err |
| 67 | 72 | } |
| 68 | 73 | |
| 69 | 74 | joined := make(chan struct{}) |
| 70 | 75 | var joinOnce sync.Once |
| 76 | + c.dial(host, port, func() { joinOnce.Do(func() { close(joined) }) }) | |
| 77 | + | |
| 78 | + select { | |
| 79 | + case <-ctx.Done(): | |
| 80 | + c.mu.Lock() | |
| 81 | + if c.client != nil { | |
| 82 | + c.client.Close() | |
| 83 | + } | |
| 84 | + c.mu.Unlock() | |
| 85 | + return ctx.Err() | |
| 86 | + case err := <-c.errCh: | |
| 87 | + _ = c.cleanupRegistration(context.Background()) | |
| 88 | + return fmt.Errorf("sessionrelay: irc connect: %w", err) | |
| 89 | + case <-joined: | |
| 90 | + go c.keepAlive(ctx, host, port) | |
| 91 | + return nil | |
| 92 | + } | |
| 93 | +} | |
| 94 | + | |
| 95 | +// dial creates a fresh girc client, wires up handlers, and starts the | |
| 96 | +// connection goroutine. onJoined fires once when the primary channel is | |
| 97 | +// joined — used as the initial-connect signal and to reset backoff on | |
| 98 | +// successful reconnects. | |
| 99 | +func (c *ircConnector) dial(host string, port int, onJoined func()) { | |
| 71 | 100 | client := girc.New(girc.Config{ |
| 72 | 101 | Server: host, |
| 73 | 102 | Port: port, |
| 74 | 103 | Nick: c.nick, |
| 75 | 104 | User: c.nick, |
| @@ -86,11 +115,13 @@ | ||
| 86 | 115 | return |
| 87 | 116 | } |
| 88 | 117 | if normalizeChannel(e.Params[0]) != c.primary { |
| 89 | 118 | return |
| 90 | 119 | } |
| 91 | - joinOnce.Do(func() { close(joined) }) | |
| 120 | + if onJoined != nil { | |
| 121 | + onJoined() | |
| 122 | + } | |
| 92 | 123 | }) |
| 93 | 124 | client.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, e girc.Event) { |
| 94 | 125 | if len(e.Params) < 1 || e.Source == nil { |
| 95 | 126 | return |
| 96 | 127 | } |
| @@ -107,51 +138,78 @@ | ||
| 107 | 138 | } |
| 108 | 139 | } |
| 109 | 140 | c.appendMessage(Message{At: time.Now(), Channel: target, Nick: sender, Text: text}) |
| 110 | 141 | }) |
| 111 | 142 | |
| 143 | + c.mu.Lock() | |
| 112 | 144 | c.client = client |
| 145 | + c.mu.Unlock() | |
| 146 | + | |
| 113 | 147 | go func() { |
| 114 | - if err := client.Connect(); err != nil && ctx.Err() == nil { | |
| 148 | + if err := client.Connect(); err != nil { | |
| 115 | 149 | select { |
| 116 | 150 | case c.errCh <- err: |
| 117 | 151 | default: |
| 118 | 152 | } |
| 119 | 153 | } |
| 120 | 154 | }() |
| 121 | - | |
| 122 | - select { | |
| 123 | - case <-ctx.Done(): | |
| 124 | - client.Close() | |
| 125 | - return ctx.Err() | |
| 126 | - case err := <-c.errCh: | |
| 127 | - _ = c.cleanupRegistration(context.Background()) | |
| 128 | - return fmt.Errorf("sessionrelay: irc connect: %w", err) | |
| 129 | - case <-joined: | |
| 130 | - return nil | |
| 155 | +} | |
| 156 | + | |
| 157 | +// keepAlive watches for connection errors and redials with exponential backoff. | |
| 158 | +// It stops when ctx is cancelled (i.e. the broker is shutting down). | |
| 159 | +func (c *ircConnector) keepAlive(ctx context.Context, host string, port int) { | |
| 160 | + wait := ircReconnectMin | |
| 161 | + for { | |
| 162 | + select { | |
| 163 | + case <-ctx.Done(): | |
| 164 | + return | |
| 165 | + case <-c.errCh: | |
| 166 | + } | |
| 167 | + | |
| 168 | + // Close the dead client before replacing it. | |
| 169 | + c.mu.Lock() | |
| 170 | + if c.client != nil { | |
| 171 | + c.client.Close() | |
| 172 | + c.client = nil | |
| 173 | + } | |
| 174 | + c.mu.Unlock() | |
| 175 | + | |
| 176 | + select { | |
| 177 | + case <-ctx.Done(): | |
| 178 | + return | |
| 179 | + case <-time.After(wait): | |
| 180 | + } | |
| 181 | + wait = min(wait*2, ircReconnectMax) | |
| 182 | + c.dial(host, port, func() { wait = ircReconnectMin }) | |
| 131 | 183 | } |
| 132 | 184 | } |
| 133 | 185 | |
| 134 | 186 | func (c *ircConnector) Post(_ context.Context, text string) error { |
| 135 | - if c.client == nil { | |
| 187 | + c.mu.RLock() | |
| 188 | + client := c.client | |
| 189 | + c.mu.RUnlock() | |
| 190 | + if client == nil { | |
| 136 | 191 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 137 | 192 | } |
| 138 | 193 | for _, channel := range c.Channels() { |
| 139 | - c.client.Cmd.Message(channel, text) | |
| 194 | + client.Cmd.Message(channel, text) | |
| 140 | 195 | } |
| 141 | 196 | return nil |
| 142 | 197 | } |
| 143 | 198 | |
| 144 | 199 | func (c *ircConnector) PostTo(_ context.Context, channel, text string) error { |
| 145 | - if c.client == nil { | |
| 200 | + c.mu.RLock() | |
| 201 | + client := c.client | |
| 202 | + c.mu.RUnlock() | |
| 203 | + if client == nil { | |
| 146 | 204 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 147 | 205 | } |
| 148 | 206 | channel = normalizeChannel(channel) |
| 149 | 207 | if channel == "" { |
| 150 | 208 | return fmt.Errorf("sessionrelay: post channel is required") |
| 151 | 209 | } |
| 152 | - c.client.Cmd.Message(channel, text) | |
| 210 | + client.Cmd.Message(channel, text) | |
| 153 | 211 | return nil |
| 154 | 212 | } |
| 155 | 213 | |
| 156 | 214 | func (c *ircConnector) MessagesSince(_ context.Context, since time.Time) ([]Message, error) { |
| 157 | 215 | c.mu.RLock() |
| @@ -253,13 +311,16 @@ | ||
| 253 | 311 | func (c *ircConnector) ControlChannel() string { |
| 254 | 312 | return c.primary |
| 255 | 313 | } |
| 256 | 314 | |
| 257 | 315 | func (c *ircConnector) Close(ctx context.Context) error { |
| 316 | + c.mu.Lock() | |
| 258 | 317 | if c.client != nil { |
| 259 | 318 | c.client.Close() |
| 319 | + c.client = nil | |
| 260 | 320 | } |
| 321 | + c.mu.Unlock() | |
| 261 | 322 | return c.cleanupRegistration(ctx) |
| 262 | 323 | } |
| 263 | 324 | |
| 264 | 325 | func (c *ircConnector) appendMessage(msg Message) { |
| 265 | 326 | c.mu.Lock() |
| 266 | 327 |
| --- pkg/sessionrelay/irc.go | |
| +++ pkg/sessionrelay/irc.go | |
| @@ -53,10 +53,15 @@ | |
| 53 | channels: append([]string(nil), cfg.Channels...), |
| 54 | messages: make([]Message, 0, defaultBufferSize), |
| 55 | errCh: make(chan error, 1), |
| 56 | }, nil |
| 57 | } |
| 58 | |
| 59 | func (c *ircConnector) Connect(ctx context.Context) error { |
| 60 | if err := c.ensureCredentials(ctx); err != nil { |
| 61 | return err |
| 62 | } |
| @@ -66,10 +71,34 @@ | |
| 66 | return err |
| 67 | } |
| 68 | |
| 69 | joined := make(chan struct{}) |
| 70 | var joinOnce sync.Once |
| 71 | client := girc.New(girc.Config{ |
| 72 | Server: host, |
| 73 | Port: port, |
| 74 | Nick: c.nick, |
| 75 | User: c.nick, |
| @@ -86,11 +115,13 @@ | |
| 86 | return |
| 87 | } |
| 88 | if normalizeChannel(e.Params[0]) != c.primary { |
| 89 | return |
| 90 | } |
| 91 | joinOnce.Do(func() { close(joined) }) |
| 92 | }) |
| 93 | client.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, e girc.Event) { |
| 94 | if len(e.Params) < 1 || e.Source == nil { |
| 95 | return |
| 96 | } |
| @@ -107,51 +138,78 @@ | |
| 107 | } |
| 108 | } |
| 109 | c.appendMessage(Message{At: time.Now(), Channel: target, Nick: sender, Text: text}) |
| 110 | }) |
| 111 | |
| 112 | c.client = client |
| 113 | go func() { |
| 114 | if err := client.Connect(); err != nil && ctx.Err() == nil { |
| 115 | select { |
| 116 | case c.errCh <- err: |
| 117 | default: |
| 118 | } |
| 119 | } |
| 120 | }() |
| 121 | |
| 122 | select { |
| 123 | case <-ctx.Done(): |
| 124 | client.Close() |
| 125 | return ctx.Err() |
| 126 | case err := <-c.errCh: |
| 127 | _ = c.cleanupRegistration(context.Background()) |
| 128 | return fmt.Errorf("sessionrelay: irc connect: %w", err) |
| 129 | case <-joined: |
| 130 | return nil |
| 131 | } |
| 132 | } |
| 133 | |
| 134 | func (c *ircConnector) Post(_ context.Context, text string) error { |
| 135 | if c.client == nil { |
| 136 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 137 | } |
| 138 | for _, channel := range c.Channels() { |
| 139 | c.client.Cmd.Message(channel, text) |
| 140 | } |
| 141 | return nil |
| 142 | } |
| 143 | |
| 144 | func (c *ircConnector) PostTo(_ context.Context, channel, text string) error { |
| 145 | if c.client == nil { |
| 146 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 147 | } |
| 148 | channel = normalizeChannel(channel) |
| 149 | if channel == "" { |
| 150 | return fmt.Errorf("sessionrelay: post channel is required") |
| 151 | } |
| 152 | c.client.Cmd.Message(channel, text) |
| 153 | return nil |
| 154 | } |
| 155 | |
| 156 | func (c *ircConnector) MessagesSince(_ context.Context, since time.Time) ([]Message, error) { |
| 157 | c.mu.RLock() |
| @@ -253,13 +311,16 @@ | |
| 253 | func (c *ircConnector) ControlChannel() string { |
| 254 | return c.primary |
| 255 | } |
| 256 | |
| 257 | func (c *ircConnector) Close(ctx context.Context) error { |
| 258 | if c.client != nil { |
| 259 | c.client.Close() |
| 260 | } |
| 261 | return c.cleanupRegistration(ctx) |
| 262 | } |
| 263 | |
| 264 | func (c *ircConnector) appendMessage(msg Message) { |
| 265 | c.mu.Lock() |
| 266 |
| --- pkg/sessionrelay/irc.go | |
| +++ pkg/sessionrelay/irc.go | |
| @@ -53,10 +53,15 @@ | |
| 53 | channels: append([]string(nil), cfg.Channels...), |
| 54 | messages: make([]Message, 0, defaultBufferSize), |
| 55 | errCh: make(chan error, 1), |
| 56 | }, nil |
| 57 | } |
| 58 | |
| 59 | const ( |
| 60 | ircReconnectMin = 2 * time.Second |
| 61 | ircReconnectMax = 30 * time.Second |
| 62 | ) |
| 63 | |
| 64 | func (c *ircConnector) Connect(ctx context.Context) error { |
| 65 | if err := c.ensureCredentials(ctx); err != nil { |
| 66 | return err |
| 67 | } |
| @@ -66,10 +71,34 @@ | |
| 71 | return err |
| 72 | } |
| 73 | |
| 74 | joined := make(chan struct{}) |
| 75 | var joinOnce sync.Once |
| 76 | c.dial(host, port, func() { joinOnce.Do(func() { close(joined) }) }) |
| 77 | |
| 78 | select { |
| 79 | case <-ctx.Done(): |
| 80 | c.mu.Lock() |
| 81 | if c.client != nil { |
| 82 | c.client.Close() |
| 83 | } |
| 84 | c.mu.Unlock() |
| 85 | return ctx.Err() |
| 86 | case err := <-c.errCh: |
| 87 | _ = c.cleanupRegistration(context.Background()) |
| 88 | return fmt.Errorf("sessionrelay: irc connect: %w", err) |
| 89 | case <-joined: |
| 90 | go c.keepAlive(ctx, host, port) |
| 91 | return nil |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | // dial creates a fresh girc client, wires up handlers, and starts the |
| 96 | // connection goroutine. onJoined fires once when the primary channel is |
| 97 | // joined — used as the initial-connect signal and to reset backoff on |
| 98 | // successful reconnects. |
| 99 | func (c *ircConnector) dial(host string, port int, onJoined func()) { |
| 100 | client := girc.New(girc.Config{ |
| 101 | Server: host, |
| 102 | Port: port, |
| 103 | Nick: c.nick, |
| 104 | User: c.nick, |
| @@ -86,11 +115,13 @@ | |
| 115 | return |
| 116 | } |
| 117 | if normalizeChannel(e.Params[0]) != c.primary { |
| 118 | return |
| 119 | } |
| 120 | if onJoined != nil { |
| 121 | onJoined() |
| 122 | } |
| 123 | }) |
| 124 | client.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, e girc.Event) { |
| 125 | if len(e.Params) < 1 || e.Source == nil { |
| 126 | return |
| 127 | } |
| @@ -107,51 +138,78 @@ | |
| 138 | } |
| 139 | } |
| 140 | c.appendMessage(Message{At: time.Now(), Channel: target, Nick: sender, Text: text}) |
| 141 | }) |
| 142 | |
| 143 | c.mu.Lock() |
| 144 | c.client = client |
| 145 | c.mu.Unlock() |
| 146 | |
| 147 | go func() { |
| 148 | if err := client.Connect(); err != nil { |
| 149 | select { |
| 150 | case c.errCh <- err: |
| 151 | default: |
| 152 | } |
| 153 | } |
| 154 | }() |
| 155 | } |
| 156 | |
| 157 | // keepAlive watches for connection errors and redials with exponential backoff. |
| 158 | // It stops when ctx is cancelled (i.e. the broker is shutting down). |
| 159 | func (c *ircConnector) keepAlive(ctx context.Context, host string, port int) { |
| 160 | wait := ircReconnectMin |
| 161 | for { |
| 162 | select { |
| 163 | case <-ctx.Done(): |
| 164 | return |
| 165 | case <-c.errCh: |
| 166 | } |
| 167 | |
| 168 | // Close the dead client before replacing it. |
| 169 | c.mu.Lock() |
| 170 | if c.client != nil { |
| 171 | c.client.Close() |
| 172 | c.client = nil |
| 173 | } |
| 174 | c.mu.Unlock() |
| 175 | |
| 176 | select { |
| 177 | case <-ctx.Done(): |
| 178 | return |
| 179 | case <-time.After(wait): |
| 180 | } |
| 181 | wait = min(wait*2, ircReconnectMax) |
| 182 | c.dial(host, port, func() { wait = ircReconnectMin }) |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | func (c *ircConnector) Post(_ context.Context, text string) error { |
| 187 | c.mu.RLock() |
| 188 | client := c.client |
| 189 | c.mu.RUnlock() |
| 190 | if client == nil { |
| 191 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 192 | } |
| 193 | for _, channel := range c.Channels() { |
| 194 | client.Cmd.Message(channel, text) |
| 195 | } |
| 196 | return nil |
| 197 | } |
| 198 | |
| 199 | func (c *ircConnector) PostTo(_ context.Context, channel, text string) error { |
| 200 | c.mu.RLock() |
| 201 | client := c.client |
| 202 | c.mu.RUnlock() |
| 203 | if client == nil { |
| 204 | return fmt.Errorf("sessionrelay: irc client not connected") |
| 205 | } |
| 206 | channel = normalizeChannel(channel) |
| 207 | if channel == "" { |
| 208 | return fmt.Errorf("sessionrelay: post channel is required") |
| 209 | } |
| 210 | client.Cmd.Message(channel, text) |
| 211 | return nil |
| 212 | } |
| 213 | |
| 214 | func (c *ircConnector) MessagesSince(_ context.Context, since time.Time) ([]Message, error) { |
| 215 | c.mu.RLock() |
| @@ -253,13 +311,16 @@ | |
| 311 | func (c *ircConnector) ControlChannel() string { |
| 312 | return c.primary |
| 313 | } |
| 314 | |
| 315 | func (c *ircConnector) Close(ctx context.Context) error { |
| 316 | c.mu.Lock() |
| 317 | if c.client != nil { |
| 318 | c.client.Close() |
| 319 | c.client = nil |
| 320 | } |
| 321 | c.mu.Unlock() |
| 322 | return c.cleanupRegistration(ctx) |
| 323 | } |
| 324 | |
| 325 | func (c *ircConnector) appendMessage(msg Message) { |
| 326 | c.mu.Lock() |
| 327 |
| --- pkg/sessionrelay/sessionrelay_test.go | ||
| +++ pkg/sessionrelay/sessionrelay_test.go | ||
| @@ -57,10 +57,12 @@ | ||
| 57 | 57 | }}) |
| 58 | 58 | case r.Method == http.MethodGet && r.URL.Path == "/v1/channels/release/messages": |
| 59 | 59 | _ = json.NewEncoder(w).Encode(map[string]any{"messages": []map[string]string{ |
| 60 | 60 | {"at": base.Add(2 * time.Second).Format(time.RFC3339Nano), "nick": "glengoolie", "text": "codex-test: /join #task-42"}, |
| 61 | 61 | }}) |
| 62 | + case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/register": | |
| 63 | + w.WriteHeader(http.StatusCreated) | |
| 62 | 64 | default: |
| 63 | 65 | http.NotFound(w, r) |
| 64 | 66 | } |
| 65 | 67 | })) |
| 66 | 68 | defer srv.Close() |
| 67 | 69 |
| --- pkg/sessionrelay/sessionrelay_test.go | |
| +++ pkg/sessionrelay/sessionrelay_test.go | |
| @@ -57,10 +57,12 @@ | |
| 57 | }}) |
| 58 | case r.Method == http.MethodGet && r.URL.Path == "/v1/channels/release/messages": |
| 59 | _ = json.NewEncoder(w).Encode(map[string]any{"messages": []map[string]string{ |
| 60 | {"at": base.Add(2 * time.Second).Format(time.RFC3339Nano), "nick": "glengoolie", "text": "codex-test: /join #task-42"}, |
| 61 | }}) |
| 62 | default: |
| 63 | http.NotFound(w, r) |
| 64 | } |
| 65 | })) |
| 66 | defer srv.Close() |
| 67 |
| --- pkg/sessionrelay/sessionrelay_test.go | |
| +++ pkg/sessionrelay/sessionrelay_test.go | |
| @@ -57,10 +57,12 @@ | |
| 57 | }}) |
| 58 | case r.Method == http.MethodGet && r.URL.Path == "/v1/channels/release/messages": |
| 59 | _ = json.NewEncoder(w).Encode(map[string]any{"messages": []map[string]string{ |
| 60 | {"at": base.Add(2 * time.Second).Format(time.RFC3339Nano), "nick": "glengoolie", "text": "codex-test: /join #task-42"}, |
| 61 | }}) |
| 62 | case r.Method == http.MethodPost && r.URL.Path == "/v1/agents/register": |
| 63 | w.WriteHeader(http.StatusCreated) |
| 64 | default: |
| 65 | http.NotFound(w, r) |
| 66 | } |
| 67 | })) |
| 68 | defer srv.Close() |
| 69 |