Improve svchost service discovery

This commit is contained in:
Daniel 2022-10-10 16:24:51 +02:00
parent b564e77168
commit a391eb3dad

View file

@ -12,7 +12,7 @@ import (
)
var (
serviceNames map[int32]string
serviceNames map[int32][]string
serviceNamesLock sync.Mutex
)
@ -22,7 +22,7 @@ var (
)
// GetServiceNames returns all service names assosicated with a svchost.exe process on Windows.
func GetServiceNames(pid int32) (string, error) {
func GetServiceNames(pid int32) ([]string, error) {
serviceNamesLock.Lock()
defer serviceNamesLock.Unlock()
@ -35,7 +35,7 @@ func GetServiceNames(pid int32) (string, error) {
serviceNames, err := GetAllServiceNames()
if err != nil {
return "", err
return nil, err
}
names, ok := serviceNames[pid]
@ -43,11 +43,11 @@ func GetServiceNames(pid int32) (string, error) {
return names, nil
}
return "", ErrServiceNotFound
return nil, ErrServiceNotFound
}
// GetAllServiceNames returns a list of service names assosicated with svchost.exe processes on Windows.
func GetAllServiceNames() (map[int32]string, error) {
func GetAllServiceNames() (map[int32][]string, error) {
output, err := exec.Command("tasklist", "/svc", "/fi", "imagename eq svchost.exe").Output()
if err != nil {
return nil, fmt.Errorf("failed to get svchost tasklist: %s", err)
@ -66,8 +66,8 @@ func GetAllServiceNames() (map[int32]string, error) {
var (
pid int32
services string
collection = make(map[int32]string)
services []string
collection = make(map[int32][]string)
)
for scanner.Scan() {
@ -83,11 +83,11 @@ func GetAllServiceNames() (map[int32]string, error) {
if fields[0] == "svchost.exe" {
// save old entry
if pid != 0 {
collection[pid] = strings.TrimSpace(services)
collection[pid] = services
}
// reset
// reset PID
pid = 0
services = ""
services = make([]string, 0, len(fields))
// check fields length
if len(fields) < 3 {
@ -106,12 +106,14 @@ func GetAllServiceNames() (map[int32]string, error) {
}
// add service names
services += " " + strings.Join(fields, " ")
for _, field := range fields {
services = append(services, strings.Trim(strings.TrimSpace(field), ","))
}
}
if pid != 0 {
// save last entry
collection[pid] = strings.TrimSpace(services)
collection[pid] = services
}
return collection, nil