From 0f28af66cd3104c7e9e488f9b463f2c37075f73f Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov <vladimir@safing.io> Date: Mon, 27 Jan 2025 17:21:54 +0200 Subject: [PATCH] Add PID in ETW DNS event in the integration dll (#1768) * [service] Add reading of PID in ETW DNS event * [service] Use PID of the ETW DNS events * [service] Fix use of nil pointer * [service] Fix compiler error --- .../interception/dnsmonitor/etwlink_windows.go | 10 +++++----- .../dnsmonitor/eventlistener_linux.go | 2 +- .../dnsmonitor/eventlistener_windows.go | 15 +++++++++++++-- .../firewall/interception/dnsmonitor/module.go | 4 ++-- service/network/connection.go | 10 +++++++++- windows_core_dll/dllmain.cpp | 8 ++++---- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/service/firewall/interception/dnsmonitor/etwlink_windows.go b/service/firewall/interception/dnsmonitor/etwlink_windows.go index cb9d8675..6f9bd16d 100644 --- a/service/firewall/interception/dnsmonitor/etwlink_windows.go +++ b/service/firewall/interception/dnsmonitor/etwlink_windows.go @@ -22,8 +22,8 @@ type ETWSession struct { state uintptr } -// NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it. -func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) { +// NewSession creates new ETW event listener and initializes it. This is a low level interface, make sure to call DestroySession when you are done using it. +func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, pid uint32, result string)) (*ETWSession, error) { if etwInterface == nil { return nil, fmt.Errorf("etw interface was nil") } @@ -35,8 +35,8 @@ func NewSession(etwInterface *integration.ETWFunctions, callback func(domain str _ = etwSession.i.StopOldSession() // Initialize notification activated callback - win32Callback := windows.NewCallback(func(domain *uint16, result *uint16) uintptr { - callback(windows.UTF16PtrToString(domain), windows.UTF16PtrToString(result)) + win32Callback := windows.NewCallback(func(domain *uint16, pid uint32, result *uint16) uintptr { + callback(windows.UTF16PtrToString(domain), pid, windows.UTF16PtrToString(result)) return 0 }) // The function only allocates memory it will not fail. @@ -83,7 +83,7 @@ func (l *ETWSession) FlushTrace() error { return l.i.FlushTrace(l.state) } -// StopTrace stopes the trace. This will cause StartTrace to return. +// StopTrace stops the trace. This will cause StartTrace to return. func (l *ETWSession) StopTrace() error { return l.i.StopTrace(l.state) } diff --git a/service/firewall/interception/dnsmonitor/eventlistener_linux.go b/service/firewall/interception/dnsmonitor/eventlistener_linux.go index 6e9bb3ee..f4fc99a0 100644 --- a/service/firewall/interception/dnsmonitor/eventlistener_linux.go +++ b/service/firewall/interception/dnsmonitor/eventlistener_linux.go @@ -141,5 +141,5 @@ func (l *Listener) processAnswer(domain string, queryResult *QueryResult) { } } - saveDomain(domain, ips, cnames) + saveDomain(domain, ips, cnames, resolver.IPInfoProfileScopeGlobal) } diff --git a/service/firewall/interception/dnsmonitor/eventlistener_windows.go b/service/firewall/interception/dnsmonitor/eventlistener_windows.go index a46e8cc6..d71c3c43 100644 --- a/service/firewall/interception/dnsmonitor/eventlistener_windows.go +++ b/service/firewall/interception/dnsmonitor/eventlistener_windows.go @@ -4,6 +4,7 @@ package dnsmonitor import ( + "context" "fmt" "net" "strconv" @@ -11,6 +12,7 @@ import ( "github.com/miekg/dns" "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/resolver" ) @@ -79,7 +81,7 @@ func (l *Listener) stop() error { return nil } -func (l *Listener) processEvent(domain string, result string) { +func (l *Listener) processEvent(domain string, pid uint32, result string) { if processIfSelfCheckDomain(dns.Fqdn(domain)) { // Not need to process result. return @@ -90,6 +92,15 @@ func (l *Listener) processEvent(domain string, result string) { return } + profileScope := resolver.IPInfoProfileScopeGlobal + // Get the profile ID if the process can be found + if proc, err := process.GetOrFindProcess(context.Background(), int(pid)); err == nil { + if profile := proc.Profile(); profile != nil { + if localProfile := profile.LocalProfile(); localProfile != nil { + profileScope = localProfile.ID + } + } + } cnames := make(map[string]string) ips := []net.IP{} @@ -115,5 +126,5 @@ func (l *Listener) processEvent(domain string, result string) { } } } - saveDomain(domain, ips, cnames) + saveDomain(domain, ips, cnames, profileScope) } diff --git a/service/firewall/interception/dnsmonitor/module.go b/service/firewall/interception/dnsmonitor/module.go index 974429f7..918d1e88 100644 --- a/service/firewall/interception/dnsmonitor/module.go +++ b/service/firewall/interception/dnsmonitor/module.go @@ -61,7 +61,7 @@ func (dl *DNSMonitor) Flush() error { return dl.listener.flush() } -func saveDomain(domain string, ips []net.IP, cnames map[string]string) { +func saveDomain(domain string, ips []net.IP, cnames map[string]string, profileScope string) { fqdn := dns.Fqdn(domain) // Create new record for this IP. record := resolver.ResolvedDomain{ @@ -75,7 +75,7 @@ func saveDomain(domain string, ips []net.IP, cnames map[string]string) { record.AddCNAMEs(cnames) // Add to cache - saveIPsInCache(ips, resolver.IPInfoProfileScopeGlobal, record) + saveIPsInCache(ips, profileScope, record) } func New(instance instance) (*DNSMonitor, error) { diff --git a/service/network/connection.go b/service/network/connection.go index 1c1bbf19..2cdf12e7 100644 --- a/service/network/connection.go +++ b/service/network/connection.go @@ -538,8 +538,9 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { // Find domain and DNS context of entity. if conn.Entity.Domain == "" && conn.process.Profile() != nil { + profileScope := conn.process.Profile().LocalProfile().ID // check if we can find a domain for that IP - ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) + ipinfo, err := resolver.GetIPInfo(profileScope, pkt.Info().RemoteIP().String()) if err != nil { // Try again with the global scope, in case DNS went through the system resolver. ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) @@ -555,6 +556,13 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { // Error flushing, dont try again. break } + // Try with profile scope + ipinfo, err = resolver.GetIPInfo(profileScope, pkt.Info().RemoteIP().String()) + if err == nil { + log.Tracer(pkt.Ctx()).Debugf("network: found domain with scope (%s) from dnsmonitor after %d tries", profileScope, +1) + break + } + // Try again with the global scope ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) if err == nil { log.Tracer(pkt.Ctx()).Debugf("network: found domain from dnsmonitor after %d tries", i+1) diff --git a/windows_core_dll/dllmain.cpp b/windows_core_dll/dllmain.cpp index cc0efaac..7674539f 100644 --- a/windows_core_dll/dllmain.cpp +++ b/windows_core_dll/dllmain.cpp @@ -22,7 +22,7 @@ static const GUID PORTMASTER_ETW_SESSION_GUID = { #define LOGSESSION_NAME L"PortmasterDNSEventListener" // Fuction type of the callback that will be called on each event. -typedef uint64_t(*GoEventRecordCallback)(wchar_t* domain, wchar_t* result); +typedef uint64_t(*GoEventRecordCallback)(wchar_t* domain, uint32_t pid, wchar_t* result); // Holds the state of the ETW Session. struct ETWSessionState { @@ -41,7 +41,7 @@ static bool getPropertyValue(PEVENT_RECORD evt, LPWSTR prop, PBYTE* pData) { DataDescriptor.ArrayIndex = 0; DWORD PropertySize = 0; - // Check if the data is avaliable and what is the size of it. + // Check if the data is available and what is the size of it. DWORD status = TdhGetPropertySize(evt, 0, NULL, 1, &DataDescriptor, &PropertySize); if (ERROR_SUCCESS != status) { @@ -79,7 +79,7 @@ static void WINAPI EventRecordCallback(PEVENT_RECORD eventRecord) { ETWSessionState* state = (ETWSessionState*)eventRecord->UserContext; if (resultValue != NULL && domainValue != NULL) { - state->callback((wchar_t*)domainValue, (wchar_t*)resultValue); + state->callback((wchar_t*)domainValue, eventRecord->EventHeader.ProcessId, (wchar_t*)resultValue); } free(resultValue); @@ -160,7 +160,7 @@ extern "C" { EVENT_TRACE_CONTROL_STOP); } - // PM_ETWFlushTrace Closes the session and frees resourses. + // PM_ETWFlushTrace Closes the session and frees recourses. __declspec(dllexport) uint32_t PM_ETWDestroySession(ETWSessionState* state) { if (state == NULL) { return 1;