Add and update netquery code based on review discussions

This commit is contained in:
Patrick Pacher 2022-03-17 14:28:01 +01:00
parent 976c0a702e
commit c2d2064ec8
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
9 changed files with 324 additions and 148 deletions

View file

@ -13,7 +13,6 @@ import ( //nolint:gci,nolintlint
_ "github.com/safing/portmaster/core" _ "github.com/safing/portmaster/core"
_ "github.com/safing/portmaster/firewall" _ "github.com/safing/portmaster/firewall"
_ "github.com/safing/portmaster/nameserver" _ "github.com/safing/portmaster/nameserver"
_ "github.com/safing/portmaster/netquery"
_ "github.com/safing/portmaster/ui" _ "github.com/safing/portmaster/ui"
_ "github.com/safing/spn/captain" _ "github.com/safing/spn/captain"
) )

View file

@ -7,10 +7,12 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/modules/subsystems"
"github.com/safing/portmaster/updates"
_ "github.com/safing/portmaster/netenv" _ "github.com/safing/portmaster/netenv"
_ "github.com/safing/portmaster/netquery"
_ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/status"
_ "github.com/safing/portmaster/ui" _ "github.com/safing/portmaster/ui"
"github.com/safing/portmaster/updates"
) )
const ( const (
@ -25,7 +27,7 @@ var (
) )
func init() { func init() {
module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat") module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "interception", "compat", "netquery")
subsystems.Register( subsystems.Register(
"core", "core",
"Core", "Core",

View file

@ -244,7 +244,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
// Save inserts the connection conn into the SQLite database. If conn // Save inserts the connection conn into the SQLite database. If conn
// already exists the table row is updated instead. // already exists the table row is updated instead.
func (db *Database) Save(ctx context.Context, conn Conn) error { func (db *Database) Save(ctx context.Context, conn Conn) error {
connMap, err := orm.EncodeAsMap(ctx, conn, "", orm.DefaultEncodeConfig) connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to encode connection for SQL: %w", err) return fmt.Errorf("failed to encode connection for SQL: %w", err)
} }

View file

@ -109,8 +109,13 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
continue continue
} }
// we clone the record metadata from the connection
// into the new model so the portbase/database layer
// can handle NEW/UPDATE correctly.
cloned := conn.Meta().Duplicate()
// push an update for the connection // push an update for the connection
if err := mng.pushConnUpdate(ctx, *model); err != nil { if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err) log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err)
} }
@ -123,7 +128,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
} }
} }
func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error { func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error {
blob, err := json.Marshal(conn) blob, err := json.Marshal(conn)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal connection: %w", err) return fmt.Errorf("failed to marshal connection: %w", err)
@ -132,7 +137,7 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error {
key := fmt.Sprintf("%s:%s%s", mng.runtimeReg.DatabaseName(), mng.pushPrefix, conn.ID) key := fmt.Sprintf("%s:%s%s", mng.runtimeReg.DatabaseName(), mng.pushPrefix, conn.ID)
wrapper, err := record.NewWrapper( wrapper, err := record.NewWrapper(
key, key,
new(record.Meta), &meta,
dsd.JSON, dsd.JSON,
blob, blob,
) )
@ -140,20 +145,6 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, conn Conn) error {
return fmt.Errorf("failed to create record wrapper: %w", err) return fmt.Errorf("failed to create record wrapper: %w", err)
} }
// FIXME(ppacher): it may happen that started != now for NEW connections.
// In that case we would push and UPD rather than NEW even if
// the connection is new ...
// Though, that's still better than always pushing NEW for existing
// connections.
// If we would use UnixNano() here chances would be even worse.
//
// Verify if the check in portbase/api/database.go is vulnerable
// to such timing issues in general.
wrapper.SetMeta(&record.Meta{
Created: conn.Started.Unix(),
Modified: time.Now().Unix(),
})
mng.push(wrapper) mng.push(wrapper)
return nil return nil
} }
@ -195,7 +186,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
extraData["reason"] = conn.Reason extraData["reason"] = conn.Reason
c.RemoteIP = conn.Entity.IP.String() c.RemoteIP = conn.Entity.IP.String()
c.RemotePort = conn.Entity.Port // FIXME(ppacher): or do we want DstPort() here? c.RemotePort = conn.Entity.Port
c.Domain = conn.Entity.Domain c.Domain = conn.Entity.Domain
c.Country = conn.Entity.Country c.Country = conn.Entity.Country
c.ASN = conn.Entity.ASN c.ASN = conn.Entity.ASN

