Update bandwidth handling and add basic chart support

This commit is contained in:
Patrick Pacher 2023-08-24 13:14:23 +02:00 committed by Daniel
parent 8b92291d2c
commit bc285b593d
6 changed files with 208 additions and 22 deletions

View file

@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
} }
defer conn.Unlock() defer conn.Unlock()
bytesIn := bwUpdate.BytesReceived
bytesOut := bwUpdate.BytesSent
// Update stats according to method. // Update stats according to method.
switch bwUpdate.Method { switch bwUpdate.Method {
case packet.Absolute: case packet.Absolute:
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
bytesOut = bwUpdate.BytesSent - conn.BytesSent
conn.BytesReceived = bwUpdate.BytesReceived conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent conn.BytesSent = bwUpdate.BytesSent
case packet.Additive: case packet.Additive:
@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
if err := netquery.DefaultModule.Store.UpdateBandwidth( if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx, ctx,
conn.HistoryEnabled, conn.HistoryEnabled,
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
conn.Process().GetKey(), conn.Process().GetKey(),
conn.ID, conn.ID,
conn.BytesReceived, bytesIn,
conn.BytesSent, bytesOut,
); err != nil { ); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err) log.Errorf("filter: failed to persist bandwidth data: %s", err)
} }

View file

@ -13,12 +13,12 @@ import (
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
) )
// ChartHandler handles requests for connection charts. // ActiveChartHandler handles requests for connection charts.
type ChartHandler struct { type ActiveChartHandler struct {
Database *Database Database *Database
} }
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
requestPayload, err := ch.parseRequest(req) requestPayload, err := ch.parseRequest(req)
if err != nil { if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest) http.Error(resp, err.Error(), http.StatusBadRequest)
@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
}) })
} }
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
var body io.Reader var body io.Reader
switch req.Method { switch req.Method {
@ -99,7 +99,8 @@ WITH RECURSIVE epoch(x) AS (
UNION ALL UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0 SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
) )
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
FROM epoch
JOIN connections JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s %s

View file

@ -0,0 +1,145 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/netquery/orm"
)
// BandwidthChartHandler handles requests for connection charts.
type BandwidthChartHandler struct {
Database *Database
}
type BandwidthChartRequest struct {
Profiles []string `json:"profiles"`
Connections []string `json:"connections"`
}
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
// actually execute the query against the database and collect the result
var result []map[string]interface{}
if err := ch.Database.Execute(
req.Context(),
query,
orm.WithNamedArgs(paramMap),
orm.WithResult(&result),
orm.WithSchema(*ch.Database.Schema),
); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
return
}
// send the HTTP status code
resp.WriteHeader(http.StatusOK)
// prepare the result encoder.
enc := json.NewEncoder(resp)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
"results": result,
"query": query,
"params": paramMap,
})
}
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
var body io.Reader
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
var requestPayload BandwidthChartRequest
blob, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
selects := []string{
"(round(time/10, 0)*10) as time",
"SUM(incoming) as incoming",
"SUM(outgoing) as outgoing",
}
groupBy := []string{"round(time/10, 0)*10"}
whereClause := ""
params := make(map[string]any)
if len(req.Profiles) > 0 {
groupBy = []string{"profile", "round(time/10, 0)*10"}
selects = append(selects, "profile")
clauses := make([]string, len(req.Profiles))
for idx, p := range req.Profiles {
key := fmt.Sprintf(":p%d", idx)
clauses[idx] = "profile = " + key
params[key] = p
}
whereClause = "WHERE " + strings.Join(clauses, " OR ")
} else if len(req.Connections) > 0 {
groupBy = []string{"conn_id", "round(time/10, 0)*10"}
selects = append(selects, "conn_id")
clauses := make([]string, len(req.Connections))
for idx, p := range req.Connections {
key := fmt.Sprintf(":c%d", idx)
clauses[idx] = "conn_id = " + key
params[key] = p
}
whereClause = "WHERE " + strings.Join(clauses, " OR ")
}
template := fmt.Sprintf(
`SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`,
strings.Join(selects, ", "),
whereClause,
strings.Join(groupBy, ", "),
)
return template, params, nil
}

View file

@ -299,6 +299,20 @@ func (db *Database) ApplyMigrations() error {
} }
} }
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
conn_id TEXT NOT NULL,
profile TEXT NOT NULL,
time INTEGER NOT NULL,
incoming INTEGER NOT NULL,
outgoing INTEGER NOT NULL,
CONSTRAINT fk_conn_id
FOREIGN KEY(conn_id) REFERENCES connections(id)
ON DELETE CASCADE
)`
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
}
return nil return nil
} }
@ -535,21 +549,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes // UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database. // the bandwidth data to the history database.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
params := map[string]any{ params := map[string]any{
":id": makeNqIDFromParts(processKey, connID), ":id": makeNqIDFromParts(processKey, connID),
} }
parts := []string{} parts := []string{}
if bytesReceived != 0 {
parts = append(parts, "bytes_received = :bytes_received") parts = append(parts, "bytes_received = :bytes_received")
params[":bytes_received"] = bytesReceived params[":bytes_received"] = bytesReceived
}
if bytesSent != 0 {
parts = append(parts, "bytes_sent = :bytes_sent") parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent params[":bytes_sent"] = bytesSent
}
updateSet := strings.Join(parts, ", ") updateSet := strings.Join(parts, ", ")
@ -570,6 +579,14 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro
} }
} }
// also add the date to the in-memory bandwidth database
params[":time"] = time.Now().Unix()
params[":profile"] = profileKey
stmt := "INSERT INTO main.bandwidth (conn_id, profile, time, incoming, outgoing) VALUES(:id, :profile, :time, :bytes_received, :bytes_sent)"
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
}
return merr.ErrorOrNil() return merr.ErrorOrNil()
} }

View file

@ -38,7 +38,7 @@ type (
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes // UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database. // the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
// CleanupHistory deletes data outside of the retention time frame from the history database. // CleanupHistory deletes data outside of the retention time frame from the history database.
CleanupHistory(ctx context.Context) error CleanupHistory(ctx context.Context) error

View file

@ -87,7 +87,11 @@ func (m *module) prepare() error {
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
} }
chartHandler := &ChartHandler{ chartHandler := &ActiveChartHandler{
Database: m.Store,
}
bwChartHandler := &BandwidthChartHandler{
Database: m.Store, Database: m.Store,
} }
@ -129,6 +133,19 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err) return fmt.Errorf("failed to register API endpoint: %w", err)
} }
if err := api.RegisterEndpoint(api.Endpoint{
// TODO: Use query parameters instead.
Path: "netquery/charts/bandwidth",
MimeType: "application/json",
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: bwChartHandler.ServeHTTP,
Name: "Bandwidth Chart",
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{ if err := api.RegisterEndpoint(api.Endpoint{
Name: "Remove connections from profile history", Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles", Description: "Remove all connections from the history database for one or more profiles",
@ -137,7 +154,6 @@ func (m *module) prepare() error {
Write: api.PermitUser, Write: api.PermitUser,
BelongsTo: m.Module, BelongsTo: m.Module,
ActionFunc: func(ar *api.Request) (msg string, err error) { ActionFunc: func(ar *api.Request) (msg string, err error) {
// TODO: Use query parameters instead.
var body struct { var body struct {
ProfileIDs []string `json:"profileIDs"` ProfileIDs []string `json:"profileIDs"`
} }