From bc285b593d52b69595082251ae9adade6d6005b5 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Thu, 24 Aug 2023 13:14:23 +0200 Subject: [PATCH] Update bandwidth handling and add basic chart support --- firewall/packet_handler.go | 11 +- ...art_handler.go => active_chart_handler.go} | 15 +- netquery/bandwidth_chart_handler.go | 145 ++++++++++++++++++ netquery/database.go | 37 +++-- netquery/manager.go | 2 +- netquery/module_api.go | 20 ++- 6 files changed, 208 insertions(+), 22 deletions(-) rename netquery/{chart_handler.go => active_chart_handler.go} (85%) create mode 100644 netquery/bandwidth_chart_handler.go diff --git a/firewall/packet_handler.go b/firewall/packet_handler.go index 0ddf3a9a..e30d68aa 100644 --- a/firewall/packet_handler.go +++ b/firewall/packet_handler.go @@ -679,9 +679,15 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { } defer conn.Unlock() + bytesIn := bwUpdate.BytesReceived + bytesOut := bwUpdate.BytesSent + // Update stats according to method. switch bwUpdate.Method { case packet.Absolute: + bytesIn = bwUpdate.BytesReceived - conn.BytesReceived + bytesOut = bwUpdate.BytesSent - conn.BytesSent + conn.BytesReceived = bwUpdate.BytesReceived conn.BytesSent = bwUpdate.BytesSent case packet.Additive: @@ -697,10 +703,11 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { if err := netquery.DefaultModule.Store.UpdateBandwidth( ctx, conn.HistoryEnabled, + fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile), conn.Process().GetKey(), conn.ID, - conn.BytesReceived, - conn.BytesSent, + bytesIn, + bytesOut, ); err != nil { log.Errorf("filter: failed to persist bandwidth data: %s", err) } diff --git a/netquery/chart_handler.go b/netquery/active_chart_handler.go similarity index 85% rename from netquery/chart_handler.go rename to netquery/active_chart_handler.go index a44f03ac..62c3ee7a 100644 --- a/netquery/chart_handler.go +++ b/netquery/active_chart_handler.go @@ -13,12 +13,12 @@ import ( "github.com/safing/portmaster/netquery/orm" ) -// ChartHandler handles requests for connection charts. -type ChartHandler struct { +// ActiveChartHandler handles requests for connection charts. +type ActiveChartHandler struct { Database *Database } -func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { +func (ch *ActiveChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { requestPayload, err := ch.parseRequest(req) if err != nil { http.Error(resp, err.Error(), http.StatusBadRequest) @@ -62,7 +62,7 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { }) } -func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl +func (ch *ActiveChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl var body io.Reader switch req.Method { @@ -99,10 +99,11 @@ WITH RECURSIVE epoch(x) AS ( UNION ALL SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0 ) -SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked FROM epoch +SELECT x as timestamp, SUM(verdict IN (2, 5, 6)) AS value, SUM(verdict NOT IN (2, 5, 6)) as countBlocked + FROM epoch JOIN connections - ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) - %s + ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 >= timestamp+0) + %s GROUP BY round(timestamp/10, 0)*10;` clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig) diff --git a/netquery/bandwidth_chart_handler.go b/netquery/bandwidth_chart_handler.go new file mode 100644 index 00000000..1d61e1de --- /dev/null +++ b/netquery/bandwidth_chart_handler.go @@ -0,0 +1,145 @@ +package netquery + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/safing/portmaster/netquery/orm" +) + +// BandwidthChartHandler handles requests for connection charts. +type BandwidthChartHandler struct { + Database *Database +} + +type BandwidthChartRequest struct { + Profiles []string `json:"profiles"` + Connections []string `json:"connections"` +} + +func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + requestPayload, err := ch.parseRequest(req) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema) + if err != nil { + http.Error(resp, err.Error(), http.StatusBadRequest) + + return + } + + // actually execute the query against the database and collect the result + var result []map[string]interface{} + if err := ch.Database.Execute( + req.Context(), + query, + orm.WithNamedArgs(paramMap), + orm.WithResult(&result), + orm.WithSchema(*ch.Database.Schema), + ); err != nil { + http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError) + + return + } + + // send the HTTP status code + resp.WriteHeader(http.StatusOK) + + // prepare the result encoder. + enc := json.NewEncoder(resp) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + + _ = enc.Encode(map[string]interface{}{ //nolint:errchkjson + "results": result, + "query": query, + "params": paramMap, + }) +} + +func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl + var body io.Reader + + switch req.Method { + case http.MethodPost, http.MethodPut: + body = req.Body + case http.MethodGet: + body = strings.NewReader(req.URL.Query().Get("q")) + default: + return nil, fmt.Errorf("invalid HTTP method") + } + + var requestPayload BandwidthChartRequest + blob, err := io.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read body" + err.Error()) + } + + body = bytes.NewReader(blob) + + dec := json.NewDecoder(body) + dec.DisallowUnknownFields() + + if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("invalid query: %w", err) + } + + return &requestPayload, nil +} + +func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { + selects := []string{ + "(round(time/10, 0)*10) as time", + "SUM(incoming) as incoming", + "SUM(outgoing) as outgoing", + } + groupBy := []string{"round(time/10, 0)*10"} + whereClause := "" + params := make(map[string]any) + + if len(req.Profiles) > 0 { + groupBy = []string{"profile", "round(time/10, 0)*10"} + selects = append(selects, "profile") + clauses := make([]string, len(req.Profiles)) + + for idx, p := range req.Profiles { + key := fmt.Sprintf(":p%d", idx) + clauses[idx] = "profile = " + key + params[key] = p + } + + whereClause = "WHERE " + strings.Join(clauses, " OR ") + } else if len(req.Connections) > 0 { + groupBy = []string{"conn_id", "round(time/10, 0)*10"} + selects = append(selects, "conn_id") + + clauses := make([]string, len(req.Connections)) + + for idx, p := range req.Connections { + key := fmt.Sprintf(":c%d", idx) + clauses[idx] = "conn_id = " + key + params[key] = p + } + + whereClause = "WHERE " + strings.Join(clauses, " OR ") + } + + template := fmt.Sprintf( + `SELECT %s FROM main.bandwidth %s GROUP BY %s ORDER BY time ASC`, + strings.Join(selects, ", "), + whereClause, + strings.Join(groupBy, ", "), + ) + + return template, params, nil +} diff --git a/netquery/database.go b/netquery/database.go index 807c0097..4750dbe4 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -299,6 +299,20 @@ func (db *Database) ApplyMigrations() error { } } + bwSchema := `CREATE TABLE IF NOT EXISTS main.bandwidth ( + conn_id TEXT NOT NULL, + profile TEXT NOT NULL, + time INTEGER NOT NULL, + incoming INTEGER NOT NULL, + outgoing INTEGER NOT NULL, + CONSTRAINT fk_conn_id + FOREIGN KEY(conn_id) REFERENCES connections(id) + ON DELETE CASCADE + )` + if err := sqlitex.ExecuteTransient(db.writeConn, bwSchema, nil); err != nil { + return fmt.Errorf("failed to create main.bandwidth database: %w", err) + } + return nil } @@ -535,21 +549,16 @@ func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error { // UpdateBandwidth updates bandwidth data for the connection and optionally also writes // the bandwidth data to the history database. -func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { +func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { params := map[string]any{ ":id": makeNqIDFromParts(processKey, connID), } parts := []string{} - if bytesReceived != 0 { - parts = append(parts, "bytes_received = :bytes_received") - params[":bytes_received"] = bytesReceived - } - - if bytesSent != 0 { - parts = append(parts, "bytes_sent = :bytes_sent") - params[":bytes_sent"] = bytesSent - } + parts = append(parts, "bytes_received = :bytes_received") + params[":bytes_received"] = bytesReceived + parts = append(parts, "bytes_sent = :bytes_sent") + params[":bytes_sent"] = bytesSent updateSet := strings.Join(parts, ", ") @@ -570,6 +579,14 @@ func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, pro } } + // also add the date to the in-memory bandwidth database + params[":time"] = time.Now().Unix() + params[":profile"] = profileKey + stmt := "INSERT INTO main.bandwidth (conn_id, profile, time, incoming, outgoing) VALUES(:id, :profile, :time, :bytes_received, :bytes_sent)" + if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("failed to update main.bandwidth: %w", err)) + } + return merr.ErrorOrNil() } diff --git a/netquery/manager.go b/netquery/manager.go index 8749f482..810782a8 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -38,7 +38,7 @@ type ( // UpdateBandwidth updates bandwidth data for the connection and optionally also writes // the bandwidth data to the history database. - UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error + UpdateBandwidth(ctx context.Context, enableHistory bool, profileKey string, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error // CleanupHistory deletes data outside of the retention time frame from the history database. CleanupHistory(ctx context.Context) error diff --git a/netquery/module_api.go b/netquery/module_api.go index 39647186..63a3d07f 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -87,7 +87,11 @@ func (m *module) prepare() error { IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } - chartHandler := &ChartHandler{ + chartHandler := &ActiveChartHandler{ + Database: m.Store, + } + + bwChartHandler := &BandwidthChartHandler{ Database: m.Store, } @@ -129,6 +133,19 @@ func (m *module) prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + // TODO: Use query parameters instead. + Path: "netquery/charts/bandwidth", + MimeType: "application/json", + Write: api.PermitUser, + BelongsTo: m.Module, + HandlerFunc: bwChartHandler.ServeHTTP, + Name: "Bandwidth Chart", + Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + if err := api.RegisterEndpoint(api.Endpoint{ Name: "Remove connections from profile history", Description: "Remove all connections from the history database for one or more profiles", @@ -137,7 +154,6 @@ func (m *module) prepare() error { Write: api.PermitUser, BelongsTo: m.Module, ActionFunc: func(ar *api.Request) (msg string, err error) { - // TODO: Use query parameters instead. var body struct { ProfileIDs []string `json:"profileIDs"` }