View file

@ -13,128 +13,135 @@ import (
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
) )
// Module bundles the netquery module state: the module handle itself,
// the portbase database interface used to subscribe to connection
// updates, the in-memory SQLite store, the manager that persists and
// pushes updates, and the feed channel connecting subscriber and
// persister workers.
type Module struct {
	*modules.Module

	// db is the portbase database interface used to subscribe to the
	// "network:" tree (see Start).
	db *database.Interface
	// sqlStore is the in-memory SQLite database created in Prepare.
	sqlStore *Database
	// mng persists connections into sqlStore and pushes runtime updates.
	mng *Manager
	// feed carries connection records from the feeder worker to the
	// persister worker; allocated in Prepare, closed in Stop.
	feed chan *network.Connection
}
func init() { func init() {
var ( mod := new(Module)
module *modules.Module mod.Module = modules.Register(
db *database.Interface
sqlStore *Database
mng *Manager
)
module = modules.Register(
"netquery", "netquery",
/* Prepare Module */ mod.Prepare,
func() error { mod.Start,
var err error mod.Stop,
db = database.NewInterface(&database.Options{
Local: true,
Internal: true,
CacheSize: 0,
})
sqlStore, err = NewInMemory()
if err != nil {
return fmt.Errorf("failed to create in-memory database: %w", err)
}
mng, err = NewManager(sqlStore, "netquery/updates/", runtime.DefaultRegistry)
if err != nil {
return fmt.Errorf("failed to create manager: %w", err)
}
return nil
},
/* Start Module */
func() error {
ch := make(chan *network.Connection, 100)
module.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
sub, err := db.Subscribe(query.New("network:"))
if err != nil {
return fmt.Errorf("failed to subscribe to network tree: %w", err)
}
defer sub.Cancel()
for {
select {
case <-ctx.Done():
return nil
case rec, ok := <-sub.Feed:
if !ok {
return nil
}
conn, ok := rec.(*network.Connection)
if !ok {
// This is fine as we also receive process updates on
// this channel.
continue
}
ch <- conn
}
}
})
module.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
defer close(ch)
mng.HandleFeed(ctx, ch)
return nil
})
module.StartWorker("netquery-row-cleaner", func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-time.After(10 * time.Second):
count, err := sqlStore.Cleanup(ctx, time.Now().Add(-5*time.Minute))
if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %w", err)
} else {
log.Infof("netquery: successfully removed %d old rows", count)
}
}
}
})
module.StartWorker("netquery-row-counter", func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-time.After(5 * time.Second):
count, err := sqlStore.CountRows(ctx)
if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %w", err)
} else {
log.Infof("netquery: currently holding %d rows in memory", count)
}
/*
if err := sqlStore.dumpTo(ctx, os.Stderr); err != nil {
log.Errorf("netquery: failed to dump sqlite memory content: %w", err)
}
*/
}
}
})
// for debugging, we provide a simple direct SQL query interface using
// the runtime database
_, err := NewRuntimeQueryRunner(sqlStore, "netquery/query/", runtime.DefaultRegistry)
if err != nil {
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
}
return nil
},
nil,
"network", "network",
"database", "database",
) )
}
module.Enable()
// Prepare implements the module prepare phase: it opens a local,
// internal database interface, creates the in-memory SQLite store and
// the update manager, and allocates the connection feed channel used
// by the workers launched in Start.
//
// Receiver renamed from m to mod for consistency with the other
// methods of *Module (Start/Stop).
func (mod *Module) Prepare() error {
	var err error

	mod.db = database.NewInterface(&database.Options{
		Local:     true,
		Internal:  true,
		CacheSize: 0,
	})

	mod.sqlStore, err = NewInMemory()
	if err != nil {
		return fmt.Errorf("failed to create in-memory database: %w", err)
	}

	mod.mng, err = NewManager(mod.sqlStore, "netquery/data/", runtime.DefaultRegistry)
	if err != nil {
		return fmt.Errorf("failed to create manager: %w", err)
	}

	// Buffered so short bursts of connection updates do not block the
	// database subscription feeder.
	mod.feed = make(chan *network.Connection, 1000)

	return nil
}
// Start implements the module start phase. It launches the netquery
// pipeline workers:
//   - netquery-feeder: subscribes to the "network:" database prefix and
//     forwards connection records into mod.feed.
//   - netquery-persister: drains mod.feed into the SQLite store.
//   - netquery-row-cleaner: periodically deletes rows for connections
//     that ended before the network module's deletion threshold.
//   - netquery-row-counter: periodically logs how many rows are held
//     in memory.
//
// It also registers a runtime SQL query endpoint for debugging.
func (mod *Module) Start() error {
	mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
		sub, err := mod.db.Subscribe(query.New("network:"))
		if err != nil {
			return fmt.Errorf("failed to subscribe to network tree: %w", err)
		}
		defer sub.Cancel()

		for {
			select {
			case <-ctx.Done():
				return nil
			case rec, ok := <-sub.Feed:
				if !ok {
					return nil
				}

				conn, ok := rec.(*network.Connection)
				if !ok {
					// This is fine as we also receive process updates on
					// this channel.
					continue
				}

				mod.feed <- conn
			}
		}
	})

	mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
		mod.mng.HandleFeed(ctx, mod.feed)
		return nil
	})

	mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(10 * time.Second):
				count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
				if err != nil {
					// NOTE: log.Errorf is Printf-style, so %v is used here;
					// %w is only interpreted by fmt.Errorf. The message also
					// previously said "failed to count number of rows" by
					// copy-paste mistake.
					log.Errorf("netquery: failed to remove old rows: %v", err)
				} else {
					log.Infof("netquery: successfully removed %d old rows", count)
				}
			}
		}
	})

	mod.StartWorker("netquery-row-counter", func(ctx context.Context) error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(5 * time.Second):
				count, err := mod.sqlStore.CountRows(ctx)
				if err != nil {
					log.Errorf("netquery: failed to count number of rows in memory: %v", err)
				} else {
					log.Infof("netquery: currently holding %d rows in memory", count)
				}

				/*
					if err := mod.sqlStore.dumpTo(ctx, os.Stderr); err != nil {
						log.Errorf("netquery: failed to dump sqlite memory content: %v", err)
					}
				*/
			}
		}
	})

	// for debugging, we provide a simple direct SQL query interface using
	// the runtime database
	_, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry)
	if err != nil {
		return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
	}

	return nil
}
func (mod *Module) Stop() error {
close(mod.feed)
return nil
} }

View file

@ -17,10 +17,10 @@ type (
} }
) )
// EncodeAsMap returns a map that contains the value of each struct field of // ToParamMap returns a map that contains the sqlite compatible value of each struct field of
// r using the sqlite column name as a map key. It either uses the name of the // r using the sqlite column name as a map key. It either uses the name of the
// exported struct field or the value of the "sqlite" tag. // exported struct field or the value of the "sqlite" tag.
func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) {
// make sure we work on a struct type // make sure we work on a struct type
val := reflect.Indirect(reflect.ValueOf(r)) val := reflect.Indirect(reflect.ValueOf(r))
if val.Kind() != reflect.Struct { if val.Kind() != reflect.Struct {
@ -38,12 +38,20 @@ func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encod
continue continue
} }
colDev, err := getColumnDef(fieldType) colDef, err := getColumnDef(fieldType)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err) return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
} }
x, found, err := runEncodeHooks(colDev, fieldType.Type, field, cfg.EncodeHooks) x, found, err := runEncodeHooks(
colDef,
fieldType.Type,
field,
append(
cfg.EncodeHooks,
encodeBasic(),
),
)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to run encode hooks: %w", err) return nil, fmt.Errorf("failed to run encode hooks: %w", err)
} }
@ -61,6 +69,69 @@ func EncodeAsMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encod
return res, nil return res, nil
} }
// EncodeValue encodes val into an sqlite-compatible representation for
// the column described by colDef. It runs the hooks from cfg followed
// by the basic-type fallback encoder; if no hook claims the value, the
// (indirected) value itself is returned. An untyped nil is returned
// as nil (runEncodeHooks short-circuits on a nil reflect.Type).
//
// ctx is currently unused but kept for interface symmetry with
// ToParamMap — TODO confirm before removing.
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
	fieldValue := reflect.ValueOf(val)
	fieldType := reflect.TypeOf(val)

	// Build a fresh hook slice instead of append(cfg.EncodeHooks, ...):
	// appending to a caller-owned slice with spare capacity would write
	// into its backing array, mutating a possibly shared config such as
	// DefaultEncodeConfig.
	hooks := make([]EncodeFunc, 0, len(cfg.EncodeHooks)+1)
	hooks = append(hooks, cfg.EncodeHooks...)
	hooks = append(hooks, encodeBasic())

	x, found, err := runEncodeHooks(colDef, fieldType, fieldValue, hooks)
	if err != nil {
		return nil, fmt.Errorf("failed to run encode hooks: %w", err)
	}

	if !found {
		// No hook handled the value; fall back to the dereferenced
		// value itself if it is valid.
		if reflect.Indirect(fieldValue).IsValid() {
			x = reflect.Indirect(fieldValue).Interface()
		}
	}

	return x, nil
}
// encodeBasic returns the fallback EncodeFunc that handles Go's basic
// types (strings, bools, ints, uints, floats and []byte), which the
// sqlite package can bind directly. Pointers are dereferenced first;
// a nil pointer encodes to nil. Any other kind yields an error.
func encodeBasic() EncodeFunc {
	return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
		kind := valType.Kind()

		// Unwrap a pointer; a nil pointer maps to SQL NULL.
		if kind == reflect.Ptr {
			if val.IsNil() {
				return nil, true, nil
			}
			valType = valType.Elem()
			kind = valType.Kind()
			val = val.Elem()
		}

		norm := normalizeKind(kind)

		// []byte is the only slice type sqlite handles directly.
		if norm == reflect.Slice && valType.Elem().Kind() == reflect.Uint8 {
			return val.Interface(), true, nil
		}

		switch norm {
		case reflect.String,
			reflect.Float64,
			reflect.Bool,
			reflect.Int,
			reflect.Uint:
			// sqlite package handles conversion of those types
			// already
			return val.Interface(), true, nil
		default:
			return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind)
		}
	}
}
func DatetimeEncoder(loc *time.Location) EncodeFunc { func DatetimeEncoder(loc *time.Location) EncodeFunc {
return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
// if fieldType holds a pointer we need to dereference the value // if fieldType holds a pointer we need to dereference the value
@ -103,6 +174,10 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
} }
func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) { func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
if valType == nil {
return nil, true, nil
}
for _, fn := range hooks { for _, fn := range hooks {
res, end, err := fn(colDev, valType, val) res, end, err := fn(colDev, valType, val)
if err != nil { if err != nil {

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"zombiezen.com/go/sqlite"
) )
func Test_EncodeAsMap(t *testing.T) { func Test_EncodeAsMap(t *testing.T) {
@ -118,9 +119,106 @@ func Test_EncodeAsMap(t *testing.T) {
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
// t.Parallel() // t.Parallel()
res, err := EncodeAsMap(ctx, c.Input, "", DefaultEncodeConfig) res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.Expected, res) assert.Equal(t, c.Expected, res)
}) })
} }
} }
// Test_EncodeValue exercises EncodeValue's time.Time handling for a
// time-tagged column: text vs. unix-epoch vs. unixnano-epoch output,
// pointer and value inputs, and the nil/zero cases that must encode
// to SQL NULL (nil).
func Test_EncodeValue(t *testing.T) {
	ctx := context.TODO()
	// Fixed reference instant so expected outputs are deterministic.
	refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)

	cases := []struct {
		Desc   string
		Column ColumnDef
		Input  interface{}
		Output interface{}
	}{
		{
			"Special value time.Time as text",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			refTime,
			refTime.Format(sqliteTimeFormat),
		},
		{
			"Special value time.Time as unix-epoch",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeInteger,
			},
			refTime,
			refTime.Unix(),
		},
		{
			"Special value time.Time as unixnano-epoch",
			ColumnDef{
				IsTime:   true,
				Type:     sqlite.TypeInteger,
				UnixNano: true,
			},
			refTime,
			refTime.UnixNano(),
		},
		{
			// Zero time encodes to nil (SQL NULL), not a formatted string.
			"Special value zero time",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			time.Time{},
			nil,
		},
		{
			// Pointer to zero time behaves like the zero value itself.
			"Special value zero time pointer",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			new(time.Time),
			nil,
		},
		{
			"Special value *time.Time as text",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			&refTime,
			refTime.Format(sqliteTimeFormat),
		},
		{
			"Special value untyped nil",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			nil,
			nil,
		},
		{
			"Special value typed nil",
			ColumnDef{
				IsTime: true,
				Type:   sqlite.TypeText,
			},
			(*time.Time)(nil),
			nil,
		},
	}

	for idx := range cases {
		c := cases[idx]
		t.Run(c.Desc, func(t *testing.T) {
			// t.Parallel()

			res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
			assert.NoError(t, err)
			assert.Equal(t, c.Output, res)
		})
	}
}

