Update bandwidth handling and add basic chart support

This commit is contained in:
Patrick Pacher 2023-08-24 13:14:23 +02:00 committed by Daniel
parent 8b92291d2c
commit bc285b593d
6 changed files with 208 additions and 22 deletions

View file

@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
}
defer conn.Unlock()
bytesIn := bwUpdate.BytesReceived
bytesOut := bwUpdate.BytesSent
// Update stats according to method.
switch bwUpdate.Method {
case packet.Absolute:
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
bytesOut = bwUpdate.BytesSent - conn.BytesSent
conn.BytesReceived = bwUpdate.BytesReceived
conn.BytesSent = bwUpdate.BytesSent
case packet.Additive:
@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
if err := netquery.DefaultModule.Store.UpdateBandwidth(
ctx,
conn.HistoryEnabled,
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
conn.Process().GetKey(),
conn.ID,
conn.BytesReceived,
conn.BytesSent,
bytesIn,
bytesOut,
); err != nil {
log.Errorf("filter: failed to persist bandwidth data: %s", err)
}

View file

@ -13,12 +13,12 @@ import (
"github.com/safing/portmaster/netquery/orm"
)
// ChartHandler handles requests for connection charts.
type ChartHandler struct {
// ActiveChartHandler handles requests for connection charts.
type ActiveChartHandler struct {
Database *Database
}
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
})
}
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
var body io.Reader
switch req.Method {
@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS (
UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
FROM epoch
JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)

View file

@ -0,0 +1,145 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/safing/portmaster/netquery/orm"
)
// BandwidthChartHandler handles requests for connection charts.
type BandwidthChartHandler struct {
Database *Database
}
type BandwidthChartRequest struct {
Profiles []string `json:"profiles"`
Connections []string `json:"connections"`
}
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
// actually execute the query against the database and collect the result
var result []map[string]interface{}
if err := ch.Database.Execute(
req.Context(),
query,
orm.WithNamedArgs(paramMap),
orm.WithResult(&result),
orm.WithSchema(*ch.Database.Schema),
); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
return
}
// send the HTTP status code
resp.WriteHeader(http.StatusOK)
// prepare the result encoder.
enc := json.NewEncoder(resp)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
"results": result,
"query": query,
"params": paramMap,
})
}
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
var body io.Reader
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
var requestPayload BandwidthChartRequest
blob, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
selects := []string{
"(round(time/10, 0)*10) as time",
"SUM(incoming) as incoming",
"SUM(outgoing) as outgoing",
}
groupBy := []string{"round(time/10, 0)*10"}
whereClause := ""
params := make(map[string]any)
if len(req.Profiles) > 0 {
groupBy = []string{"profile", "round(time/10, 0)*10"}
selects = append(selects, "profile")
clauses := make([]string, len(req.Profiles))
for idx, p := range req.Profiles {
key := fmt.Sprintf(":p%d", idx)
clauses[idx] = "profile = " + key
params[key] = p
}
whereClause = "WHERE " + strings.Join(clauses, " OR ")
} else if len(req.Connections) > 0 {
groupBy = []string{"conn_id", "round(time/10, 0)*10"}
selects = append(selects, "conn_id")
clauses := make([]string, len(req.Connections))
for idx, p := range req.Connections {
key := fmt.Sprintf(":c%d", idx)
clauses[idx] = "conn_id = " + key
params[key] = p
}
whereClause = "WHERE " + strings.Join(clauses, " OR ")
}
template := fmt.Sprintf(
`SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`,
strings.Join(selects, ", "),
whereClause,
strings.Join(groupBy, ", "),
)
return template, params, nil
}

View file

@ -299,6 +299,20 @@ func (db *Database) ApplyMigrations() error {
}
}
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
conn_id TEXT NOT NULL,
profile TEXT NOT NULL,
time INTEGER NOT NULL,
incoming INTEGER NOT NULL,
outgoing INTEGER NOT NULL,
CONSTRAINT fk_conn_id
FOREIGN KEY(conn_id) REFERENCES connections(id)
ON DELETE CASCADE
)`
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
}
return nil
}
@ -535,21 +549,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
params := map[string]any{
":id": makeNqIDFromParts(processKey, connID),
}
parts := []string{}
if bytesReceived != 0 {
parts = append(parts, "bytes_received = :bytes_received")
params[":bytes_received"] = bytesReceived
}
if bytesSent != 0 {
parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent
}
parts = append(parts, "bytes_received = :bytes_received")
params[":bytes_received"] = bytesReceived
parts = append(parts, "bytes_sent = :bytes_sent")
params[":bytes_sent"] = bytesSent
updateSet := strings.Join(parts, ", ")
@ -570,6 +579,14 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro
}
}
// also add the date to the in-memory bandwidth database
params[":time"] = time.Now().Unix()
params[":profile"] = profileKey
stmt := "INSERT INTO main.bandwidth (conn_id, profile, time, incoming, outgoing) VALUES(:id, :profile, :time, :bytes_received, :bytes_sent)"
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
}
return merr.ErrorOrNil()
}

View file

@ -38,7 +38,7 @@ type (
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
// the bandwidth data to the history database.
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
// CleanupHistory deletes data outside of the retention time frame from the history database.
CleanupHistory(ctx context.Context) error

View file

@ -87,7 +87,11 @@ func (m *module) prepare() error {
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
}
chartHandler := &ChartHandler{
chartHandler := &ActiveChartHandler{
Database: m.Store,
}
bwChartHandler := &BandwidthChartHandler{
Database: m.Store,
}
@ -129,6 +133,19 @@ func (m *module) prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
// TODO: Use query parameters instead.
Path: "netquery/charts/bandwidth",
MimeType: "application/json",
Write: api.PermitUser,
BelongsTo: m.Module,
HandlerFunc: bwChartHandler.ServeHTTP,
Name: "Bandwidth Chart",
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
if err := api.RegisterEndpoint(api.Endpoint{
Name: "Remove connections from profile history",
Description: "Remove all connections from the history database for one or more profiles",
@ -137,7 +154,6 @@ func (m *module) prepare() error {
Write: api.PermitUser,
BelongsTo: m.Module,
ActionFunc: func(ar *api.Request) (msg string, err error) {
// TODO: Use query parameters instead.
var body struct {
ProfileIDs []string `json:"profileIDs"`
}