Skip to content

Commit 40467e1

Browse files
bgrantdshafer
andauthored
chore: Fix save behavior for Django 3 and drop Django 2 support (#89)
* fix BoundEndpointManager get_response saves * use EndpointResourceAttribute on resource * preserve type arg on ResourceEndpointDefinition * Merge __init__ methods * Fix typo * update fields in memory that changed on save to the database * Fix getattr when attr doesn't exist * reset_state * need to manage Django model state * check_relationship=False for the in-memory update * deal with foreign keys in update_or_create * Guard against type=None * django3.2 moved an exception * Add a couple of docstrings * Don't LYIL * Don't bother with old import * Drop Django 2 dependency * Drop Django 2 testing * Update CHANGELOG Co-authored-by: Drew Shafer <[email protected]>
1 parent ccce1d9 commit 40467e1

File tree

7 files changed

+138
-58
lines changed

7 files changed

+138
-58
lines changed

.github/workflows/tests.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ jobs:
99
strategy:
1010
matrix:
1111
python-version: ['3.7', '3.8', '3.9', '3.10']
12-
django: ['2.2', '3.2']
1312

1413
steps:
1514
- uses: actions/checkout@v3

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Versioning](https://semver.org/spec/v2.0.0.html).
1616
### Fixed
1717
- [PR 95](https://github.com/salesforce/django-declarative-apis/pull/95) Fill in missing CHANGELOG entries
1818
- [PR 90](https://github.com/salesforce/django-declarative-apis/pull/90) Fix `BoundEndpointManager` `get_response` saves
19+
- [PR 89](https://github.com/salesforce/django-declarative-apis/pull/89) Fix save behavior for Django 3 and drop Django 2 support
1920

2021
### Changed
2122
- [PR 93](https://github.com/salesforce/django-declarative-apis/pull/93) Remove spaces in name of `test` GitHub Action

django_declarative_apis/machinery/__init__.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from django.conf import settings
1515
from django.http import HttpResponse
1616

17+
from dirtyfields.dirtyfields import reset_state
18+
1719
from django_declarative_apis.machinery.filtering import apply_filters_to_object
1820
from django_declarative_apis.models import BaseConsumer
1921
from django_declarative_apis.resources.utils import HttpStatusCode
@@ -36,7 +38,7 @@
3638
ResourceField,
3739
)
3840

39-
# these imports are unusued in this file but may be used in other projects
41+
# these imports are unused in this file but may be used in other projects
4042
# that use `machinery` as an interface
4143
from .attributes import TypedEndpointAttributeMixin, RequestFieldGroup # noqa
4244
from .utils import locate_object, rate_limit_exceeded
@@ -84,7 +86,7 @@ def resource(self):
8486
return Todo.objects.get(id=self.resource_id)
8587
"""
8688

87-
def __init__(self, type, filter=None, returns_list=False, **kwargs):
89+
def __init__(self, type=None, filter=None, returns_list=False, **kwargs):
8890
super().__init__(**kwargs)
8991
self.type = type
9092
self.filter = filter
@@ -100,10 +102,12 @@ def get_instance_value(self, owner_instance, owner_class):
100102
return self
101103
try:
102104
value = self.func(owner_instance)
103-
except django.core.exceptions.ObjectDoesNotExist:
104-
raise errors.ClientErrorNotFound(
105-
"{0} instance not found".format(self.type.__name__)
106-
)
105+
except django.core.exceptions.ObjectDoesNotExist as e: # noqa: F841
106+
try:
107+
message = f"{self.type.__name__} instance not found"
108+
except AttributeError as e: # noqa: F841
109+
message = "Resource instance not found"
110+
raise errors.ClientErrorNotFound(message)
107111

108112
if value.__class__ == dict:
109113
return value
@@ -173,6 +177,35 @@ def __init__(cls, class_name, bases=None, dict=None):
173177
pass
174178

175179

180+
def current_dirty_dict(resource):
181+
"""Get the `current` (in-memory) values for fields that have not yet been written to the database."""
182+
new_data = resource.get_dirty_fields(check_relationship=True, verbose=True)
183+
field_name_to_att_name = {f.name: f.attname for f in resource._meta.concrete_fields}
184+
return {
185+
field_name_to_att_name[key]: values["current"]
186+
for key, values in new_data.items()
187+
}
188+
189+
190+
def update_dirty(resource):
191+
"""Write dirty fields to the database."""
192+
dirty_dict = current_dirty_dict(resource)
193+
resource_next, created = type(resource).objects.update_or_create(
194+
pk=resource.pk, defaults=dirty_dict
195+
)
196+
197+
# update fields in memory that changed on save to the database
198+
field_name_to_att_name = {f.name: f.attname for f in resource._meta.concrete_fields}
199+
for k, v in resource_next._as_dict(check_relationship=True).items():
200+
att_key = field_name_to_att_name[k]
201+
if getattr(resource, att_key, None) != v:
202+
setattr(resource, att_key, v)
203+
resource._state.adding = False
204+
resource._state.db = resource_next._state.db
205+
resource._state.fields_cache = {}
206+
reset_state(type(resource), resource)
207+
208+
176209
class EndpointBinder:
177210
class BoundEndpointManager:
178211
def __init__(self, manager, bound_endpoint):
@@ -197,7 +230,7 @@ def get_response(self): # noqa: C901
197230

198231
if hasattr(resource, "is_dirty"):
199232
if resource and resource.is_dirty(check_relationship=True):
200-
resource.save()
233+
update_dirty(resource)
201234

202235
endpoint_tasks = sorted(
203236
self.manager.endpoint_tasks, key=lambda t: t.priority
@@ -213,13 +246,18 @@ def get_response(self): # noqa: C901
213246
immediate_task.run(self.bound_endpoint)
214247

215248
except errors.ClientError as ce:
216-
if ce.save_changes and resource and resource.is_dirty():
217-
resource.save()
249+
if (
250+
ce.save_changes
251+
and resource
252+
and hasattr(resource, "is_dirty")
253+
and resource.is_dirty()
254+
):
255+
update_dirty(resource)
218256
raise
219257

220258
if hasattr(resource, "is_dirty"):
221259
if resource and resource.is_dirty(check_relationship=True):
222-
resource.save()
260+
update_dirty(resource)
223261

224262
for deferred_task in deferred_tasks:
225263
deferred_task.run(self.bound_endpoint)
@@ -904,10 +942,10 @@ class ResourceEndpointDefinition(EndpointDefinition):
904942
"""
905943

906944
def __init__(self, *args, **kwargs):
907-
super().__init__()
945+
super().__init__(*args, **kwargs)
908946
self._cached_resource = None
909947

910-
@property
948+
@EndpointResourceAttribute()
911949
def resource(self):
912950
"""Queries the object manager of `self.resource_model` for the given id
913951
(`self.resource_id`).

django_declarative_apis/machinery/filtering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from django.db import models
1313
from django.db.models import ManyToOneRel
14+
from django.core.exceptions import FieldDoesNotExist
1415

1516
NEVER = 0
1617
ALWAYS = 1
@@ -80,7 +81,7 @@ def _get_filtered_field_value(
8081
else:
8182
try:
8283
val = getattr(inst, field_name)
83-
except (AttributeError, models.fields.FieldDoesNotExist) as e: # noqa
84+
except (AttributeError, FieldDoesNotExist) as e: # noqa
8485
return None
8586

8687
if isinstance(val, models.Manager):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Django >=2.2, <4
1+
Django >=3.2, <4
22
celery>=4.0.2,!=4.1.0
33
cryptography>=2.0,<=3.4.8
44
decorator==4.0.11

tests/machinery/test_base.py

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from django_declarative_apis.machinery import errors, filtering, tasks
2222
from django_declarative_apis.machinery.tasks import future_task_runner
2323
from django_declarative_apis.resources.utils import HttpStatusCode
24-
from tests import testutils
24+
from tests import testutils, models
2525

2626
_TEST_RESOURCE = {"foo": "bar"}
2727

@@ -147,70 +147,96 @@ def resource(self):
147147
mock_logging.error.assert_called_with("('something bad happened',)\nNone")
148148

149149
def test_get_response_with_dirty_resource(self):
150-
class _TestResource:
151-
def is_dirty(self, check_relationship=False):
152-
return True
150+
class _TestEndpoint1(machinery.EndpointDefinition):
151+
@machinery.endpoint_resource(type=models.DirtyFieldsModel)
152+
def resource(self):
153+
result = models.DirtyFieldsModel(field="abcde")
154+
result.fk_field = models.TestModel.objects.create(int_field=1)
155+
return result
156+
157+
class _TestEndpoint2(machinery.EndpointDefinition):
158+
@machinery.endpoint_resource(type=models.DirtyFieldsModel)
159+
def resource(self):
160+
result = models.DirtyFieldsModel(field="abcde")
161+
result.fk_field = models.TestModel.objects.create(int_field=1)
162+
return result
153163

154-
def save(self):
164+
@machinery.task
165+
def null_task(self):
155166
pass
156167

157-
class _TestEndpoint(machinery.EndpointDefinition):
158-
@machinery.endpoint_resource(type=_TestResource)
168+
class _TestEndpoint3(machinery.EndpointDefinition):
169+
@machinery.endpoint_resource(type=models.DirtyFieldsModel)
159170
def resource(self):
160-
return _TestResource()
171+
result = models.DirtyFieldsModel(field="abcde")
172+
result.fk_field = models.TestModel.objects.create(int_field=1)
173+
return result
161174

162-
endpoint = _TestEndpoint()
163-
manager = machinery.EndpointBinder.BoundEndpointManager(
164-
machinery._EndpointRequestLifecycleManager(endpoint), endpoint
165-
)
175+
@machinery.task
176+
def task(self):
177+
self.resource.field = "zyxwv"
178+
179+
for test_name, endpoint_cls, expected_call_count in (
180+
("No Task", _TestEndpoint1, 1),
181+
("No-op Task", _TestEndpoint2, 1),
182+
("With Task", _TestEndpoint3, 2),
183+
):
184+
with self.subTest(test_name):
185+
endpoint = endpoint_cls()
186+
manager = machinery.EndpointBinder.BoundEndpointManager(
187+
machinery._EndpointRequestLifecycleManager(endpoint), endpoint
188+
)
166189

167-
class _FakeRequest:
168-
META = {}
190+
class _FakeRequest:
191+
META = {}
169192

170-
manager.bound_endpoint.request = _FakeRequest()
193+
manager.bound_endpoint.request = _FakeRequest()
171194

172-
with mock.patch.object(_TestResource, "save", return_value=None) as mock_save:
173-
manager.get_response()
174-
# save is called before and after tasks. since we've hardcoded _TestResource.is_dirty to return True,
175-
# both of them should fire
176-
self.assertEqual(mock_save.call_count, 2)
195+
with mock.patch.object(
196+
models.DirtyFieldsModel.objects,
197+
"update_or_create",
198+
wraps=models.DirtyFieldsModel.objects.update_or_create,
199+
) as mock_uoc:
200+
manager.get_response()
201+
self.assertEqual(mock_uoc.call_count, expected_call_count)
177202

178203
def test_get_response_with_client_error_while_executing_tasks(self):
179-
class _TestResource:
180-
def is_dirty(self, check_relationship=False):
181-
return True
182-
183-
def save(self):
184-
pass
185-
186204
class _TestEndpoint(machinery.EndpointDefinition):
187-
@machinery.endpoint_resource(type=_TestResource)
205+
@machinery.endpoint_resource(type=models.DirtyFieldsModel)
188206
def resource(self):
189-
return _TestResource()
207+
result = models.DirtyFieldsModel(id=1, field="abcde")
208+
result.fk_field = models.TestModel.objects.create(int_field=1)
209+
return result
190210

191211
@machinery.task
192212
def raise_an_exception(self):
213+
self.resource.field = "zyxwv"
193214
raise errors.ClientError(
194215
code=http.HTTPStatus.BAD_REQUEST,
195216
message="something bad happened",
196217
save_changes=error_should_save_changes,
197218
)
198219

199220
for error_should_save_changes in (True, False):
200-
with mock.patch.object(_TestResource, "save") as mock_save:
201-
endpoint = _TestEndpoint()
202-
manager = machinery.EndpointBinder.BoundEndpointManager(
203-
machinery._EndpointRequestLifecycleManager(endpoint), endpoint
204-
)
205-
try:
206-
manager.get_response()
207-
self.fail("This should have failed")
208-
except errors.ClientError:
209-
# save should be called twice if the exception says the resource should be saved: once before
210-
# tasks are executed and once during exception handling.
211-
self.assertEqual(
212-
mock_save.call_count, 2 if error_should_save_changes else 1
221+
with self.subTest(f"error_should_save_changes={error_should_save_changes}"):
222+
with mock.patch.object(
223+
models.DirtyFieldsModel.objects,
224+
"update_or_create",
225+
wraps=models.DirtyFieldsModel.objects.update_or_create,
226+
) as mock_uoc:
227+
endpoint = _TestEndpoint()
228+
manager = machinery.EndpointBinder.BoundEndpointManager(
229+
machinery._EndpointRequestLifecycleManager(endpoint), endpoint
213230
)
231+
try:
232+
manager.get_response()
233+
self.fail("This should have failed")
234+
except errors.ClientError:
235+
# save should be called twice if the exception says the resource should be saved: once before
236+
# tasks are executed and once during exception handling.
237+
self.assertEqual(
238+
mock_uoc.call_count, 2 if error_should_save_changes else 1
239+
)
214240

215241
def test_get_response_custom_http_response(self):
216242
expected_data = {"foo": "bar"}
@@ -280,7 +306,11 @@ class _QuerySet(list):
280306
data = _QuerySet([_TestResource("foo", "bar"), _TestResource("bar", "baz")])
281307

282308
filter_def = {
283-
_TestResource: {"name": filtering.ALWAYS, "secret": filtering.NEVER}
309+
_TestResource: {
310+
"name": filtering.ALWAYS,
311+
"secret": filtering.NEVER,
312+
"foo": filtering.ALWAYS,
313+
}
284314
}
285315

286316
class _TestEndpoint(machinery.EndpointDefinition):

tests/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,14 @@ class ParentModel(models.Model):
5151
class RootNode(models.Model):
5252
id = models.IntegerField(primary_key=True)
5353
parent_field = models.ForeignKey(ParentModel, on_delete=models.CASCADE)
54+
55+
56+
try:
57+
import dirtyfields
58+
59+
class DirtyFieldsModel(dirtyfields.DirtyFieldsMixin, models.Model):
60+
field = models.CharField(max_length=100)
61+
fk_field = models.ForeignKey(TestModel, null=False, on_delete=models.CASCADE)
62+
63+
except Exception:
64+
pass

0 commit comments

Comments
 (0)