mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 03:20:11 +00:00
Harden AI session storage paths
This commit is contained in:
parent
f88d89622a
commit
d5b4850715
2 changed files with 108 additions and 7 deletions
|
|
@ -1,6 +1,8 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
|
@ -66,9 +68,30 @@ func NewSessionStore(dataDir string) (*SessionStore, error) {
|
|||
|
||||
// sessionPath returns the file path for a session
|
||||
func (s *SessionStore) sessionPath(id string) string {
|
||||
return filepath.Join(s.dataDir, hashedSessionStorageName(id)+".json")
|
||||
}
|
||||
|
||||
func (s *SessionStore) legacySessionPath(id string) string {
|
||||
return filepath.Join(s.dataDir, id+".json")
|
||||
}
|
||||
|
||||
func hashedSessionStorageName(id string) string {
|
||||
sum := sha256.Sum256([]byte(id))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func resolveSessionPath(primaryPath, legacyPath string) string {
|
||||
if _, err := os.Stat(primaryPath); err == nil {
|
||||
return primaryPath
|
||||
}
|
||||
if legacyPath != "" {
|
||||
if _, err := os.Stat(legacyPath); err == nil {
|
||||
return legacyPath
|
||||
}
|
||||
}
|
||||
return primaryPath
|
||||
}
|
||||
|
||||
// List returns all sessions, sorted by updated_at descending
|
||||
func (s *SessionStore) List() ([]Session, error) {
|
||||
s.mu.RLock()
|
||||
|
|
@ -85,11 +108,16 @@ func (s *SessionStore) List() ([]Session, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
data, err := s.readSession(strings.TrimSuffix(entry.Name(), ".json"))
|
||||
file, err := os.ReadFile(filepath.Join(s.dataDir, entry.Name()))
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to read session file")
|
||||
continue
|
||||
}
|
||||
var data sessionData
|
||||
if err := json.Unmarshal(file, &data); err != nil {
|
||||
log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to parse session file")
|
||||
continue
|
||||
}
|
||||
|
||||
sessions = append(sessions, Session{
|
||||
ID: data.ID,
|
||||
|
|
@ -158,12 +186,20 @@ func (s *SessionStore) Delete(id string) error {
|
|||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
path := s.sessionPath(id)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("session not found: %s", id)
|
||||
primaryPath := s.sessionPath(id)
|
||||
legacyPath := s.legacySessionPath(id)
|
||||
var removed bool
|
||||
for _, path := range []string{primaryPath, legacyPath} {
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
removed = true
|
||||
}
|
||||
if !removed {
|
||||
return fmt.Errorf("session not found: %s", id)
|
||||
}
|
||||
|
||||
// Also clean up resolved context, FSM, and knowledge accumulator
|
||||
|
|
@ -239,7 +275,7 @@ func (s *SessionStore) UpdateLastMessage(id string, msg Message) error {
|
|||
|
||||
// readSession reads a session from disk (caller must hold lock)
|
||||
func (s *SessionStore) readSession(id string) (*sessionData, error) {
|
||||
path := s.sessionPath(id)
|
||||
path := resolveSessionPath(s.sessionPath(id), s.legacySessionPath(id))
|
||||
file, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
|
@ -267,6 +303,10 @@ func (s *SessionStore) writeSession(data sessionData) error {
|
|||
if err := os.WriteFile(path, file, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write session: %w", err)
|
||||
}
|
||||
legacyPath := s.legacySessionPath(data.ID)
|
||||
if legacyPath != path {
|
||||
_ = os.Remove(legacyPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -117,3 +119,62 @@ func TestGenerateTitle(t *testing.T) {
|
|||
assert.Equal(t, tt.expected, generateTitle(tt.input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_HashedPathsAndLegacyCompatibility(t *testing.T) {
|
||||
store, err := NewSessionStore(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
session := sessionData{
|
||||
ID: "legacy-session",
|
||||
Title: "Legacy Title",
|
||||
Messages: []Message{{Role: "user", Content: "hello"}},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
path := store.sessionPath(session.ID)
|
||||
assert.Equal(t, filepath.Join(store.dataDir, hashedSessionStorageName(session.ID)+".json"), path)
|
||||
assert.NotContains(t, filepath.Base(path), "..")
|
||||
|
||||
legacyPath := store.legacySessionPath(session.ID)
|
||||
raw, err := json.Marshal(session)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(legacyPath, raw, 0600))
|
||||
|
||||
got, err := store.Get(session.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, session.ID, got.ID)
|
||||
assert.Equal(t, session.Title, got.Title)
|
||||
|
||||
sessions, err := store.List()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sessions, 1)
|
||||
assert.Equal(t, session.ID, sessions[0].ID)
|
||||
}
|
||||
|
||||
func TestSessionStore_PathTraversalIDsStayWithinStore(t *testing.T) {
|
||||
store, err := NewSessionStore(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.writeSession(sessionData{
|
||||
ID: "..",
|
||||
Title: "Traversal",
|
||||
Messages: []Message{},
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
path := store.sessionPath("..")
|
||||
rel, err := filepath.Rel(store.dataDir, path)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, strings.HasPrefix(rel, ".."))
|
||||
|
||||
got, err := store.Get("..")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "..", got.ID)
|
||||
|
||||
require.NoError(t, store.Delete(".."))
|
||||
_, err = store.Get("..")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue