package sqlite

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"sync"
	"time"

	"github.com/aarondl/opt/omit"
	"github.com/aarondl/opt/omitnull"
	migrate "github.com/rubenv/sql-migrate"
	sqldblogger "github.com/simukti/sqldb-logger"
	"github.com/stephenafamo/bob"
	"github.com/stephenafamo/bob/dialect/sqlite"
	"github.com/stephenafamo/bob/dialect/sqlite/im"
	"github.com/stephenafamo/bob/dialect/sqlite/um"
	_ "modernc.org/sqlite"

	"github.com/safing/portmaster/base/database/accessor"
	"github.com/safing/portmaster/base/database/iterator"
	"github.com/safing/portmaster/base/database/query"
	"github.com/safing/portmaster/base/database/record"
	"github.com/safing/portmaster/base/database/storage"
	"github.com/safing/portmaster/base/database/storage/sqlite/models"
	"github.com/safing/portmaster/base/log"
	"github.com/safing/structures/dsd"
)

// Errors.
var (
	// ErrQueryTimeout is reported through the query iterator when a matched
	// record could not be handed to the consumer within one second.
	ErrQueryTimeout = errors.New("query timeout")
)

// SQLite storage.
//
// Fields:
//   - name: database name that is put into record wrappers built from rows.
//   - db: raw database/sql handle, used by the maintenance pragmas.
//   - bob: bob wrapper around the database handle, used by the query models.
//   - wg: counts in-flight operations; Shutdown waits for it.
//   - ctx / cancelCtx: context passed to database calls; Shutdown cancels it.
type SQLite struct {
	name string

	db  *sql.DB
	bob bob.DB
	wg  sync.WaitGroup

	ctx       context.Context
	cancelCtx context.CancelFunc
}

// init registers the "sqlite" storage type with the storage package.
func init() {
	// The registration result is deliberately discarded: there is nothing
	// useful to do with it during package initialization.
	_ = storage.Register("sqlite", func(name, location string) (storage.Interface, error) {
		db, err := NewSQLite(name, location)
		if err != nil {
			// Return an untyped nil so that callers never receive a non-nil
			// interface value that wraps a nil *SQLite.
			return nil, err
		}
		return db, nil
	})
}

// NewSQLite opens the sqlite database called name that is stored in the
// directory location. Statement printing is disabled.
func NewSQLite(name, location string) (*SQLite, error) {
	const printStmts = false
	return openSQLite(name, location, printStmts)
}

// openSQLite opens the SQLite database file "db.sqlite" inside location,
// applies the pragmas, runs all pending migrations and returns the storage.
//
// If printStmts is set, the driver is wrapped so that every statement is
// printed via statementLogger.
//
// If a step fails after the database handle was created, the handle is closed
// again before the error is returned.
func openSQLite(name, location string, printStmts bool) (*SQLite, error) {
	dbFile := filepath.Join(location, "db.sqlite")

	// Open database file.
	// Default settings:
	// _time_format = YYYY-MM-DDTHH:MM:SS.SSS
	// _txlock = deferred
	db, err := sql.Open("sqlite", dbFile)
	if err != nil {
		return nil, fmt.Errorf("open sqlite: %w", err)
	}

	// Enable statement printing.
	if printStmts {
		drv := db.Driver()
		// The plain handle is replaced below and has not been used yet.
		// Close it so it does not linger; the result is irrelevant here.
		_ = db.Close()
		db = sqldblogger.OpenDriver(dbFile, drv, &statementLogger{})
	}

	// Set other settings.
	// NOTE(review): synchronous and cache_size are per-connection settings in
	// SQLite, while database/sql may open additional pooled connections that
	// never run these statements - confirm whether they belong in the DSN.
	pragmas := []string{
		"PRAGMA journal_mode=WAL;",   // Corruption safe write ahead log for txs.
		"PRAGMA synchronous=NORMAL;", // Best for WAL.
		"PRAGMA cache_size=-10000;",  // 10MB Cache.
	}
	for _, pragma := range pragmas {
		if _, err := db.Exec(pragma); err != nil {
			_ = db.Close() // Best effort, the pragma error is what matters.
			return nil, fmt.Errorf("failed to init sqlite with %s: %w", pragma, err)
		}
	}

	// Run migrations on database.
	n, err := migrate.Exec(db, "sqlite3", getMigrations(), migrate.Up)
	if err != nil {
		_ = db.Close() // Best effort, the migration error is what matters.
		return nil, fmt.Errorf("migrate sqlite: %w", err)
	}
	log.Debugf("database/sqlite: ran %d migrations on %s database", n, name)

	// Return as bob database.
	// Keep the raw handle too: Maintain and MaintainThorough execute their
	// pragmas directly on it.
	ctx, cancelCtx := context.WithCancel(context.Background())
	return &SQLite{
		name:      name,
		db:        db,
		bob:       bob.NewDB(db),
		ctx:       ctx,
		cancelCtx: cancelCtx,
	}, nil
}

