Merge pull request from safing/fix/missing-dll-failure

Fix Missing dll failure
This commit is contained in:
Daniel Hååvi 2024-12-02 15:23:07 +01:00 committed by GitHub
commit 6e173e3b96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 90 additions and 26 deletions

View file

@ -14,7 +14,7 @@ import (
) )
type ETWSession struct { type ETWSession struct {
i integration.ETWFunctions i *integration.ETWFunctions
shutdownGuard atomic.Bool shutdownGuard atomic.Bool
shutdownMutex sync.Mutex shutdownMutex sync.Mutex
@ -23,7 +23,10 @@ type ETWSession struct {
} }
// NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it. // NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it.
func NewSession(etwInterface integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) { func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) {
if etwInterface == nil {
return nil, fmt.Errorf("etw interface was nil")
}
etwSession := &ETWSession{ etwSession := &ETWSession{
i: etwInterface, i: etwInterface,
} }
@ -47,7 +50,7 @@ func NewSession(etwInterface integration.ETWFunctions, callback func(domain stri
// Initialize session. // Initialize session.
err := etwSession.i.InitializeSession(etwSession.state) err := etwSession.i.InitializeSession(etwSession.state)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialzie session: %q", err) return nil, fmt.Errorf("failed to initialize session: %q", err)
} }
return etwSession, nil return etwSession, nil
@ -65,6 +68,10 @@ func (l *ETWSession) IsRunning() bool {
// FlushTrace flushes the trace buffer. // FlushTrace flushes the trace buffer.
func (l *ETWSession) FlushTrace() error { func (l *ETWSession) FlushTrace() error {
if l.i == nil {
return fmt.Errorf("session not initialized")
}
l.shutdownMutex.Lock() l.shutdownMutex.Lock()
defer l.shutdownMutex.Unlock() defer l.shutdownMutex.Unlock()
@ -83,6 +90,9 @@ func (l *ETWSession) StopTrace() error {
// DestroySession closes the session and frees the allocated memory. Listener cannot be used after this function is called. // DestroySession closes the session and frees the allocated memory. Listener cannot be used after this function is called.
func (l *ETWSession) DestroySession() error { func (l *ETWSession) DestroySession() error {
if l.i == nil {
return fmt.Errorf("session not initialized")
}
l.shutdownMutex.Lock() l.shutdownMutex.Lock()
defer l.shutdownMutex.Unlock() defer l.shutdownMutex.Unlock()

View file

@ -23,22 +23,38 @@ func newListener(module *DNSMonitor) (*Listener, error) {
ResolverInfo.Source = resolver.ServerSourceETW ResolverInfo.Source = resolver.ServerSourceETW
listener := &Listener{} listener := &Listener{}
var err error
// Initialize new dns event session. // Initialize new dns event session.
listener.etw, err = NewSession(module.instance.OSIntegration().GetETWInterface(), listener.processEvent) err := initializeSessions(module, listener)
if err != nil { if err != nil {
return nil, err // Listen for event if the dll has been loaded
module.instance.OSIntegration().OnInitializedEvent.AddCallback("loader-listener", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) {
err = initializeSessions(module, listener)
if err != nil {
return false, err
}
return true, nil
})
} }
// Start listening for events.
module.mgr.Go("etw-dns-event-listener", func(w *mgr.WorkerCtx) error {
return listener.etw.StartTrace()
})
return listener, nil return listener, nil
} }
func initializeSessions(module *DNSMonitor, listener *Listener) error {
var err error
listener.etw, err = NewSession(module.instance.OSIntegration().GetETWInterface(), listener.processEvent)
if err != nil {
return err
}
// Start listener
module.mgr.Go("etw-dns-event-listener", func(w *mgr.WorkerCtx) error {
return listener.etw.StartTrace()
})
return nil
}
func (l *Listener) flush() error { func (l *Listener) flush() error {
if l.etw == nil {
return fmt.Errorf("etw not initialized")
}
return l.etw.FlushTrace() return l.etw.FlushTrace()
} }

View file

@ -19,8 +19,8 @@ type ETWFunctions struct {
stopOldSession *windows.Proc stopOldSession *windows.Proc
} }
func initializeETW(dll *windows.DLL) (ETWFunctions, error) { func initializeETW(dll *windows.DLL) (*ETWFunctions, error) {
var functions ETWFunctions functions := &ETWFunctions{}
var err error var err error
functions.createState, err = dll.FindProc("PM_ETWCreateState") functions.createState, err = dll.FindProc("PM_ETWCreateState")
if err != nil { if err != nil {

View file

@ -5,24 +5,55 @@ package integration
import ( import (
"fmt" "fmt"
"sync"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/mgr"
"github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/service/updates"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
type OSSpecific struct { type OSSpecific struct {
dll *windows.DLL dll *windows.DLL
etwFunctions ETWFunctions etwFunctions *ETWFunctions
} }
// Initialize loads the dll and finds all the needed functions from it. // Initialize loads the dll and finds all the needed functions from it.
func (i *OSIntegration) Initialize() error { func (i *OSIntegration) Initialize() error {
// Try to load dll
err := i.loadDLL()
if err != nil {
log.Errorf("integration: failed to load dll: %s", err)
callbackLock := sync.Mutex{}
// listen for event from the updater and try to load again if any.
i.instance.Updates().EventResourcesUpdated.AddCallback("core-dll-loader", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) {
// Make sure no multiple callas are executed at the same time.
callbackLock.Lock()
defer callbackLock.Unlock()
// Try to load again.
err = i.loadDLL()
if err != nil {
log.Errorf("integration: failed to load dll: %s", err)
} else {
log.Info("integration: initialize successful after updater event")
}
return false, nil
})
} else {
log.Info("integration: initialize successful")
}
return nil
}
func (i *OSIntegration) loadDLL() error {
// Find path to the dll. // Find path to the dll.
file, err := updates.GetPlatformFile("dll/portmaster-core.dll") file, err := updates.GetPlatformFile("dll/portmaster-core.dll")
if err != nil { if err != nil {
return err return err
} }
// Load the DLL. // Load the DLL.
i.os.dll, err = windows.LoadDLL(file.Path()) i.os.dll, err = windows.LoadDLL(file.Path())
if err != nil { if err != nil {
@ -35,10 +66,13 @@ func (i *OSIntegration) Initialize() error {
return err return err
} }
// Notify listeners
i.OnInitializedEvent.Submit(struct{}{})
return nil return nil
} }
// CleanUp releases any resourses allocated during initializaion. // CleanUp releases any resources allocated during initialization.
func (i *OSIntegration) CleanUp() error { func (i *OSIntegration) CleanUp() error {
if i.os.dll != nil { if i.os.dll != nil {
return i.os.dll.Release() return i.os.dll.Release()
@ -46,7 +80,7 @@ func (i *OSIntegration) CleanUp() error {
return nil return nil
} }
// GetETWInterface return struct containing all the ETW related functions. // GetETWInterface return struct containing all the ETW related functions, and nil if it was not loaded yet
func (i *OSIntegration) GetETWInterface() ETWFunctions { func (i *OSIntegration) GetETWInterface() *ETWFunctions {
return i.os.etwFunctions return i.os.etwFunctions
} }

View file

@ -7,8 +7,9 @@ import (
// OSIntegration module provides special integration with the OS. // OSIntegration module provides special integration with the OS.
type OSIntegration struct { type OSIntegration struct {
m *mgr.Manager m *mgr.Manager
states *mgr.StateMgr
OnInitializedEvent *mgr.EventMgr[struct{}]
//nolint:unused //nolint:unused
os OSSpecific os OSSpecific
@ -20,10 +21,9 @@ type OSIntegration struct {
func New(instance instance) (*OSIntegration, error) { func New(instance instance) (*OSIntegration, error) {
m := mgr.New("OSIntegration") m := mgr.New("OSIntegration")
module := &OSIntegration{ module := &OSIntegration{
m: m, m: m,
states: m.NewStateMgr(), OnInitializedEvent: mgr.NewEventMgr[struct{}]("on-initialized", m),
instance: instance,
instance: instance,
} }
return module, nil return module, nil

View file

@ -550,7 +550,11 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
if module.instance.Resolver().IsDisabled() && conn.shouldWaitForDomain() { if module.instance.Resolver().IsDisabled() && conn.shouldWaitForDomain() {
// Flush the dns listener buffer and try again. // Flush the dns listener buffer and try again.
for i := range 4 { for i := range 4 {
_ = module.instance.DNSMonitor().Flush() err = module.instance.DNSMonitor().Flush()
if err != nil {
// Error flushing, dont try again.
break
}
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
if err == nil { if err == nil {
log.Tracer(pkt.Ctx()).Debugf("network: found domain from dnsmonitor after %d tries", i+1) log.Tracer(pkt.Ctx()).Debugf("network: found domain from dnsmonitor after %d tries", i+1)