Skip to content

Commit 9ac0e8c

Browse files
committed
Add more unit tests
1 parent 155e094 commit 9ac0e8c

File tree

7 files changed

+363
-2
lines changed

7 files changed

+363
-2
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package middleware_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"npm/internal/api/middleware"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestAccessControl(t *testing.T) {
14+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
w.WriteHeader(http.StatusOK)
16+
})
17+
18+
rr := httptest.NewRecorder()
19+
req, err := http.NewRequest("GET", "/", nil)
20+
assert.Nil(t, err)
21+
accessControl := middleware.AccessControl(handler)
22+
accessControl.ServeHTTP(rr, req)
23+
assert.Equal(t, http.StatusOK, rr.Code)
24+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
25+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package middleware_test
2+
3+
import (
4+
"bytes"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
c "npm/internal/api/context"
10+
"npm/internal/api/middleware"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestBodyContext(t *testing.T) {
16+
// Create a test request with a body
17+
body := []byte(`{"name": "John", "age": 30}`)
18+
req, err := http.NewRequest("POST", "/test", bytes.NewBuffer(body))
19+
assert.Nil(t, err)
20+
21+
// Create a test response recorder
22+
rr := httptest.NewRecorder()
23+
24+
// Create a test handler that checks the context for the body data
25+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26+
bodyData := r.Context().Value(c.BodyCtxKey).([]byte)
27+
assert.Equal(t, body, bodyData)
28+
})
29+
30+
// Wrap the handler with the BodyContext middleware
31+
mw := middleware.BodyContext()(handler)
32+
33+
// Call the middleware with the test request and response recorder
34+
mw.ServeHTTP(rr, req)
35+
36+
// Check that the response status code is 200
37+
status := rr.Code
38+
assert.Equal(t, http.StatusOK, status)
39+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package middleware_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"npm/internal/api/middleware"
9+
10+
"github.com/go-chi/chi/v5"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func TestCors(t *testing.T) {
15+
r := chi.NewRouter()
16+
r.Use(middleware.Cors(r))
17+
18+
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
19+
w.Write([]byte("test"))
20+
})
21+
22+
req, err := http.NewRequest("GET", "/test", nil)
23+
assert.Nil(t, err)
24+
25+
rr := httptest.NewRecorder()
26+
r.ServeHTTP(rr, req)
27+
28+
assert.Equal(t, "GET,OPTIONS", rr.Header().Get("Access-Control-Allow-Methods"))
29+
assert.Equal(t, "Authorization,Host,Content-Type,Connection,User-Agent,Cache-Control,Accept-Encoding", rr.Header().Get("Access-Control-Allow-Headers"))
30+
assert.Equal(t, "test", rr.Body.String())
31+
}
32+
33+
func TestCorsNoRoute(t *testing.T) {
34+
r := chi.NewRouter()
35+
r.Use(middleware.Cors(r))
36+
37+
req, err := http.NewRequest("GET", "/test", nil)
38+
assert.Nil(t, err)
39+
40+
rr := httptest.NewRecorder()
41+
r.ServeHTTP(rr, req)
42+
43+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods"))
44+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers"))
45+
}
46+
47+
func TestOptions(t *testing.T) {
48+
r := chi.NewRouter()
49+
r.Use(middleware.Options(r))
50+
51+
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
52+
w.Write([]byte("test"))
53+
})
54+
55+
req, err := http.NewRequest("OPTIONS", "/test", nil)
56+
assert.Nil(t, err)
57+
58+
rr := httptest.NewRecorder()
59+
r.ServeHTTP(rr, req)
60+
61+
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
62+
assert.Equal(t, "application/json", rr.Header().Get("Content-Type"))
63+
assert.Equal(t, "{}", rr.Body.String())
64+
}
65+
66+
func TestOptionsNoRoute(t *testing.T) {
67+
r := chi.NewRouter()
68+
r.Use(middleware.Options(r))
69+
70+
req, err := http.NewRequest("OPTIONS", "/test", nil)
71+
assert.Nil(t, err)
72+
73+
rr := httptest.NewRecorder()
74+
r.ServeHTTP(rr, req)
75+
76+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin"))
77+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods"))
78+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers"))
79+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package middleware_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
10+
"npm/internal/api/middleware"
11+
"npm/internal/config"
12+
)
13+
14+
func TestEnforceSetup(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
shouldBeSetup bool
18+
isSetup bool
19+
expectedCode int
20+
}{
21+
{
22+
name: "should allow request when setup is expected and is setup",
23+
shouldBeSetup: true,
24+
isSetup: true,
25+
expectedCode: http.StatusOK,
26+
},
27+
{
28+
name: "should error when setup is expected but not setup",
29+
shouldBeSetup: true,
30+
isSetup: false,
31+
expectedCode: http.StatusForbidden,
32+
},
33+
{
34+
name: "should allow request when setup is not expected and not setup",
35+
shouldBeSetup: false,
36+
isSetup: false,
37+
expectedCode: http.StatusOK,
38+
},
39+
{
40+
name: "should error when setup is not expected but is setup",
41+
shouldBeSetup: false,
42+
isSetup: true,
43+
expectedCode: http.StatusForbidden,
44+
},
45+
}
46+
47+
for _, tt := range tests {
48+
t.Run(tt.name, func(t *testing.T) {
49+
config.IsSetup = tt.isSetup
50+
51+
handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
52+
w.WriteHeader(http.StatusOK)
53+
}))
54+
55+
req := httptest.NewRequest(http.MethodGet, "/", nil)
56+
w := httptest.NewRecorder()
57+
handler.ServeHTTP(w, req)
58+
assert.Equal(t, tt.expectedCode, w.Code)
59+
})
60+
}
61+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package middleware_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
c "npm/internal/api/context"
10+
"npm/internal/api/middleware"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestExpansion(t *testing.T) {
16+
t.Run("with expand query param", func(t *testing.T) {
17+
req, err := http.NewRequest("GET", "/path?expand=item1,item2", nil)
18+
assert.NoError(t, err)
19+
20+
rr := httptest.NewRecorder()
21+
22+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
expand := middleware.GetExpandFromContext(r)
24+
assert.Equal(t, []string{"item1", "item2"}, expand)
25+
})
26+
27+
middleware.Expansion(handler).ServeHTTP(rr, req)
28+
29+
assert.Equal(t, http.StatusOK, rr.Code)
30+
})
31+
32+
t.Run("without expand query param", func(t *testing.T) {
33+
req, err := http.NewRequest("GET", "/path", nil)
34+
assert.NoError(t, err)
35+
36+
rr := httptest.NewRecorder()
37+
38+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39+
expand := middleware.GetExpandFromContext(r)
40+
assert.Nil(t, expand)
41+
})
42+
43+
middleware.Expansion(handler).ServeHTTP(rr, req)
44+
45+
assert.Equal(t, http.StatusOK, rr.Code)
46+
})
47+
}
48+
49+
func TestGetExpandFromContext(t *testing.T) {
50+
t.Run("with context value", func(t *testing.T) {
51+
req, err := http.NewRequest("GET", "/path", nil)
52+
assert.NoError(t, err)
53+
54+
ctx := req.Context()
55+
ctx = context.WithValue(ctx, c.ExpansionCtxKey, []string{"item1", "item2"})
56+
req = req.WithContext(ctx)
57+
58+
expand := middleware.GetExpandFromContext(req)
59+
assert.Equal(t, []string{"item1", "item2"}, expand)
60+
})
61+
62+
t.Run("without context value", func(t *testing.T) {
63+
req, err := http.NewRequest("GET", "/path", nil)
64+
assert.NoError(t, err)
65+
66+
expand := middleware.GetExpandFromContext(req)
67+
assert.Nil(t, expand)
68+
})
69+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package middleware_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
c "npm/internal/api/context"
10+
"npm/internal/api/middleware"
11+
"npm/internal/entity/user"
12+
"npm/internal/model"
13+
"npm/internal/tags"
14+
15+
"github.com/stretchr/testify/assert"
16+
)
17+
18+
func TestListQuery(t *testing.T) {
19+
tests := []struct {
20+
name string
21+
queryParams string
22+
expectedStatus int
23+
}{
24+
{
25+
name: "valid query params",
26+
queryParams: "?name:contains=John&sort=name.desc",
27+
expectedStatus: http.StatusOK,
28+
},
29+
{
30+
name: "invalid sort field",
31+
queryParams: "?name:contains=John&sort=invalid_field",
32+
expectedStatus: http.StatusBadRequest,
33+
},
34+
{
35+
name: "invalid filter value",
36+
queryParams: "?name=123",
37+
expectedStatus: http.StatusOK,
38+
},
39+
}
40+
41+
for _, tt := range tests {
42+
t.Run(tt.name, func(t *testing.T) {
43+
req, err := http.NewRequest("GET", "/test"+tt.queryParams, nil)
44+
assert.NoError(t, err)
45+
46+
testObj := user.Model{}
47+
48+
ctx := context.Background()
49+
ctx = context.WithValue(ctx, c.FiltersCtxKey, tags.GetFilterSchema(testObj))
50+
51+
rr := httptest.NewRecorder()
52+
handler := middleware.ListQuery(testObj)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53+
w.WriteHeader(http.StatusOK)
54+
}))
55+
56+
handler.ServeHTTP(rr, req.WithContext(ctx))
57+
58+
assert.Equal(t, tt.expectedStatus, rr.Code)
59+
})
60+
}
61+
}
62+
63+
func TestGetFiltersFromContext(t *testing.T) {
64+
req, err := http.NewRequest("GET", "/test", nil)
65+
assert.NoError(t, err)
66+
67+
filters := []model.Filter{
68+
{Field: "name", Modifier: "contains", Value: []string{"test"}},
69+
}
70+
ctx := context.WithValue(req.Context(), c.FiltersCtxKey, filters)
71+
req = req.WithContext(ctx)
72+
73+
result := middleware.GetFiltersFromContext(req)
74+
assert.Equal(t, filters, result)
75+
}
76+
77+
func TestGetSortFromContext(t *testing.T) {
78+
req, err := http.NewRequest("GET", "/test", nil)
79+
assert.NoError(t, err)
80+
81+
sorts := []model.Sort{
82+
{Field: "name", Direction: "asc"},
83+
}
84+
ctx := context.WithValue(req.Context(), c.SortCtxKey, sorts)
85+
req = req.WithContext(ctx)
86+
87+
result := middleware.GetSortFromContext(req)
88+
assert.Equal(t, sorts, result)
89+
}

backend/internal/api/middleware/pretty_print.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ func PrettyPrint(next http.Handler) http.Handler {
1313
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1414
prettyStr := r.URL.Query().Get("pretty")
1515
if prettyStr == "1" || prettyStr == "true" {
16-
ctx := r.Context()
17-
ctx = context.WithValue(ctx, c.PrettyPrintCtxKey, true)
16+
ctx := context.WithValue(r.Context(), c.PrettyPrintCtxKey, true)
1817
next.ServeHTTP(w, r.WithContext(ctx))
1918
} else {
2019
next.ServeHTTP(w, r)

0 commit comments

Comments
 (0)