// Get returns a database record.
//
// If no row exists for key, the returned error wraps storage.ErrNotFound.
// Any other lookup failure is returned as-is (wrapped with the key), so that
// a database error is not mistaken for a missing record.
func (db *SQLite) Get(key string) (record.Record, error) {
	db.wg.Add(1)
	defer db.wg.Done()

	// Get record from database.
	r, err := models.FindRecord(db.ctx, db.bob, key)
	switch {
	case err == nil:
		// Found.
	case errors.Is(err, sql.ErrNoRows):
		return nil, fmt.Errorf("%w: %w", storage.ErrNotFound, err)
	default:
		return nil, fmt.Errorf("get record %q: %w", key, err)
	}

	// Return data in wrapper.
	wrapped, err := record.NewWrapperFromDatabase(
		db.name,
		key,
		getMeta(r),
		uint8(r.Format.GetOrZero()), //nolint:gosec // Values are within uint8.
		r.Value.GetOrZero(),
	)
	if err != nil {
		// Return an untyped nil instead of a nil wrapper inside the interface.
		return nil, err
	}
	return wrapped, nil
}

// GetMeta fetches the record stored under key via Get and returns only its
// metadata. Errors from Get are passed through unchanged.
func (db *SQLite) GetMeta(key string) (*record.Meta, error) {
	rec, getErr := db.Get(key)
	if getErr == nil {
		return rec.Meta(), nil
	}
	return nil, getErr
}

// Put writes a single record to the database, outside of any transaction.
func (db *SQLite) Put(r record.Record) (record.Record, error) {
	var noTx *bob.Tx
	return db.putRecord(r, noTx)
}

// putRecord serializes r as JSON and upserts it by key. If tx is non-nil the
// statement runs inside that transaction and the record is locked for the
// duration of the call; otherwise it runs directly on the database.
// On success the given record is returned.
func (db *SQLite) putRecord(r record.Record, tx *bob.Tx) (record.Record, error) {
	db.wg.Add(1)
	defer db.wg.Done()

	// Pick where to execute. The record is only locked here when it is
	// written as part of a transaction.
	var exec bob.Executor = db.bob
	if tx != nil {
		exec = tx

		r.Lock()
		defer r.Unlock()
	}

	// Serialize to JSON.
	payload, err := r.MarshalDataOnly(r, dsd.JSON)
	if err != nil {
		return nil, err
	}

	// Store format and value as NULL when there is no data.
	format := omitnull.From(int16(dsd.JSON))
	value := omitnull.From(payload)
	if len(payload) == 0 {
		format.Null()
		value.Null()
	}

	// Insert the row, or overwrite all columns if the key already exists.
	meta := r.Meta()
	upsert := models.Records.Insert(
		&models.RecordSetter{
			Key:        omit.From(r.DatabaseKey()),
			Format:     format,
			Value:      value,
			Created:    omit.From(meta.Created),
			Modified:   omit.From(meta.Modified),
			Expires:    omit.From(meta.Expires),
			Deleted:    omit.From(meta.Deleted),
			Secret:     omit.From(meta.IsSecret()),
			Crownjewel: omit.From(meta.IsCrownJewel()),
		},
		im.OnConflict("key").DoUpdate(
			im.SetExcluded("format", "value", "created", "modified", "expires", "deleted", "secret", "crownjewel"),
		),
	)

	if _, err = upsert.Exec(db.ctx, exec); err != nil {
		return nil, err
	}
	return r, nil
}

