Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions specifyweb/backend/accounts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class OAuthLogin(TypedDict):
state: str
provider: str
provider_conf: ProviderConf
next: str

class ExternalUser(TypedDict):
"""Information passed through a session variable to associate the
Expand Down
60 changes: 52 additions & 8 deletions specifyweb/backend/accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,43 @@

logger = logging.getLogger(__name__)

def _normalize_next_url(next_url: str) -> str:
prefix = '/accounts/choose_collection/?next='
return (
unquote_plus(next_url[len(prefix):])
if next_url.startswith(prefix)
else next_url
)

def _redirect_with_next(path: str, next_url: str) -> http.HttpResponseRedirect:
return http.HttpResponseRedirect(
path if len(next_url) == 0 else f'{path}?{urlencode({"next": next_url})}'
)

def _legacy_login_redirect(next_url: str) -> http.HttpResponseRedirect:
return _redirect_with_next(
'/accounts/legacy_login/',
(
'/accounts/choose_collection/'
if len(next_url) == 0
else f'/accounts/choose_collection/?{urlencode({"next": next_url})}'
),
)

def _get_next_url(
request: http.HttpRequest, oauth_login: OAuthLogin | None = None
) -> str:
if request.method == 'POST':
return request.POST.get('next', '')
return request.GET.get('next', oauth_login['next'] if oauth_login else '')

@require_http_methods(['GET', 'POST'])
def oic_login(request: http.HttpRequest) -> http.HttpResponse:
"""Initiates the OpenId Connect login process by providing the list of
available providers, then redirecting to the one chosen.
"""
if request.method == 'POST':
next_url = _normalize_next_url(_get_next_url(request))
provider = request.POST['provider']
provider_info_dict = settings.OAUTH_LOGIN_PROVIDERS[provider]
assert is_provider_info(provider_info_dict), "provider_info_dict does not match ProviderInfo structure"
Expand All @@ -65,6 +96,7 @@ def oic_login(request: http.HttpRequest) -> http.HttpResponse:
'state': state,
'provider': provider,
'provider_conf': provider_conf,
'next': next_url,
}
request.session['oauth_login'] = oauth_login

Expand All @@ -82,7 +114,11 @@ def oic_login(request: http.HttpRequest) -> http.HttpResponse:
{'provider': p, 'title': d['title']}
for p, d in settings.OAUTH_LOGIN_PROVIDERS.items()
]
return render(request, "oic_login.html", context={'providers': providers})
return render(
request,
"oic_login.html",
context={'providers': providers, 'next': _get_next_url(request)},
)

@openapi(schema={
"get": {
Expand Down Expand Up @@ -121,12 +157,20 @@ def oic_providers(request: http.HttpRequest) -> http.HttpResponse:
@require_GET
def oic_callback(request: http.HttpRequest) -> http.HttpResponse:
"""Handles the return callback from the OIC identity provider. """
oauth_login: OAuthLogin = request.session['oauth_login']
del request.session['oauth_login'] # not really necessary, but might as well clean up.
assert crypto.constant_time_compare(request.GET['state'], oauth_login['state'])
oauth_login: OAuthLogin | None = request.session.pop('oauth_login', None)
next_url = _normalize_next_url(_get_next_url(request, oauth_login))

if oauth_login is None:
logger.warning('OAuth callback received without session state.')
return _redirect_with_next('/accounts/login/', next_url)

if not crypto.constant_time_compare(request.GET.get('state', ''), oauth_login['state']):
logger.warning('OAuth callback received mismatched state.')
return _redirect_with_next('/accounts/login/', next_url)

if 'error' in request.GET:
logger.error('OAuth error: %s', request.GET)
return http.HttpResponseRedirect('/accounts/login/')
return _redirect_with_next('/accounts/login/', next_url)

provider = oauth_login['provider']
provider_conf: ProviderConf = oauth_login['provider_conf']
Expand Down Expand Up @@ -168,7 +212,7 @@ def oic_callback(request: http.HttpRequest) -> http.HttpResponse:
)
del request.session['invite_token']
login(request, user, backend='django.contrib.auth.backends.ModelBackend')
return http.HttpResponseRedirect('/accounts/choose_collection/')
return _redirect_with_next('/accounts/choose_collection/', next_url)

try:
spuserexternalid = Spuserexternalid.objects.get(provider=provider, providerid=str(ext_user['sub']))
Expand All @@ -183,7 +227,7 @@ def oic_callback(request: http.HttpRequest) -> http.HttpResponse:
'idtoken': ext_user,
}
request.session['external_user'] = external_user
return http.HttpResponseRedirect('/accounts/legacy_login/')
return _legacy_login_redirect(next_url)

if not spuserexternalid.enabled:
return http.HttpResponse("Logins with this identity are disabled.", content_type="text/plain")
Expand All @@ -195,7 +239,7 @@ def oic_callback(request: http.HttpRequest) -> http.HttpResponse:
login(request,
cast(AbstractBaseUser, spuserexternalid.specifyuser),
backend='django.contrib.auth.backends.ModelBackend')
return http.HttpResponseRedirect('/accounts/choose_collection')
return _redirect_with_next('/accounts/choose_collection/', next_url)

@require_GET
@login_maybe_required
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import React from 'react';

import { commonText } from '../../localization/common';
import { userText } from '../../localization/user';
import { redirectToLoginWithResume } from '../../utils/authResume';
import type { RA } from '../../utils/types';
import { Button } from '../Atoms/Button';
import type { AnyTree } from '../DataModel/helperTypes';
import { schema } from '../DataModel/schema';
import type { Tables } from '../DataModel/types';
import { userInformation } from '../InitialContext/userInformation';
import { Dialog } from '../Molecules/Dialog';
import { formatUrl } from '../Router/queryString';
import type { toolDefinitions } from '../Security/registry';
import {
partsToResourceName,
Expand Down Expand Up @@ -230,11 +230,7 @@ export function PermissionError({
buttons={userText.logIn()}
forceToTop
header={userText.sessionTimeOut()}
onClose={(): void =>
globalThis.location.assign(
formatUrl('/accounts/login/', { next: globalThis.location.href })
)
}
onClose={(): void => redirectToLoginWithResume()}
>
{userText.sessionTimeOutDescription()}
</Dialog>
Expand Down
60 changes: 54 additions & 6 deletions specifyweb/frontend/js_src/lib/components/QueryBuilder/Wrapped.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ import React from 'react';
import { useUnloadProtect } from '../../hooks/navigation';
import { useResource } from '../../hooks/resource';
import { useAsyncState } from '../../hooks/useAsyncState';
import { useAuthResume } from '../../hooks/useAuthResume';
import { useBooleanState } from '../../hooks/useBooleanState';
import { useCachedState } from '../../hooks/useCachedState';
import { useErrorContext } from '../../hooks/useErrorContext';
import { commonText } from '../../localization/common';
import { queryText } from '../../localization/query';
import {
consumeAuthResumePayload,
currentAuthResumeUrl,
} from '../../utils/authResume';
import { smoothScroll } from '../../utils/dom';
import { f } from '../../utils/functools';
import type { RA } from '../../utils/types';
Expand Down Expand Up @@ -42,6 +47,13 @@ import {
import { getMappingLineData } from '../WbPlanView/navigator';
import { navigatorSpecs } from '../WbPlanView/navigatorSpecs';
import { datasetVariants } from '../WbUtils/datasetVariants';
import {
type QueryBuilderResumePayload,
queryBuilderFlagsFromQuery,
queryBuilderFlagsRequireSave,
queryBuilderResumeKind,
restoreQueryBuilderState,
} from './authResume';
import { CheckReadAccess } from './CheckReadAccess';
import { MakeRecordSetButton } from './Components';
import { IsQueryBasicContext, useQueryViewPref } from './Context';
Expand Down Expand Up @@ -100,25 +112,36 @@ function Wrapped({
readonly isSeries: boolean | null;
}) => void;
}): JSX.Element {
const restoredSnapshot = React.useRef(
consumeAuthResumePayload<QueryBuilderResumePayload>(
queryBuilderResumeKind,
currentAuthResumeUrl()
)
).current;
const [query, setQuery] = useResource(queryResource);
useErrorContext('query', query);

const [treeRanksLoaded = false] = useAsyncState(fetchTreeRanks, false);

const table = getTableById(query.contextTableId);
const [selectedRows, setSelectedRows] = React.useState<ReadonlySet<number>>(
new Set()
() => new Set(restoredSnapshot?.selectedRows ?? [])
);

const buildInitialState = React.useCallback(
const baseInitialState = React.useMemo(
() =>
getInitialState({
query,
queryResource,
table,
autoRun,
}),
[queryResource, table, autoRun]
[query, queryResource, table, autoRun]
);

const initialState = React.useMemo(
() => restoreQueryBuilderState(baseInitialState, restoredSnapshot),
[baseInitialState, restoredSnapshot]
);

const [showMappingView = true, _] = useCachedState(
Expand All @@ -133,15 +156,24 @@ function Wrapped({
const [saveRequired, setSaveRequired] = React.useState(false);

React.useEffect(() => {
const initialState = buildInitialState();
dispatch({
type: 'ResetStateAction',
state: initialState,
});
initialFields.current = JSON.stringify(initialState.fields);
}, [buildInitialState]);
initialFields.current = JSON.stringify(baseInitialState.fields);
}, [baseInitialState, initialState]);
useErrorContext('state', state);

React.useEffect(() => {
if (restoredSnapshot === undefined) return;
setQuery({
...query,
...restoredSnapshot.query,
});
if (queryBuilderFlagsRequireSave(query, restoredSnapshot))
setSaveRequired(true);
}, []);

const checkForChanges = React.useMemo(
() =>
throttle(
Expand Down Expand Up @@ -183,6 +215,22 @@ function Wrapped({

const promptToSave = saveRequired && !isEmbedded;

useAuthResume(() =>
state === pendingState || (!promptToSave && state.queryRunCount === 0)
? undefined
: {
version: 1,
createdAt: Date.now(),
kind: queryBuilderResumeKind,
url: currentAuthResumeUrl(),
payload: {
query: queryBuilderFlagsFromQuery(query),
selectedRows: Array.from(selectedRows),
state,
},
}
);

const unsetUnloadProtect = useUnloadProtect(
promptToSave,
queryText.queryUnloadProtect()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import type { RA } from '../../utils/types';
import type { SerializedResource } from '../DataModel/helperTypes';
import type { SpQuery } from '../DataModel/types';
import type { MainState } from './reducer';

export const queryBuilderResumeKind = 'query-builder';

export type QueryBuilderResumePayload = {
readonly query: Pick<
SerializedResource<SpQuery>,
'countOnly' | 'searchSynonymy' | 'selectDistinct' | 'smushed'
>;
readonly selectedRows: RA<number>;
readonly state: MainState;
};

export function queryBuilderFlagsFromQuery(
query: SerializedResource<SpQuery>
): QueryBuilderResumePayload['query'] {
return {
countOnly: query.countOnly,
searchSynonymy: query.searchSynonymy,
selectDistinct: query.selectDistinct,
smushed: query.smushed,
};
}

export function restoreQueryBuilderState(
baseState: MainState,
snapshot: QueryBuilderResumePayload | undefined
): MainState {
return snapshot === undefined ||
snapshot.state.baseTableName !== baseState.baseTableName
? baseState
: {
...snapshot.state,
baseTableName: baseState.baseTableName,
};
}

export function queryBuilderFlagsRequireSave(
query: SerializedResource<SpQuery>,
snapshot: QueryBuilderResumePayload | undefined
): boolean {
if (snapshot === undefined) return false;
const flags = queryBuilderFlagsFromQuery(query);
return (
flags.countOnly !== snapshot.query.countOnly ||
flags.searchSynonymy !== snapshot.query.searchSynonymy ||
flags.selectDistinct !== snapshot.query.selectDistinct ||
flags.smushed !== snapshot.query.smushed
);
}
Loading
Loading