From 0f28af66cd3104c7e9e488f9b463f2c37075f73f Mon Sep 17 00:00:00 2001
From: Vladimir Stoilov <vladimir@safing.io>
Date: Mon, 27 Jan 2025 17:21:54 +0200
Subject: [PATCH] Add PID in ETW DNS event in the integration dll (#1768)

* [service] Add reading of PID in ETW DNS event

* [service] Use PID of the ETW DNS events

* [service] Fix use of nil pointer

* [service] Fix compiler error
---
 .../interception/dnsmonitor/etwlink_windows.go    | 10 +++++-----
 .../dnsmonitor/eventlistener_linux.go             |  2 +-
 .../dnsmonitor/eventlistener_windows.go           | 15 +++++++++++++--
 .../firewall/interception/dnsmonitor/module.go    |  4 ++--
 service/network/connection.go                     | 10 +++++++++-
 windows_core_dll/dllmain.cpp                      |  8 ++++----
 6 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/service/firewall/interception/dnsmonitor/etwlink_windows.go b/service/firewall/interception/dnsmonitor/etwlink_windows.go
index cb9d8675..6f9bd16d 100644
--- a/service/firewall/interception/dnsmonitor/etwlink_windows.go
+++ b/service/firewall/interception/dnsmonitor/etwlink_windows.go
@@ -22,8 +22,8 @@ type ETWSession struct {
 	state uintptr
 }
 
-// NewSession creates new ETW event listener and initilizes it. This is a low level interface, make sure to call DestorySession when you are done using it.
-func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, result string)) (*ETWSession, error) {
+// NewSession creates new ETW event listener and initializes it. This is a low level interface, make sure to call DestroySession when you are done using it.
+func NewSession(etwInterface *integration.ETWFunctions, callback func(domain string, pid uint32, result string)) (*ETWSession, error) {
 	if etwInterface == nil {
 		return nil, fmt.Errorf("etw interface was nil")
 	}
@@ -35,8 +35,8 @@ func NewSession(etwInterface *integration.ETWFunctions, callback func(domain str
 	_ = etwSession.i.StopOldSession()
 
 	// Initialize notification activated callback
-	win32Callback := windows.NewCallback(func(domain *uint16, result *uint16) uintptr {
-		callback(windows.UTF16PtrToString(domain), windows.UTF16PtrToString(result))
+	win32Callback := windows.NewCallback(func(domain *uint16, pid uint32, result *uint16) uintptr {
+		callback(windows.UTF16PtrToString(domain), pid, windows.UTF16PtrToString(result))
 		return 0
 	})
 	// The function only allocates memory it will not fail.
@@ -83,7 +83,7 @@ func (l *ETWSession) FlushTrace() error {
 	return l.i.FlushTrace(l.state)
 }
 
-// StopTrace stopes the trace. This will cause StartTrace to return.
+// StopTrace stops the trace. This will cause StartTrace to return.
 func (l *ETWSession) StopTrace() error {
 	return l.i.StopTrace(l.state)
 }
diff --git a/service/firewall/interception/dnsmonitor/eventlistener_linux.go b/service/firewall/interception/dnsmonitor/eventlistener_linux.go
index 6e9bb3ee..f4fc99a0 100644
--- a/service/firewall/interception/dnsmonitor/eventlistener_linux.go
+++ b/service/firewall/interception/dnsmonitor/eventlistener_linux.go
@@ -141,5 +141,5 @@ func (l *Listener) processAnswer(domain string, queryResult *QueryResult) {
 		}
 	}
 
-	saveDomain(domain, ips, cnames)
+	saveDomain(domain, ips, cnames, resolver.IPInfoProfileScopeGlobal)
 }
diff --git a/service/firewall/interception/dnsmonitor/eventlistener_windows.go b/service/firewall/interception/dnsmonitor/eventlistener_windows.go
index a46e8cc6..d71c3c43 100644
--- a/service/firewall/interception/dnsmonitor/eventlistener_windows.go
+++ b/service/firewall/interception/dnsmonitor/eventlistener_windows.go
@@ -4,6 +4,7 @@
 package dnsmonitor
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"strconv"
@@ -11,6 +12,7 @@ import (
 
 	"github.com/miekg/dns"
 	"github.com/safing/portmaster/service/mgr"
+	"github.com/safing/portmaster/service/process"
 	"github.com/safing/portmaster/service/resolver"
 )
 
@@ -79,7 +81,7 @@ func (l *Listener) stop() error {
 	return nil
 }
 
-func (l *Listener) processEvent(domain string, result string) {
+func (l *Listener) processEvent(domain string, pid uint32, result string) {
 	if processIfSelfCheckDomain(dns.Fqdn(domain)) {
 		// Not need to process result.
 		return
@@ -90,6 +92,15 @@ func (l *Listener) processEvent(domain string, result string) {
 		return
 	}
 
+	profileScope := resolver.IPInfoProfileScopeGlobal
+	// Get the profile ID if the process can be found
+	if proc, err := process.GetOrFindProcess(context.Background(), int(pid)); err == nil {
+		if profile := proc.Profile(); profile != nil {
+			if localProfile := profile.LocalProfile(); localProfile != nil {
+				profileScope = localProfile.ID
+			}
+		}
+	}
 	cnames := make(map[string]string)
 	ips := []net.IP{}
 
@@ -115,5 +126,5 @@ func (l *Listener) processEvent(domain string, result string) {
 			}
 		}
 	}
-	saveDomain(domain, ips, cnames)
+	saveDomain(domain, ips, cnames, profileScope)
 }
diff --git a/service/firewall/interception/dnsmonitor/module.go b/service/firewall/interception/dnsmonitor/module.go
index 974429f7..918d1e88 100644
--- a/service/firewall/interception/dnsmonitor/module.go
+++ b/service/firewall/interception/dnsmonitor/module.go
@@ -61,7 +61,7 @@ func (dl *DNSMonitor) Flush() error {
 	return dl.listener.flush()
 }
 
-func saveDomain(domain string, ips []net.IP, cnames map[string]string) {
+func saveDomain(domain string, ips []net.IP, cnames map[string]string, profileScope string) {
 	fqdn := dns.Fqdn(domain)
 	// Create new record for this IP.
 	record := resolver.ResolvedDomain{
@@ -75,7 +75,7 @@ func saveDomain(domain string, ips []net.IP, cnames map[string]string) {
 	record.AddCNAMEs(cnames)
 
 	// Add to cache
-	saveIPsInCache(ips, resolver.IPInfoProfileScopeGlobal, record)
+	saveIPsInCache(ips, profileScope, record)
 }
 
 func New(instance instance) (*DNSMonitor, error) {
diff --git a/service/network/connection.go b/service/network/connection.go
index 1c1bbf19..2cdf12e7 100644
--- a/service/network/connection.go
+++ b/service/network/connection.go
@@ -538,8 +538,9 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
 
 	// Find domain and DNS context of entity.
 	if conn.Entity.Domain == "" && conn.process.Profile() != nil {
+		profileScope := conn.process.Profile().LocalProfile().ID
 		// check if we can find a domain for that IP
-		ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
+		ipinfo, err := resolver.GetIPInfo(profileScope, pkt.Info().RemoteIP().String())
 		if err != nil {
 			// Try again with the global scope, in case DNS went through the system resolver.
 			ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
@@ -555,6 +556,13 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
 						// Error flushing, dont try again.
 						break
 					}
+					// Try with profile scope
+					ipinfo, err = resolver.GetIPInfo(profileScope, pkt.Info().RemoteIP().String())
+					if err == nil {
+						log.Tracer(pkt.Ctx()).Debugf("network: found domain with scope (%s) from dnsmonitor after %d tries", profileScope, +1)
+						break
+					}
+					// Try again with the global scope
 					ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
 					if err == nil {
 						log.Tracer(pkt.Ctx()).Debugf("network: found domain from dnsmonitor after %d tries", i+1)
diff --git a/windows_core_dll/dllmain.cpp b/windows_core_dll/dllmain.cpp
index cc0efaac..7674539f 100644
--- a/windows_core_dll/dllmain.cpp
+++ b/windows_core_dll/dllmain.cpp
@@ -22,7 +22,7 @@ static const GUID PORTMASTER_ETW_SESSION_GUID = {
 #define LOGSESSION_NAME L"PortmasterDNSEventListener"
 
 // Fuction type of the callback that will be called on each event.
-typedef uint64_t(*GoEventRecordCallback)(wchar_t* domain, wchar_t* result);
+typedef uint64_t(*GoEventRecordCallback)(wchar_t* domain, uint32_t pid, wchar_t* result);
 
 // Holds the state of the ETW Session.
 struct ETWSessionState {
@@ -41,7 +41,7 @@ static bool getPropertyValue(PEVENT_RECORD evt, LPWSTR prop, PBYTE* pData) {
     DataDescriptor.ArrayIndex = 0;
 
     DWORD PropertySize = 0;
-    // Check if the data is avaliable and what is the size of it.
+    // Check if the data is available and what is the size of it.
     DWORD status =
         TdhGetPropertySize(evt, 0, NULL, 1, &DataDescriptor, &PropertySize);
     if (ERROR_SUCCESS != status) {
@@ -79,7 +79,7 @@ static void WINAPI EventRecordCallback(PEVENT_RECORD eventRecord) {
     ETWSessionState* state = (ETWSessionState*)eventRecord->UserContext;
 
     if (resultValue != NULL && domainValue != NULL) {
-        state->callback((wchar_t*)domainValue, (wchar_t*)resultValue);
+        state->callback((wchar_t*)domainValue, eventRecord->EventHeader.ProcessId, (wchar_t*)resultValue);
     }
 
     free(resultValue);
@@ -160,7 +160,7 @@ extern "C" {
             EVENT_TRACE_CONTROL_STOP);
     }
 
-    // PM_ETWFlushTrace Closes the session and frees resourses.
+    // PM_ETWFlushTrace Closes the session and frees recourses.
     __declspec(dllexport) uint32_t PM_ETWDestroySession(ETWSessionState* state) {
         if (state == NULL) {
             return 1;