// PutMany stores many records in the database within one transaction.
//
// Records sent on the returned batch channel are written in order. Sending a
// nil record (or closing the channel) commits the transaction. Exactly one
// value is delivered on the returned error channel:
//   - the commit result, after the batch was finalized,
//   - the write error, if storing a record failed (the batch is rolled back),
//   - the context error, if the database context was canceled before the
//     batch was finalized (the batch is rolled back),
//   - the error from starting the transaction.
//
// The shadowDelete parameter is currently not used by this implementation.
func (db *SQLite) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) {
	db.wg.Add(1)
	defer db.wg.Done()

	batch := make(chan record.Record, 100)
	// Buffer of one: the single result can always be delivered without
	// blocking, even if the caller never reads it.
	errs := make(chan error, 1)

	tx, err := db.bob.BeginTx(db.ctx, nil)
	if err != nil {
		errs <- err
		return batch, errs
	}

	// rollback aborts the transaction and returns the cause, extended by the
	// rollback error if rolling back failed for a reason other than the
	// transaction already being finished.
	rollback := func(cause error) error {
		rbErr := tx.Rollback()
		if rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			return fmt.Errorf("%w (rollback failed: %w)", cause, rbErr)
		}
		return cause
	}

	// start handler
	go func() {
		for {
			select {
			case r := <-batch:
				if r == nil {
					// Finalize transaction.
					errs <- tx.Commit()
					return
				}

				// Write record.
				if _, putErr := db.putRecord(r, &tx); putErr != nil {
					errs <- rollback(putErr)
					return
				}

			case <-db.ctx.Done():
				// Report the cancellation instead of the rollback result, so
				// that an aborted batch is never mistaken for a success.
				errs <- rollback(db.ctx.Err())
				return
			}
		}
	}()

	return batch, errs
}

// Delete removes the record stored under key from the database.
func (db *SQLite) Delete(key string) error {
	db.wg.Add(1)
	defer db.wg.Done()

	// Only the key is needed to address the row.
	return (&models.Record{Key: key}).Delete(db.ctx, db.bob)
}

// Query returns an iterator for the supplied query.
// The query is checked first; matching records are then fed to the iterator
// by a separate goroutine.
func (db *SQLite) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
	db.wg.Add(1)
	defer db.wg.Done()

	if _, checkErr := q.Check(); checkErr != nil {
		return nil, fmt.Errorf("invalid query: %w", checkErr)
	}

	it := iterator.New()
	go db.queryExecutor(it, q, local, internal)

	return it, nil
}

// queryExecutor runs q against the database and feeds every matching record
// into queryIter. It always finishes the iterator, passing on the error that
// stopped the execution (if any).
//
// Rows are pre-selected by key prefix in SQL and then filtered in Go by key,
// by validity and permission of their metadata (local, internal), and - if
// the query has a where condition - by their JSON data.
//
// If the consumer does not accept a record within one second, the execution
// is stopped with ErrQueryTimeout.
func (db *SQLite) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
	db.wg.Add(1)
	defer db.wg.Done()

	// Build query.
	var recordQuery *sqlite.ViewQuery[*models.Record, models.RecordSlice]
	if q.DatabaseKeyPrefix() != "" {
		recordQuery = models.Records.View.Query(
			models.SelectWhere.Records.Key.Like(q.DatabaseKeyPrefix() + "%"),
		)
	} else {
		recordQuery = models.Records.View.Query()
	}

	// Get cursor to go over all records in the query.
	cursor, err := models.RecordsQuery.Cursor(recordQuery, db.ctx, db.bob)
	if err != nil {
		queryIter.Finish(err)
		return
	}
	defer func() {
		_ = cursor.Close()
	}()