View file

@ -18,6 +18,7 @@ var (
TagUnixNano = "unixnano" TagUnixNano = "unixnano"
TagPrimaryKey = "primary" TagPrimaryKey = "primary"
TagAutoIncrement = "autoincrement" TagAutoIncrement = "autoincrement"
TagTime = "time"
TagNotNull = "not-null" TagNotNull = "not-null"
TagNullable = "nullable" TagNullable = "nullable"
TagTypeInt = "integer" TagTypeInt = "integer"
@ -48,6 +49,7 @@ type (
PrimaryKey bool PrimaryKey bool
AutoIncrement bool AutoIncrement bool
UnixNano bool UnixNano bool
IsTime bool
} }
) )
@ -188,6 +190,8 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
def.Nullable = true def.Nullable = true
case TagUnixNano: case TagUnixNano:
def.UnixNano = true def.UnixNano = true
case TagTime:
def.IsTime = true
// basic column types // basic column types
case TagTypeInt: case TagTypeInt:

View file

@ -12,7 +12,7 @@ import (
const ( const (
cleanerTickDuration = 5 * time.Second cleanerTickDuration = 5 * time.Second
deleteConnsAfterEndedThreshold = 10 * time.Minute DeleteConnsAfterEndedThreshold = 10 * time.Minute
) )
func connectionCleaner(ctx context.Context) error { func connectionCleaner(ctx context.Context) error {
@ -41,7 +41,7 @@ func cleanConnections() (activePIDs map[int]struct{}) {
_ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { _ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error {
now := time.Now().UTC() now := time.Now().UTC()
nowUnix := now.Unix() nowUnix := now.Unix()
deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() deleteOlderThan := now.Add(-DeleteConnsAfterEndedThreshold).Unix()
// network connections // network connections
for _, conn := range conns.clone() { for _, conn := range conns.clone() {