Skip to content

Commit a0e17f9

Browse files
committed
Better checking for api sort param to prevent sql injection
And moved filters out and cached object reflection
1 parent 9b32329 commit a0e17f9

File tree

12 files changed

+312
-223
lines changed

12 files changed

+312
-223
lines changed

backend/internal/api/context/context.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ var (
77
UserIDCtxKey = &contextKey{"UserID"}
88
// FiltersCtxKey is the name of the Filters value on the context
99
FiltersCtxKey = &contextKey{"Filters"}
10+
// SortCtxKey is the name of the Sort value on the context
11+
SortCtxKey = &contextKey{"Sort"}
1012
// PrettyPrintCtxKey is the name of the pretty print context
1113
PrettyPrintCtxKey = &contextKey{"Pretty"}
1214
// ExpansionCtxKey is the name of the expansion context

backend/internal/api/handler/helpers.go

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package handler
33
import (
44
"net/http"
55
"strconv"
6-
"strings"
76

87
"npm/internal/api/context"
8+
"npm/internal/api/middleware"
99
"npm/internal/model"
1010

1111
"github.com/go-chi/chi/v5"
@@ -23,50 +23,11 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
2323
return pageInfo, err
2424
}
2525

26-
pageInfo.Sort = getSortParameter(r)
26+
pageInfo.Sort = middleware.GetSortFromContext(r)
2727

2828
return pageInfo, nil
2929
}
3030

