Skip to content

Commit b98a810

Browse files
authored
feat: Pydantic field validation (#101)
* Require pydantic * Add tests for Pydantic field validation / coercion * Add Pydantic validation / coercion * Autoformat * CHANGELOG entry * Update docstring * Add nested Pydantic testing
1 parent 00a7854 commit b98a810

File tree

6 files changed

+186
-8
lines changed

6 files changed

+186
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
# [unreleased] - XXXX-XX-XX
88
### Added
99
- [PR 100](https://github.com/salesforce/django-declarative-apis/pull/100) Improve `errors.py`
10+
- [PR 101](https://github.com/salesforce/django-declarative-apis/pull/101) Allow Pydantic models as field types
1011

1112
# [0.23.1] - 2022-05-17
1213
### Fixed

django_declarative_apis/machinery/attributes.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import time
1414

1515
from django.db import models as django_models
16+
import pydantic
1617

1718
from . import errors
1819
from . import tasks
@@ -112,7 +113,9 @@ def documentation(self):
112113
class TypedEndpointAttributeMixin:
113114
def __init__(self, *args, **kwargs):
114115
self.field_type = kwargs.pop("type", str)
115-
if self.field_type not in RequestField.VALID_FIELD_TYPES:
116+
if not any(
117+
issubclass(self.field_type, t) for t in RequestField.VALID_FIELD_TYPES
118+
):
116119
raise NotImplementedError(
117120
"Request fields of type {0} not supported".format(
118121
self.field_type.__name__
@@ -124,6 +127,8 @@ def coerce_value_to_type(self, raw_value):
124127
try:
125128
if self.field_type == bool and not isinstance(raw_value, self.field_type):
126129
return "rue" in raw_value
130+
elif issubclass(self.field_type, pydantic.BaseModel):
131+
return self.field_type.parse_obj(raw_value)
127132
else:
128133
if isinstance(raw_value, collections.abc.Iterable) and not isinstance(
129134
raw_value, (str, dict)
@@ -201,8 +206,8 @@ class RequestField(TypedEndpointAttributeMixin, RequestProperty):
201206
"""Endpoint properties are called fields. Fields can be simple types such as int,
202207
or they can be used as a decorator on a function.
203208
204-
**Valid field types:** :code:`int`, :code:`bool`, :code:`float`, :code:`str`,
205-
:code:`dict`, :code:`complex`
209+
**Valid field types:** A subclass of :code:`int`, :code:`bool`, :code:`float`,
210+
:code:`str`, :code:`dict`, :code:`complex`, :code:`pydantic.BaseModel`
206211
207212
**Example**
208213
@@ -258,7 +263,7 @@ class FooDefinition(EndpointDefinition):
258263
259264
"""
260265

261-
VALID_FIELD_TYPES = (bool, int, float, complex, str, dict)
266+
VALID_FIELD_TYPES = (bool, int, float, complex, str, dict, pydantic.BaseModel)
262267

263268
def __init__(self, *args, **kwargs):
264269
self.default_value = kwargs.pop("default", None)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ cryptography>=2.0,<=3.4.8
44
decorator==4.0.11
55
django-dirtyfields>=1.2.1
66
oauthlib[signedtoken,rsa]>=2.0.6,<3.1.0
7+
pydantic>=1.8

tests/tests.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,104 @@ def test_run_deferred_task(self):
7676
"/simple", consumer=self.consumer, expected_status_code=HTTPStatus.OK
7777
)
7878
self.assertTrue(cache.get("deferred_task_called"))
79+
80+
def test_dict_field_endpoint(self):
81+
good_dict = {
82+
"length": 11,
83+
"description": "This is a description",
84+
"timestamp": "2022-10-24T00:00:00",
85+
"words": ["foo", "bar", "baz", "quux"],
86+
}
87+
test_data = [
88+
(good_dict, HTTPStatus.OK, "good dict"),
89+
({}, HTTPStatus.OK, "empty_dict"),
90+
(list(good_dict), HTTPStatus.BAD_REQUEST, "list"),
91+
("a string", HTTPStatus.BAD_REQUEST, "string"),
92+
(1337, HTTPStatus.BAD_REQUEST, "int"),
93+
]
94+
95+
for dct, expected_status, message in test_data:
96+
data = {"dict_type_field": dct}
97+
with self.subTest(message):
98+
response = self.client.post(
99+
"/dictfield",
100+
consumer=self.consumer,
101+
data=data,
102+
expected_status_code=expected_status,
103+
content_type="application/json",
104+
)
105+
if expected_status == HTTPStatus.OK:
106+
self.assertDictEqual(json.loads(response.content), data)
107+
108+
def test_pydantic_field_endpoint(self):
109+
good_dict = {
110+
"length": 11,
111+
"description": "This is a description",
112+
"timestamp": "2022-10-24T00:00:00",
113+
"words": ["foo", "bar", "baz", "quux"],
114+
}
115+
test_data = [
116+
(good_dict, HTTPStatus.OK, "no errors"),
117+
({**good_dict, "length": "eleven"}, HTTPStatus.BAD_REQUEST, "bad length"),
118+
(
119+
{**good_dict, "description": ["one", "two"]},
120+
HTTPStatus.BAD_REQUEST,
121+
"bad description",
122+
),
123+
(
124+
{**good_dict, "timestamp": "2022-10-24T99:99:99"},
125+
HTTPStatus.BAD_REQUEST,
126+
"bad timestamp",
127+
),
128+
(
129+
{**good_dict, "words": "foo bar baz quux"},
130+
HTTPStatus.BAD_REQUEST,
131+
"bad words",
132+
),
133+
]
134+
135+
for dct, expected_status, message in test_data:
136+
data = {"pydantic_type_field": dct}
137+
with self.subTest(message):
138+
response = self.client.post(
139+
"/pydanticfield",
140+
consumer=self.consumer,
141+
data=data,
142+
expected_status_code=expected_status,
143+
content_type="application/json",
144+
)
145+
if expected_status == HTTPStatus.OK:
146+
self.assertDictEqual(json.loads(response.content), data)
147+
148+
def test_nested_pydantic_field_endpoint(self):
149+
good_dict = {"b": "hello", "c": {"a": "world"}}
150+
test_data = [
151+
(good_dict, HTTPStatus.OK, "no errors"),
152+
({**good_dict, "b": list("abc")}, HTTPStatus.BAD_REQUEST, "bad b"),
153+
({**good_dict, "c": 11}, HTTPStatus.BAD_REQUEST, "bad c"),
154+
({**good_dict, "c": {"a": list("abc")}}, HTTPStatus.BAD_REQUEST, "bad a"),
155+
({**good_dict, "c": {}}, HTTPStatus.BAD_REQUEST, "missing a"),
156+
(
157+
{k: v for (k, v) in good_dict.items() if k != "b"},
158+
HTTPStatus.BAD_REQUEST,
159+
"missing b",
160+
),
161+
(
162+
{k: v for (k, v) in good_dict.items() if k != "c"},
163+
HTTPStatus.BAD_REQUEST,
164+
"missing c",
165+
),
166+
]
167+
168+
for dct, expected_status, message in test_data:
169+
data = {"nested_pydantic_type_field": dct}
170+
with self.subTest(message):
171+
response = self.client.post(
172+
"/nestedpydanticfield",
173+
consumer=self.consumer,
174+
data=data,
175+
expected_status_code=expected_status,
176+
content_type="application/json",
177+
)
178+
if expected_status == HTTPStatus.OK:
179+
self.assertDictEqual(json.loads(response.content), data)

tests/urls.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
UUID4_REGEX = r"[0-9a-fA-F]{8}-([0-9a-fA-F]{4}-){3}[0-9a-fA-F]{12}"
1515

1616
urlpatterns = [
17-
re_path(r"^simple", resource_adapter(get=views.SimpleEndpointDefinition)),
18-
re_path(r"^dict", resource_adapter(get=views.DictEndpointDefinition)),
17+
re_path(r"^simple$", resource_adapter(get=views.SimpleEndpointDefinition)),
18+
re_path(r"^dict$", resource_adapter(get=views.DictEndpointDefinition)),
19+
re_path(r"^dictfield$", resource_adapter(post=views.DictFieldEndpointDefinition)),
20+
re_path(
21+
r"^pydanticfield$", resource_adapter(post=views.PydanticFieldEndpointDefinition)
22+
),
23+
re_path(
24+
r"^nestedpydanticfield$",
25+
resource_adapter(post=views.NestedPydanticFieldEndpointDefinition),
26+
),
1927
]

tests/views.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,20 @@
55
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
66
#
77

8+
import datetime
9+
import json
10+
from typing import List
11+
12+
from django.core.cache import cache
13+
14+
import pydantic
15+
816
from django_declarative_apis.machinery import (
917
EndpointDefinition,
1018
field,
1119
deferrable_task,
1220
endpoint_resource,
1321
)
14-
from django.core.cache import cache
15-
1622
from tests.models import TestModel
1723

1824

@@ -44,3 +50,59 @@ def is_authorized(self):
4450
def resource(self):
4551
inst = TestModel.objects.create(int_field=1)
4652
return {"test": inst, "deep_test": {"test": inst}}
53+
54+
55+
class DictFieldEndpointDefinition(EndpointDefinition):
56+
def is_authorized(self):
57+
return True
58+
59+
dict_type_field = field(type=dict, required=True)
60+
61+
@endpoint_resource(type=dict)
62+
def resource(self):
63+
return {
64+
"dict_type_field": self.dict_type_field,
65+
}
66+
67+
68+
class PydanticFieldEndpointDefinition(EndpointDefinition):
69+
def is_authorized(self):
70+
return True
71+
72+
class _TestData(pydantic.BaseModel):
73+
length: int
74+
description: str
75+
timestamp: datetime.datetime
76+
words: List[str]
77+
78+
pydantic_type_field = field(type=_TestData, required=True)
79+
80+
@endpoint_resource(type=dict)
81+
def resource(self):
82+
return {
83+
"pydantic_type_field": json.loads(self.pydantic_type_field.json()),
84+
}
85+
86+
87+
class _ModelA(pydantic.BaseModel):
88+
a: str
89+
90+
91+
class _ModelB(pydantic.BaseModel):
92+
b: str
93+
c: _ModelA
94+
95+
96+
class NestedPydanticFieldEndpointDefinition(EndpointDefinition):
97+
def is_authorized(self):
98+
return True
99+
100+
nested_pydantic_type_field = field(type=_ModelB, required=True)
101+
102+
@endpoint_resource(type=dict)
103+
def resource(self):
104+
return {
105+
"nested_pydantic_type_field": json.loads(
106+
self.nested_pydantic_type_field.json()
107+
),
108+
}

0 commit comments

Comments
 (0)