mirror of
https://github.com/safing/portmaster
synced 2025-04-20 02:49:10 +00:00
148 lines
3.5 KiB
Go
148 lines
3.5 KiB
Go
package netquery
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/safing/portmaster/service/netquery/orm"
|
|
)
|
|
|
|
// BandwidthChartHandler handles requests for connection charts.
|
|
type BandwidthChartHandler struct {
|
|
Database *Database
|
|
}
|
|
|
|
// BandwidthChartRequest holds a request for a bandwidth chart.
|
|
type BandwidthChartRequest struct {
|
|
Interval int `json:"interval"`
|
|
Query Query `json:"query"`
|
|
GroupBy []string `json:"groupBy"`
|
|
}
|
|
|
|
func (ch *BandwidthChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { //nolint:dupl
|
|
requestPayload, err := ch.parseRequest(req)
|
|
if err != nil {
|
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
|
|
if err != nil {
|
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
// actually execute the query against the database and collect the result
|
|
var result []map[string]interface{}
|
|
if err := ch.Database.Execute(
|
|
req.Context(),
|
|
query,
|
|
orm.WithNamedArgs(paramMap),
|
|
orm.WithResult(&result),
|
|
orm.WithSchema(*ch.Database.Schema),
|
|
); err != nil {
|
|
http.Error(resp, failedQuery+err.Error(), http.StatusInternalServerError)
|
|
|
|
return
|
|
}
|
|
|
|
// send the HTTP status code
|
|
resp.WriteHeader(http.StatusOK)
|
|
|
|
// prepare the result encoder.
|
|
enc := json.NewEncoder(resp)
|
|
enc.SetEscapeHTML(false)
|
|
enc.SetIndent("", " ")
|
|
|
|
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
|
"results": result,
|
|
"query": query,
|
|
"params": paramMap,
|
|
})
|
|
}
|
|
|
|
func (ch *BandwidthChartHandler) parseRequest(req *http.Request) (*BandwidthChartRequest, error) { //nolint:dupl
|
|
var body io.Reader
|
|
|
|
switch req.Method {
|
|
case http.MethodPost, http.MethodPut:
|
|
body = req.Body
|
|
case http.MethodGet:
|
|
body = strings.NewReader(req.URL.Query().Get("q"))
|
|
default:
|
|
return nil, fmt.Errorf("invalid HTTP method")
|
|
}
|
|
|
|
var requestPayload BandwidthChartRequest
|
|
blob, err := io.ReadAll(body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read body: %w", err)
|
|
}
|
|
|
|
body = bytes.NewReader(blob)
|
|
|
|
dec := json.NewDecoder(body)
|
|
dec.DisallowUnknownFields()
|
|
|
|
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
|
return nil, fmt.Errorf("invalid query: %w", err)
|
|
}
|
|
|
|
return &requestPayload, nil
|
|
}
|
|
|
|
func (req *BandwidthChartRequest) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
|
if req.Interval == 0 {
|
|
req.Interval = 10
|
|
}
|
|
|
|
interval := fmt.Sprintf("round(time/%d, 0)*%d", req.Interval, req.Interval)
|
|
|
|
// make sure there are only allowed fields specified in the request group-by
|
|
for _, gb := range req.GroupBy {
|
|
def := schema.GetColumnDef(gb)
|
|
if def == nil {
|
|
return "", nil, fmt.Errorf("unsupported groupBy key: %q", gb)
|
|
}
|
|
}
|
|
|
|
selects := append([]string{
|
|
interval + " as timestamp",
|
|
"SUM(incoming) as incoming",
|
|
"SUM(outgoing) as outgoing",
|
|
}, req.GroupBy...)
|
|
|
|
groupBy := append([]string{interval}, req.GroupBy...)
|
|
|
|
whereClause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
if whereClause != "" {
|
|
whereClause = "WHERE " + whereClause
|
|
}
|
|
|
|
template := fmt.Sprintf(
|
|
`SELECT %s
|
|
FROM main.bandwidth AS bw
|
|
JOIN main.connections AS conns
|
|
ON bw.conn_id = conns.id
|
|
%s
|
|
GROUP BY %s
|
|
ORDER BY time ASC`,
|
|
strings.Join(selects, ", "),
|
|
whereClause,
|
|
strings.Join(groupBy, ", "),
|
|
)
|
|
|
|
return template, params, nil
|
|
}
|