Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 19 additions & 31 deletions api_v2/views/mixins/eager_loading_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,38 @@ class (ie. ReadOnlyModelViewSet).

## Usage Example
```
# EagerLoadingMixin inhertired before base-case
class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
queryset = models.Creature.objects.all().order_by('pk')
serializer_class = serializers.CreatureSerializer
filterset_class = CreatureFilterSet

select_related_fields = [] # ForeignKey relations to optimise with select_related()
prefetch_related_fields = [] # ManyToMany/reverse relations to optimise with prefetch_related()
```
"""

# Override these lists in child views
select_related_fields = [] # ForeignKeys to optimise
prefetch_related_fields = [] # ManyToMany & reverse relationships to prefetch
select_related_fields = []
prefetch_related_fields = []

def get_queryset(self):
"""Override DRF's default get_queryset() method to apply eager loading"""
queryset = super().get_queryset()

# Get query parameters from request
requested_fields = self.request.query_params.get('fields', '').split(',')
depth = int(self.request.query_params.get('depth', 0))
filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields)
filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields)

# if no fields are passed via query param, select/prefetch all fields defined on the view
if not requested_fields:
queryset = queryset.select_related(*self.select_related_fields)
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
return queryset

# filter selected fields against fields requested by user via query params
# this stops Django prefetching data that isn't even returned by this view
select_fields = []
for field_to_select in self.select_related_fields:
if any(field_in_request in field_to_select for field_in_request in requested_fields):
select_fields.append(field_to_select)

# filter prefetch fields against fields requested by user via query params
# this stops Django prefetching data that isn't even returned by this view
prefetch_fields = []
for field_to_prefetch in self.prefetch_related_fields:
if any(field_in_request in field_to_prefetch for field_in_request in requested_fields):
prefetch_fields.append(field_to_prefetch)
return queryset \
.select_related(*filtered_select_fields) \
.prefetch_related(*filtered_prefetch_fields)

# Apply filtered optimisations
queryset = queryset.select_related(*select_fields)
queryset = queryset.prefetch_related(*prefetch_fields)
return queryset
def filter_fields(self, related_fields, requested_fields):
"""
Filters'related_fields' according to whether they are included in
'requested_fields'. Used to remove fields from eager loading if they are
not requested (and thus not returned by API), avoiding unnecessary DB calls
"""
if not any(requested_fields):
return related_fields
return [
related_field for related_field in related_fields
if any(related_field == req or related_field.startswith(req + '__') for req in requested_fields)
]
Loading