/* Package edgeos provides a client for interacting with Ubiquiti EdgeOS devices via their REST API. It supports authentication, session management, and retrieval of system and interface configuration from one or more devices. */ package edgeos import ( "bytes" "context" "crypto/tls" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "sync" ) // Client handles communication with EdgeOS devices. type Client struct { mu sync.RWMutex devices map[string]*deviceClient } type deviceClient struct { config Config client *http.Client cookies []*http.Cookie authInfo *AuthResponse mu sync.Mutex } func newDeviceClient(cfg Config) *deviceClient { if cfg.Scheme == "" { cfg.Scheme = "https" } var tr http.RoundTripper if cfg.Transport != nil { tr = cfg.Transport } else { defaultTr := http.DefaultTransport.(*http.Transport).Clone() if defaultTr.TLSClientConfig == nil { defaultTr.TLSClientConfig = &tls.Config{} } defaultTr.TLSClientConfig.InsecureSkipVerify = cfg.Insecure tr = defaultTr } client := &http.Client{ Transport: tr, Timeout: cfg.Timeout, } return &deviceClient{ config: cfg, client: client, } } // MustNew creates a new Client with the given configurations. // It panics if a configuration is invalid (though currently we just accept all). func MustNew(ctx context.Context, configs []Config) *Client { devices := make(map[string]*deviceClient) for _, cfg := range configs { devices[cfg.Host] = newDeviceClient(cfg) } return &Client{ devices: devices, } } // Add adds a new device to the client. // It returns an error if a device with the same host already exists. func (c *Client) Add(cfg *Config) error { if cfg == nil { return fmt.Errorf("config cannot be nil") } d := newDeviceClient(*cfg) c.mu.Lock() defer c.mu.Unlock() if _, ok := c.devices[cfg.Host]; ok { return fmt.Errorf("device already exists: %s", cfg.Host) } c.devices[cfg.Host] = d return nil } // Del removes a device from the client. // It returns an error if the device does not exist. func (c *Client) Del(host string) error { c.mu.Lock() defer c.mu.Unlock() if _, ok := c.devices[host]; !ok { return fmt.Errorf("device not found: %s", host) } delete(c.devices, host) return nil } func (c *Client) getDeviceByHost(host string) (*deviceClient, error) { c.mu.RLock() defer c.mu.RUnlock() d, ok := c.devices[host] if !ok { return nil, fmt.Errorf("device not found: %s", host) } return d, nil } func (d *deviceClient) login(ctx context.Context) error { d.mu.Lock() defer d.mu.Unlock() reqUrl := fmt.Sprintf("%s://%s/api/login2", d.config.Scheme, d.config.Host) data := url.Values{} data.Set("username", d.config.Username) data.Set("password", d.config.Password) req, err := http.NewRequestWithContext(ctx, "POST", reqUrl, strings.NewReader(data.Encode())) if err != nil { return err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") req.Header.Set("Origin", fmt.Sprintf("%s://%s", d.config.Scheme, d.config.Host)) resp, err := d.client.Do(req) if err != nil { return err } defer resp.Body.Close() respPayload, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode != http.StatusOK { return fmt.Errorf("login failed [%d]: %s", resp.StatusCode, string(respPayload)) } var authResp AuthResponse if err := json.Unmarshal(respPayload, &authResp); err != nil { return fmt.Errorf("failed to parse auth response: %w", err) } if !authResp.Authenticated { return fmt.Errorf("authentication failed for user %s", d.config.Username) } d.authInfo = &authResp d.cookies = resp.Cookies() return nil } func (d *deviceClient) do(ctx context.Context, method, path string, body any, out any) error { err := d.doRequest(ctx, method, path, body, out) if err == nil { return nil } if strings.Contains(err.Error(), "status 401") || strings.Contains(err.Error(), "unauthorized") { if loginErr := d.login(ctx); loginErr != nil { return fmt.Errorf("re-login failed: %w", loginErr) } return d.doRequest(ctx, method, path, body, out) } return err } func (d *deviceClient) doRequest(ctx context.Context, method, path string, body any, out any) error { reqUrl := fmt.Sprintf("%s://%s%s", d.config.Scheme, d.config.Host, path) var reqBody io.Reader if body != nil { b, err := json.Marshal(body) if err != nil { return err } reqBody = bytes.NewBuffer(b) } req, err := http.NewRequestWithContext(ctx, method, reqUrl, reqBody) if err != nil { return err } d.mu.Lock() cookies := d.cookies d.mu.Unlock() if len(cookies) > 0 { cookieURL, _ := url.Parse(reqUrl) for _, cookie := range cookies { if cookie.Domain == "" || strings.HasSuffix(cookieURL.Host, cookie.Domain) { req.AddCookie(cookie) } } } if body != nil { req.Header.Set("Content-Type", "application/json") } req.Header.Set("Accept", "application/json") resp, err := d.client.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode == http.StatusUnauthorized { return fmt.Errorf("status 401") } if resp.StatusCode != http.StatusOK { b, _ := io.ReadAll(resp.Body) return fmt.Errorf("request failed: status %d, body: %s", resp.StatusCode, string(b)) } if out != nil { if err := json.NewDecoder(resp.Body).Decode(out); err != nil { return err } } return nil }