diff --git a/api/client/api.go b/api/client/api.go new file mode 100644 index 0000000..e1a1664 --- /dev/null +++ b/api/client/api.go @@ -0,0 +1,57 @@ +package client + +// Get sends a get command to the API. +func (c *Client) Get(key string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestGet, key, nil) + return op +} + +// Query sends a query command to the API. +func (c *Client) Query(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestQuery, query, nil) + return op +} + +// Sub sends a sub command to the API. +func (c *Client) Sub(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestSub, query, nil) + return op +} + +// Qsub sends a qsub command to the API. +func (c *Client) Qsub(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestQsub, query, nil) + return op +} + +// Create sends a create command to the API. +func (c *Client) Create(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestCreate, key, value) + return op +} + +// Update sends an update command to the API. +func (c *Client) Update(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestUpdate, key, value) + return op +} + +// Insert sends an insert command to the API. +func (c *Client) Insert(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestInsert, key, value) + return op +} + +// Delete sends a delete command to the API. +func (c *Client) Delete(key string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestDelete, key, nil) + return op +} diff --git a/api/client/client.go b/api/client/client.go new file mode 100644 index 0000000..bbdb14e --- /dev/null +++ b/api/client/client.go @@ -0,0 +1,241 @@ +package client + +import ( + "fmt" + "sync" + "time" + + "github.com/Safing/portbase/log" + "github.com/gorilla/websocket" + "github.com/tevino/abool" +) + +const ( + backOffTimer = 1 * time.Second + + offlineSignal uint8 = 0 + onlineSignal uint8 = 1 +) + +// The Client enables easy interaction with the API. +type Client struct { + sync.Mutex + + server string + + onlineSignal chan struct{} + offlineSignal chan struct{} + shutdownSignal chan struct{} + lastSignal uint8 + + send chan *Message + resend chan *Message + recv chan *Message + + operations map[string]*Operation + nextOpID uint64 + + wsConn *websocket.Conn + lastError string +} + +// NewClient returns a new Client. +func NewClient(server string) *Client { + c := &Client{ + server: server, + onlineSignal: make(chan struct{}), + offlineSignal: make(chan struct{}), + shutdownSignal: make(chan struct{}), + lastSignal: offlineSignal, + send: make(chan *Message, 100), + resend: make(chan *Message, 1), + recv: make(chan *Message, 100), + operations: make(map[string]*Operation), + } + go c.handler() + return c +} + +// Connect connects to the API once. +func (c *Client) Connect() error { + defer c.signalOffline() + + err := c.wsConnect() + if err != nil && err.Error() != c.lastError { + log.Errorf("client: error connecting to Portmaster: %s", err) + c.lastError = err.Error() + } + return err +} + +// StayConnected calls Connect again whenever the connection is lost. +func (c *Client) StayConnected() { + log.Infof("client: connecting to Portmaster at %s", c.server) + + c.Connect() + for { + select { + case <-time.After(backOffTimer): + log.Infof("client: reconnecting...") + c.Connect() + case <-c.shutdownSignal: + return + } + } +} + +// Shutdown shuts the client down. +func (c *Client) Shutdown() { + select { + case <-c.shutdownSignal: + default: + close(c.shutdownSignal) + } +} + +func (c *Client) signalOnline() { + c.Lock() + defer c.Unlock() + if c.lastSignal == offlineSignal { + log.Infof("client: went online") + c.offlineSignal = make(chan struct{}) + close(c.onlineSignal) + c.lastSignal = onlineSignal + + // resend unsent request + for _, op := range c.operations { + if op.resuscitationEnabled.IsSet() && op.request.sent != nil && op.request.sent.SetToIf(true, false) { + op.client.send <- op.request + log.Infof("client: resuscitated %s %s %s", op.request.OpID, op.request.Type, op.request.Key) + } + } + + } +} + +func (c *Client) signalOffline() { + c.Lock() + defer c.Unlock() + if c.lastSignal == onlineSignal { + log.Infof("client: went offline") + c.onlineSignal = make(chan struct{}) + close(c.offlineSignal) + c.lastSignal = offlineSignal + + // signal offline status to operations + for _, op := range c.operations { + op.handle(&Message{ + OpID: op.ID, + Type: MsgOffline, + }) + } + + } +} + +// Online returns a closed channel read if the client is connected to the API. +func (c *Client) Online() <-chan struct{} { + c.Lock() + defer c.Unlock() + return c.onlineSignal +} + +// Offline returns a closed channel read if the client is not connected to the API. +func (c *Client) Offline() <-chan struct{} { + c.Lock() + defer c.Unlock() + return c.offlineSignal +} + +func (c *Client) handler() { + for { + select { + + case m := <-c.recv: + + if m == nil { + return + } + + c.Lock() + op, ok := c.operations[m.OpID] + c.Unlock() + + if ok { + log.Tracef("client: [%s] received %s msg: %s", m.OpID, m.Type, m.Key) + op.handle(m) + } else { + log.Tracef("client: received message for unknown operation %s", m.OpID) + } + + case <-c.shutdownSignal: + return + + } + } +} + +// Operation represents a single operation by a client. +type Operation struct { + ID string + request *Message + client *Client + handleFunc func(*Message) + handler chan *Message + resuscitationEnabled *abool.AtomicBool +} + +func (op *Operation) handle(m *Message) { + if op.handleFunc != nil { + op.handleFunc(m) + } else { + select { + case op.handler <- m: + default: + log.Warningf("client: handler channel of operation %s overflowed", op.ID) + } + } +} + +// Cancel the operation. +func (op *Operation) Cancel() { + op.client.Lock() + defer op.client.Unlock() + delete(op.client.operations, op.ID) + close(op.handler) +} + +// Send sends a request to the API. +func (op *Operation) Send(command, text string, data interface{}) { + op.request = &Message{ + OpID: op.ID, + Type: command, + Key: text, + Value: data, + sent: abool.NewBool(false), + } + log.Tracef("client: [%s] sending %s msg: %s", op.request.OpID, op.request.Type, op.request.Key) + op.client.send <- op.request +} + +// EnableResuscitation will resend the request after reconnecting to the API. +func (op *Operation) EnableResuscitation() { + op.resuscitationEnabled.Set() +} + +// NewOperation returns a new operation. +func (c *Client) NewOperation(handleFunc func(*Message)) *Operation { + c.Lock() + defer c.Unlock() + + c.nextOpID++ + op := &Operation{ + ID: fmt.Sprintf("#%d", c.nextOpID), + client: c, + handleFunc: handleFunc, + handler: make(chan *Message, 100), + resuscitationEnabled: abool.NewBool(false), + } + c.operations[op.ID] = op + return op +} diff --git a/api/client/const.go b/api/client/const.go new file mode 100644 index 0000000..c189c0e --- /dev/null +++ b/api/client/const.go @@ -0,0 +1,30 @@ +package client + +// message types +const ( + msgRequestGet = "get" + msgRequestQuery = "query" + msgRequestSub = "sub" + msgRequestQsub = "qsub" + msgRequestCreate = "create" + msgRequestUpdate = "update" + msgRequestInsert = "insert" + msgRequestDelete = "delete" + + MsgOk = "ok" + MsgError = "error" + MsgDone = "done" + MsgSuccess = "success" + MsgUpdate = "upd" + MsgNew = "new" + MsgDelete = "del" + MsgWarning = "warning" + + MsgOffline = "offline" // special message type for signaling the handler that the connection was lost + + apiSeperator = "|" +) + +var ( + apiSeperatorBytes = []byte(apiSeperator) +) diff --git a/api/client/message.go b/api/client/message.go new file mode 100644 index 0000000..31864c1 --- /dev/null +++ b/api/client/message.go @@ -0,0 +1,117 @@ +package client + +import ( + "bytes" + "errors" + + "github.com/Safing/portbase/container" + "github.com/Safing/portbase/formats/dsd" + "github.com/tevino/abool" +) + +var ( + ErrMalformedMessage = errors.New("malformed message") +) + +type Message struct { + OpID string + Type string + Key string + RawValue []byte + Value interface{} + sent *abool.AtomicBool +} + +func ParseMessage(data []byte) (*Message, error) { + parts := bytes.SplitN(data, apiSeperatorBytes, 4) + if len(parts) < 2 { + return nil, ErrMalformedMessage + } + + m := &Message{ + OpID: string(parts[0]), + Type: string(parts[1]), + } + + switch m.Type { + case MsgOk, MsgUpdate, MsgNew: + // parse key and data + // 127|ok|| + // 127|upd|| + // 127|new|| + if len(parts) != 4 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + m.RawValue = parts[3] + case MsgDelete: + // parse key + // 127|del| + if len(parts) != 3 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + case MsgWarning, MsgError: + // parse message + // 127|error| + // 127|warning| // error with single record, operation continues + if len(parts) != 3 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + case MsgDone, MsgSuccess: + // nothing more to do + // 127|success + // 127|done + } + + return m, nil +} + +func (m *Message) Pack() ([]byte, error) { + c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type)) + + if m.Key != "" { + c.Append(apiSeperatorBytes) + c.Append([]byte(m.Key)) + if len(m.RawValue) > 0 { + c.Append(apiSeperatorBytes) + c.Append(m.RawValue) + } else if m.Value != nil { + var err error + m.RawValue, err = dsd.Dump(m.Value, dsd.JSON) + if err != nil { + return nil, err + } + c.Append(apiSeperatorBytes) + c.Append(m.RawValue) + } + } + + return c.CompileData(), nil +} + +func (m *Message) IsOk() bool { + return m.Type == MsgOk +} +func (m *Message) IsDone() bool { + return m.Type == MsgDone +} +func (m *Message) IsError() bool { + return m.Type == MsgError +} +func (m *Message) IsUpdate() bool { + return m.Type == MsgUpdate +} +func (m *Message) IsNew() bool { + return m.Type == MsgNew +} +func (m *Message) IsDelete() bool { + return m.Type == MsgDelete +} +func (m *Message) IsWarning() bool { + return m.Type == MsgWarning +} +func (m *Message) GetMessage() string { + return m.Key +} diff --git a/api/client/websocket.go b/api/client/websocket.go new file mode 100644 index 0000000..70e2e35 --- /dev/null +++ b/api/client/websocket.go @@ -0,0 +1,121 @@ +package client + +import ( + "fmt" + "sync" + + "github.com/Safing/portbase/log" + "github.com/tevino/abool" + + "github.com/gorilla/websocket" +) + +type wsState struct { + wsConn *websocket.Conn + wg sync.WaitGroup + failing *abool.AtomicBool + failSignal chan struct{} +} + +func (c *Client) wsConnect() error { + state := &wsState{ + failing: abool.NewBool(false), + failSignal: make(chan struct{}), + } + + var err error + state.wsConn, _, err = websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s/api/database/v1", c.server), nil) + if err != nil { + return err + } + + c.signalOnline() + + state.wg.Add(2) + go c.wsReader(state) + go c.wsWriter(state) + + // wait for end of connection + select { + case <-state.failSignal: + case <-c.shutdownSignal: + state.Error("") + } + state.wsConn.Close() + state.wg.Wait() + + return nil +} + +func (c *Client) wsReader(state *wsState) { + defer state.wg.Done() + for { + _, data, err := state.wsConn.ReadMessage() + log.Tracef("client: read message") + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + state.Error(fmt.Sprintf("client: read error: %s", err)) + } else { + state.Error("client: connection closed by server") + } + return + } + log.Tracef("client: received message: %s", string(data)) + m, err := ParseMessage(data) + if err != nil { + log.Warningf("client: failed to parse message: %s", err) + } else { + select { + case c.recv <- m: + case <-state.failSignal: + return + } + } + } +} + +func (c *Client) wsWriter(state *wsState) { + defer state.wg.Done() + for { + select { + case <-state.failSignal: + return + case m := <-c.resend: + data, err := m.Pack() + if err == nil { + err = state.wsConn.WriteMessage(websocket.BinaryMessage, data) + } + if err != nil { + state.Error(fmt.Sprintf("client: write error: %s", err)) + return + } + log.Tracef("client: sent message: %s", string(data)) + if m.sent != nil { + m.sent.Set() + } + case m := <-c.send: + data, err := m.Pack() + if err == nil { + err = state.wsConn.WriteMessage(websocket.BinaryMessage, data) + } + if err != nil { + c.resend <- m + state.Error(fmt.Sprintf("client: write error: %s", err)) + return + } + log.Tracef("client: sent message: %s", string(data)) + if m.sent != nil { + m.sent.Set() + } + } + } +} + +func (state *wsState) Error(message string) { + if state.failing.SetToIf(false, true) { + close(state.failSignal) + if message != "" { + log.Warning(message) + } + } +}