diff --git a/pkg/eia/eia.go b/pkg/eia/eia.go index 3b3b224..f61c56d 100644 --- a/pkg/eia/eia.go +++ b/pkg/eia/eia.go @@ -2,9 +2,11 @@ package eia import ( "context" + "errors" "fmt" "net/url" "slices" + "strings" "time" "github.com/deepmap/oapi-codegen/pkg/securityprovider" @@ -66,7 +68,14 @@ func NewFacets(facets ...*Facet) *eiaapi.Facets { func NewClient(opts *ClientOpts) (*Client, error) { baseURL := defaultBaseURL if opts.BaseURL != nil { + if !strings.HasPrefix(opts.BaseURL.Scheme, "http") { + return nil, errors.New("invalid scheme, only http or https supported") + } baseURL = opts.BaseURL.String() + + if _, err := url.Parse(baseURL); err != nil { + return nil, err + } } hcTimeout := defaultPingTimeout diff --git a/pkg/eia/eia_reflection_test.go b/pkg/eia/eia_reflection_test.go new file mode 100644 index 0000000..9a694bd --- /dev/null +++ b/pkg/eia/eia_reflection_test.go @@ -0,0 +1,40 @@ +package eia + +import ( + "reflect" + "testing" +) + +func TestGetRoutes(t *testing.T) { + type args struct { + suffixes []string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "List known routes", + args: args{ + suffixes: []string{ + "Aeo", + "Electricity", + "Gas", + }, + }, + want: []string{ + "GetV2Aeo", + "GetV2Electricity", + "GetV2NaturalGas", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetRoutes(tt.args.suffixes...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetRoutes() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/eia/eia_test.go b/pkg/eia/eia_test.go index 403496f..c9cb20d 100644 --- a/pkg/eia/eia_test.go +++ b/pkg/eia/eia_test.go @@ -1,6 +1,8 @@ package eia import ( + "context" + "net/url" "reflect" "testing" ) @@ -65,3 +67,48 @@ func TestNewFacets(t *testing.T) { }) } } + +func TestNewClient(t *testing.T) { + type args struct { + opts *ClientOpts + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + { + name: "Bad URL Scheme", + args: args{opts: &ClientOpts{ + Context: context.TODO(), + APIKey: "testkey", + BaseURL: &url.URL{Scheme: "grpc"}, + }}, + want: nil, + wantErr: true, + }, + { + name: "Bad Host", + args: args{opts: &ClientOpts{ + Context: context.TODO(), + APIKey: "testkey", + BaseURL: &url.URL{Host: "bad host:realbad"}, + }}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClient(tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClient() = %v, want %v", got, tt.want) + } + }) + } +}