recordsLoop:
	for cursor.Next() {
		// Get next record
		r, cErr := cursor.Get()
		if cErr != nil {
			err = fmt.Errorf("cursor error: %w", cErr)
			break recordsLoop
		}

		// Check if key matches.
		if !q.MatchesKey(r.Key) {
			continue recordsLoop
		}

		// Check Meta.
		m := getMeta(r)
		if !m.CheckValidity() ||
			!m.CheckPermission(local, internal) {
			continue recordsLoop
		}

		// Check Data.
		if q.HasWhereCondition() {
			// Records without data cannot match a where condition.
			if r.Format.IsNull() || r.Value.IsNull() {
				continue recordsLoop
			}

			jsonData := string(r.Value.GetOrZero())
			jsonAccess := accessor.NewJSONAccessor(&jsonData)
			if !q.MatchesAccessor(jsonAccess) {
				continue recordsLoop
			}
		}

		// Build database record.
		matched, wErr := record.NewWrapperFromDatabase(
			db.name,
			r.Key,
			m,
			uint8(r.Format.GetOrZero()), //nolint:gosec // Values are within uint8.
			r.Value.GetOrZero(),
		)
		if wErr != nil {
			// Never hand a nil wrapper to the consumer.
			err = fmt.Errorf("failed to wrap record %s: %w", r.Key, wErr)
			break recordsLoop
		}

		// Try a cheap non-blocking hand-over first and only set up the
		// timeout if the consumer is not ready right away.
		select {
		case <-queryIter.Done:
			break recordsLoop
		case queryIter.Next <- matched:
		default:
			select {
			case <-queryIter.Done:
				break recordsLoop
			case queryIter.Next <- matched:
			case <-time.After(1 * time.Second):
				err = ErrQueryTimeout
				break recordsLoop
			}
		}
	}

	// Next() also returns false when iterating failed. Surface that failure
	// instead of reporting a clean end of results.
	if err == nil {
		if cErr := cursor.Err(); cErr != nil {
			err = fmt.Errorf("cursor error: %w", cErr)
		}
	}

	queryIter.Finish(err)
}

// Purge deletes all records that match the given query. It returns the number of successful deletes and an error.
//
// Only queries without a where condition on a local and internal interface
// are supported; everything else returns storage.ErrNotImplemented. With
// shadowDelete, matching records are not removed but have their format and
// value cleared and their deleted timestamp set to now.
//
// The returned count is taken from a separate count query that runs before
// the delete or update statement. The statements run with the given ctx.
func (db *SQLite) Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error) {
	db.wg.Add(1)
	defer db.wg.Done()

	if !local || !internal || q.HasWhereCondition() {
		// TODO: Purging by iterating over all query results and deleting the
		// matching ones is not implemented yet. It is untested and also needs
		// handling of shadowDelete.
		// For now: Use only without where condition and with a local and
		// internal db interface.
		return 0, storage.ErrNotImplemented
	}

	// NOTE(review): LIKE treats "_" and "%" inside the prefix as wildcards,
	// and there is no key check in Go on this path - confirm that key
	// prefixes cannot make this match more records than intended.
	keyPattern := q.DatabaseKeyPrefix() + "%"

	// First count entries (SQLite does not support affected rows)
	n, err := models.Records.Query(
		models.SelectWhere.Records.Key.Like(keyPattern),
	).Count(ctx, db.bob)
	if err != nil || n == 0 {
		return int(n), err
	}

	if shadowDelete {
		// Mark purged records as deleted.
		now := time.Now().Unix()
		_, err = models.Records.Update(
			um.SetCol("format").ToArg(nil),
			um.SetCol("value").ToArg(nil),
			um.SetCol("deleted").ToArg(now),
			models.UpdateWhere.Records.Key.Like(keyPattern),
		).Exec(ctx, db.bob)
	} else {
		// Delete entries.
		_, err = models.Records.Delete(
			models.DeleteWhere.Records.Key.Like(keyPattern),
		).Exec(ctx, db.bob)
	}

	return int(n), err
}

// ReadOnly returns whether the database is read only.
// This storage always reports false.
func (db *SQLite) ReadOnly() bool {
	return false
}

