diff --git a/openedx/core/djangoapps/enrollments/forms.py b/openedx/core/djangoapps/enrollments/forms.py index e87210b3fee4..8c2a7a1d01bc 100644 --- a/openedx/core/djangoapps/enrollments/forms.py +++ b/openedx/core/djangoapps/enrollments/forms.py @@ -14,13 +14,68 @@ class CourseEnrollmentsApiListForm(Form): """ A form that validates the query string parameters for the CourseEnrollmentsApiListView. + + ADR 0033 – OEP-68 parameter naming standardization: + - ``course_key`` is the preferred parameter name; ``course_id`` is accepted + as a deprecated alias (BC strategy §1). When both are present, + ``course_key`` wins. + - ``course_keys`` is the preferred parameter name; ``course_ids`` is + accepted as a deprecated alias (same precedence rule). + Internally the cleaned_data continues to expose ``course_id`` / + ``course_ids`` so call sites do not need to change. Use + :meth:`legacy_param_aliases_used` to detect when the deprecated names were + sent by the client (used to emit the ``Deprecation`` HTTP header). """ MAX_INPUT_COUNT = 100 + # Legacy / OEP-68 alias pairs: (legacy, preferred). + _LEGACY_PARAM_ALIASES = ( + ("course_id", "course_key"), + ("course_ids", "course_keys"), + ) + username = CharField(required=False) course_id = CharField(required=False) + course_key = CharField(required=False) course_ids = CharField(required=False) + course_keys = CharField(required=False) email = CharField(required=False) + def __init__(self, query_params, *args, **kwargs): + # Capture the raw param names supplied on the wire, *before* Django's + # form layer resolves aliases, so :meth:`legacy_param_aliases_used` + # can later report exactly which legacy names were used. + try: + raw_keys = set(query_params.keys()) + except AttributeError: + raw_keys = set() + self._raw_param_names = raw_keys + + # Coalesce OEP-68 preferred names into the legacy fields so the + # downstream view code keeps reading ``course_id`` / ``course_ids`` + # without changes. Preferred wins when both are sent. + if hasattr(query_params, "copy"): + data = query_params.copy() + else: + data = dict(query_params) + for legacy_name, preferred_name in self._LEGACY_PARAM_ALIASES: + preferred_value = data.get(preferred_name) + if preferred_value: + data[legacy_name] = preferred_value + + super().__init__(data, *args, **kwargs) + + def legacy_param_aliases_used(self): + """ + Return the list of legacy (OEP-68-violating) parameter names actually + present in the request, in declaration order. + + Used by the view layer to emit the ADR 0033 ``Deprecation`` header. + """ + return [ + legacy for legacy, _preferred in self._LEGACY_PARAM_ALIASES + if legacy in self._raw_param_names + ] + def clean_course_id(self): """ Validate and return a course ID. diff --git a/openedx/core/djangoapps/enrollments/paginators.py b/openedx/core/djangoapps/enrollments/paginators.py index e7534c05bc67..554e3a9bd791 100644 --- a/openedx/core/djangoapps/enrollments/paginators.py +++ b/openedx/core/djangoapps/enrollments/paginators.py @@ -3,14 +3,19 @@ """ -from rest_framework.pagination import CursorPagination +from edx_rest_framework_extensions.paginators import DefaultPagination # ADR 0032 -class CourseEnrollmentsApiListPagination(CursorPagination): +class CourseEnrollmentsApiListPagination(DefaultPagination): """ - Paginator for the Course enrollments list API. + ADR 0032 – standard pagination for the admin enrollments list API + (GET /api/enrollment/v1/enrollments). + + Extends DefaultPagination with a larger default page size appropriate + for an admin-facing, bulk-query endpoint. The full 7-field response + envelope (count, num_pages, current_page, start, next, previous, + results) is provided by DefaultPagination.get_paginated_response. """ page_size = 100 page_size_query_param = 'page_size' max_page_size = 100 - page_query_param = 'page' diff --git a/openedx/core/djangoapps/enrollments/serializers.py b/openedx/core/djangoapps/enrollments/serializers.py index 9b64cc95caf8..10d9ef2b3932 100644 --- a/openedx/core/djangoapps/enrollments/serializers.py +++ b/openedx/core/djangoapps/enrollments/serializers.py @@ -137,3 +137,22 @@ class Meta: model = CourseEnrollmentAllowed exclude = ["id"] lookup_field = "user" + + +class UserRoleSerializer(serializers.Serializer): # pylint: disable=abstract-method + """Serializes a single course-level role entry for a user.""" + + org = serializers.CharField() + course_id = serializers.SerializerMethodField() + role = serializers.CharField() + + def get_course_id(self, obj): + """Return course_id as a string.""" + return str(obj.course_id) + + +class UserRolesResponseSerializer(serializers.Serializer): # pylint: disable=abstract-method + """Serializes the full response payload for EnrollmentUserRolesView.""" + + roles = UserRoleSerializer(many=True) + is_staff = serializers.BooleanField() diff --git a/openedx/core/djangoapps/enrollments/tests/test_views.py b/openedx/core/djangoapps/enrollments/tests/test_views.py index 41c5a9624c24..86f7a4aeb586 100644 --- a/openedx/core/djangoapps/enrollments/tests/test_views.py +++ b/openedx/core/djangoapps/enrollments/tests/test_views.py @@ -36,7 +36,7 @@ from openedx.core.djangoapps.course_groups import cohorts from openedx.core.djangoapps.embargo.models import Country, CountryAccessRule, RestrictedCourse from openedx.core.djangoapps.embargo.test_utils import restrict_course -from openedx.core.djangoapps.enrollments import api, data +from openedx.core.djangoapps.enrollments import data from openedx.core.djangoapps.enrollments.errors import CourseEnrollmentError from openedx.core.djangoapps.enrollments.views import EnrollmentUserThrottle from openedx.core.djangoapps.notifications.config.waffle import ENABLE_NOTIFICATIONS @@ -711,9 +711,9 @@ def test_get_enrollment_details_bad_course(self): ) assert resp.status_code == status.HTTP_400_BAD_REQUEST - @patch.object(api, "get_enrollment") - def test_get_enrollment_internal_error(self, mock_get_enrollment): - mock_get_enrollment.side_effect = CourseEnrollmentError("Something bad happened.") + @patch.object(CourseEnrollment.objects, "get") + def test_get_enrollment_internal_error(self, mock_get): + mock_get.side_effect = CourseEnrollmentError("Something bad happened.") resp = self.client.get( reverse( 'courseenrollment', @@ -2031,3 +2031,825 @@ def test_delete_enrollment_allowed(self, delete_data, expected_result): self.client.post(self.url, self.data) response = self.client.delete(self.url, delete_data) assert response.status_code == expected_result + + # --- Response-shape tests (ADR 0025 serializer migration) --- + + def test_post_response_shape(self): + """POST 201 response contains the expected fields from CourseEnrollmentAllowedSerializer.""" + response = self.client.post(self.url, self.data) + assert response.status_code == status.HTTP_201_CREATED + body = response.json() + assert body['email'] == self.data['email'] + assert body['course_id'] == self.data['course_id'] + assert body['auto_enroll'] is False + assert 'created' in body + + def test_post_auto_enroll_true_in_response(self): + """POST with auto_enroll=true is reflected in the 201 response.""" + response = self.client.post(self.url, {**self.data, 'auto_enroll': True}) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()['auto_enroll'] is True + + def test_post_missing_email_returns_field_error(self): + """POST without email returns a serializer field-level 400 with an 'email' key.""" + response = self.client.post(self.url, {'course_id': self.data['course_id']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'email' in response.json() + + def test_post_missing_course_id_returns_field_error(self): + """POST without course_id returns a serializer field-level 400 with a 'course_id' key.""" + response = self.client.post(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'course_id' in response.json() + + def test_post_duplicate_returns_409_with_message(self): + """A duplicate POST returns 409 with a 'message' key.""" + self.client.post(self.url, self.data) + response = self.client.post(self.url, self.data) + assert response.status_code == status.HTTP_409_CONFLICT + assert 'message' in response.json() + + def test_get_response_is_list(self): + """GET response body is a JSON list.""" + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + assert isinstance(response.json(), list) + + def test_get_empty_response_is_empty_list(self): + """GET with no matching enrollments returns an empty list, not null.""" + response = self.client.get(self.url, {'email': 'nobody@example.com'}) + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + def test_get_item_shape(self): + """Each item in the GET response has the fields from CourseEnrollmentAllowedSerializer.""" + self.client.post(self.url, self.data) + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + item = response.json()[0] + assert item['email'] == self.data['email'] + assert item['course_id'] == self.data['course_id'] + assert 'auto_enroll' in item + assert 'created' in item + + def test_get_multiple_entries_returned(self): + """GET returns all enrollment-allowed records for a given email.""" + second_course = 'course-v1:edX+OtherX+Other_Course' + self.client.post(self.url, self.data) + self.client.post(self.url, {'email': self.data['email'], 'course_id': second_course}) + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + results = response.json() + assert len(results) == 2 + assert all(r['email'] == self.data['email'] for r in results) + + def test_delete_missing_email_returns_field_error(self): + """DELETE without email returns a serializer field-level 400 with an 'email' key.""" + self.client.post(self.url, self.data) + response = self.client.delete(self.url, {'course_id': self.data['course_id']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'email' in response.json() + + +@skip_unless_lms +class EnrollmentViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentView (GET /enrollment/v1/enrollment/{course_id} and + /enrollment/v1/enrollment/{username},{course_id}) response structure is preserved + after migrating to direct serializer usage (ADR 0025). + """ + + USERNAME = "Bob" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + CourseEnrollment.enroll(self.user, self.course.id) + + def _get_by_course_id(self): + return self.client.get( + reverse('courseenrollment', kwargs={'course_id': str(self.course.id)}) + ) + + def _get_by_username_and_course_id(self): + return self.client.get( + reverse('courseenrollment', kwargs={'username': self.USERNAME, 'course_id': str(self.course.id)}) + ) + + def test_get_by_course_id_returns_200(self): + assert self._get_by_course_id().status_code == status.HTTP_200_OK + + def test_get_by_username_course_id_returns_200(self): + assert self._get_by_username_and_course_id().status_code == status.HTTP_200_OK + + def test_get_response_top_level_fields(self): + """Response contains the expected top-level enrollment fields.""" + body = self._get_by_course_id().json() + for field in ('created', 'mode', 'is_active', 'user', 'course_details'): + assert field in body, f"Missing top-level field: {field}" + + def test_get_response_user_and_mode(self): + """user and mode values match the enrollment.""" + body = self._get_by_course_id().json() + assert body['user'] == self.USERNAME + assert body['mode'] == CourseMode.DEFAULT_MODE_SLUG + assert body['is_active'] is True + + def test_get_by_username_course_id_matches_by_course_id(self): + """Both URL shapes return identical response bodies.""" + by_course = self._get_by_course_id().json() + by_username = self._get_by_username_and_course_id().json() + assert by_course == by_username + + def test_get_course_details_fields(self): + """course_details contains the expected nested fields.""" + course_details = self._get_by_course_id().json()['course_details'] + for field in ( + 'course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes', 'pacing_type', + ): + assert field in course_details, f"Missing course_details field: {field}" + assert course_details['course_id'] == str(self.course.id) + + def test_get_no_enrollment_returns_null(self): + """GET for a course the user never enrolled in returns HTTP 200 with a null body.""" + unenrolled_course = CourseFactory.create(emit_signals=True) + resp = self.client.get( + reverse('courseenrollment', kwargs={'course_id': str(unenrolled_course.id)}) + ) + assert resp.status_code == status.HTTP_200_OK + assert resp.json() is None + + +@skip_unless_lms +class EnrollmentCourseDetailViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentCourseDetailView (GET /enrollment/v1/course/{course_id}) + response structure is preserved after migrating to CourseSerializer + direct ORM (ADR 0025). + """ + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + + def _get_course_details(self, course_id=None, include_expired=False): + url = reverse('courseenrollmentdetails', kwargs={'course_id': course_id or str(self.course.id)}) + if include_expired: + url += '?include_expired=1' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_course_details().status_code == status.HTTP_200_OK + + def test_response_top_level_fields(self): + """Response contains the expected top-level CourseSerializer fields.""" + body = self._get_course_details().json() + for field in ('course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes', 'pacing_type'): + assert field in body, f"Missing field: {field}" + + def test_course_id_matches_requested_course(self): + body = self._get_course_details().json() + assert body['course_id'] == str(self.course.id) + + def test_course_modes_is_list(self): + body = self._get_course_details().json() + assert isinstance(body['course_modes'], list) + + def test_course_mode_fields(self): + """Each mode entry contains the expected fields.""" + body = self._get_course_details().json() + mode = body['course_modes'][0] + for field in ('slug', 'name', 'min_price', 'suggested_prices', 'currency', + 'expiration_datetime', 'description', 'sku', 'bulk_sku'): + assert field in mode, f"Missing course_mode field: {field}" + + def test_invalid_course_id_returns_400(self): + resp = self._get_course_details(course_id='not/a/real/course') + assert resp.status_code == status.HTTP_400_BAD_REQUEST + + def test_nonexistent_course_returns_400(self): + resp = self._get_course_details(course_id='course-v1:Org+NonExistent+2099') + assert resp.status_code == status.HTTP_400_BAD_REQUEST + + +@skip_unless_lms +class EnrollmentListViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentListView (GET /enrollment/v1/enrollment) + response structure is preserved after migrating to CourseEnrollmentSerializer + ORM (ADR 0025). + """ + + USERNAME = "TestLearner" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + CourseEnrollment.enroll(self.user, self.course.id) + + def _get_enrollments(self, user=None): + url = reverse('courseenrollments') + if user: + url += f'?user={user}' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_enrollments().status_code == status.HTTP_200_OK + + def test_response_is_list(self): + body = self._get_enrollments().json() + assert isinstance(body, list) + + def test_enrollment_top_level_fields(self): + """Each enrollment entry contains the expected top-level fields.""" + body = self._get_enrollments().json() + assert len(body) >= 1 + entry = body[0] + for field in ('created', 'mode', 'is_active', 'user', 'course_details'): + assert field in entry, f"Missing top-level field: {field}" + + def test_enrollment_user_and_mode_values(self): + body = self._get_enrollments().json() + entry = body[0] + assert entry['user'] == self.USERNAME + assert entry['mode'] == CourseMode.DEFAULT_MODE_SLUG + assert entry['is_active'] is True + + def test_enrollment_course_details_fields(self): + """course_details nested object contains the expected fields.""" + body = self._get_enrollments().json() + course_details = body[0]['course_details'] + for field in ('course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes'): + assert field in course_details, f"Missing course_details field: {field}" + + def test_no_enrollments_returns_empty_list(self): + """A user with no enrollments gets an empty list, not null or an error.""" + new_user = UserFactory.create(password=self.PASSWORD) + self.client.login(username=new_user.username, password=self.PASSWORD) + body = self.client.get(reverse('courseenrollments')).json() + assert body == [] + + +@skip_unless_lms +class UserRoleViewResponseShapeTest(ModuleStoreTestCase): + """ + Tests that verify EnrollmentUserRolesView (GET /enrollment/v1/roles/) + response structure is preserved after migrating to UserRolesResponseSerializer (ADR 0025). + """ + + USERNAME = "RoleTester" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True, org="testorg", course="c1", run="r1") + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + + def _get_roles(self, course_id=None): + url = reverse('roles') + if course_id: + url += f'?course_id={course_id}' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_roles().status_code == status.HTTP_200_OK + + def test_response_top_level_keys(self): + """Response always contains 'roles' (list) and 'is_staff' (bool).""" + body = self._get_roles().json() + assert 'roles' in body + assert 'is_staff' in body + assert isinstance(body['roles'], list) + assert isinstance(body['is_staff'], bool) + + def test_no_roles_returns_empty_list(self): + body = self._get_roles().json() + assert body['roles'] == [] + assert body['is_staff'] is False + + def test_role_entry_shape(self): + """A role entry contains org, course_id, and role fields.""" + role = CourseStaffRole(self.course.id) + role.add_users(self.user) + body = self._get_roles().json() + assert len(body['roles']) == 1 + entry = body['roles'][0] + for field in ('org', 'course_id', 'role'): + assert field in entry, f"Missing role field: {field}" + assert entry['org'] == self.course.org + assert entry['course_id'] == str(self.course.id) + + def test_is_staff_true_for_staff_user(self): + staff_user = UserFactory.create(password=self.PASSWORD, is_staff=True) + self.client.login(username=staff_user.username, password=self.PASSWORD) + body = self._get_roles().json() + assert body['is_staff'] is True + + def test_filter_by_course_id(self): + """course_id query param filters roles to that course only.""" + course2 = CourseFactory.create(emit_signals=True, org="other", course="c2", run="r2") + CourseStaffRole(self.course.id).add_users(self.user) + CourseStaffRole(course2.id).add_users(self.user) + body = self._get_roles(course_id=str(self.course.id)).json() + assert all(r['course_id'] == str(self.course.id) for r in body['roles']) + + +# --------------------------------------------------------------------------- +# ADR 0028 – EnrollmentViewSet permission regression tests +# --------------------------------------------------------------------------- + +@skip_unless_lms +class TestEnrollmentViewSetList(APITestCase): + """ + ADR 0028 – permission regression tests for EnrollmentViewSet.list (GET /enrollment/). + Migrated from EnrollmentListView. + """ + API_KEY = "test-api-key" + + def setUp(self): + super().setUp() + self.user = UserFactory.create(password="test") + self.url = reverse("enrollment-list") + + def test_unauthenticated_gets_401(self): + """Unauthenticated request must be rejected.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + @patch("openedx.core.djangoapps.enrollments.views.CourseEnrollment.objects") + def test_authenticated_user_gets_200(self, mock_objects): + """An authenticated user must reach the list action (permission check passes).""" + mock_objects.filter.return_value.select_related.return_value = [] + self.client.force_authenticate(user=self.user) + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + + @patch("openedx.core.djangoapps.enrollments.views.CourseEnrollment.objects") + def test_valid_api_key_gets_200(self, mock_objects): + """A valid API key must bypass session auth and reach the list action.""" + mock_objects.filter.return_value.select_related.return_value = [] + with override_settings(EDX_API_KEY=self.API_KEY): + response = self.client.get(self.url, HTTP_X_EDX_API_KEY=self.API_KEY) + assert response.status_code == status.HTTP_200_OK + + def test_invalid_api_key_without_session_gets_401(self): + """An invalid API key without session auth must be rejected.""" + response = self.client.get(self.url, HTTP_X_EDX_API_KEY="wrong-key") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@skip_unless_lms +class TestEnrollmentViewSetCreate(APITestCase): + """ + ADR 0028 – permission regression tests for EnrollmentViewSet.create (POST /enrollment/). + Migrated from EnrollmentListView. + """ + + def setUp(self): + super().setUp() + self.user = UserFactory.create(password="test") + self.url = reverse("enrollment-list") + + def test_unauthenticated_post_gets_401(self): + """Unauthenticated POST must be rejected.""" + response = self.client.post(self.url, data={}, content_type="application/json") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_authenticated_post_missing_course_id_gets_400(self): + """Authenticated POST without course_id must return 400.""" + self.client.force_authenticate(user=self.user) + response = self.client.post(self.url, data={}, content_type="application/json") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +@skip_unless_lms +class TestEnrollmentViewSetUnenroll(APITestCase): + """ + ADR 0028 – permission regression tests for EnrollmentViewSet.unenroll (POST /enrollment/unenroll/). + Migrated from UnenrollmentView. Requires IsAuthenticated + CanRetireUser. + """ + + def setUp(self): + super().setUp() + self.user = UserFactory.create(password="test") + self.url = reverse("enrollment-unenroll") + + def test_unauthenticated_gets_401(self): + """Unauthenticated request must be rejected.""" + response = self.client.post(self.url, data={"username": self.user.username}, content_type="application/json") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_authenticated_non_retirement_user_gets_403(self): + """An authenticated user without CanRetireUser permission must get 403.""" + self.client.force_authenticate(user=self.user) + response = self.client.post(self.url, data={"username": self.user.username}, content_type="application/json") + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@skip_unless_lms +class TestEnrollmentViewSetAllowed(APITestCase): + """ + ADR 0028 – permission regression tests for EnrollmentViewSet.allowed + (GET/POST/DELETE /enrollment/enrollment_allowed/). Migrated from EnrollmentAllowedView. + Requires IsAdminUser. + """ + + def setUp(self): + super().setUp() + self.user = UserFactory.create(password="test") + self.admin = AdminFactory.create(password="test") + self.url = reverse("enrollment-allowed") + + def test_unauthenticated_get_gets_401(self): + """Unauthenticated GET must be rejected.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_non_admin_get_gets_403(self): + """Regular authenticated user must get 403.""" + self.client.force_authenticate(user=self.user) + response = self.client.get(self.url) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_admin_get_gets_200(self): + """Admin user must get 200 and an empty list.""" + self.client.force_authenticate(user=self.admin) + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + assert response.data == [] + + def test_non_admin_post_gets_403(self): + """Regular authenticated user POST must get 403.""" + self.client.force_authenticate(user=self.user) + response = self.client.post( + self.url, + data={"email": "test@example.com", "course_id": "course-v1:edX+DemoX+Demo_Course"}, + content_type="application/json", + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_non_admin_delete_gets_403(self): + """Regular authenticated user DELETE must get 403.""" + self.client.force_authenticate(user=self.user) + response = self.client.delete( + self.url, + data={"email": "test@example.com", "course_id": "course-v1:edX+DemoX+Demo_Course"}, + content_type="application/json", + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +# --------------------------------------------------------------------------- +# ADR 0032 – Pagination standardization tests +# --------------------------------------------------------------------------- + +@skip_unless_lms +class TestCourseEnrollmentsApiListPaginatorStructure(APITestCase): + """ + ADR 0032 – structural checks for CourseEnrollmentsApiListPagination. + + Verifies that the paginator subclasses DefaultPagination (not CursorPagination) + and that CourseEnrollmentsApiListView wires it up correctly. + """ + + def test_paginator_is_defaultpagination_subclass(self): + """CourseEnrollmentsApiListPagination must subclass DefaultPagination (not CursorPagination).""" + from edx_rest_framework_extensions.paginators import DefaultPagination + from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination + assert issubclass(CourseEnrollmentsApiListPagination, DefaultPagination) + + def test_view_uses_correct_paginator(self): + """CourseEnrollmentsApiListView.pagination_class must be CourseEnrollmentsApiListPagination.""" + from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination + from openedx.core.djangoapps.enrollments.views import CourseEnrollmentsApiListView + assert CourseEnrollmentsApiListView.pagination_class is CourseEnrollmentsApiListPagination + + def test_enrollment_viewset_uses_defaultpagination(self): + """EnrollmentViewSet.pagination_class must be DefaultPagination (ADR 0032).""" + from edx_rest_framework_extensions.paginators import DefaultPagination + from openedx.core.djangoapps.enrollments.views import EnrollmentViewSet + assert EnrollmentViewSet.pagination_class is DefaultPagination + + +@skip_unless_lms +class TestCourseEnrollmentsApiListPaginationEnvelope(APITestCase): + """ + ADR 0032 – pagination envelope regression tests for CourseEnrollmentsApiListView + (GET /api/enrollment/v1/enrollments). + + Verifies that the response includes all 7 required envelope fields after + migration from CursorPagination to DefaultPagination. + """ + + def setUp(self): + super().setUp() + self.admin = UserFactory.create(is_staff=True, is_superuser=True) + self.client.force_authenticate(user=self.admin) + self.url = reverse("courseenrollmentsapilist") + + def test_response_includes_full_envelope(self): + """All 7 ADR 0032 envelope fields must be present in every paginated response.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + for field in ('count', 'num_pages', 'current_page', 'start', 'next', 'previous', 'results'): + assert field in response.data, f"ADR 0032: missing envelope field '{field}'" + + def test_current_page_is_one_on_first_page(self): + """?page=1 must return current_page=1 in the response envelope.""" + response = self.client.get(self.url, {'page': 1}) + assert response.status_code == status.HTTP_200_OK + assert response.data['current_page'] == 1 + + def test_start_is_zero_on_first_page(self): + """start must be 0 for the first page (0-based index of first item on the page).""" + response = self.client.get(self.url, {'page': 1}) + assert response.status_code == status.HTTP_200_OK + assert response.data['start'] == 0 + + def test_results_is_a_list(self): + """results must be a list (not null or a dict).""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + assert isinstance(response.data['results'], list) + + +@skip_unless_lms +class TestEnrollmentViewSetListPaginationEnvelope(APITestCase): + """ + ADR 0032 – pagination envelope regression tests for EnrollmentViewSet.list + (GET /api/enrollment/v1/enrollment/). + + Verifies that the response includes all 7 required envelope fields after + adding DefaultPagination to the ViewSet. + """ + + def setUp(self): + super().setUp() + self.user = UserFactory.create(password="test") + self.client.force_authenticate(user=self.user) + self.url = reverse("enrollment-list") + + def test_response_includes_full_envelope(self): + """All 7 ADR 0032 envelope fields must be present in the list response.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + for field in ('count', 'num_pages', 'current_page', 'start', 'next', 'previous', 'results'): + assert field in response.data, f"ADR 0032: missing envelope field '{field}'" + + def test_current_page_is_one_on_first_page(self): + """?page=1 must return current_page=1 in the response envelope.""" + response = self.client.get(self.url, {'page': 1}) + assert response.status_code == status.HTTP_200_OK + assert response.data['current_page'] == 1 + + def test_start_is_zero_on_first_page(self): + """start must be 0 for the first page.""" + response = self.client.get(self.url, {'page': 1}) + assert response.status_code == status.HTTP_200_OK + assert response.data['start'] == 0 + + def test_results_is_a_list(self): + """results must be a list (empty for a new user with no enrollments).""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + assert isinstance(response.data['results'], list) + + def test_count_reflects_user_enrollment_count(self): + """count must equal the number of enrollments for the user.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + expected = CourseEnrollment.objects.filter(user=self.user).count() + assert response.data['count'] == expected + + +# --------------------------------------------------------------------------- +# ADR 0033 – Sorting / OEP-68 parameter-naming standardization tests +# --------------------------------------------------------------------------- + + +_ADR_0033_DEPRECATION_HEADER_COURSE_ID = ( + "Parameter 'course_id' is deprecated. Use 'course_key' instead. " + "Support will be removed in release ''." +) +_ADR_0033_DEPRECATION_HEADER_COURSE_IDS = ( + "Parameter 'course_ids' is deprecated. Use 'course_keys' instead. " + "Support will be removed in release ''." +) +_ADR_0033_DEPRECATION_HEADER_COURSE_ID_AND_IDS = ( + "Parameter 'course_id' is deprecated. Use 'course_key' instead. " + "Parameter 'course_ids' is deprecated. Use 'course_keys' instead. " + "Support will be removed in release ''." +) + + +@skip_unless_lms +class TestCourseEnrollmentsApiListAdr0033(APITestCase, ModuleStoreTestCase): + """ + ADR 0033 – tests for CourseEnrollmentsApiListView (GET /api/enrollment/v1/enrollments). + + Covers: + * OEP-68 §2: ``course_key`` / ``course_keys`` work identically to the + legacy ``course_id`` / ``course_ids`` filters. + * BC §1: legacy and preferred names are accepted simultaneously; preferred + wins when both are sent. + * BC §2: the ``Deprecation`` HTTP header is emitted when (and only when) + a legacy name appears in the request — even if the preferred name is + also present. + * §3: the standard ``ordering`` parameter applies a whitelisted ORDER BY. + """ + + def setUp(self): + super().setUp() + self.url = reverse("courseenrollmentsapilist") + self.admin = AdminFactory(username="adr33admin", email="adr33admin@example.com", password="edx") + self.student_a = UserFactory(username="adr33a", email="a@example.com", password="edx") + self.student_b = UserFactory(username="adr33b", email="b@example.com", password="edx") + + self.course_a = CourseFactory.create(org="adr33", number="A", run="r", emit_signals=True) + self.course_b = CourseFactory.create(org="adr33", number="B", run="r", emit_signals=True) + + for mode_slug in ("honor", "audit"): + CourseModeFactory.create( + course_id=self.course_a.id, mode_slug=mode_slug, mode_display_name=mode_slug, + ) + CourseModeFactory.create( + course_id=self.course_b.id, mode_slug=mode_slug, mode_display_name=mode_slug, + ) + + data.create_course_enrollment(self.student_a.username, str(self.course_a.id), "honor", True) + data.create_course_enrollment(self.student_b.username, str(self.course_b.id), "audit", True) + + self.client.login(username=self.admin.username, password="edx") + + # ---- course_key / course_id ---- + + def test_new_course_key_param_no_header(self): + """``?course_key=…`` returns 200 and no ``Deprecation`` header.""" + response = self.client.get(self.url, {"course_key": str(self.course_a.id)}) + assert response.status_code == status.HTTP_200_OK + assert "Deprecation" not in response.headers + # Filtering still works through the alias path. + results = response.data["results"] + assert all(r["course_id"] == str(self.course_a.id) for r in results) + assert results + + def test_legacy_course_id_param_emits_header(self): + """``?course_id=…`` returns 200 and emits the ADR 0033 header.""" + response = self.client.get(self.url, {"course_id": str(self.course_a.id)}) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_ID + assert response.data["results"] + + def test_course_key_wins_when_both_sent_but_header_still_emitted(self): + """When both names are sent, the preferred name wins but the header is still emitted.""" + response = self.client.get(self.url, { + "course_key": str(self.course_a.id), + "course_id": str(self.course_b.id), + }) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_ID + # course_key (course_a) must win — none of course_b's enrollments should appear. + results = response.data["results"] + assert all(r["course_id"] == str(self.course_a.id) for r in results) + assert results + + # ---- course_keys / course_ids ---- + + def test_legacy_course_ids_param_emits_header(self): + """``?course_ids=…`` returns 200 and emits the ADR 0033 header.""" + response = self.client.get(self.url, { + "course_ids": f"{self.course_a.id},{self.course_b.id}", + }) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_IDS + + def test_new_course_keys_param_no_header(self): + """``?course_keys=…`` returns 200 with no ``Deprecation`` header.""" + response = self.client.get(self.url, { + "course_keys": f"{self.course_a.id},{self.course_b.id}", + }) + assert response.status_code == status.HTTP_200_OK + assert "Deprecation" not in response.headers + + def test_both_legacy_names_combined_header(self): + """Sending both legacy names produces a combined deprecation header.""" + response = self.client.get(self.url, { + "course_id": str(self.course_a.id), + "course_ids": f"{self.course_a.id},{self.course_b.id}", + }) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_ID_AND_IDS + + # ---- baseline / no params ---- + + def test_no_legacy_params_no_header(self): + """A plain unfiltered request emits no ``Deprecation`` header.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + assert "Deprecation" not in response.headers + + # ---- ordering whitelist ---- + + def test_ordering_created_ascending(self): + """``?ordering=created`` orders results by ``created`` ascending.""" + response = self.client.get(self.url, {"ordering": "created"}) + assert response.status_code == status.HTTP_200_OK + created_values = [row["created"] for row in response.data["results"]] + assert created_values == sorted(created_values) + + def test_ordering_created_descending(self): + """``?ordering=-created`` orders results by ``created`` descending.""" + response = self.client.get(self.url, {"ordering": "-created"}) + assert response.status_code == status.HTTP_200_OK + created_values = [row["created"] for row in response.data["results"]] + assert created_values == sorted(created_values, reverse=True) + + def test_ordering_invalid_value_is_ignored(self): + """A value outside the whitelist is silently ignored (no 400, no ORDER BY).""" + response = self.client.get(self.url, {"ordering": "user__email"}) # not in whitelist + assert response.status_code == status.HTTP_200_OK + + +@ddt.ddt +@skip_unless_lms +class TestEnrollmentUserRolesAdr0033(ModuleStoreTestCase): + """ + ADR 0033 – tests for EnrollmentUserRolesView (GET /api/enrollment/v1/roles/). + + Covers: + * OEP-68 §2: ``course_key`` is the preferred filter; ``course_id`` is a + deprecated alias. + * BC §2: ``Deprecation`` header is emitted when the legacy name is used. + * Tie-break: when both names are sent, ``course_key`` wins and the + header is still emitted. + """ + + USERNAME = "adr33-roles" + EMAIL = "adr33-roles@example.com" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course_a = CourseFactory.create(emit_signals=True, org="adr33r", course="a", run="r") + self.course_b = CourseFactory.create(emit_signals=True, org="adr33r", course="b", run="r") + self.user = UserFactory.create(username=self.USERNAME, email=self.EMAIL, password=self.PASSWORD) + CourseStaffRole(self.course_a.id).add_users(self.user) + CourseStaffRole(self.course_b.id).add_users(self.user) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + self.url = reverse("roles") + + def test_new_course_key_param_no_header(self): + """``?course_key=…`` returns 200 and no ``Deprecation`` header; filter applies.""" + response = self.client.get(self.url, {"course_key": str(self.course_a.id)}) + assert response.status_code == status.HTTP_200_OK + assert "Deprecation" not in response.headers + roles = json.loads(response.content.decode("utf-8"))["roles"] + assert {r["course_id"] for r in roles} == {str(self.course_a.id)} + + def test_legacy_course_id_param_emits_header(self): + """``?course_id=…`` returns 200 and emits the ADR 0033 header; filter applies.""" + response = self.client.get(self.url, {"course_id": str(self.course_a.id)}) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_ID + roles = json.loads(response.content.decode("utf-8"))["roles"] + assert {r["course_id"] for r in roles} == {str(self.course_a.id)} + + def test_course_key_wins_when_both_sent_header_still_emitted(self): + """When both names are sent, ``course_key`` wins and the header is still emitted.""" + response = self.client.get(self.url, { + "course_key": str(self.course_a.id), + "course_id": str(self.course_b.id), + }) + assert response.status_code == status.HTTP_200_OK + assert response.headers.get("Deprecation") == _ADR_0033_DEPRECATION_HEADER_COURSE_ID + roles = json.loads(response.content.decode("utf-8"))["roles"] + assert {r["course_id"] for r in roles} == {str(self.course_a.id)} + + def test_no_filter_no_header(self): + """Plain ``GET /roles/`` does not emit the ``Deprecation`` header.""" + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + assert "Deprecation" not in response.headers diff --git a/openedx/core/djangoapps/enrollments/urls.py b/openedx/core/djangoapps/enrollments/urls.py index 828d4b61798b..163be26e2ec2 100644 --- a/openedx/core/djangoapps/enrollments/urls.py +++ b/openedx/core/djangoapps/enrollments/urls.py @@ -5,6 +5,7 @@ from django.conf import settings from django.urls import path, re_path +from rest_framework.routers import DefaultRouter from .views import ( CourseEnrollmentsApiListView, @@ -13,10 +14,18 @@ EnrollmentListView, EnrollmentUserRolesView, EnrollmentView, + EnrollmentViewSet, UnenrollmentView, ) -urlpatterns = [ +# ADR 0028: EnrollmentViewSet registered via DefaultRouter. +# Generates: GET/POST /enrollment/, POST /enrollment/unenroll/, GET/POST/DELETE /enrollment/enrollment_allowed/ +router = DefaultRouter() +router.register(r"enrollment", EnrollmentViewSet, basename="enrollment") + +urlpatterns = router.urls + [ + # EnrollmentView kept as-is: non-standard {username},{course_key} URL is incompatible with + # DefaultRouter lookup — migrate to ViewSet retrieve() in a follow-up (TODO ADR 0028). re_path( r"^enrollment/{username},{course_key}$".format( username=settings.USERNAME_PATTERN, course_key=settings.COURSE_ID_PATTERN @@ -25,12 +34,15 @@ name="courseenrollment", ), re_path(rf"^enrollment/{settings.COURSE_ID_PATTERN}$", EnrollmentView.as_view(), name="courseenrollment"), - path("enrollment", EnrollmentListView.as_view(), name="courseenrollments"), re_path(r"^enrollments/?$", CourseEnrollmentsApiListView.as_view(), name="courseenrollmentsapilist"), re_path( rf"^course/{settings.COURSE_ID_PATTERN}$", EnrollmentCourseDetailView.as_view(), name="courseenrollmentdetails" ), - path("unenroll/", UnenrollmentView.as_view(), name="unenrollment"), path("roles/", EnrollmentUserRolesView.as_view(), name="roles"), + + # DEPRECATED (ADR 0028): flat URL patterns kept for backward compatibility. + # Will be removed after one named release. Use the router-generated enrollment/ URLs instead. + path("enrollment", EnrollmentListView.as_view(), name="courseenrollments"), + path("unenroll/", UnenrollmentView.as_view(), name="unenrollment"), path("enrollment_allowed/", EnrollmentAllowedView.as_view(), name="courseenrollmentallowed"), ] diff --git a/openedx/core/djangoapps/enrollments/views.py b/openedx/core/djangoapps/enrollments/views.py index dc3423245e9b..f8d484ecefce 100644 --- a/openedx/core/djangoapps/enrollments/views.py +++ b/openedx/core/djangoapps/enrollments/views.py @@ -13,15 +13,24 @@ from django.db import IntegrityError # lint-amnesty, pylint: disable=wrong-import-order from django.db.models import Q # lint-amnesty, pylint: disable=wrong-import-order from django.utils.decorators import method_decorator # lint-amnesty, pylint: disable=wrong-import-order +from drf_spectacular.utils import ( # lint-amnesty, pylint: disable=wrong-import-order + extend_schema, + extend_schema_view, + OpenApiParameter, + OpenApiRequest, + OpenApiResponse, +) from edx_rest_framework_extensions.auth.jwt.authentication import ( JwtAuthentication, ) # lint-amnesty, pylint: disable=wrong-import-order from edx_rest_framework_extensions.auth.session.authentication import ( SessionAuthenticationAllowInactiveUser, ) # lint-amnesty, pylint: disable=wrong-import-order +from edx_rest_framework_extensions.paginators import DefaultPagination # lint-amnesty, pylint: disable=wrong-import-order from opaque_keys import InvalidKeyError # lint-amnesty, pylint: disable=wrong-import-order from opaque_keys.edx.keys import CourseKey # lint-amnesty, pylint: disable=wrong-import-order -from rest_framework import permissions, status # lint-amnesty, pylint: disable=wrong-import-order +from rest_framework import permissions, status, viewsets # lint-amnesty, pylint: disable=wrong-import-order +from rest_framework.decorators import action # lint-amnesty, pylint: disable=wrong-import-order from rest_framework.generics import ListAPIView # lint-amnesty, pylint: disable=wrong-import-order from rest_framework.response import Response # lint-amnesty, pylint: disable=wrong-import-order from rest_framework.throttling import UserRateThrottle # lint-amnesty, pylint: disable=wrong-import-order @@ -35,19 +44,23 @@ from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from openedx.core.djangoapps.cors_csrf.decorators import ensure_csrf_cookie_cross_domain from openedx.core.djangoapps.course_groups.cohorts import CourseUserGroup, add_user_to_cohort, get_cohort_by_name -from openedx.core.djangoapps.embargo import api as embargo_api -from openedx.core.djangoapps.enrollments import api -from openedx.core.djangoapps.enrollments.errors import ( +from openedx.core.djangoapps.content.course_overviews.models import CourseOverview # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.embargo import api as embargo_api # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments import api # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.errors import ( # lint-amnesty, pylint: disable=wrong-import-order CourseEnrollmentError, CourseEnrollmentExistsError, CourseModeNotFoundError, InvalidEnrollmentAttribute, ) -from openedx.core.djangoapps.enrollments.forms import CourseEnrollmentsApiListForm -from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination -from openedx.core.djangoapps.enrollments.serializers import ( +from openedx.core.djangoapps.enrollments.forms import CourseEnrollmentsApiListForm # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.serializers import ( # lint-amnesty, pylint: disable=wrong-import-order CourseEnrollmentAllowedSerializer, + CourseEnrollmentSerializer, CourseEnrollmentsApiListSerializer, + CourseSerializer, + UserRolesResponseSerializer, ) from openedx.core.djangoapps.user_api.accounts.permissions import CanRetireUser from openedx.core.djangoapps.user_api.models import UserRetirementStatus @@ -70,6 +83,73 @@ } +# ADR 0027 — shared OpenAPI parameter and response building blocks +def _path_param(name: str, description: str) -> OpenApiParameter: + return OpenApiParameter( + name=name, description=description, required=True, type=str, location=OpenApiParameter.PATH, + ) + + +def _query_param(name: str, description: str, required: bool = False, type_=str) -> OpenApiParameter: + return OpenApiParameter( + name=name, description=description, required=required, type=type_, location=OpenApiParameter.QUERY, + ) + + +_COURSE_ID_PATH_PARAM = _path_param("course_id", "Course ID (e.g. course-v1:org+course+run).") +_USERNAME_PATH_PARAM = _path_param("username", "Username of the user.") +_USER_QUERY_PARAM = _query_param("user", "Username of the user whose enrollments to list.") +_INCLUDE_EXPIRED_QUERY_PARAM = _query_param( + "include_expired", "If '1', include expired enrollment modes in the response.", +) +_PAGE_QUERY_PARAM = _query_param("page", "Page number to retrieve. Default 1.") +_PAGE_SIZE_QUERY_PARAM = _query_param("page_size", "Items per page (default 10, max 100).") + +_RESP_UNAUTHENTICATED = OpenApiResponse(description="The requester is not authenticated.") +_RESP_FORBIDDEN = OpenApiResponse(description="The requester does not have permission for this operation.") +_RESP_NOT_FOUND = OpenApiResponse(description="The requested resource does not exist.") +_RESP_BAD_REQUEST = OpenApiResponse(description="Invalid request data or parameters.") + + +# ADR 0033 – sorting / OEP-68 parameter naming standardization helpers. +# Used by list endpoints that accept legacy parameter names (e.g. ``course_id`` +# instead of ``course_key``) so they can emit the BC-strategy §2 ``Deprecation`` +# HTTP header without each view duplicating the boilerplate. +def _build_legacy_param_deprecation_header(legacy_to_preferred): + """ + Build the ``Deprecation`` HTTP header value for one or more legacy parameter + names, each paired with its OEP-68-compliant replacement. + + Example: ``[('course_id', 'course_key')]`` → + ``"Parameter 'course_id' is deprecated. Use 'course_key' instead. ..."`` + """ + parts = [ + f"Parameter '{legacy}' is deprecated. Use '{preferred}' instead." + for legacy, preferred in legacy_to_preferred + ] + parts.append("Support will be removed in release ''.") + return " ".join(parts) + + +def _maybe_set_legacy_param_deprecation_header(request, response, alias_pairs): + """ + Set the ADR 0033 ``Deprecation`` HTTP header on ``response`` when any of + the legacy parameter names in ``alias_pairs`` is present in the request's + query string. + + ``alias_pairs`` is a sequence of ``(legacy, preferred)`` tuples (e.g. + ``[('course_id', 'course_key'), ('course_ids', 'course_keys')]``). The + header is emitted whenever any *legacy* name appears, even if the + corresponding ``preferred`` name was also supplied (in which case the + preferred value wins, but the caller should still be told that the legacy + alias is deprecated). + """ + used = [(legacy, preferred) for legacy, preferred in alias_pairs if legacy in request.query_params] + if used: + response['Deprecation'] = _build_legacy_param_deprecation_header(used) + return response + + class EnrollmentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf): """Session authentication that allows inactive users and cross-domain requests.""" @@ -187,9 +267,27 @@ class EnrollmentView(APIView, ApiKeyPermissionMixIn): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseEnrollmentSerializer # Since the course about page on the marketing site uses this API to auto-enroll users, # we need to support cross-domain CSRF. + @extend_schema( + summary="Retrieve a user's enrollment in a course", + description=( + "Returns the current user's enrollment for the specified course, or the named user's " + "enrollment when invoked with the {username},{course_id} URL form (server-to-server or " + "staff only)." + ), + parameters=[_USERNAME_PATH_PARAM, _COURSE_ID_PATH_PARAM], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentSerializer, + description="Enrollment retrieved successfully (or empty body if no enrollment).", + ), + 400: _RESP_BAD_REQUEST, + 404: _RESP_NOT_FOUND, + }, + ) @method_decorator(ensure_csrf_cookie_cross_domain) def get(self, request, course_id=None, username=None): """Create, read, or update enrollment information for a user. @@ -221,7 +319,17 @@ def get(self, request, course_id=None, username=None): return Response(status=status.HTTP_404_NOT_FOUND) try: - return Response(api.get_enrollment(username, course_id)) + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f"No course '{course_id}' found for enrollment"}, + ) + + try: + enrollment = CourseEnrollment.objects.get(user__username=username, course_id=course_key) + except CourseEnrollment.DoesNotExist: + return Response(None) except CourseEnrollmentError: return Response( status=status.HTTP_400_BAD_REQUEST, @@ -233,6 +341,9 @@ def get(self, request, course_id=None, username=None): }, ) + serializer = self.serializer_class(enrollment) + return Response(serializer.data) + class EnrollmentUserRolesView(APIView): """ @@ -266,17 +377,51 @@ class EnrollmentUserRolesView(APIView): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) - + serializer_class = UserRolesResponseSerializer + + # ADR 0033 §2 / OEP-68: ``course_key`` is the standardized name; + # ``course_id`` is retained as a deprecated alias. + _LEGACY_PARAM_ALIASES = (("course_id", "course_key"),) + + @extend_schema( + summary="List the current user's course roles", + description=( + "Returns the list of course-level roles held by the currently logged-in user, plus an " + "is_staff flag. Optionally filters by course_key." + ), + parameters=[ + _query_param("course_key", "If provided, only roles for this course are returned (per OEP-68)."), + OpenApiParameter( + name="course_id", + description="Deprecated alias for 'course_key' (ADR 0033). Use 'course_key' instead.", + required=False, + type=str, + location=OpenApiParameter.QUERY, + deprecated=True, + ), + ], + responses={ + 200: OpenApiResponse( + response=UserRolesResponseSerializer, + description="Roles retrieved successfully.", + ), + 400: _RESP_BAD_REQUEST, + }, + ) @method_decorator(ensure_csrf_cookie_cross_domain) def get(self, request): """ - Gets a list of all roles for the currently logged-in user, filtered by course_id if supplied + Gets a list of all roles for the currently logged-in user, filtered by + ``course_key`` (preferred, ADR 0033 / OEP-68) or ``course_id`` (deprecated + alias). When both are present, ``course_key`` wins; in either case the + response carries the ADR 0033 ``Deprecation`` header if the legacy name + was used. """ try: - course_id = request.GET.get("course_id") + course_key = request.GET.get("course_key") or request.GET.get("course_id") roles_data = api.get_user_roles(request.user.username) - if course_id: - roles_data = [role for role in roles_data if str(role.course_id) == course_id] + if course_key: + roles_data = [role for role in roles_data if str(role.course_id) == course_key] except Exception: # pylint: disable=broad-except return Response( status=status.HTTP_400_BAD_REQUEST, @@ -286,13 +431,13 @@ def get(self, request): ) }, ) - return Response( - { - "roles": [ - {"org": role.org, "course_id": str(role.course_id), "role": role.role} for role in roles_data - ], - "is_staff": request.user.is_staff, - } + serializer = self.serializer_class({ + "roles": list(roles_data), + "is_staff": request.user.is_staff, + }) + response = Response(serializer.data) + return _maybe_set_legacy_param_deprecation_header( + request, response, self._LEGACY_PARAM_ALIASES, ) @@ -363,7 +508,24 @@ class EnrollmentCourseDetailView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = (EnrollmentUserThrottle,) - + serializer_class = CourseSerializer + + @extend_schema( + summary="Get enrollment details for a course", + description=( + "Returns the course schedule and the enrollment modes supported by the course. " + "This endpoint does not require authentication. Use ?include_expired=1 to include " + "expired enrollment modes." + ), + parameters=[_COURSE_ID_PATH_PARAM, _INCLUDE_EXPIRED_QUERY_PARAM], + responses={ + 200: OpenApiResponse( + response=CourseSerializer, + description="Course enrollment details retrieved successfully.", + ), + 400: _RESP_BAD_REQUEST, + }, + ) def get(self, request, course_id=None): """Read enrollment information for a particular course. @@ -380,14 +542,25 @@ def get(self, request, course_id=None): """ try: - return Response(api.get_course_enrollment_details(course_id, bool(request.GET.get("include_expired", "")))) - except CourseNotFoundError: + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f"No course found for course ID '{course_id}'"}, + ) + try: + course_overview = CourseOverview.get_from_id(course_key) + except CourseOverview.DoesNotExist: return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": ("No course found for course ID '{course_id}'").format(course_id=course_id)}, + data={"message": f"No course found for course ID '{course_id}'"}, ) + include_expired = bool(request.GET.get("include_expired", "")) + serializer = self.serializer_class(course_overview, include_expired=include_expired) + return Response(serializer.data) +# DEPRECATED (ADR 0028): Use EnrollmentViewSet.unenroll action instead. Will be removed after one named release. class UnenrollmentView(APIView): """ **Use Cases** @@ -428,7 +601,30 @@ class UnenrollmentView(APIView): permissions.IsAuthenticated, CanRetireUser, ) - + serializer_class = CourseEnrollmentSerializer + + @extend_schema( + operation_id="enrollment_v1_unenroll_deprecated", + summary="Unenroll a user from all courses (deprecated)", + description=( + "Deprecated. Use POST /api/enrollment/v1/enrollment/unenroll/ " + "(EnrollmentViewSet.unenroll action) instead. Privileged retirement-pipeline use only." + ), + request=OpenApiRequest( + request={ + "type": "object", + "properties": {"username": {"type": "string"}}, + "required": ["username"], + } + ), + responses={ + 200: OpenApiResponse(description="List of courses from which the user was unenrolled."), + 204: OpenApiResponse(description="User has no active enrollments."), + 404: OpenApiResponse(description="Username not specified or no retirement status for user."), + 500: OpenApiResponse(description="Unexpected error during unenrollment."), + }, + deprecated=True, + ) def post(self, request): """ Unenrolls the specified user from all courses. @@ -438,9 +634,10 @@ def post(self, request): username = request.data["username"] # Ensure that a retirement request status row exists for this username. UserRetirementStatus.get_retirement_for_retirement_action(username) - enrollments = api.get_enrollments(username) - active_enrollments = [enrollment for enrollment in enrollments if enrollment["is_active"]] - if len(active_enrollments) < 1: + active_enrollments = CourseEnrollment.objects.filter( + user__username=username, is_active=True + ) + if not active_enrollments.exists(): return Response(status=status.HTTP_204_NO_CONTENT) return Response(api.unenroll_user_from_all_courses(username)) except KeyError: @@ -451,6 +648,497 @@ def post(self, request): return Response(str(exc), status=status.HTTP_500_INTERNAL_SERVER_ERROR) +# ADR 0028 – consolidated from EnrollmentListView, UnenrollmentView, EnrollmentAllowedView +@can_disable_rate_limit +class EnrollmentViewSet(viewsets.ViewSet, ApiKeyPermissionMixIn): + """ + DRF ViewSet for the Enrollment API. + + Consolidates EnrollmentListView, UnenrollmentView, and EnrollmentAllowedView into a single + ViewSet registered via DefaultRouter per ADR 0028. + + Actions: + list GET /enrollment/ List enrollments for the current user. + create POST /enrollment/ Enroll the current user in a course. + unenroll POST /enrollment/unenroll/ Unenroll a user from all courses (retirement pipeline). + allowed GET/POST/DELETE /enrollment/enrollment_allowed/ Manage CourseEnrollmentAllowed records. + """ + + authentication_classes = ( + JwtAuthentication, + BearerAuthenticationAllowInactiveUser, + EnrollmentCrossDomainSessionAuth, + ) + permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) + throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseEnrollmentSerializer + pagination_class = DefaultPagination # ADR 0032 + + def get_serializer_class(self): + """Return CourseEnrollmentAllowedSerializer for the 'allowed' action, else the default.""" + if self.action == "allowed": + return CourseEnrollmentAllowedSerializer + return self.serializer_class + + def get_serializer(self, *args, **kwargs): + """Instantiate and return the appropriate serializer for this action.""" + return self.get_serializer_class()(*args, **kwargs) + + @extend_schema( + summary="List enrollments for a user (paginated)", + description=( + "Returns a paginated list of enrollments for the currently logged-in user, or for the " + "user named by the 'user' query parameter (staff/admin/api-key access required to view " + "another user's enrollments — otherwise filtered to courses the requester staffs)." + ), + parameters=[_USER_QUERY_PARAM, _PAGE_QUERY_PARAM, _PAGE_SIZE_QUERY_PARAM], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentSerializer(many=True), + description="Paginated enrollment list.", + ), + 401: _RESP_UNAUTHENTICATED, + }, + ) + @method_decorator(ensure_csrf_cookie_cross_domain) + def list(self, request): + """Gets a list of all course enrollments for a user. + + Returns a paginated list for the currently logged-in user, or for the user named by the + 'user' GET parameter. If the username does not match that of the currently logged-in user, + only courses for which the currently logged-in user has the Staff or Admin role are listed. + + **Pagination Parameters** + + - ``page`` (int): Page number to retrieve. Default is 1. + - ``page_size`` (int): Items per page. Default is 10, max is 100. + + **Response Envelope** + + - ``count`` (int): Total number of results. + - ``num_pages`` (int): Total number of pages. + - ``current_page`` (int): The current page number. + - ``start`` (int): The 0-based index of the first item on this page. + - ``next`` (str|null): URL for the next page, or null. + - ``previous`` (str|null): URL for the previous page, or null. + - ``results`` (list): The list of enrollments for this page. + """ + username = request.GET.get("user", request.user.username) + enrollments = CourseEnrollment.objects.filter( + user__username=username + ).select_related("user", "course") # "course" is the FK field; "course_overview" is a property + paginator = self.pagination_class() + if ( + username == request.user.username + or GlobalStaff().has_user(request.user) + or self.has_api_key_permissions(request) + ): + page = paginator.paginate_queryset(enrollments, request, view=self) + return paginator.get_paginated_response(self.get_serializer(page, many=True).data) + filtered_enrollments = [ + enrollment for enrollment in enrollments + if user_has_role(request.user, CourseStaffRole(enrollment.course_id)) + ] + page = paginator.paginate_queryset(filtered_enrollments, request, view=self) + return paginator.get_paginated_response(self.get_serializer(page, many=True).data) + + @extend_schema( + summary="Create or update an enrollment", + description=( + "Enrolls a user in a course. Server-to-server calls may deactivate or modify the mode " + "of existing enrollments; all other requests go through add_enrollment(), which creates " + "or reactivates enrollments. The request body must include course_details.course_id." + ), + request=OpenApiRequest(request=CourseEnrollmentSerializer), + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentSerializer, + description="Enrollment created, reactivated, or updated successfully.", + ), + 400: _RESP_BAD_REQUEST, + 403: _RESP_FORBIDDEN, + 404: _RESP_NOT_FOUND, + 406: OpenApiResponse(description="The specified user does not exist."), + }, + ) + @method_decorator(ensure_csrf_cookie_cross_domain) + def create(self, request): + # pylint: disable=too-many-statements + """Enrolls the currently logged-in user in a course. + + Server-to-server calls may deactivate or modify the mode of existing enrollments. All other + requests go through add_enrollment(), which allows creation and reactivation of enrollments. + """ + username = request.data.get("user") + course_id = request.data.get("course_details", {}).get("course_id") + + if not course_id: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "Course ID must be specified to create a new enrollment."}, + ) + + try: + course_id = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, data={"message": f"No course '{course_id}' found for enrollment"} + ) + + mode = request.data.get("mode") + + has_api_key_permissions = self.has_api_key_permissions(request) + + if ( + username + and username != request.user.username + and not has_api_key_permissions + and not GlobalStaff().has_user(request.user) + ): + return Response(status=status.HTTP_404_NOT_FOUND) + + if not username: + email = request.data.get("email") + if email: + if not has_api_key_permissions and not GlobalStaff().has_user(request.user): + return Response(status=status.HTTP_404_NOT_FOUND) + try: + username = User.objects.get(email=email).username + except ObjectDoesNotExist: + return Response( + status=status.HTTP_406_NOT_ACCEPTABLE, + data={"message": f"The user with the email address {email} does not exist."}, + ) + else: + username = request.user.username + + if ( + mode not in (CourseMode.AUDIT, CourseMode.HONOR, None) + and not has_api_key_permissions + and not GlobalStaff().has_user(request.user) + ): + return Response( + status=status.HTTP_403_FORBIDDEN, + data={ + "message": "User does not have permission to create enrollment with mode [{mode}].".format( + mode=mode + ) + }, + ) + + try: + user = User.objects.get(username=username) + except ObjectDoesNotExist: + return Response( + status=status.HTTP_406_NOT_ACCEPTABLE, data={"message": f"The user {username} does not exist."} + ) + + embargo_response = embargo_api.get_embargo_response(request, course_id, user) + + if embargo_response: + return embargo_response + + try: + is_active = request.data.get("is_active") + if is_active is not None and not isinstance(is_active, bool): + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": ("'{value}' is an invalid enrollment activation status.").format(value=is_active)}, + ) + + explicit_linked_enterprise = request.data.get("linked_enterprise_customer") + if explicit_linked_enterprise and has_api_key_permissions and enterprise_enabled(): + enterprise_api_client = EnterpriseApiServiceClient() + consent_client = ConsentApiServiceClient() + try: + enterprise_api_client.post_enterprise_course_enrollment(username, str(course_id)) + except EnterpriseApiException as error: + log.exception( + "An unexpected error occurred while creating the new EnterpriseCourseEnrollment " + "for user [%s] in course run [%s]", + username, + course_id, + ) + raise CourseEnrollmentError(str(error)) # lint-amnesty, pylint: disable=raise-missing-from + kwargs = { + "username": username, + "course_id": str(course_id), + "enterprise_customer_uuid": explicit_linked_enterprise, + } + consent_client.provide_consent(**kwargs) + + enrollment_attributes = request.data.get("enrollment_attributes") + force_enrollment = request.data.get("force_enrollment") + if force_enrollment is not None and not isinstance(force_enrollment, bool): + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": ("'{value}' is an invalid force enrollment status.").format(value=force_enrollment) + }, + ) + force_enrollment = force_enrollment and GlobalStaff().has_user(request.user) + enrollment = api.get_enrollment(username, str(course_id)) + mode_changed = enrollment and mode is not None and enrollment["mode"] != mode + active_changed = enrollment and is_active is not None and enrollment["is_active"] != is_active + missing_attrs = [] + if enrollment_attributes: + actual_attrs = ["{namespace}:{name}".format(**attr) for attr in enrollment_attributes] + missing_attrs = set(REQUIRED_ATTRIBUTES.get(mode, [])) - set(actual_attrs) + if (GlobalStaff().has_user(request.user) or has_api_key_permissions) and (mode_changed or active_changed): + if mode_changed and active_changed and not is_active: + msg = "Enrollment mode mismatch: active mode={}, requested mode={}. Won't deactivate.".format( + enrollment["mode"], mode + ) + log.warning(msg) + return Response(status=status.HTTP_400_BAD_REQUEST, data={"message": msg}) + + if missing_attrs: + msg = "Missing enrollment attributes: requested mode={} required attributes={}".format( + mode, REQUIRED_ATTRIBUTES.get(mode) + ) + log.warning(msg) + return Response(status=status.HTTP_400_BAD_REQUEST, data={"message": msg}) + + response = api.update_enrollment( + username, + str(course_id), + mode=mode, + is_active=is_active, + enrollment_attributes=enrollment_attributes, + include_expired=has_api_key_permissions, + ) + else: + response = api.add_enrollment( + username, + str(course_id), + mode=mode, + is_active=is_active, + enrollment_attributes=enrollment_attributes, + enterprise_uuid=request.data.get("enterprise_uuid"), + force_enrollment=force_enrollment, + include_expired=force_enrollment, + ) + + cohort_name = request.data.get("cohort") + if cohort_name is not None: + cohort = get_cohort_by_name(course_id, cohort_name) + try: + add_user_to_cohort(cohort, user) + except ValueError: + log.exception("Cohort re-addition") + email_opt_in = request.data.get("email_opt_in", None) + if email_opt_in is not None: + org = course_id.org + update_email_opt_in(request.user, org, email_opt_in) + + log.info("The user [%s] has already been enrolled in course run [%s].", username, course_id) + return Response(response) + + except InvalidEnrollmentAttribute as error: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": str(error), + "localizedMessage": str(error), + } + ) + except EnrollmentNotAllowed as error: + return Response( + status=status.HTTP_403_FORBIDDEN, + data={ + "message": str(error), + "localizedMessage": str(error), + } + ) + except CourseModeNotFoundError as error: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": ( + "The [{mode}] course mode is expired or otherwise unavailable for course run [{course_id}]." + ).format(mode=mode, course_id=course_id), + "course_details": error.data, + }, + ) + except CourseNotFoundError: + return Response( + status=status.HTTP_400_BAD_REQUEST, data={"message": f"No course '{course_id}' found for enrollment"} + ) + except CourseEnrollmentExistsError as error: + log.warning("An enrollment already exists for user [%s] in course run [%s].", username, course_id) + return Response(data=error.enrollment) + except CourseEnrollmentError: + log.exception( + "An error occurred while creating the new course enrollment for user [%s] in course run [%s]", + username, + course_id, + ) + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": ( + "An error occurred while creating the new course enrollment for user " + "'{username}' in course '{course_id}'" + ).format(username=username, course_id=course_id) + }, + ) + except CourseUserGroup.DoesNotExist: + log.exception("Missing cohort [%s] in course run [%s]", cohort_name, course_id) + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "An error occured while adding to cohort [%s]" % cohort_name}, + ) + finally: + if has_api_key_permissions: + try: + current_enrollment_obj = CourseEnrollment.objects.get( + user__username=username, course_id=course_id + ) + actual_mode = current_enrollment_obj.mode + actual_activation = current_enrollment_obj.is_active + except CourseEnrollment.DoesNotExist: + actual_mode = None + actual_activation = None + audit_log( + "enrollment_change_requested", + course_id=str(course_id), + requested_mode=mode, + actual_mode=actual_mode, + requested_activation=is_active, + actual_activation=actual_activation, + user_id=user.id, + ) + + @extend_schema( + summary="Unenroll a user from all courses (retirement)", + description=( + "Privileged retirement-pipeline use only. Unenrolls the named user from every active " + "enrollment. The request must be made by a service user with CanRetireUser permission, " + "not the user being unenrolled." + ), + request=OpenApiRequest( + request={ + "type": "object", + "properties": {"username": {"type": "string"}}, + "required": ["username"], + } + ), + responses={ + 200: OpenApiResponse(description="List of courses from which the user was unenrolled."), + 204: OpenApiResponse(description="User has no active enrollments."), + 404: OpenApiResponse(description="Username not specified or no retirement status for user."), + 500: OpenApiResponse(description="Unexpected error during unenrollment."), + }, + ) + @action( + detail=False, + methods=["post"], + url_path="unenroll", + permission_classes=[permissions.IsAuthenticated, CanRetireUser], + ) + def unenroll(self, request): + """Unenrolls the specified user from all courses. + + Privileged retirement-pipeline use only. The request must be made by a service user + with CanRetireUser permission, not the user being unenrolled. + """ + try: + username = request.data["username"] + UserRetirementStatus.get_retirement_for_retirement_action(username) + active_enrollments = CourseEnrollment.objects.filter( + user__username=username, is_active=True + ) + if not active_enrollments.exists(): + return Response(status=status.HTTP_204_NO_CONTENT) + return Response(api.unenroll_user_from_all_courses(username)) + except KeyError: + return Response("Username not specified.", status=status.HTTP_404_NOT_FOUND) + except UserRetirementStatus.DoesNotExist: + return Response("No retirement request status for username.", status=status.HTTP_404_NOT_FOUND) + except Exception as exc: # pylint: disable=broad-except + return Response(str(exc), status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @extend_schema( + summary="Manage CourseEnrollmentAllowed records (admin-only)", + description=( + "GET lists allowed enrollments for an email; POST creates a new one; DELETE removes one " + "by email + course_id. Admin-only." + ), + request=OpenApiRequest(request=CourseEnrollmentAllowedSerializer), + parameters=[_query_param("email", "Email to query (GET only). Defaults to the requester's email.")], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentAllowedSerializer(many=True), + description="GET success — list of allowed enrollments for the email.", + ), + 201: OpenApiResponse( + response=CourseEnrollmentAllowedSerializer, + description="POST success — allowed enrollment created.", + ), + 204: OpenApiResponse(description="DELETE success — allowed enrollment deleted."), + 400: _RESP_BAD_REQUEST, + 404: OpenApiResponse(description="DELETE: allowed enrollment not found for the given email/course."), + 409: OpenApiResponse(description="POST: allowed enrollment already exists for this email/course."), + }, + ) + @action( + detail=False, + methods=["get", "post", "delete"], + url_path="enrollment_allowed", + permission_classes=[permissions.IsAdminUser], + throttle_classes=[EnrollmentUserThrottle], + ) + def allowed(self, request): + """Retrieve, create, or delete CourseEnrollmentAllowed records. Admin-only. + + GET /enrollment/enrollment_allowed/?email= List allowed enrollments for an email. + POST /enrollment/enrollment_allowed/ Create a new allowed enrollment. + DELETE /enrollment/enrollment_allowed/ Delete an existing allowed enrollment. + """ + if request.method == "GET": + user_email = request.query_params.get("email") or request.user.email + enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) + return Response( + status=status.HTTP_200_OK, + data=self.get_serializer(enrollments_allowed, many=True).data, + ) + + serializer = self.get_serializer(data=request.data) + if not serializer.is_valid(): + return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors) + + if request.method == "POST": + try: + enrollment_allowed = serializer.save() + except IntegrityError: + return Response( + status=status.HTTP_409_CONFLICT, + data={ + "message": ( + f"An enrollment allowed with email {serializer.validated_data.get('email')} " + f"and course {serializer.validated_data.get('course_id')} already exists." + ) + }, + ) + return Response( + status=status.HTTP_201_CREATED, + data=self.get_serializer(enrollment_allowed).data, + ) + + # DELETE + email = serializer.validated_data.get("email") + course_id = serializer.validated_data.get("course_id") + try: + CourseEnrollmentAllowed.objects.get(email=email, course_id=course_id).delete() + return Response(status=status.HTTP_204_NO_CONTENT) + except ObjectDoesNotExist: + return Response( + status=status.HTTP_404_NOT_FOUND, + data={"message": f"An enrollment allowed with email {email} and course {course_id} doesn't exists."}, + ) + + +# DEPRECATED (ADR 0028): Use EnrollmentViewSet instead. Will be removed after one named release. @can_disable_rate_limit class EnrollmentListView(APIView, ApiKeyPermissionMixIn): """ @@ -633,10 +1321,28 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseEnrollmentSerializer # Since the course about page on the marketing site # uses this API to auto-enroll users, we need to support # cross-domain CSRF. + @extend_schema( + operation_id="enrollment_v1_enrollment_list_deprecated", + summary="List enrollments for a user (deprecated)", + description=( + "Deprecated. Use GET /api/enrollment/v1/enrollment/ (EnrollmentViewSet.list) instead. " + "This legacy endpoint returns an unpaginated list." + ), + parameters=[_USER_QUERY_PARAM], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentSerializer(many=True), + description="Enrollments retrieved successfully.", + ), + 401: _RESP_UNAUTHENTICATED, + }, + deprecated=True, + ) @method_decorator(ensure_csrf_cookie_cross_domain) def get(self, request): """Gets a list of all course enrollments for a user. @@ -656,30 +1362,42 @@ def get(self, request): courses. """ username = request.GET.get("user", request.user.username) - try: - enrollment_data = api.get_enrollments(username) - except CourseEnrollmentError: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={ - "message": ("An error occurred while retrieving enrollments for user '{username}'").format( - username=username - ) - }, - ) + enrollments = CourseEnrollment.objects.filter( + user__username=username + ).select_related("user", "course_overview") if ( username == request.user.username or GlobalStaff().has_user(request.user) or self.has_api_key_permissions(request) ): - return Response(enrollment_data) - filtered_data = [] - for enrollment in enrollment_data: - course_key = CourseKey.from_string(enrollment["course_details"]["course_id"]) - if user_has_role(request.user, CourseStaffRole(course_key)): - filtered_data.append(enrollment) - return Response(filtered_data) - + serializer = self.serializer_class(enrollments, many=True) + return Response(serializer.data) + filtered_enrollments = [ + enrollment for enrollment in enrollments + if user_has_role(request.user, CourseStaffRole(enrollment.course_id)) + ] + serializer = self.serializer_class(filtered_enrollments, many=True) + return Response(serializer.data) + + @extend_schema( + operation_id="enrollment_v1_enrollment_create_deprecated", + summary="Create or update an enrollment (deprecated)", + description=( + "Deprecated. Use POST /api/enrollment/v1/enrollment/ (EnrollmentViewSet.create) instead." + ), + request=OpenApiRequest(request=CourseEnrollmentSerializer), + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentSerializer, + description="Enrollment created, reactivated, or updated successfully.", + ), + 400: _RESP_BAD_REQUEST, + 403: _RESP_FORBIDDEN, + 404: _RESP_NOT_FOUND, + 406: OpenApiResponse(description="The specified user does not exist."), + }, + deprecated=True, + ) def post(self, request): # pylint: disable=too-many-statements """Enrolls the currently logged-in user in a course. @@ -929,18 +1647,76 @@ def post(self, request): finally: # Assumes that the ecommerce service uses an API key to authenticate. if has_api_key_permissions: - current_enrollment = api.get_enrollment(username, str(course_id)) + try: + current_enrollment_obj = CourseEnrollment.objects.get( + user__username=username, course_id=course_id + ) + actual_mode = current_enrollment_obj.mode + actual_activation = current_enrollment_obj.is_active + except CourseEnrollment.DoesNotExist: + actual_mode = None + actual_activation = None audit_log( "enrollment_change_requested", course_id=str(course_id), requested_mode=mode, - actual_mode=current_enrollment["mode"] if current_enrollment else None, + actual_mode=actual_mode, requested_activation=is_active, - actual_activation=current_enrollment["is_active"] if current_enrollment else None, + actual_activation=actual_activation, user_id=user.id, ) +@extend_schema_view( + get=extend_schema( + summary="List all course enrollments (admin-only, paginated)", + description=( + "Admin-only paginated list of CourseEnrollment records, optionally filtered by " + "course_key, course_keys, username, or email, and optionally ordered." + ), + parameters=[ + # ADR 0033 §2 / OEP-68: ``course_key`` and ``course_keys`` are the + # standardized names; ``course_id`` and ``course_ids`` are kept as + # deprecated aliases (BC strategy §1) and trigger a ``Deprecation`` + # HTTP header (BC strategy §2). + _query_param("course_key", "Filter to enrollments for this course (per OEP-68)."), + _query_param("course_keys", "Comma-separated list of course keys (per OEP-68)."), + OpenApiParameter( + name="course_id", + description="Deprecated alias for 'course_key' (ADR 0033). Use 'course_key' instead.", + required=False, + type=str, + location=OpenApiParameter.QUERY, + deprecated=True, + ), + OpenApiParameter( + name="course_ids", + description="Deprecated alias for 'course_keys' (ADR 0033). Use 'course_keys' instead.", + required=False, + type=str, + location=OpenApiParameter.QUERY, + deprecated=True, + ), + _query_param("username", "Comma-separated list of usernames."), + _query_param("email", "Comma-separated list of emails."), + _query_param( + "ordering", + "Order results by one of: created, -created, id, -id (ADR 0033 §3).", + ), + _PAGE_QUERY_PARAM, + _PAGE_SIZE_QUERY_PARAM, + ], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentsApiListSerializer(many=True), + description="Paginated list of course enrollments.", + ), + 400: _RESP_BAD_REQUEST, + 401: _RESP_UNAUTHENTICATED, + 403: _RESP_FORBIDDEN, + }, + ), +) @can_disable_rate_limit class CourseEnrollmentsApiListView(DeveloperErrorViewMixin, ListAPIView): """ @@ -977,9 +1753,9 @@ class CourseEnrollmentsApiListView(DeveloperErrorViewMixin, ListAPIView): * email: List of comma-separated emails. Filters the result to the course enrollments of the given users. Optional. - * page_size: Number of results to return per page. Optional. + * page_size: Number of results to return per page. Default 100, max 100. Optional. - * page: Page number to retrieve. Optional. + * page: Page number to retrieve. Default is 1. Optional. **Response Values** @@ -988,6 +1764,20 @@ class CourseEnrollmentsApiListView(DeveloperErrorViewMixin, ListAPIView): The HTTP 200 response has the following values. + * count: Total number of course enrollments matching the request. + + * num_pages: Total number of pages. + + * current_page: The current page number. + + * start: The 0-based index of the first item on this page. + + * next: The URL to the next page of results, or null if this is the + last page. + + * previous: The URL to the previous page of results, or null if this + is the first page. + * results: A list of the course enrollments matching the request. * created: Date and time when the course enrollment was created. @@ -1000,12 +1790,6 @@ class CourseEnrollmentsApiListView(DeveloperErrorViewMixin, ListAPIView): * course_id: Course ID of the course in the course enrollment. - * next: The URL to the next page of results, or null if this is the - last page. - - * previous: The URL to the next page of results, or null if this - is the first page. - If the user is not logged in, a 401 error is returned. If the user is not global staff, a 403 error is returned. @@ -1028,9 +1812,35 @@ class CourseEnrollmentsApiListView(DeveloperErrorViewMixin, ListAPIView): serializer_class = CourseEnrollmentsApiListSerializer pagination_class = CourseEnrollmentsApiListPagination + # ADR 0033 §3: whitelist of allowed values for the standard ``ordering`` + # query parameter. Any other value is silently ignored (the queryset + # falls back to the default ordering). + ALLOWED_ORDERING_FIELDS = frozenset({"created", "-created", "id", "-id"}) + + # ADR 0033 §2 / OEP-68 alias pairs accepted by this endpoint. Used by + # the response post-processor to emit the ``Deprecation`` header when a + # caller still sends the legacy name. + _LEGACY_PARAM_ALIASES = ( + ("course_id", "course_key"), + ("course_ids", "course_keys"), + ) + def get_queryset(self): """ - Get all the course enrollments for the given course_id and/or given list of usernames. + Get all the course enrollments for the given course key(s) and/or given list of usernames. + + ADR 0033 compliance notes: + - Filter parameters accept both the OEP-68-preferred names + (``course_key``, ``course_keys``) and the deprecated legacy names + (``course_id``, ``course_ids``). Resolution is handled by + :class:`CourseEnrollmentsApiListForm`. + - The DRF-standard ``ordering`` query parameter is honored when its + value is in :pyattr:`ALLOWED_ORDERING_FIELDS`. + - Full migration to ``django-filter``/``DjangoFilterBackend`` is + tracked as a follow-up: the existing ``CourseEnrollmentsApiListForm`` + performs nuanced parsing (CSV → list, MAX 100, course-key + validation, username validation) that is not a free conversion to + a ``FilterSet``. """ form = CourseEnrollmentsApiListForm(self.request.query_params) @@ -1055,9 +1865,24 @@ def get_queryset(self): queryset = queryset.filter(user__username__in=usernames) if emails: queryset = queryset.filter(user__email__in=emails) + + ordering = self.request.query_params.get("ordering") + if ordering in self.ALLOWED_ORDERING_FIELDS: + queryset = queryset.order_by(ordering) return queryset + def list(self, request, *args, **kwargs): + """ + ADR 0033 BC §2: emit the ``Deprecation`` HTTP header when a caller + still uses the legacy ``course_id`` / ``course_ids`` parameter names. + """ + response = super().list(request, *args, **kwargs) + return _maybe_set_legacy_param_deprecation_header( + request, response, self._LEGACY_PARAM_ALIASES, + ) + +# DEPRECATED (ADR 0028): Use EnrollmentViewSet.allowed action instead. Will be removed after one named release. class EnrollmentAllowedView(APIView): """ A view that allows the retrieval and creation of enrollment allowed for a given user email and course id. @@ -1067,6 +1892,23 @@ class EnrollmentAllowedView(APIView): throttle_classes = (EnrollmentUserThrottle,) serializer_class = CourseEnrollmentAllowedSerializer + @extend_schema( + operation_id="enrollment_v1_enrollment_allowed_list_deprecated", + summary="List allowed enrollments by email (deprecated)", + description=( + "Deprecated. Use GET /api/enrollment/v1/enrollment/enrollment_allowed/ " + "(EnrollmentViewSet.allowed action) instead. Admin-only." + ), + parameters=[_query_param("email", "Email to query. Defaults to the requester's email if omitted.")], + responses={ + 200: OpenApiResponse( + response=CourseEnrollmentAllowedSerializer(many=True), + description="Allowed enrollments retrieved successfully.", + ), + 403: _RESP_FORBIDDEN, + }, + deprecated=True, + ) def get(self, request): """ Returns the enrollments allowed for a given user email. @@ -1087,13 +1929,29 @@ def get(self, request): if not user_email: user_email = request.user.email - enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) or [] - serialized_enrollments_allowed = [ - CourseEnrollmentAllowedSerializer(enrollment).data for enrollment in enrollments_allowed - ] - - return Response(status=status.HTTP_200_OK, data=serialized_enrollments_allowed) - + enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) + serializer = self.serializer_class(enrollments_allowed, many=True) + return Response(status=status.HTTP_200_OK, data=serializer.data) + + @extend_schema( + operation_id="enrollment_v1_enrollment_allowed_create_deprecated", + summary="Create an allowed enrollment (deprecated)", + description=( + "Deprecated. Use POST /api/enrollment/v1/enrollment/enrollment_allowed/ " + "(EnrollmentViewSet.allowed action) instead. Admin-only." + ), + request=OpenApiRequest(request=CourseEnrollmentAllowedSerializer), + responses={ + 201: OpenApiResponse( + response=CourseEnrollmentAllowedSerializer, + description="Allowed enrollment created.", + ), + 400: _RESP_BAD_REQUEST, + 403: _RESP_FORBIDDEN, + 409: OpenApiResponse(description="Allowed enrollment already exists for this email/course."), + }, + deprecated=True, + ) def post(self, request): """ Creates an enrollment allowed for a given user email and course id. @@ -1126,24 +1984,41 @@ def post(self, request): - 403: Forbidden, you need to be staff. - 409: Conflict, enrollment allowed already exists. """ - is_bad_request_response, email, course_id = self.check_required_data(request) - auto_enroll = request.data.get("auto_enroll", False) - if is_bad_request_response: - return is_bad_request_response + serializer = self.serializer_class(data=request.data) + if not serializer.is_valid(): + return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors) try: - enrollment_allowed = CourseEnrollmentAllowed.objects.create( - email=email, course_id=course_id, auto_enroll=auto_enroll - ) + enrollment_allowed = serializer.save() except IntegrityError: return Response( status=status.HTTP_409_CONFLICT, - data={"message": f"An enrollment allowed with email {email} and course {course_id} already exists."}, + data={ + "message": ( + f"An enrollment allowed with email {serializer.validated_data.get('email')} " + f"and course {serializer.validated_data.get('course_id')} already exists." + ) + }, ) - serializer = CourseEnrollmentAllowedSerializer(enrollment_allowed) - return Response(status=status.HTTP_201_CREATED, data=serializer.data) - + return Response(status=status.HTTP_201_CREATED, data=self.serializer_class(enrollment_allowed).data) + + @extend_schema( + operation_id="enrollment_v1_enrollment_allowed_destroy_deprecated", + summary="Delete an allowed enrollment (deprecated)", + description=( + "Deprecated. Use DELETE /api/enrollment/v1/enrollment/enrollment_allowed/ " + "(EnrollmentViewSet.allowed action) instead. Admin-only." + ), + request=OpenApiRequest(request=CourseEnrollmentAllowedSerializer), + responses={ + 204: OpenApiResponse(description="Allowed enrollment deleted."), + 400: _RESP_BAD_REQUEST, + 403: _RESP_FORBIDDEN, + 404: OpenApiResponse(description="Allowed enrollment not found for the given email/course."), + }, + deprecated=True, + ) def delete(self, request): """ Deletes an enrollment allowed for a given user email and course id. @@ -1174,32 +2049,18 @@ def delete(self, request): - 403: Forbidden, you need to be staff. - 404: Not found, the course enrollment allowed doesn't exists. """ - is_bad_request_response, email, course_id = self.check_required_data(request) - if is_bad_request_response: - return is_bad_request_response + serializer = self.serializer_class(data=request.data) + if not serializer.is_valid(): + return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors) + + email = serializer.validated_data.get("email") + course_id = serializer.validated_data.get("course_id") try: CourseEnrollmentAllowed.objects.get(email=email, course_id=course_id).delete() - return Response( - status=status.HTTP_204_NO_CONTENT, - ) + return Response(status=status.HTTP_204_NO_CONTENT) except ObjectDoesNotExist: return Response( status=status.HTTP_404_NOT_FOUND, data={"message": f"An enrollment allowed with email {email} and course {course_id} doesn't exists."}, ) - - def check_required_data(self, request): - """ - Check if the request has email and course_id. - """ - email = request.data.get("email") - course_id = request.data.get("course_id") - if not email or not course_id: - is_bad_request = Response( - status=status.HTTP_400_BAD_REQUEST, - data={"message": "Please provide a value for 'email' and 'course_id' in the request data."}, - ) - else: - is_bad_request = None - return (is_bad_request, email, course_id)