mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Update bandwidth handling and add basic chart support
This commit is contained in:
parent
8b92291d2c
commit
bc285b593d
6 changed files with 208 additions and 22 deletions
|
@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
|
||||||
}
|
}
|
||||||
defer conn.Unlock()
|
defer conn.Unlock()
|
||||||
|
|
||||||
|
bytesIn := bwUpdate.BytesReceived
|
||||||
|
bytesOut := bwUpdate.BytesSent
|
||||||
|
|
||||||
// Update stats according to method.
|
// Update stats according to method.
|
||||||
switch bwUpdate.Method {
|
switch bwUpdate.Method {
|
||||||
case packet.Absolute:
|
case packet.Absolute:
|
||||||
|
bytesIn = bwUpdate.BytesReceived - conn.BytesReceived
|
||||||
|
bytesOut = bwUpdate.BytesSent - conn.BytesSent
|
||||||
|
|
||||||
conn.BytesReceived = bwUpdate.BytesReceived
|
conn.BytesReceived = bwUpdate.BytesReceived
|
||||||
conn.BytesSent = bwUpdate.BytesSent
|
conn.BytesSent = bwUpdate.BytesSent
|
||||||
case packet.Additive:
|
case packet.Additive:
|
||||||
|
@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) {
|
||||||
if err := netquery.DefaultModule.Store.UpdateBandwidth(
|
if err := netquery.DefaultModule.Store.UpdateBandwidth(
|
||||||
ctx,
|
ctx,
|
||||||
conn.HistoryEnabled,
|
conn.HistoryEnabled,
|
||||||
|
fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile),
|
||||||
conn.Process().GetKey(),
|
conn.Process().GetKey(),
|
||||||
conn.ID,
|
conn.ID,
|
||||||
conn.BytesReceived,
|
bytesIn,
|
||||||
conn.BytesSent,
|
bytesOut,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Errorf("filter: failed to persist bandwidth data: %s", err)
|
log.Errorf("filter: failed to persist bandwidth data: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,12 +13,12 @@ import (
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChartHandler handles requests for connection charts.
|
// ActiveChartHandler handles requests for connection charts.
|
||||||
type ChartHandler struct {
|
type ActiveChartHandler struct {
|
||||||
Database *Database
|
Database *Database
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
requestPayload, err := ch.parseRequest(req)
|
requestPayload, err := ch.parseRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
|
func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
|
||||||
var body io.Reader
|
var body io.Reader
|
||||||
|
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
|
@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS (
|
||||||
UNION ALL
|
UNION ALL
|
||||||
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
|
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
|
||||||
)
|
)
|
||||||
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch
|
SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked
|
||||||
|
FROM epoch
|
||||||
JOIN connections
|
JOIN connections
|
||||||
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
|
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0)
|
||||||
%s
|
%s
|
||||||
GROUP BY round(timestamp/10, 0)*10;`
|
GROUP BY round(timestamp/10, 0)*10;`
|
||||||
|
|
||||||
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
|
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
|
145
netquery/bandwidth_chart_handler.go
Normal file
145
netquery/bandwidth_chart_handler.go
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
package netquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BandwidthChartHandler handles requests for connection charts.
|
||||||
|
type BandwidthChartHandler struct {
|
||||||
|
Database *Database
|
||||||
|
}
|
||||||
|
|
||||||
|
type BandwidthChartRequest struct {
|
||||||
|
Profiles []string `json:"profiles"`
|
||||||
|
Connections []string `json:"connections"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
|
requestPayload, err := ch.parseRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// actually execute the query against the database and collect the result
|
||||||
|
var result []map[string]interface{}
|
||||||
|
if err := ch.Database.Execute(
|
||||||
|
req.Context(),
|
||||||
|
query,
|
||||||
|
orm.WithNamedArgs(paramMap),
|
||||||
|
orm.WithResult(&result),
|
||||||
|
orm.WithSchema(*ch.Database.Schema),
|
||||||
|
); err != nil {
|
||||||
|
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the HTTP status code
|
||||||
|
resp.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// prepare the result encoder.
|
||||||
|
enc := json.NewEncoder(resp)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
|
||||||
|
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||||
|
"results": result,
|
||||||
|
"query": query,
|
||||||
|
"params": paramMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
|
||||||
|
var body io.Reader
|
||||||
|
|
||||||
|
switch req.Method {
|
||||||
|
case http.MethodPost, http.MethodPut:
|
||||||
|
body = req.Body
|
||||||
|
case http.MethodGet:
|
||||||
|
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid HTTP method")
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestPayload BandwidthChartRequest
|
||||||
|
blob, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
body = bytes.NewReader(blob)
|
||||||
|
|
||||||
|
dec := json.NewDecoder(body)
|
||||||
|
dec.DisallowUnknownFields()
|
||||||
|
|
||||||
|
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return nil, fmt.Errorf("invalid query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &requestPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||||
|
selects := []string{
|
||||||
|
"(round(time/10, 0)*10) as time",
|
||||||
|
"SUM(incoming) as incoming",
|
||||||
|
"SUM(outgoing) as outgoing",
|
||||||
|
}
|
||||||
|
groupBy := []string{"round(time/10, 0)*10"}
|
||||||
|
whereClause := ""
|
||||||
|
params := make(map[string]any)
|
||||||
|
|
||||||
|
if len(req.Profiles) > 0 {
|
||||||
|
groupBy = []string{"profile", "round(time/10, 0)*10"}
|
||||||
|
selects = append(selects, "profile")
|
||||||
|
clauses := make([]string, len(req.Profiles))
|
||||||
|
|
||||||
|
for idx, p := range req.Profiles {
|
||||||
|
key := fmt.Sprintf(":p%d", idx)
|
||||||
|
clauses[idx] = "profile = " + key
|
||||||
|
params[key] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause = "WHERE " + strings.Join(clauses, " OR ")
|
||||||
|
} else if len(req.Connections) > 0 {
|
||||||
|
groupBy = []string{"conn_id", "round(time/10, 0)*10"}
|
||||||
|
selects = append(selects, "conn_id")
|
||||||
|
|
||||||
|
clauses := make([]string, len(req.Connections))
|
||||||
|
|
||||||
|
for idx, p := range req.Connections {
|
||||||
|
key := fmt.Sprintf(":c%d", idx)
|
||||||
|
clauses[idx] = "conn_id = " + key
|
||||||
|
params[key] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause = "WHERE " + strings.Join(clauses, " OR ")
|
||||||
|
}
|
||||||
|
|
||||||
|
template := fmt.Sprintf(
|
||||||
|
`SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`,
|
||||||
|
strings.Join(selects, ", "),
|
||||||
|
whereClause,
|
||||||
|
strings.Join(groupBy, ", "),
|
||||||
|
)
|
||||||
|
|
||||||
|
return template, params, nil
|
||||||
|
}
|
|
@ -299,6 +299,20 @@ func (db *Database) ApplyMigrations() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth (
|
||||||
|
conn_id TEXT NOT NULL,
|
||||||
|
profile TEXT NOT NULL,
|
||||||
|
time INTEGER NOT NULL,
|
||||||
|
incoming INTEGER NOT NULL,
|
||||||
|
outgoing INTEGER NOT NULL,
|
||||||
|
CONSTRAINT fk_conn_id
|
||||||
|
FOREIGN KEY(conn_id) REFERENCES connections(id)
|
||||||
|
ON DELETE CASCADE
|
||||||
|
)`
|
||||||
|
if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil {
|
||||||
|
return fmt.Errorf("failed to create main.bandwidth database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -535,21 +549,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error {
|
||||||
|
|
||||||
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
||||||
// the bandwidth data to the history database.
|
// the bandwidth data to the history database.
|
||||||
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
|
func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error {
|
||||||
params := map[string]any{
|
params := map[string]any{
|
||||||
":id": makeNqIDFromParts(processKey, connID),
|
":id": makeNqIDFromParts(processKey, connID),
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := []string{}
|
parts := []string{}
|
||||||
if bytesReceived != 0 {
|
parts = append(parts, "bytes_received = :bytes_received")
|
||||||
parts = append(parts, "bytes_received = :bytes_received")
|
params[":bytes_received"] = bytesReceived
|
||||||
params[":bytes_received"] = bytesReceived
|
parts = append(parts, "bytes_sent = :bytes_sent")
|
||||||
}
|
params[":bytes_sent"] = bytesSent
|
||||||
|
|
||||||
if bytesSent != 0 {
|
|
||||||
parts = append(parts, "bytes_sent = :bytes_sent")
|
|
||||||
params[":bytes_sent"] = bytesSent
|
|
||||||
}
|
|
||||||
|
|
||||||
updateSet := strings.Join(parts, ", ")
|
updateSet := strings.Join(parts, ", ")
|
||||||
|
|
||||||
|
@ -570,6 +579,14 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// also add the date to the in-memory bandwidth database
|
||||||
|
params[":time"] = time.Now().Unix()
|
||||||
|
params[":profile"] = profileKey
|
||||||
|
stmt := "INSERT INTO main.bandwidth (conn_id, profile, time, incoming, outgoing) VALUES(:id, :profile, :time, :bytes_received, :bytes_sent)"
|
||||||
|
if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil {
|
||||||
|
merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
return merr.ErrorOrNil()
|
return merr.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ type (
|
||||||
|
|
||||||
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
// UpdateBandwidth updates bandwidth data for the connection and optionally also writes
|
||||||
// the bandwidth data to the history database.
|
// the bandwidth data to the history database.
|
||||||
UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
|
UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error
|
||||||
|
|
||||||
// CleanupHistory deletes data outside of the retention time frame from the history database.
|
// CleanupHistory deletes data outside of the retention time frame from the history database.
|
||||||
CleanupHistory(ctx context.Context) error
|
CleanupHistory(ctx context.Context) error
|
||||||
|
|
|
@ -87,7 +87,11 @@ func (m *module) prepare() error {
|
||||||
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
|
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
|
||||||
}
|
}
|
||||||
|
|
||||||
chartHandler := &ChartHandler{
|
chartHandler := &ActiveChartHandler{
|
||||||
|
Database: m.Store,
|
||||||
|
}
|
||||||
|
|
||||||
|
bwChartHandler := &BandwidthChartHandler{
|
||||||
Database: m.Store,
|
Database: m.Store,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,6 +133,19 @@ func (m *module) prepare() error {
|
||||||
return fmt.Errorf("failed to register API endpoint: %w", err)
|
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := api.RegisterEndpoint(api.Endpoint{
|
||||||
|
// TODO: Use query parameters instead.
|
||||||
|
Path: "netquery/charts/bandwidth",
|
||||||
|
MimeType: "application/json",
|
||||||
|
Write: api.PermitUser,
|
||||||
|
BelongsTo: m.Module,
|
||||||
|
HandlerFunc: bwChartHandler.ServeHTTP,
|
||||||
|
Name: "Bandwidth Chart",
|
||||||
|
Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := api.RegisterEndpoint(api.Endpoint{
|
if err := api.RegisterEndpoint(api.Endpoint{
|
||||||
Name: "Remove connections from profile history",
|
Name: "Remove connections from profile history",
|
||||||
Description: "Remove all connections from the history database for one or more profiles",
|
Description: "Remove all connections from the history database for one or more profiles",
|
||||||
|
@ -137,7 +154,6 @@ func (m *module) prepare() error {
|
||||||
Write: api.PermitUser,
|
Write: api.PermitUser,
|
||||||
BelongsTo: m.Module,
|
BelongsTo: m.Module,
|
||||||
ActionFunc: func(ar *api.Request) (msg string, err error) {
|
ActionFunc: func(ar *api.Request) (msg string, err error) {
|
||||||
// TODO: Use query parameters instead.
|
|
||||||
var body struct {
|
var body struct {
|
||||||
ProfileIDs []string `json:"profileIDs"`
|
ProfileIDs []string `json:"profileIDs"`
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue