From eee7ad88142e4e572a248afa92cf404dfb243900 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 2 Apr 2021 13:23:36 +0300 Subject: [PATCH] seclient: improved concurrency correctness --- seclient/jar.go | 53 +++++++++++++++++++++ seclient/seclient.go | 111 +++++++++++++++++++++++++------------------ 2 files changed, 119 insertions(+), 45 deletions(-) create mode 100644 seclient/jar.go diff --git a/seclient/jar.go b/seclient/jar.go new file mode 100644 index 0000000..106504c --- /dev/null +++ b/seclient/jar.go @@ -0,0 +1,53 @@ +package seclient + +import ( + "net/http" + "net/http/cookiejar" + "net/url" + "sync" + + "golang.org/x/net/publicsuffix" +) + +type StdJar struct { + jar *cookiejar.Jar + mux sync.RWMutex +} + +func NewStdJar() (*StdJar, error) { + var jar StdJar + + err := jar.Reset() + if err != nil { + return nil, err + } + + return &jar, nil +} + +func (j *StdJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.mux.RLock() + j.jar.SetCookies(u, cookies) + j.mux.RUnlock() +} + +func (j *StdJar) Cookies(u *url.URL) []*http.Cookie { + j.mux.RLock() + c := j.jar.Cookies(u) + j.mux.RUnlock() + return c +} + +func (j *StdJar) Reset() error { + jar, err := cookiejar.New(&cookiejar.Options{ + PublicSuffixList: publicsuffix.List, + }) + if err != nil { + return err + } + + j.mux.Lock() + j.jar = jar + j.mux.Unlock() + return nil +} diff --git a/seclient/seclient.go b/seclient/seclient.go index 6c482d9..e74df99 100644 --- a/seclient/seclient.go +++ b/seclient/seclient.go @@ -8,13 +8,11 @@ import ( "io/ioutil" "math/rand" "net/http" - "net/http/cookiejar" "net/url" "strings" "sync" dac "github.com/Snawoot/go-http-digest-auth-client" - "golang.org/x/net/publicsuffix" ) const ( @@ -61,7 +59,7 @@ var DefaultSESettings = SESettings{ } type SEClient struct { - HttpClient *http.Client + httpClient *http.Client Settings SESettings SubscriberEmail string SubscriberPassword string @@ -69,7 +67,7 @@ type SEClient struct { AssignedDeviceID string AssignedDeviceIDHash string DevicePassword string - StateMux sync.RWMutex + Mux sync.Mutex rng *rand.Rand } @@ -90,8 +88,14 @@ func NewSEClient(apiUsername, apiSecret string, transport http.RoundTripper) (*S return nil, err } + jar, err := NewStdJar() + if err != nil { + return nil, err + } + res := &SEClient{ - HttpClient: &http.Client{ + httpClient: &http.Client{ + Jar: jar, Transport: dac.NewDigestTransport(apiUsername, apiSecret, transport), }, Settings: DefaultSESettings, @@ -99,48 +103,48 @@ func NewSEClient(apiUsername, apiSecret string, transport http.RoundTripper) (*S DeviceID: device_id, } - err = res.ResetCookies() - if err != nil { - return nil, err - } - return res, nil } func (c *SEClient) ResetCookies() error { - jar, err := cookiejar.New(&cookiejar.Options{ - PublicSuffixList: publicsuffix.List, - }) - if err != nil { - return err - } + c.Mux.Lock() + defer c.Mux.Unlock() + return c.resetCookies() +} - c.StateMux.Lock() - c.HttpClient.Jar = jar - c.StateMux.Unlock() - return nil +func (c *SEClient) resetCookies() error { + return (c.httpClient.Jar.(*StdJar)).Reset() } func (c *SEClient) AnonRegister(ctx context.Context) error { + c.Mux.Lock() + defer c.Mux.Unlock() + localPart, err := randomEmailLocalPart(c.rng) if err != nil { return err } - subscriberEmail := fmt.Sprintf("%s@%s.best.vpn", localPart, c.Settings.ClientType) - subscriberPassword := capitalHexSHA1(subscriberEmail) + c.SubscriberEmail = fmt.Sprintf("%s@%s.best.vpn", localPart, c.Settings.ClientType) + c.SubscriberPassword = capitalHexSHA1(c.SubscriberEmail) - c.StateMux.Lock() - c.SubscriberEmail = subscriberEmail - c.SubscriberPassword = subscriberPassword - c.StateMux.Unlock() - - return c.Register(ctx) + return c.register(ctx) } func (c *SEClient) Register(ctx context.Context) error { + c.Mux.Lock() + defer c.Mux.Unlock() + return c.register(ctx) +} + +func (c *SEClient) register(ctx context.Context) error { + err := c.resetCookies() + if err != nil { + return err + } + var regRes SERegisterSubscriberResponse - err := c.RpcCall(ctx, c.Settings.Endpoints.RegisterSubscriber, StrKV{ + err = c.rpcCall(ctx, c.Settings.Endpoints.RegisterSubscriber, StrKV{ "email": c.SubscriberEmail, "password": c.SubscriberPassword, }, ®Res) @@ -156,8 +160,11 @@ func (c *SEClient) Register(ctx context.Context) error { } func (c *SEClient) RegisterDevice(ctx context.Context) error { + c.Mux.Lock() + defer c.Mux.Unlock() + var regRes SERegisterDeviceResponse - err := c.RpcCall(ctx, c.Settings.Endpoints.RegisterDevice, StrKV{ + err := c.rpcCall(ctx, c.Settings.Endpoints.RegisterDevice, StrKV{ "client_type": c.Settings.ClientType, "device_hash": c.DeviceID, "device_name": c.Settings.DeviceName, @@ -171,17 +178,18 @@ func (c *SEClient) RegisterDevice(ctx context.Context) error { regRes.Status.Code, regRes.Status.Message) } - c.StateMux.Lock() c.AssignedDeviceID = regRes.Data.DeviceID c.DevicePassword = regRes.Data.DevicePassword c.AssignedDeviceIDHash = capitalHexSHA1(regRes.Data.DeviceID) - c.StateMux.Unlock() return nil } func (c *SEClient) GeoList(ctx context.Context) ([]SEGeoEntry, error) { + c.Mux.Lock() + defer c.Mux.Unlock() + var geoListRes SEGeoListResponse - err := c.RpcCall(ctx, c.Settings.Endpoints.GeoList, StrKV{ + err := c.rpcCall(ctx, c.Settings.Endpoints.GeoList, StrKV{ "device_id": c.AssignedDeviceIDHash, }, &geoListRes) if err != nil { @@ -197,8 +205,11 @@ func (c *SEClient) GeoList(ctx context.Context) ([]SEGeoEntry, error) { } func (c *SEClient) Discover(ctx context.Context, requestedGeo string) ([]SEIPEntry, error) { + c.Mux.Lock() + defer c.Mux.Unlock() + var discoverRes SEDiscoverResponse - err := c.RpcCall(ctx, c.Settings.Endpoints.Discover, StrKV{ + err := c.rpcCall(ctx, c.Settings.Endpoints.Discover, StrKV{ "serial_no": c.AssignedDeviceIDHash, "requested_geo": requestedGeo, }, &discoverRes) @@ -215,13 +226,16 @@ func (c *SEClient) Discover(ctx context.Context, requestedGeo string) ([]SEIPEnt } func (c *SEClient) Login(ctx context.Context) error { - err := c.ResetCookies() + c.Mux.Lock() + defer c.Mux.Unlock() + + err := c.resetCookies() if err != nil { return err } var loginRes SESubscriberLoginResponse - err = c.RpcCall(ctx, c.Settings.Endpoints.SubscriberLogin, StrKV{ + err = c.rpcCall(ctx, c.Settings.Endpoints.SubscriberLogin, StrKV{ "login": c.SubscriberEmail, "password": c.SubscriberPassword, "client_type": c.Settings.ClientType, @@ -238,8 +252,11 @@ func (c *SEClient) Login(ctx context.Context) error { } func (c *SEClient) DeviceGeneratePassword(ctx context.Context) error { + c.Mux.Lock() + defer c.Mux.Unlock() + var genRes SEDeviceGeneratePasswordResponse - err := c.RpcCall(ctx, c.Settings.Endpoints.DeviceGeneratePassword, StrKV{ + err := c.rpcCall(ctx, c.Settings.Endpoints.DeviceGeneratePassword, StrKV{ "device_id": c.AssignedDeviceID, }, &genRes) if err != nil { @@ -251,18 +268,15 @@ func (c *SEClient) DeviceGeneratePassword(ctx context.Context) error { genRes.Status.Code, genRes.Status.Message) } - c.StateMux.Lock() c.DevicePassword = genRes.Data.DevicePassword - c.StateMux.Unlock() return nil } func (c *SEClient) GetProxyCredentials() (string, string) { - c.StateMux.RLock() - assignedDeviceIDHash := c.AssignedDeviceIDHash - devicePassword := c.DevicePassword - c.StateMux.RUnlock() - return assignedDeviceIDHash, devicePassword + c.Mux.Lock() + defer c.Mux.Unlock() + + return c.AssignedDeviceIDHash, c.DevicePassword } func (c *SEClient) populateRequest(req *http.Request) { @@ -272,6 +286,13 @@ func (c *SEClient) populateRequest(req *http.Request) { } func (c *SEClient) RpcCall(ctx context.Context, endpoint string, params map[string]string, res interface{}) error { + c.Mux.Lock() + defer c.Mux.Unlock() + + return c.rpcCall(ctx, endpoint, params, res) +} + +func (c *SEClient) rpcCall(ctx context.Context, endpoint string, params map[string]string, res interface{}) error { input := make(url.Values) for k, v := range params { input[k] = []string{v} @@ -289,7 +310,7 @@ func (c *SEClient) RpcCall(ctx context.Context, endpoint string, params map[stri req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") - resp, err := c.HttpClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return err }