// Injected returns whether the database is injected.
// This storage always reports false.
func (db *SQLite) Injected() bool {
	return false
}

// MaintainRecordStates maintains records states in the database.
//
// With shadowDelete, expired records are marked as deleted (format and value
// are cleared, deleted is set to now), and records that were marked as
// deleted before purgeDeletedBefore are removed.
//
// Without shadowDelete, expired records as well as all records marked as
// deleted are removed right away.
//
// All statements run with the given ctx, so the caller can cancel or bound
// the maintenance run.
func (db *SQLite) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error {
	db.wg.Add(1)
	defer db.wg.Done()

	now := time.Now().Unix()

	// Option 1: Using shadow delete.
	if shadowDelete {
		// Mark expired records as deleted.
		// An expires value of 0 means that no expiry is set; a deleted value
		// of 0 means that the record is not yet marked as deleted.
		_, err := models.Records.Update(
			um.SetCol("format").ToArg(nil),
			um.SetCol("value").ToArg(nil),
			um.SetCol("deleted").ToArg(now),
			models.UpdateWhere.Records.Deleted.EQ(0),
			models.UpdateWhere.Records.Expires.GT(0),
			models.UpdateWhere.Records.Expires.LT(now),
		).Exec(ctx, db.bob)
		if err != nil {
			return fmt.Errorf("failed to shadow delete expired records: %w", err)
		}

		// Purge deleted records before threshold.
		purgeThreshold := purgeDeletedBefore.Unix()
		_, err = models.Records.Delete(
			models.DeleteWhere.Records.Deleted.GT(0),
			models.DeleteWhere.Records.Deleted.LT(purgeThreshold),
		).Exec(ctx, db.bob)
		if err != nil {
			return fmt.Errorf("failed to purge deleted records (before threshold): %w", err)
		}
		return nil
	}

	// Option 2: Immediate delete.

	// Delete expired record.
	_, err := models.Records.Delete(
		models.DeleteWhere.Records.Expires.GT(0),
		models.DeleteWhere.Records.Expires.LT(now),
	).Exec(ctx, db.bob)
	if err != nil {
		return fmt.Errorf("failed to delete expired records: %w", err)
	}

	// Delete shadow deleted records.
	_, err = models.Records.Delete(
		models.DeleteWhere.Records.Deleted.GT(0),
	).Exec(ctx, db.bob)
	if err != nil {
		return fmt.Errorf("failed to purge deleted records: %w", err)
	}

	return nil
}

// Maintain runs a light maintenance pass: it executes an incremental vacuum
// limited to 25 pages on the raw database handle, using the given ctx.
func (db *SQLite) Maintain(ctx context.Context) error {
	db.wg.Add(1)
	defer db.wg.Done()

	// 25 pages are about 100KB of freelist per run, assuming 4KB page size.
	const stmt = "PRAGMA incremental_vacuum(25);"
	if _, err := db.db.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}

// MaintainThorough runs a thorough maintenance pass: it executes an
// incremental vacuum without a page limit on the raw database handle, using
// the given ctx.
func (db *SQLite) MaintainThorough(ctx context.Context) error {
	db.wg.Add(1)
	defer db.wg.Done()

	// Without a page count, all pages are removed from the freelist.
	const stmt = "PRAGMA incremental_vacuum;"
	if _, err := db.db.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}

// Shutdown waits for all in-flight operations, cancels the database context
// and closes the database. It returns the error from closing.
func (db *SQLite) Shutdown() error {
	// Let running operations finish before their context is canceled.
	db.wg.Wait()
	db.cancelCtx()

	closeErr := db.bob.Close()
	return closeErr
}

// statementLogger prints database statements to standard output. It is handed
// to sqldblogger when statement printing is enabled.
type statementLogger struct{}

// Log prints the message and its attached data. The context and the log level
// are not used.
func (statementLogger) Log(_ context.Context, _ sqldblogger.Level, msg string, data map[string]any) {
	fmt.Printf("SQL: %s --- %+v\n", msg, data)
}