From 1754eb6e84d10b61adcad347435d6b276c007114 Mon Sep 17 00:00:00 2001 From: Ryan McGuire Date: Sun, 4 Jan 2026 13:56:19 -0500 Subject: [PATCH] Add thread-safe Add/Del methods and refactor client locking - Add Add and Del methods to Client for dynamic host management. - Add RWMutex to Client to protect the devices map. - Add Transport to Config to allow mocking HTTP transport in tests. - Add getDeviceByHost helper to centralize device lookup locking. - Refactor GetAll* methods to snapshot host keys before iteration to avoid concurrent map read/write panic. - Add tests for thread safety and Add/Del functionality. --- pkg/edgeos/api.go | 106 ++++++++++++++++++++++++++---------- pkg/edgeos/client.go | 97 +++++++++++++++++++++++++-------- pkg/edgeos/client_test.go | 109 ++++++++++++++++++++++++++++++++++++++ pkg/edgeos/config.go | 7 ++- 4 files changed, 268 insertions(+), 51 deletions(-) create mode 100644 pkg/edgeos/client_test.go diff --git a/pkg/edgeos/api.go b/pkg/edgeos/api.go index 0902eae..c62f81a 100644 --- a/pkg/edgeos/api.go +++ b/pkg/edgeos/api.go @@ -3,15 +3,14 @@ package edgeos import ( "context" "errors" - "fmt" "sync" ) // GetInterfaces retrieves the interfaces for a specific device. func (c *Client) GetInterfaces(ctx context.Context, host string) ([]Interface, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out []Interface @@ -31,7 +30,14 @@ func (c *Client) GetAllInterfaces(ctx context.Context) (map[string][]Interface, errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetInterfaces(ctx, host) if err != nil { @@ -51,9 +57,9 @@ func (c *Client) GetAllInterfaces(ctx context.Context) (map[string][]Interface, // GetDevice retrieves the device info for a specific device. func (c *Client) GetDevice(ctx context.Context, host string) (*Device, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out Device @@ -73,7 +79,14 @@ func (c *Client) GetAllDevices(ctx context.Context) (map[string]*Device, error) errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetDevice(ctx, host) if err != nil { @@ -93,9 +106,9 @@ func (c *Client) GetAllDevices(ctx context.Context) (map[string]*Device, error) // GetSystem retrieves the system info for a specific device. func (c *Client) GetSystem(ctx context.Context, host string) (*System, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out System @@ -115,7 +128,14 @@ func (c *Client) GetAllSystems(ctx context.Context) (map[string]*System, error) errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetSystem(ctx, host) if err != nil { @@ -135,9 +155,9 @@ func (c *Client) GetAllSystems(ctx context.Context) (map[string]*System, error) // GetVLANs retrieves the VLANs for a specific device. func (c *Client) GetVLANs(ctx context.Context, host string) (*VLANs, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out VLANs @@ -157,7 +177,14 @@ func (c *Client) GetAllVLANs(ctx context.Context) (map[string]*VLANs, error) { errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetVLANs(ctx, host) if err != nil { @@ -177,9 +204,9 @@ func (c *Client) GetAllVLANs(ctx context.Context) (map[string]*VLANs, error) { // GetServices retrieves the services for a specific device. func (c *Client) GetServices(ctx context.Context, host string) (*Services, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out Services @@ -199,7 +226,14 @@ func (c *Client) GetAllServices(ctx context.Context) (map[string]*Services, erro errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetServices(ctx, host) if err != nil { @@ -219,9 +253,9 @@ func (c *Client) GetAllServices(ctx context.Context) (map[string]*Services, erro // GetStatistics retrieves the statistics for a specific device. func (c *Client) GetStatistics(ctx context.Context, host string) ([]Statistics, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out []Statistics @@ -241,7 +275,14 @@ func (c *Client) GetAllStatistics(ctx context.Context) (map[string][]Statistics, errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetStatistics(ctx, host) if err != nil { @@ -261,9 +302,9 @@ func (c *Client) GetAllStatistics(ctx context.Context) (map[string][]Statistics, // GetNeighbors retrieves the neighbors for a specific device. func (c *Client) GetNeighbors(ctx context.Context, host string) ([]Neighbor, error) { - d, ok := c.devices[host] - if !ok { - return nil, fmt.Errorf("device not found: %s", host) + d, err := c.getDeviceByHost(host) + if err != nil { + return nil, err } var out []Neighbor @@ -283,7 +324,14 @@ func (c *Client) GetAllNeighbors(ctx context.Context) (map[string][]Neighbor, er errs error ) - for host := range c.devices { + c.mu.RLock() + hosts := make([]string, 0, len(c.devices)) + for h := range c.devices { + hosts = append(hosts, h) + } + c.mu.RUnlock() + + for _, host := range hosts { wg.Go(func() { res, err := c.GetNeighbors(ctx, host) if err != nil { diff --git a/pkg/edgeos/client.go b/pkg/edgeos/client.go index ea7746f..1df1aba 100644 --- a/pkg/edgeos/client.go +++ b/pkg/edgeos/client.go @@ -21,6 +21,7 @@ import ( // Client handles communication with EdgeOS devices. type Client struct { + mu sync.RWMutex devices map[string]*deviceClient } @@ -31,33 +32,42 @@ type deviceClient struct { mu sync.Mutex } +func newDeviceClient(cfg Config) *deviceClient { + // Ensure scheme is set + if cfg.Scheme == "" { + cfg.Scheme = "https" + } + + var tr http.RoundTripper + if cfg.Transport != nil { + tr = cfg.Transport + } else { + defaultTr := http.DefaultTransport.(*http.Transport).Clone() + if defaultTr.TLSClientConfig == nil { + defaultTr.TLSClientConfig = &tls.Config{} + } + defaultTr.TLSClientConfig.InsecureSkipVerify = cfg.Insecure + tr = defaultTr + } + + client := &http.Client{ + Transport: tr, + Timeout: cfg.Timeout, + } + + return &deviceClient{ + config: cfg, + client: client, + } +} + // MustNew creates a new Client with the given configurations. // It panics if a configuration is invalid (though currently we just accept all). func MustNew(ctx context.Context, configs []Config) *Client { devices := make(map[string]*deviceClient) for _, cfg := range configs { - // Use Host as the key. - // Ensure scheme is set - if cfg.Scheme == "" { - cfg.Scheme = "https" - } - - tr := http.DefaultTransport.(*http.Transport).Clone() - if tr.TLSClientConfig == nil { - tr.TLSClientConfig = &tls.Config{} - } - tr.TLSClientConfig.InsecureSkipVerify = cfg.Insecure - - client := &http.Client{ - Transport: tr, - Timeout: cfg.Timeout, - } - - devices[cfg.Host] = &deviceClient{ - config: cfg, - client: client, - } + devices[cfg.Host] = newDeviceClient(cfg) } return &Client{ @@ -65,6 +75,51 @@ func MustNew(ctx context.Context, configs []Config) *Client { } } +// Add adds a new device to the client. +// It returns an error if a device with the same host already exists. +func (c *Client) Add(cfg *Config) error { + if cfg == nil { + return fmt.Errorf("config cannot be nil") + } + + d := newDeviceClient(*cfg) + + c.mu.Lock() + defer c.mu.Unlock() + + if _, ok := c.devices[cfg.Host]; ok { + return fmt.Errorf("device already exists: %s", cfg.Host) + } + + c.devices[cfg.Host] = d + return nil +} + +// Del removes a device from the client. +// It returns an error if the device does not exist. +func (c *Client) Del(host string) error { + c.mu.Lock() + defer c.mu.Unlock() + + if _, ok := c.devices[host]; !ok { + return fmt.Errorf("device not found: %s", host) + } + + delete(c.devices, host) + return nil +} + +func (c *Client) getDeviceByHost(host string) (*deviceClient, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + d, ok := c.devices[host] + if !ok { + return nil, fmt.Errorf("device not found: %s", host) + } + return d, nil +} + func (d *deviceClient) login(ctx context.Context) error { d.mu.Lock() defer d.mu.Unlock() diff --git a/pkg/edgeos/client_test.go b/pkg/edgeos/client_test.go new file mode 100644 index 0000000..f7ec7cf --- /dev/null +++ b/pkg/edgeos/client_test.go @@ -0,0 +1,109 @@ +package edgeos + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "sync" + "testing" +) + +type mockTransport struct { + RoundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if m.RoundTripFunc != nil { + return m.RoundTripFunc(req) + } + // Default mock response + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString("{}")), + Header: make(http.Header), + }, nil +} + +func TestClient_ThreadSafety(t *testing.T) { + ctx := context.Background() + client := MustNew(ctx, []Config{}) + + var wg sync.WaitGroup + start := make(chan struct{}) + + // Writer: Adds and deletes hosts + wg.Add(1) + go func() { + defer wg.Done() + <-start + for i := 0; i < 100; i++ { + host := fmt.Sprintf("host-%d", i) + cfg := &Config{ + Host: host, + Transport: &mockTransport{}, + } + if err := client.Add(cfg); err != nil { + // verify we don't error on valid add + t.Logf("Add error: %v", err) + } + + // We invoke Del immediately. + if err := client.Del(host); err != nil { + t.Logf("Del error: %v", err) + } + } + }() + + // Reader: Iterates hosts + wg.Add(1) + go func() { + defer wg.Done() + <-start + for i := 0; i < 10; i++ { + // GetAllInterfaces iterates keys. + // With mock transport, this will succeed (returning empty structs) + // checking for race conditions. + _, _ = client.GetAllInterfaces(ctx) + } + }() + + close(start) + wg.Wait() +} + +func TestClient_AddDel(t *testing.T) { + ctx := context.Background() + client := MustNew(ctx, []Config{}) + + cfg := &Config{ + Host: "test-host", + Transport: &mockTransport{}, + } + if err := client.Add(cfg); err != nil { + t.Fatalf("Add failed: %v", err) + } + + if err := client.Add(cfg); err == nil { + t.Fatal("Expected error adding duplicate host, got nil") + } + + // Verify we can retrieve it + // Mock transport returns 200 OK with empty body, so GetInterfaces should return empty slice (or error decoding if empty body is not valid JSON array? actually "{}" is valid object, but GetInterfaces expects array for /interfaces?) + // Let's check api.go: GetInterfaces calls /interfaces. + // We can customize the mock if we want to test success return. + // For this test, we just care that it doesn't return "device not found". + _, err := client.GetInterfaces(ctx, "test-host") + if err != nil && err.Error() == "device not found: test-host" { + t.Fatal("Device should exist") + } + + if err := client.Del("test-host"); err != nil { + t.Fatalf("Del failed: %v", err) + } + + if err := client.Del("test-host"); err == nil { + t.Fatal("Expected error deleting non-existent host, got nil") + } +} \ No newline at end of file diff --git a/pkg/edgeos/config.go b/pkg/edgeos/config.go index a0b5747..f18652a 100644 --- a/pkg/edgeos/config.go +++ b/pkg/edgeos/config.go @@ -1,6 +1,9 @@ package edgeos -import "time" +import ( + "net/http" + "time" +) // Config represents the configuration for an EdgeOS device. type Config struct { @@ -10,4 +13,6 @@ type Config struct { Username string Password string Timeout time.Duration + // Transport allows customizing the http transport (useful for testing) + Transport http.RoundTripper }