-
Notifications
You must be signed in to change notification settings - Fork 104
feat: implement state persistence #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: cj/refactor-conversation-orig
Are you sure you want to change the base?
Changes from all commits
a0f8bb5
ca3cdff
1c224e9
30f82d7
12bed1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,6 +103,26 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er | |
| } | ||
| } | ||
|
|
||
| // Get the variables related to state management | ||
| stateFile := viper.GetString(StateFile) | ||
| loadState := true | ||
| saveState := true | ||
| if stateFile != "" { | ||
| if !viper.IsSet(LoadState) { | ||
| loadState = true | ||
| } else { | ||
| loadState = viper.GetBool(LoadState) | ||
| } | ||
|
|
||
| if !viper.IsSet(SaveState) { | ||
| saveState = true | ||
| } else { | ||
| saveState = viper.GetBool(SaveState) | ||
| } | ||
| } | ||
|
|
||
| pidFile := viper.GetString(PidFile) | ||
|
|
||
| printOpenAPI := viper.GetBool(FlagPrintOpenAPI) | ||
| var process *termexec.Process | ||
| if printOpenAPI { | ||
|
|
@@ -128,7 +148,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er | |
| AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), | ||
| AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), | ||
| InitialPrompt: initialPrompt, | ||
| StatePersistenceCfg: httpapi.StatePersistenceCfg{ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: Naming |
||
| StateFile: stateFile, | ||
| LoadState: loadState, | ||
| SaveState: saveState, | ||
| PidFile: pidFile, | ||
| }, | ||
| }) | ||
|
|
||
| if err != nil { | ||
| return xerrors.Errorf("failed to create server: %w", err) | ||
| } | ||
|
|
@@ -137,6 +164,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er | |
| return nil | ||
| } | ||
| srv.StartSnapshotLoop(ctx) | ||
| srv.HandleSignals(ctx, process) | ||
| logger.Info("Starting server on port", "port", port) | ||
| processExitCh := make(chan error, 1) | ||
| go func() { | ||
|
|
@@ -152,7 +180,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er | |
| logger.Error("Failed to stop server", "error", err) | ||
| } | ||
| }() | ||
| if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { | ||
| if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { | ||
| return xerrors.Errorf("failed to start server: %w", err) | ||
| } | ||
| select { | ||
|
|
@@ -191,6 +219,10 @@ const ( | |
| FlagAllowedOrigins = "allowed-origins" | ||
| FlagExit = "exit" | ||
| FlagInitialPrompt = "initial-prompt" | ||
| StateFile = "state-file" | ||
| LoadState = "load-state" | ||
| SaveState = "save-state" | ||
| PidFile = "pid-file" | ||
| ) | ||
|
|
||
| func CreateServerCmd() *cobra.Command { | ||
|
|
@@ -229,6 +261,10 @@ func CreateServerCmd() *cobra.Command { | |
| // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. | ||
| {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, | ||
| {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, | ||
| {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, | ||
| {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, | ||
| {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, | ||
| {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, | ||
| } | ||
|
|
||
| for _, spec := range flagSpecs { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,18 +34,20 @@ import ( | |
|
|
||
| // Server represents the HTTP server | ||
| type Server struct { | ||
| router chi.Router | ||
| api huma.API | ||
| port int | ||
| srv *http.Server | ||
| mu sync.RWMutex | ||
| logger *slog.Logger | ||
| conversation *st.PTYConversation | ||
| agentio *termexec.Process | ||
| agentType mf.AgentType | ||
| emitter *EventEmitter | ||
| chatBasePath string | ||
| tempDir string | ||
| router chi.Router | ||
| api huma.API | ||
| port int | ||
| srv *http.Server | ||
| mu sync.RWMutex | ||
| logger *slog.Logger | ||
| conversation *st.PTYConversation | ||
| agentio *termexec.Process | ||
| agentType mf.AgentType | ||
| emitter *EventEmitter | ||
| chatBasePath string | ||
| tempDir string | ||
| statePersistenceCfg StatePersistenceCfg | ||
| stateLoadComplete bool | ||
| } | ||
|
|
||
| func (s *Server) NormalizeSchema(schema any) any { | ||
|
|
@@ -94,14 +96,22 @@ func (s *Server) GetOpenAPI() string { | |
| // because the action of taking a snapshot takes time too. | ||
| const snapshotInterval = 25 * time.Millisecond | ||
|
|
||
| type StatePersistenceCfg struct { | ||
| StateFile string | ||
| LoadState bool | ||
| SaveState bool | ||
| PidFile string | ||
| } | ||
|
|
||
| type ServerConfig struct { | ||
| AgentType mf.AgentType | ||
| Process *termexec.Process | ||
| Port int | ||
| ChatBasePath string | ||
| AllowedHosts []string | ||
| AllowedOrigins []string | ||
| InitialPrompt string | ||
| AgentType mf.AgentType | ||
| Process *termexec.Process | ||
| Port int | ||
| ChatBasePath string | ||
| AllowedHosts []string | ||
| AllowedOrigins []string | ||
| InitialPrompt string | ||
| StatePersistenceCfg StatePersistenceCfg | ||
| } | ||
|
|
||
| // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. | ||
|
|
@@ -260,16 +270,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { | |
| logger.Info("Created temporary directory for uploads", "tempDir", tempDir) | ||
|
|
||
| s := &Server{ | ||
| router: router, | ||
| api: api, | ||
| port: config.Port, | ||
| conversation: conversation, | ||
| logger: logger, | ||
| agentio: config.Process, | ||
| agentType: config.AgentType, | ||
| emitter: emitter, | ||
| chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), | ||
| tempDir: tempDir, | ||
| router: router, | ||
| api: api, | ||
| port: config.Port, | ||
| conversation: conversation, | ||
| logger: logger, | ||
| agentio: config.Process, | ||
| agentType: config.AgentType, | ||
| emitter: emitter, | ||
| chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), | ||
| tempDir: tempDir, | ||
| statePersistenceCfg: config.StatePersistenceCfg, | ||
| stateLoadComplete: false, | ||
| } | ||
|
|
||
| // Register API routes | ||
|
|
@@ -336,16 +348,22 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { | |
| for { | ||
| currentStatus := s.conversation.Status() | ||
|
|
||
| // Send initial prompt when agent becomes stable for the first time | ||
| if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { | ||
| // Send initial prompt & load state when agent becomes stable for the first time | ||
| if convertStatus(currentStatus) == AgentStatusStable { | ||
|
|
||
| if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { | ||
| s.logger.Error("Failed to send initial prompt", "error", err) | ||
| } else { | ||
| s.conversation.InitialPromptSent = true | ||
| s.conversation.ReadyForInitialPrompt = false | ||
| currentStatus = st.ConversationStatusChanging | ||
| s.logger.Info("Initial prompt sent successfully") | ||
| if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { | ||
| _, _ = s.conversation.LoadState(s.statePersistenceCfg.StateFile) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not objecting, just curious. Why do we wait for stability to load the state? |
||
| s.stateLoadComplete = true | ||
| } | ||
| if !s.conversation.InitialPromptSent { | ||
| if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { | ||
| s.logger.Error("Failed to send initial prompt", "error", err) | ||
| } else { | ||
| s.conversation.InitialPromptSent = true | ||
| s.conversation.ReadyForInitialPrompt = false | ||
| currentStatus = st.ConversationStatusChanging | ||
| s.logger.Info("Initial prompt sent successfully") | ||
| } | ||
| } | ||
| } | ||
| s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) | ||
|
|
@@ -587,6 +605,11 @@ func (s *Server) Start() error { | |
| Handler: s.router, | ||
| } | ||
|
|
||
| // Write PID file if configured | ||
| if err := s.writePIDFile(); err != nil { | ||
| return xerrors.Errorf("failed to write PID file: %w", err) | ||
| } | ||
|
|
||
| return s.srv.ListenAndServe() | ||
| } | ||
|
|
||
|
|
@@ -610,6 +633,70 @@ func (s *Server) cleanupTempDir() { | |
| } | ||
| } | ||
|
|
||
| // writePIDFile writes the current process ID to the configured PID file | ||
| func (s *Server) writePIDFile() error { | ||
| if s.statePersistenceCfg.PidFile == "" { | ||
| return nil | ||
| } | ||
|
|
||
| pid := os.Getpid() | ||
| pidContent := fmt.Sprintf("%d\n", pid) | ||
|
|
||
| // Create directory if it doesn't exist | ||
| dir := filepath.Dir(s.statePersistenceCfg.PidFile) | ||
| if err := os.MkdirAll(dir, 0o755); err != nil { | ||
| return xerrors.Errorf("failed to create PID file directory: %w", err) | ||
| } | ||
|
|
||
| // Write PID file | ||
| if err := os.WriteFile(s.statePersistenceCfg.PidFile, []byte(pidContent), 0o644); err != nil { | ||
| return xerrors.Errorf("failed to write PID file: %w", err) | ||
| } | ||
|
|
||
| s.logger.Info("Wrote PID file", "pidFile", s.statePersistenceCfg.PidFile, "pid", pid) | ||
| return nil | ||
| } | ||
|
|
||
| // cleanupPIDFile removes the PID file if it exists | ||
| func (s *Server) cleanupPIDFile() { | ||
| if s.statePersistenceCfg.PidFile == "" { | ||
| return | ||
| } | ||
|
|
||
| if err := os.Remove(s.statePersistenceCfg.PidFile); err != nil && !os.IsNotExist(err) { | ||
| s.logger.Error("Failed to remove PID file", "pidFile", s.statePersistenceCfg.PidFile, "error", err) | ||
| } else if err == nil { | ||
| s.logger.Info("Removed PID file", "pidFile", s.statePersistenceCfg.PidFile) | ||
| } | ||
| } | ||
|
|
||
| // saveAndCleanup saves the conversation state and cleans up before shutdown | ||
| func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { | ||
| // Save conversation state if configured (synchronously before closing process) | ||
| s.saveStateIfConfigured(sig.String()) | ||
|
|
||
| // Clean up PID file | ||
| s.cleanupPIDFile() | ||
|
|
||
| // Now close the process | ||
| if err := process.Close(s.logger, 5*time.Second); err != nil { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels a bit strange to control the process from this "far in". To me it would make sense to invert some of this control, i.e. change how things are wired up in If we close here, won't we likely be logging an error in |
||
| s.logger.Error("Error closing process", "signal", sig, "error", err) | ||
| } | ||
| } | ||
|
|
||
| // saveStateIfConfigured saves the conversation state if configured | ||
| func (s *Server) saveStateIfConfigured(source string) { | ||
| if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { | ||
| if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { | ||
| s.logger.Error("Failed to save conversation state", "source", source, "error", err) | ||
| } else { | ||
| s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceCfg.StateFile) | ||
| } | ||
| } else { | ||
| s.logger.Warn("Save requested but state saving is not configured", "source", source) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this print save requested for regular stop signals like SIGTERM? I think this log is only applicable for USR1. |
||
| } | ||
| } | ||
|
|
||
| // registerStaticFileRoutes sets up routes for serving static files | ||
| func (s *Server) registerStaticFileRoutes() { | ||
| chatHandler := FileServerWithIndexFallback(s.chatBasePath) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,42 @@ | ||||||||
| //go:build unix | ||||||||
|
|
||||||||
| package httpapi | ||||||||
|
|
||||||||
| import ( | ||||||||
| "context" | ||||||||
| "os" | ||||||||
| "os/signal" | ||||||||
| "syscall" | ||||||||
|
|
||||||||
| "github.com/coder/agentapi/lib/termexec" | ||||||||
| ) | ||||||||
|
|
||||||||
| // HandleSignals sets up signal handlers for: | ||||||||
| // - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process | ||||||||
| // - SIGUSR1: save conversation state without exiting | ||||||||
| func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { | ||||||||
| // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) | ||||||||
| shutdownCh := make(chan os.Signal, 1) | ||||||||
| signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) | ||||||||
| go func() { | ||||||||
| sig := <-shutdownCh | ||||||||
| s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) | ||||||||
|
|
||||||||
| s.saveAndCleanup(sig, process) | ||||||||
| }() | ||||||||
|
|
||||||||
| // Handle SIGUSR1 for save without exit | ||||||||
| saveOnlyCh := make(chan os.Signal, 1) | ||||||||
| signal.Notify(saveOnlyCh, syscall.SIGUSR1) | ||||||||
| go func() { | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Suggestion: Good practice to unregister on teardown. |
||||||||
| for { | ||||||||
| select { | ||||||||
| case <-saveOnlyCh: | ||||||||
| s.logger.Info("Received SIGUSR1, saving state without exiting") | ||||||||
| s.saveStateIfConfigured("SIGUSR1") | ||||||||
| case <-ctx.Done(): | ||||||||
| return | ||||||||
| } | ||||||||
| } | ||||||||
| }() | ||||||||
| } | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| //go:build windows | ||
|
|
||
| package httpapi | ||
|
|
||
| import ( | ||
| "context" | ||
| "os" | ||
| "os/signal" | ||
| "syscall" | ||
|
|
||
| "github.com/coder/agentapi/lib/termexec" | ||
| ) | ||
|
|
||
| // HandleSignals sets up signal handlers for Windows. | ||
| // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). | ||
| func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { | ||
| // Handle shutdown signals (SIGTERM, SIGINT only on Windows) | ||
| shutdownCh := make(chan os.Signal, 1) | ||
| signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this compile on Windows? IIRC we can only support |
||
| go func() { | ||
| sig := <-shutdownCh | ||
| s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) | ||
|
|
||
| s.saveAndCleanup(sig, process) | ||
| }() | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: I think it'd make sense to move pid file handling here rather than httpapi as it becomes a bit disconnected and here we can write it early.