31-
func getSortParameter(r *http.Request) []model.Sort {
32-
var sortFields []model.Sort
33-
34-
queryValues := r.URL.Query()
35-
sortString := queryValues.Get("sort")
36-
if sortString == "" {
37-
return sortFields
38-
}
39-
40-
// Split sort fields up in to slice
41-
sorts := strings.Split(sortString, ",")
42-
for _, sortItem := range sorts {
43-
if strings.Contains(sortItem, ".") {
44-
theseItems := strings.Split(sortItem, ".")
45-
46-
switch strings.ToLower(theseItems[1]) {
47-
case "desc":
48-
fallthrough
49-
case "descending":
50-
theseItems[1] = "DESC"
51-
default:
52-
theseItems[1] = "ASC"
53-
}
54-
55-
sortFields = append(sortFields, model.Sort{
56-
Field: theseItems[0],
57-
Direction: theseItems[1],
58-
})
59-
} else {
60-
sortFields = append(sortFields, model.Sort{
61-
Field: sortItem,
62-
Direction: "ASC",
63-
})
64-
}
65-
}
66-
67-
return sortFields
68-
}
69-
7031
func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) {
7132
queryValues := r.URL.Query()
7233
varValue := queryValues.Get(varName)

backend/internal/api/middleware/expansion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
// Expansion will determine whether the request should have objects expanded
12-
// with ?expand=1 or ?expand=true
12+
// with ?expand=item,item
1313
func Expansion(next http.Handler) http.Handler {
1414
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1515
expandStr := r.URL.Query().Get("expand")

backend/internal/api/middleware/filters.go

Lines changed: 0 additions & 118 deletions
This file was deleted.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"strings"
9+
10+
c "npm/internal/api/context"
11+
h "npm/internal/api/http"
12+
"npm/internal/entity"
13+
"npm/internal/model"
14+
"npm/internal/tags"
15+
"npm/internal/util"
16+
17+
"github.com/qri-io/jsonschema"
18+
)
19+
20+
// ListQuery will accept a pre-defined schemaData to validate against the GET query params
21+
// passed in to this endpoint. This will ensure that the filters are not injecting SQL
22+
// and the sort parameter is valid as well.
23+
// After we have determined what the Filters are to be, they are saved on the Context
24+
// to be used later in other endpoints.
25+
func ListQuery(obj interface{}) func(http.Handler) http.Handler {
26+
schemaData := entity.GetFilterSchema(obj, true)
27+
filterMap := tags.GetFilterMap(obj)
28+
29+
return func(next http.Handler) http.Handler {
30+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31+
ctx := r.Context()
32+
33+
ctx, statusCode, errMsg, errors := listQueryFilters(r, ctx, schemaData)
34+
if statusCode > 0 {
35+
h.ResultErrorJSON(w, r, statusCode, errMsg, errors)
36+
return
37+
}
38+
39+
ctx, statusCode, errMsg = listQuerySort(r, filterMap, ctx)
40+
if statusCode > 0 {
41+
h.ResultErrorJSON(w, r, statusCode, errMsg, nil)
42+
return
43+
}
44+
45+
next.ServeHTTP(w, r.WithContext(ctx))
46+
})
47+
}
48+
}
49+
50+
func listQuerySort(
51+
r *http.Request,
52+
filterMap map[string]model.FilterMapValue,
53+
ctx context.Context,
54+
) (context.Context, int, string) {
55+
var sortFields []model.Sort
56+
57+
sortString := r.URL.Query().Get("sort")
58+
if sortString == "" {
59+
return ctx, 0, ""
60+
}
61+
62+
// Split sort fields up in to slice
63+
sorts := strings.Split(sortString, ",")
64+
for _, sortItem := range sorts {
65+
if strings.Contains(sortItem, ".") {
66+
theseItems := strings.Split(sortItem, ".")
67+
68+
switch strings.ToLower(theseItems[1]) {
69+
case "desc":
70+
fallthrough
71+
case "descending":
72+
theseItems[1] = "DESC"
73+
default:
74+
theseItems[1] = "ASC"
75+
}
76+
77+
sortFields = append(sortFields, model.Sort{
78+
Field: theseItems[0],
79+
Direction: theseItems[1],
80+
})
81+
} else {
82+
sortFields = append(sortFields, model.Sort{
83+
Field: sortItem,
84+
Direction: "ASC",
85+
})
86+
}
87+
}
88+
89+
// check against filter schema
90+
for _, f := range sortFields {
91+
if _, exists := filterMap[f.Field]; !exists {
92+
return ctx, http.StatusBadRequest, "Invalid sort field"
93+
}
94+
}
95+
96+
ctx = context.WithValue(ctx, c.SortCtxKey, sortFields)
97+
98+
// No problems!
99+
return ctx, 0, ""
100+
}
101+
102+
func listQueryFilters(
103+
r *http.Request,
104+
ctx context.Context,
105+
schemaData string,
106+
) (context.Context, int, string, interface{}) {
107+
reservedFilterKeys := []string{
108+
"limit",
109+
"offset",
110+
"sort",
111+
"expand",
112+
"t", // This is used as a timestamp paramater in some clients and can be ignored
113+
}
114+
115+
var filters []model.Filter
116+
for key, val := range r.URL.Query() {
117+
key = strings.ToLower(key)
118+
119+
// Split out the modifier from the field name and set a default modifier
120+
var keyParts []string
121+
keyParts = strings.Split(key, ":")
122+
if len(keyParts) == 1 {
123+
// Default modifier
124+
keyParts = append(keyParts, "equals")
125+
}
126+
127+
// Only use this filter if it's not a reserved get param
128+
if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) {
129+
for _, valItem := range val {
130+
// Check that the val isn't empty
131+
if len(strings.TrimSpace(valItem)) > 0 {
132+
valSlice := []string{valItem}
133+
if keyParts[1] == "in" || keyParts[1] == "notin" {
134+
valSlice = strings.Split(valItem, ",")
135+
}
136+
137+
filters = append(filters, model.Filter{
138+
Field: keyParts[0],
139+
Modifier: keyParts[1],
140+
Value: valSlice,
141+
})
142+
}
143+
}
144+
}
145+
}
146+
147+
// Only validate schema if there are filters to validate
148+
if len(filters) > 0 {
149+
// Marshal the Filters in to a JSON string so that the Schema Validation works against it
150+
filterData, marshalErr := json.MarshalIndent(filters, "", " ")
151+
if marshalErr != nil {
152+
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil
153+
}
154+
155+
// Create root schema
156+
rs := &jsonschema.Schema{}
157+
if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
158+
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil
159+
}
160+
161+
// Validate it
162+
errors, jsonError := rs.ValidateBytes(ctx, filterData)
163+
if jsonError != nil {
164+
return ctx, http.StatusBadRequest, jsonError.Error(), nil
165+
}
166+
167+
if len(errors) > 0 {
168+
return ctx, http.StatusBadRequest, "Invalid Filters", errors
169+
}
170+
171+
ctx = context.WithValue(ctx, c.FiltersCtxKey, filters)
172+
}
173+
174+
// No problems!
175+
return ctx, 0, "", nil
176+
}
177+
178+
// GetFiltersFromContext returns the Filters
179+
func GetFiltersFromContext(r *http.Request) []model.Filter {
180+
filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter)
181+
if !ok {
182+
// the assertion failed
183+
return nil
184+
}
185+
return filters
186+
}
187+
188+
// GetSortFromContext returns the Sort
189+
func GetSortFromContext(r *http.Request) []model.Sort {
190+
sorts, ok := r.Context().Value(c.SortCtxKey).([]model.Sort)
191+
if !ok {
192+
// the assertion failed
193+
return nil
194+
}
195+
return sorts
196+
}

0 commit comments

Comments
 (0)