diff --git a/pkg/testing/httptesting/client.go b/pkg/testing/httptesting/client.go new file mode 100644 index 000000000..3ff094f99 --- /dev/null +++ b/pkg/testing/httptesting/client.go @@ -0,0 +1,68 @@ +package httptesting + +import ( + "encoding/json" + "net/http" + "os" +) + +// Simplied client for testing that doesn't require multiple URLs + +type EchoSave struct { + // saveTo provides a way for tests to verify http.Request fields. + + // An http.Client's transport layer has only one method, so there's no way to + // return variables while adhering to it's interface. One solution is to use + // type casting where the caller must know the transport layer is actually + // of type "EchoSave". But a cleaner approach is to pass in the address of + // a local variable, and store the http.Request there. + + // Callers provide the address of a local variable, which is stored here. + saveTo **http.Request + content string + err error +} + +func (st *EchoSave) RoundTrip(req *http.Request) (*http.Response, error) { + if st.saveTo != nil { + // If the caller provided a local variable, update it with the latest http.Request + *st.saveTo = req + } + resp := BuildResponseString(http.StatusOK, st.content) + SetHeader(resp, "Content-Type", "application/json") + return resp, st.err +} + +func HttpClientFromFile(filename string) *http.Client { + rawBytes, err := os.ReadFile(filename) + transport := EchoSave{err: err, content: string(rawBytes)} + return &http.Client{Transport: &transport} +} + +func HttpClientWithContent(content string) *http.Client { + transport := EchoSave{content: content} + return &http.Client{Transport: &transport} +} + +func HttpClientWithError(err error) *http.Client { + transport := EchoSave{err: err} + return &http.Client{Transport: &transport} +} + +func HttpClientWithJson(jsonData interface{}) *http.Client { + jsonBytes, err := json.Marshal(jsonData) + transport := EchoSave{err: err, content: string(jsonBytes)} + return &http.Client{Transport: &transport} +} + +// "Saver" refers to saving the *http.Request in a local variable provided by the caller. +func HttpClientSaver(saved **http.Request, content string) *http.Client { + transport := EchoSave{saveTo: saved, content: content} + return &http.Client{Transport: &transport} +} + +func HttpClientSaverWithJson(saved **http.Request, jsonData interface{}) *http.Client { + jsonBytes, err := json.Marshal(jsonData) + transport := EchoSave{saveTo: saved, err: err, content: string(jsonBytes)} + return &http.Client{Transport: &transport} +} diff --git a/pkg/testing/httptesting/response.go b/pkg/testing/httptesting/response.go new file mode 100644 index 000000000..77b7f21fc --- /dev/null +++ b/pkg/testing/httptesting/response.go @@ -0,0 +1,54 @@ +package httptesting + +import ( + "bytes" + "encoding/json" + "io" + "net/http" +) + +func BuildResponse(code int, payload []byte) *http.Response { + return &http.Response{ + StatusCode: code, + Body: io.NopCloser(bytes.NewBuffer(payload)), + ContentLength: int64(len(payload)), + } +} + +func BuildResponseString(code int, payload string) *http.Response { + b := []byte(payload) + return &http.Response{ + StatusCode: code, + Body: io.NopCloser( + bytes.NewBuffer(b), + ), + ContentLength: int64(len(b)), + } +} + +func BuildResponseJson(code int, payload interface{}) *http.Response { + data, err := json.Marshal(payload) + if err != nil { + return BuildResponseString(http.StatusInternalServerError, `{error: "httptesting.MockTransport error calling json.Marshal()"}`) + } + + resp := BuildResponse(code, data) + resp.Header = http.Header{} + resp.Header.Set("Content-Type", "application/json") + return resp +} + +func SetHeader(resp *http.Response, name string, value string) *http.Response { + if resp.Header == nil { + resp.Header = http.Header{} + } + resp.Header.Set(name, value) + return resp +} + +func DeleteHeader(resp *http.Response, name string) *http.Response { + if resp.Header != nil { + resp.Header.Del(name) + } + return resp +} diff --git a/pkg/testing/httptesting/transport.go b/pkg/testing/httptesting/transport.go new file mode 100644 index 000000000..f4ed47d9d --- /dev/null +++ b/pkg/testing/httptesting/transport.go @@ -0,0 +1,98 @@ +package httptesting + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" +) + +type RoundTripFunc func(req *http.Request) (*http.Response, error) + +type MockTransport struct { + getHandlers map[string]RoundTripFunc + postHandlers map[string]RoundTripFunc + deleteHandlers map[string]RoundTripFunc + putHandlers map[string]RoundTripFunc +} + +func (transport *MockTransport) GET(path string, f RoundTripFunc) { + if transport.getHandlers == nil { + transport.getHandlers = make(map[string]RoundTripFunc) + } + + transport.getHandlers[path] = f +} + +func (transport *MockTransport) POST(path string, f RoundTripFunc) { + if transport.postHandlers == nil { + transport.postHandlers = make(map[string]RoundTripFunc) + } + + transport.postHandlers[path] = f +} + +func (transport *MockTransport) DELETE(path string, f RoundTripFunc) { + if transport.deleteHandlers == nil { + transport.deleteHandlers = make(map[string]RoundTripFunc) + } + + transport.deleteHandlers[path] = f +} + +func (transport *MockTransport) PUT(path string, f RoundTripFunc) { + if transport.putHandlers == nil { + transport.putHandlers = make(map[string]RoundTripFunc) + } + + transport.putHandlers[path] = f +} + +// Used for migration to MAX v3 api, where order cancel uses DELETE (MAX v2 api uses POST). +func (transport *MockTransport) PostOrDelete(isDelete bool, path string, f RoundTripFunc) { + if isDelete { + transport.DELETE(path, f) + } else { + transport.POST(path, f) + } +} + +func (transport *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + var handlers map[string]RoundTripFunc + + switch strings.ToUpper(req.Method) { + + case "GET": + handlers = transport.getHandlers + case "POST": + handlers = transport.postHandlers + case "DELETE": + handlers = transport.deleteHandlers + case "PUT": + handlers = transport.putHandlers + + default: + return nil, errors.Errorf("unsupported mock transport request method: %s", req.Method) + + } + + f, ok := handlers[req.URL.Path] + if !ok { + return nil, errors.Errorf("roundtrip mock to %s %s is not defined", req.Method, req.URL.Path) + } + + return f(req) +} + +func MockWithJsonReply(url string, rawData interface{}) *http.Client { + tripFunc := func(_ *http.Request) (*http.Response, error) { + return BuildResponseJson(http.StatusOK, rawData), nil + } + + transport := &MockTransport{} + transport.DELETE(url, tripFunc) + transport.GET(url, tripFunc) + transport.POST(url, tripFunc) + transport.PUT(url, tripFunc) + return &http.Client{Transport: transport} +}