diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index edfbe197..b7f1d1ec 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -13,7 +13,6 @@ import ( //nolint:gci,nolintlint _ "github.com/safing/portmaster/core" _ "github.com/safing/portmaster/firewall" _ "github.com/safing/portmaster/nameserver" - _ "github.com/safing/portmaster/netquery" _ "github.com/safing/portmaster/ui" _ "github.com/safing/spn/captain" ) diff --git a/core/core.go b/core/core.go index 80bd6343..ec2431b6 100644 --- a/core/core.go +++ b/core/core.go @@ -7,10 +7,12 @@ import ( "github.com/safing/portbase/modules" "github.com/safing/portbase/modules/subsystems" + "github.com/safing/portmaster/updates" + _ "github.com/safing/portmaster/netenv" + _ "github.com/safing/portmaster/netquery" _ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/ui" - "github.com/safing/portmaster/updates" ) const ( @@ -25,7 +27,7 @@ var ( ) func init() { - module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat") + module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat", "netquery") subsystems.Register( "core", "Core", diff --git a/netquery/database.go b/netquery/database.go index 7485f19c..cf5450a0 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -244,7 +244,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { // Save inserts the connection conn into the SQLite database. If conn // already exists the table row is updated instead. func (db *Database) Save(ctx context.Context, conn Conn) error { - connMap, err := orm.EncodeAsMap(ctx, conn, "", orm.DefaultEncodeConfig) + connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig) if err != nil { return fmt.Errorf("failed to encode connection for SQL: %w", err) } diff --git a/netquery/manager.go b/netquery/manager.go index cd9b2335..647e82ed 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -109,8 +109,13 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect continue } + // we clone the record metadata from the connection + // into the new model so the portbase/database layer + // can handle NEW/UPDATE correctly. + cloned := conn.Meta().Duplicate() + // push an update for the connection - if err := mng.pushConnUpdate(ctx, *model); err != nil { + if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil { log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err) } @@ -123,7 +128,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect } } -func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error { +func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error { blob, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection: %w", err) @@ -132,7 +137,7 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error { key := fmt.Sprintf("%s:%s%s", mng.runtimeReg.DatabaseName(), mng.pushPrefix, conn.ID) wrapper, err := record.NewWrapper( key, - new(record.Meta), + &meta, dsd.JSON, blob, ) @@ -140,20 +145,6 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error { return fmt.Errorf("failed to create record wrapper: %w", err) } - // FIXME(ppacher): it may happen that started != now for NEW connections. - // In that case we would push and UPD rather than NEW even if - // the connection is new ... - // Though, that's still better than always pushing NEW for existing - // connections. - // If we would use UnixNano() here chances would be even worse. - // - // Verify if the check in portbase/api/database.go is vulnerable - // to such timing issues in general. - wrapper.SetMeta(&record.Meta{ - Created: conn.Started.Unix(), - Modified: time.Now().Unix(), - }) - mng.push(wrapper) return nil } @@ -195,7 +186,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) { extraData["reason"] = conn.Reason c.RemoteIP = conn.Entity.IP.String() - c.RemotePort = conn.Entity.Port // FIXME(ppacher): or do we want DstPort() here? + c.RemotePort = conn.Entity.Port c.Domain = conn.Entity.Domain c.Country = conn.Entity.Country c.ASN = conn.Entity.ASN diff --git a/netquery/module_api.go b/netquery/module_api.go index da4e7a1a..6bc9cb4f 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -13,128 +13,135 @@ import ( "github.com/safing/portmaster/network" ) +type Module struct { + *modules.Module + + db *database.Interface + sqlStore *Database + mng *Manager + feed chan *network.Connection +} + func init() { - var ( - module *modules.Module - db *database.Interface - sqlStore *Database - mng *Manager - ) - - module = modules.Register( + mod := new(Module) + mod.Module = modules.Register( "netquery", - /* Prepare Module */ - func() error { - var err error - - db = database.NewInterface(&database.Options{ - Local: true, - Internal: true, - CacheSize: 0, - }) - - sqlStore, err = NewInMemory() - if err != nil { - return fmt.Errorf("failed to create in-memory database: %w", err) - } - - mng, err = NewManager(sqlStore, "netquery/updates/", runtime.DefaultRegistry) - if err != nil { - return fmt.Errorf("failed to create manager: %w", err) - } - - return nil - }, - /* Start Module */ - func() error { - ch := make(chan *network.Connection, 100) - - module.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error { - sub, err := db.Subscribe(query.New("network:")) - if err != nil { - return fmt.Errorf("failed to subscribe to network tree: %w", err) - } - defer sub.Cancel() - - for { - select { - case <-ctx.Done(): - return nil - case rec, ok := <-sub.Feed: - if !ok { - return nil - } - - conn, ok := rec.(*network.Connection) - if !ok { - // This is fine as we also receive process updates on - // this channel. - continue - } - - ch <- conn - } - } - }) - - module.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error { - defer close(ch) - - mng.HandleFeed(ctx, ch) - return nil - }) - - module.StartWorker("netquery-row-cleaner", func(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return nil - case <-time.After(10 * time.Second): - count, err := sqlStore.Cleanup(ctx, time.Now().Add(-5*time.Minute)) - if err != nil { - log.Errorf("netquery: failed to count number of rows in memory: %w", err) - } else { - log.Infof("netquery: successfully removed %d old rows", count) - } - } - } - }) - - module.StartWorker("netquery-row-counter", func(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return nil - case <-time.After(5 * time.Second): - count, err := sqlStore.CountRows(ctx) - if err != nil { - log.Errorf("netquery: failed to count number of rows in memory: %w", err) - } else { - log.Infof("netquery: currently holding %d rows in memory", count) - } - - /* - if err := sqlStore.dumpTo(ctx, os.Stderr); err != nil { - log.Errorf("netquery: failed to dump sqlite memory content: %w", err) - } - */ - } - } - }) - - // for debugging, we provide a simple direct SQL query interface using - // the runtime database - _, err := NewRuntimeQueryRunner(sqlStore, "netquery/query/", runtime.DefaultRegistry) - if err != nil { - return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) - } - - return nil - }, - nil, + mod.Prepare, + mod.Start, + mod.Stop, "network", "database", ) - - module.Enable() +} + +func (m *Module) Prepare() error { + var err error + + m.db = database.NewInterface(&database.Options{ + Local: true, + Internal: true, + CacheSize: 0, + }) + + m.sqlStore, err = NewInMemory() + if err != nil { + return fmt.Errorf("failed to create in-memory database: %w", err) + } + + m.mng, err = NewManager(m.sqlStore, "netquery/data/", runtime.DefaultRegistry) + if err != nil { + return fmt.Errorf("failed to create manager: %w", err) + } + + m.feed = make(chan *network.Connection, 1000) + + return nil +} + +func (mod *Module) Start() error { + mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error { + sub, err := mod.db.Subscribe(query.New("network:")) + if err != nil { + return fmt.Errorf("failed to subscribe to network tree: %w", err) + } + defer sub.Cancel() + + for { + select { + case <-ctx.Done(): + return nil + case rec, ok := <-sub.Feed: + if !ok { + return nil + } + + conn, ok := rec.(*network.Connection) + if !ok { + // This is fine as we also receive process updates on + // this channel. + continue + } + + mod.feed <- conn + } + } + }) + + mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error { + mod.mng.HandleFeed(ctx, mod.feed) + return nil + }) + + mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(10 * time.Second): + count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold)) + if err != nil { + log.Errorf("netquery: failed to count number of rows in memory: %w", err) + } else { + log.Infof("netquery: successfully removed %d old rows", count) + } + } + } + }) + + mod.StartWorker("netquery-row-counter", func(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(5 * time.Second): + count, err := mod.sqlStore.CountRows(ctx) + if err != nil { + log.Errorf("netquery: failed to count number of rows in memory: %w", err) + } else { + log.Infof("netquery: currently holding %d rows in memory", count) + } + + /* + if err := sqlStore.dumpTo(ctx, os.Stderr); err != nil { + log.Errorf("netquery: failed to dump sqlite memory content: %w", err) + } + */ + } + } + }) + + // for debugging, we provide a simple direct SQL query interface using + // the runtime database + _, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry) + if err != nil { + return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) + } + + return nil +} + +func (mod *Module) Stop() error { + close(mod.feed) + + return nil } diff --git a/netquery/orm/encoder.go b/netquery/orm/encoder.go index c2d9a62c..fc4e772c 100644 --- a/netquery/orm/encoder.go +++ b/netquery/orm/encoder.go @@ -17,10 +17,10 @@ type ( } ) -// EncodeAsMap returns a map that contains the value of each struct field of +// ToParamMap returns a map that contains the sqlite compatible value of each struct field of // r using the sqlite column name as a map key. It either uses the name of the // exported struct field or the value of the "sqlite" tag. -func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { +func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { // make sure we work on a struct type val := reflect.Indirect(reflect.ValueOf(r)) if val.Kind() != reflect.Struct { @@ -38,12 +38,20 @@ func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encod continue } - colDev, err := getColumnDef(fieldType) + colDef, err := getColumnDef(fieldType) if err != nil { return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err) } - x, found, err := runEncodeHooks(colDev, fieldType.Type, field, cfg.EncodeHooks) + x, found, err := runEncodeHooks( + colDef, + fieldType.Type, + field, + append( + cfg.EncodeHooks, + encodeBasic(), + ), + ) if err != nil { return nil, fmt.Errorf("failed to run encode hooks: %w", err) } @@ -61,6 +69,69 @@ func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encod return res, nil } +func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) { + fieldValue := reflect.ValueOf(val) + fieldType := reflect.TypeOf(val) + + x, found, err := runEncodeHooks( + colDef, + fieldType, + fieldValue, + append( + cfg.EncodeHooks, + encodeBasic(), + ), + ) + if err != nil { + return nil, fmt.Errorf("failed to run encode hooks: %w", err) + } + + if !found { + if reflect.Indirect(fieldValue).IsValid() { + x = reflect.Indirect(fieldValue).Interface() + } + } + + return x, nil +} + +func encodeBasic() EncodeFunc { + return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { + kind := valType.Kind() + if kind == reflect.Ptr { + valType = valType.Elem() + kind = valType.Kind() + + if val.IsNil() { + return nil, true, nil + } + + val = val.Elem() + } + + switch normalizeKind(kind) { + case reflect.String, + reflect.Float64, + reflect.Bool, + reflect.Int, + reflect.Uint: + // sqlite package handles conversion of those types + // already + return val.Interface(), true, nil + + case reflect.Slice: + if valType.Elem().Kind() == reflect.Uint8 { + // this is []byte + return val.Interface(), true, nil + } + fallthrough + + default: + return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind) + } + } +} + func DatetimeEncoder(loc *time.Location) EncodeFunc { return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { // if fieldType holds a pointer we need to dereference the value @@ -103,6 +174,10 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc { } func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) { + if valType == nil { + return nil, true, nil + } + for _, fn := range hooks { res, end, err := fn(colDev, valType, val) if err != nil { diff --git a/netquery/orm/encoder_test.go b/netquery/orm/encoder_test.go index 056bf953..8b802a78 100644 --- a/netquery/orm/encoder_test.go +++ b/netquery/orm/encoder_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "zombiezen.com/go/sqlite" ) func Test_EncodeAsMap(t *testing.T) { @@ -118,9 +119,106 @@ func Test_EncodeAsMap(t *testing.T) { t.Run(c.Desc, func(t *testing.T) { // t.Parallel() - res, err := EncodeAsMap(ctx, c.Input, "", DefaultEncodeConfig) + res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) assert.NoError(t, err) assert.Equal(t, c.Expected, res) }) } } + +func Test_EncodeValue(t *testing.T) { + ctx := context.TODO() + refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) + + cases := []struct { + Desc string + Column ColumnDef + Input interface{} + Output interface{} + }{ + { + "Special value time.Time as text", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + refTime, + refTime.Format(sqliteTimeFormat), + }, + { + "Special value time.Time as unix-epoch", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeInteger, + }, + refTime, + refTime.Unix(), + }, + { + "Special value time.Time as unixnano-epoch", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeInteger, + UnixNano: true, + }, + refTime, + refTime.UnixNano(), + }, + { + "Special value zero time", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + time.Time{}, + nil, + }, + { + "Special value zero time pointer", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + new(time.Time), + nil, + }, + { + "Special value *time.Time as text", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + &refTime, + refTime.Format(sqliteTimeFormat), + }, + { + "Special value untyped nil", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + nil, + nil, + }, + { + "Special value typed nil", + ColumnDef{ + IsTime: true, + Type: sqlite.TypeText, + }, + (*time.Time)(nil), + nil, + }, + } + + for idx := range cases { + c := cases[idx] + t.Run(c.Desc, func(t *testing.T) { + // t.Parallel() + + res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig) + assert.NoError(t, err) + assert.Equal(t, c.Output, res) + }) + } +} diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index aa57e21a..9289b06d 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -18,6 +18,7 @@ var ( TagUnixNano = "unixnano" TagPrimaryKey = "primary" TagAutoIncrement = "autoincrement" + TagTime = "time" TagNotNull = "not-null" TagNullable = "nullable" TagTypeInt = "integer" @@ -48,6 +49,7 @@ type ( PrimaryKey bool AutoIncrement bool UnixNano bool + IsTime bool } ) @@ -188,6 +190,8 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error { def.Nullable = true case TagUnixNano: def.UnixNano = true + case TagTime: + def.IsTime = true // basic column types case TagTypeInt: diff --git a/network/clean.go b/network/clean.go index 608e2c88..6ff1eb38 100644 --- a/network/clean.go +++ b/network/clean.go @@ -12,7 +12,7 @@ import ( const ( cleanerTickDuration = 5 * time.Second - deleteConnsAfterEndedThreshold = 10 * time.Minute + DeleteConnsAfterEndedThreshold = 10 * time.Minute ) func connectionCleaner(ctx context.Context) error { @@ -41,7 +41,7 @@ func cleanConnections() (activePIDs map[int]struct{}) { _ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { now := time.Now().UTC() nowUnix := now.Unix() - deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() + deleteOlderThan := now.Add(-DeleteConnsAfterEndedThreshold).Unix() // network connections for _, conn := range conns.clone() {