Add thread-safe Add/Del methods and refactor client locking

- Add Add and Del methods to Client for dynamic host management.
- Add RWMutex to Client to protect the devices map.
- Add Transport to Config to allow mocking HTTP transport in tests.
- Add getDeviceByHost helper to centralize device lookup locking.
- Refactor GetAll* methods to snapshot host keys before iteration to avoid concurrent map read/write panic.
- Add tests for thread safety and Add/Del functionality.
This commit is contained in:
2026-01-04 13:56:19 -05:00
parent 906d005edf
commit 1754eb6e84
4 changed files with 268 additions and 51 deletions

View File

@@ -3,15 +3,14 @@ package edgeos
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync" "sync"
) )
// GetInterfaces retrieves the interfaces for a specific device. // GetInterfaces retrieves the interfaces for a specific device.
func (c *Client) GetInterfaces(ctx context.Context, host string) ([]Interface, error) { func (c *Client) GetInterfaces(ctx context.Context, host string) ([]Interface, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out []Interface var out []Interface
@@ -31,7 +30,14 @@ func (c *Client) GetAllInterfaces(ctx context.Context) (map[string][]Interface,
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetInterfaces(ctx, host) res, err := c.GetInterfaces(ctx, host)
if err != nil { if err != nil {
@@ -51,9 +57,9 @@ func (c *Client) GetAllInterfaces(ctx context.Context) (map[string][]Interface,
// GetDevice retrieves the device info for a specific device. // GetDevice retrieves the device info for a specific device.
func (c *Client) GetDevice(ctx context.Context, host string) (*Device, error) { func (c *Client) GetDevice(ctx context.Context, host string) (*Device, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out Device var out Device
@@ -73,7 +79,14 @@ func (c *Client) GetAllDevices(ctx context.Context) (map[string]*Device, error)
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetDevice(ctx, host) res, err := c.GetDevice(ctx, host)
if err != nil { if err != nil {
@@ -93,9 +106,9 @@ func (c *Client) GetAllDevices(ctx context.Context) (map[string]*Device, error)
// GetSystem retrieves the system info for a specific device. // GetSystem retrieves the system info for a specific device.
func (c *Client) GetSystem(ctx context.Context, host string) (*System, error) { func (c *Client) GetSystem(ctx context.Context, host string) (*System, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out System var out System
@@ -115,7 +128,14 @@ func (c *Client) GetAllSystems(ctx context.Context) (map[string]*System, error)
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetSystem(ctx, host) res, err := c.GetSystem(ctx, host)
if err != nil { if err != nil {
@@ -135,9 +155,9 @@ func (c *Client) GetAllSystems(ctx context.Context) (map[string]*System, error)
// GetVLANs retrieves the VLANs for a specific device. // GetVLANs retrieves the VLANs for a specific device.
func (c *Client) GetVLANs(ctx context.Context, host string) (*VLANs, error) { func (c *Client) GetVLANs(ctx context.Context, host string) (*VLANs, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out VLANs var out VLANs
@@ -157,7 +177,14 @@ func (c *Client) GetAllVLANs(ctx context.Context) (map[string]*VLANs, error) {
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetVLANs(ctx, host) res, err := c.GetVLANs(ctx, host)
if err != nil { if err != nil {
@@ -177,9 +204,9 @@ func (c *Client) GetAllVLANs(ctx context.Context) (map[string]*VLANs, error) {
// GetServices retrieves the services for a specific device. // GetServices retrieves the services for a specific device.
func (c *Client) GetServices(ctx context.Context, host string) (*Services, error) { func (c *Client) GetServices(ctx context.Context, host string) (*Services, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out Services var out Services
@@ -199,7 +226,14 @@ func (c *Client) GetAllServices(ctx context.Context) (map[string]*Services, erro
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetServices(ctx, host) res, err := c.GetServices(ctx, host)
if err != nil { if err != nil {
@@ -219,9 +253,9 @@ func (c *Client) GetAllServices(ctx context.Context) (map[string]*Services, erro
// GetStatistics retrieves the statistics for a specific device. // GetStatistics retrieves the statistics for a specific device.
func (c *Client) GetStatistics(ctx context.Context, host string) ([]Statistics, error) { func (c *Client) GetStatistics(ctx context.Context, host string) ([]Statistics, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out []Statistics var out []Statistics
@@ -241,7 +275,14 @@ func (c *Client) GetAllStatistics(ctx context.Context) (map[string][]Statistics,
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetStatistics(ctx, host) res, err := c.GetStatistics(ctx, host)
if err != nil { if err != nil {
@@ -261,9 +302,9 @@ func (c *Client) GetAllStatistics(ctx context.Context) (map[string][]Statistics,
// GetNeighbors retrieves the neighbors for a specific device. // GetNeighbors retrieves the neighbors for a specific device.
func (c *Client) GetNeighbors(ctx context.Context, host string) ([]Neighbor, error) { func (c *Client) GetNeighbors(ctx context.Context, host string) ([]Neighbor, error) {
d, ok := c.devices[host] d, err := c.getDeviceByHost(host)
if !ok { if err != nil {
return nil, fmt.Errorf("device not found: %s", host) return nil, err
} }
var out []Neighbor var out []Neighbor
@@ -283,7 +324,14 @@ func (c *Client) GetAllNeighbors(ctx context.Context) (map[string][]Neighbor, er
errs error errs error
) )
for host := range c.devices { c.mu.RLock()
hosts := make([]string, 0, len(c.devices))
for h := range c.devices {
hosts = append(hosts, h)
}
c.mu.RUnlock()
for _, host := range hosts {
wg.Go(func() { wg.Go(func() {
res, err := c.GetNeighbors(ctx, host) res, err := c.GetNeighbors(ctx, host)
if err != nil { if err != nil {

View File

@@ -21,6 +21,7 @@ import (
// Client handles communication with EdgeOS devices. // Client handles communication with EdgeOS devices.
type Client struct { type Client struct {
mu sync.RWMutex
devices map[string]*deviceClient devices map[string]*deviceClient
} }
@@ -31,40 +32,94 @@ type deviceClient struct {
mu sync.Mutex mu sync.Mutex
} }
// MustNew creates a new Client with the given configurations. func newDeviceClient(cfg Config) *deviceClient {
// It panics if a configuration is invalid (though currently we just accept all).
func MustNew(ctx context.Context, configs []Config) *Client {
devices := make(map[string]*deviceClient)
for _, cfg := range configs {
// Use Host as the key.
// Ensure scheme is set // Ensure scheme is set
if cfg.Scheme == "" { if cfg.Scheme == "" {
cfg.Scheme = "https" cfg.Scheme = "https"
} }
tr := http.DefaultTransport.(*http.Transport).Clone() var tr http.RoundTripper
if tr.TLSClientConfig == nil { if cfg.Transport != nil {
tr.TLSClientConfig = &tls.Config{} tr = cfg.Transport
} else {
defaultTr := http.DefaultTransport.(*http.Transport).Clone()
if defaultTr.TLSClientConfig == nil {
defaultTr.TLSClientConfig = &tls.Config{}
}
defaultTr.TLSClientConfig.InsecureSkipVerify = cfg.Insecure
tr = defaultTr
} }
tr.TLSClientConfig.InsecureSkipVerify = cfg.Insecure
client := &http.Client{ client := &http.Client{
Transport: tr, Transport: tr,
Timeout: cfg.Timeout, Timeout: cfg.Timeout,
} }
devices[cfg.Host] = &deviceClient{ return &deviceClient{
config: cfg, config: cfg,
client: client, client: client,
} }
} }
// MustNew creates a new Client with the given configurations.
// It panics if a configuration is invalid (though currently we just accept all).
func MustNew(ctx context.Context, configs []Config) *Client {
devices := make(map[string]*deviceClient)
for _, cfg := range configs {
devices[cfg.Host] = newDeviceClient(cfg)
}
return &Client{ return &Client{
devices: devices, devices: devices,
} }
} }
// Add adds a new device to the client.
// It returns an error if a device with the same host already exists.
func (c *Client) Add(cfg *Config) error {
if cfg == nil {
return fmt.Errorf("config cannot be nil")
}
d := newDeviceClient(*cfg)
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.devices[cfg.Host]; ok {
return fmt.Errorf("device already exists: %s", cfg.Host)
}
c.devices[cfg.Host] = d
return nil
}
// Del removes a device from the client.
// It returns an error if the device does not exist.
func (c *Client) Del(host string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.devices[host]; !ok {
return fmt.Errorf("device not found: %s", host)
}
delete(c.devices, host)
return nil
}
func (c *Client) getDeviceByHost(host string) (*deviceClient, error) {
c.mu.RLock()
defer c.mu.RUnlock()
d, ok := c.devices[host]
if !ok {
return nil, fmt.Errorf("device not found: %s", host)
}
return d, nil
}
func (d *deviceClient) login(ctx context.Context) error { func (d *deviceClient) login(ctx context.Context) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()

109
pkg/edgeos/client_test.go Normal file
View File

@@ -0,0 +1,109 @@
package edgeos
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"sync"
"testing"
)
type mockTransport struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if m.RoundTripFunc != nil {
return m.RoundTripFunc(req)
}
// Default mock response
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBufferString("{}")),
Header: make(http.Header),
}, nil
}
func TestClient_ThreadSafety(t *testing.T) {
ctx := context.Background()
client := MustNew(ctx, []Config{})
var wg sync.WaitGroup
start := make(chan struct{})
// Writer: Adds and deletes hosts
wg.Add(1)
go func() {
defer wg.Done()
<-start
for i := 0; i < 100; i++ {
host := fmt.Sprintf("host-%d", i)
cfg := &Config{
Host: host,
Transport: &mockTransport{},
}
if err := client.Add(cfg); err != nil {
// verify we don't error on valid add
t.Logf("Add error: %v", err)
}
// We invoke Del immediately.
if err := client.Del(host); err != nil {
t.Logf("Del error: %v", err)
}
}
}()
// Reader: Iterates hosts
wg.Add(1)
go func() {
defer wg.Done()
<-start
for i := 0; i < 10; i++ {
// GetAllInterfaces iterates keys.
// With mock transport, this will succeed (returning empty structs)
// checking for race conditions.
_, _ = client.GetAllInterfaces(ctx)
}
}()
close(start)
wg.Wait()
}
func TestClient_AddDel(t *testing.T) {
ctx := context.Background()
client := MustNew(ctx, []Config{})
cfg := &Config{
Host: "test-host",
Transport: &mockTransport{},
}
if err := client.Add(cfg); err != nil {
t.Fatalf("Add failed: %v", err)
}
if err := client.Add(cfg); err == nil {
t.Fatal("Expected error adding duplicate host, got nil")
}
// Verify we can retrieve it
// Mock transport returns 200 OK with empty body, so GetInterfaces should return empty slice (or error decoding if empty body is not valid JSON array? actually "{}" is valid object, but GetInterfaces expects array for /interfaces?)
// Let's check api.go: GetInterfaces calls /interfaces.
// We can customize the mock if we want to test success return.
// For this test, we just care that it doesn't return "device not found".
_, err := client.GetInterfaces(ctx, "test-host")
if err != nil && err.Error() == "device not found: test-host" {
t.Fatal("Device should exist")
}
if err := client.Del("test-host"); err != nil {
t.Fatalf("Del failed: %v", err)
}
if err := client.Del("test-host"); err == nil {
t.Fatal("Expected error deleting non-existent host, got nil")
}
}

View File

@@ -1,6 +1,9 @@
package edgeos package edgeos
import "time" import (
"net/http"
"time"
)
// Config represents the configuration for an EdgeOS device. // Config represents the configuration for an EdgeOS device.
type Config struct { type Config struct {
@@ -10,4 +13,6 @@ type Config struct {
Username string Username string
Password string Password string
Timeout time.Duration Timeout time.Duration
// Transport allows customizing the http transport (useful for testing)
Transport http.RoundTripper
} }