%%%-------------------------------------------------------------------
%%% @author Lukasz Opiola
%%% @copyright (C) 2018 ACK CYFRONET AGH
%%% This software is released under the MIT license
%%% cited in 'LICENSE.txt'.
%%% @end
%%%-------------------------------------------------------------------
%%% @doc
%%% This module handles signing in to Onezone via different IdPs using
%%% Single Sign-On protocols (OpenId or SAML).
%%% @end
%%%-------------------------------------------------------------------
-module(idp_auth).

-include("auth/auth_common.hrl").
-include("auth/auth_errors.hrl").
-include("http/gui_paths.hrl").
-include("datastore/oz_datastore_models.hrl").
-include_lib("ctool/include/onedata.hrl").
-include_lib("ctool/include/logging.hrl").
-include_lib("ctool/include/errors.hrl").

% URL in IdP where the client should be redirected for authentication
-type login_endpoint() :: binary().
% URL where the client should be redirected after logging in to an IdP
-type redirect_uri() :: binary().
% Map of parsed parameters received in the query string on the consume endpoint (Key => Val).
% NOTE: map() rather than #{} - in the type language #{} denotes only the empty map.
-type query_params() :: map().
% OpenID access token, used to retrieve user info from an IdP
-type access_token() :: binary().
-type access_token_ttl() :: time:seconds().
% Refresh token - used to refresh access tokens
-type refresh_token() :: binary().
-export_type([
    login_endpoint/0,
    redirect_uri/0,
    query_params/0,
    access_token/0,
    access_token_ttl/0,
    refresh_token/0
]).

% Modules implementing the SSO protocols, resolved per IdP by get_protocol_handler/1
-type protocol_handler() :: saml_protocol | openid_protocol.

% Current time in seconds, as reported by the global clock
-define(NOW_SECONDS(), global_clock:timestamp_seconds()).
% Number of seconds before its expiry at which a stored access token is refreshed
% (when a refresh token is available); read via oz_worker:get_env/2, 300 by default
-define(REFRESH_THRESHOLD, oz_worker:get_env(idp_access_token_refresh_threshold, 300)).

%% API
-export([get_login_endpoint/4, validate_login/2]).
-export([acquire_idp_access_token/2, refresh_idp_access_token/2]).

%%%===================================================================
%%% API functions
%%%===================================================================

%%--------------------------------------------------------------------
%% @doc
%% Resolves where a client should be redirected to authenticate in the given
%% IdP (via SAML or OIDC, depending on the protocol configured for the IdP).
%% A new state token is created for this login attempt and handed over to the
%% protocol handler. On success, returns a map that includes three keys:
%%      <<"method">>
%%      <<"url">>
%%      <<"formData">>
%% describing the request that should be performed to reach the login page.
%% Errors returned by the handler are passed through; any crash is logged
%% and reported as an internal server error.
%% @end
%%--------------------------------------------------------------------
-spec get_login_endpoint(auth_config:idp(), LinkAccount :: false | {true, od_user:id()},
    RedirectAfterLogin :: binary(), TestMode :: boolean()) ->
    {ok, map()} | {error, term()}.
get_login_endpoint(IdP, LinkAccount, RedirectAfterLogin, TestMode) ->
    % Test mode (if requested) is enabled before the state token is created
    TestMode andalso idp_auth_test_mode:process_enable_test_mode(),
    try
        {ok, State} = state_token:create(IdP, LinkAccount, RedirectAfterLogin, TestMode),
        HandlerModule = get_protocol_handler(IdP),
        case HandlerModule:get_login_endpoint(IdP, State) of
            {error, _} = HandlerError ->
                HandlerError;
            {ok, LoginEndpoint} = Success ->
                ?debug("Redirecting for login to IdP '~tp' (state: ~ts):~n~tp", [
                    IdP, State, LoginEndpoint
                ]),
                Success
        end
    catch
        Class:Reason:Stack ->
            ?error_stacktrace(
                "Cannot resolve redirect URL for IdP '~tp' - ~tp:~tp",
                [IdP, Class, Reason],
                Stack
            ),
            ?ERR_INTERNAL_SERVER_ERROR(?err_ctx(), undefined)
    end.


%%--------------------------------------------------------------------
%% @doc
%% Validates an incoming login request based on the data payload carried by Req.
%% Returns {ok, UserId, RedirectPage} on success. Otherwise returns
%% {auth_error, Error, StateToken, RedirectPage}, where RedirectPage is the
%% login page if the state token is unknown, or the redirect_after_login page
%% stored in the state. Thrown {error, _} terms are reported as they are;
%% other crashes during validation become an internal server error.
%% @end
%%--------------------------------------------------------------------
-spec validate_login(gui:method(), cowboy_req:req()) ->
    {ok, od_user:id(), RedirectPage :: binary()} |
    {auth_error, {error, term()}, state_token:state_token(), RedirectPage :: binary()}.
validate_login(Method, Req) ->
    {StateToken, Payload} = parse_payload(Method, Req),
    case state_token:lookup(StateToken) of
        {ok, #{idp := IdP, redirect_after_login := RedirectAfterLogin} = StateInfo} ->
            try validate_login_by_state(Payload, StateToken, StateInfo) of
                {ok, UserId} ->
                    {ok, UserId, RedirectAfterLogin};
                {error, _} = ValidationError ->
                    log_error(ValidationError, IdP, StateToken, []),
                    {auth_error, ValidationError, StateToken, RedirectAfterLogin}
            catch
                throw:{error, _} = ThrownError:Stack ->
                    log_error(ThrownError, IdP, StateToken, Stack),
                    {auth_error, ThrownError, StateToken, RedirectAfterLogin};
                Class:Reason:Stack ->
                    log_error({Class, Reason}, IdP, StateToken, Stack),
                    InternalError = ?ERR_INTERNAL_SERVER_ERROR(?err_ctx(), undefined),
                    {auth_error, InternalError, StateToken, RedirectAfterLogin}
            end;
        error ->
            % The state token was not generated by us or has expired
            log_error(?ERROR_INVALID_STATE, undefined, StateToken, []),
            {auth_error, ?ERROR_INVALID_STATE, StateToken, <<?LOGIN_PAGE_PATH>>}
    end.


%%--------------------------------------------------------------------
%% @doc
%% Acquires an access token for given user, issued by given IdP.
%% Blocked users get ?ERR_USER_BLOCKED. Otherwise, the user's linked accounts
%% from the IdP are tried in order until one of them yields a token; if none
%% does, the result for the last such account is returned.
%% Returns ?ERROR_NOT_FOUND when:
%%  * the user does not have an account in such IdP
%%  * there is no access token stored
%%  * the stored access token has expired and there is no viable refresh token
%% Can return ?ERR_INTERNAL_SERVER_ERROR in case token refresh goes wrong.
%% @end
%%--------------------------------------------------------------------
-spec acquire_idp_access_token(od_user:record(), auth_config:idp()) ->
    {ok, {access_token(), access_token_ttl()}} | {error, term()}.
acquire_idp_access_token(#od_user{blocked = true}, _) ->
    ?ERR_USER_BLOCKED(?err_ctx());
acquire_idp_access_token(#od_user{linked_accounts = LinkedAccounts}, IdP) ->
    AccountsInIdP = [
        Account || #linked_account{idp = AccountIdP} = Account <- LinkedAccounts,
        AccountIdP == IdP
    ],
    % Once a token is found, the remaining accounts are not queried
    lists:foldl(fun
        (_Account, {ok, _} = Found) -> Found;
        (Account, _PreviousResult) -> acquire_idp_access_token(Account)
    end, ?ERROR_NOT_FOUND, AccountsInIdP).

%%--------------------------------------------------------------------
%% @private
%% @doc
%% Acquires an access token from a single linked account, returning it along
%% with its remaining TTL (in seconds).
%% Without a refresh token, the stored token is returned while it is valid.
%% With a refresh token, the stored token is returned only if it stays valid
%% for longer than ?REFRESH_THRESHOLD seconds; otherwise a new one is obtained
%% via refresh_idp_access_token/2.
%% @end
%%--------------------------------------------------------------------
-spec acquire_idp_access_token(od_user:linked_account()) ->
    {ok, {access_token(), access_token_ttl()}} | {error, term()}.
acquire_idp_access_token(#linked_account{access_token = {undefined, 0}}) ->
    % No access token stored
    ?ERROR_NOT_FOUND;
acquire_idp_access_token(#linked_account{
    idp = IdP, access_token = {AccessToken, Expires}, refresh_token = RefreshToken
}) ->
    Now = ?NOW_SECONDS(),
    case RefreshToken of
        undefined ->
            % No refresh token - no point in trying to refresh the access token
            case Expires > Now of
                true -> {ok, {AccessToken, Expires - Now}};
                false -> ?ERROR_NOT_FOUND
            end;
        _ ->
            case Expires - ?REFRESH_THRESHOLD > Now of
                true -> {ok, {AccessToken, Expires - Now}};
                false -> refresh_idp_access_token(IdP, RefreshToken)
            end
    end.


%%--------------------------------------------------------------------
%% @doc
%% Acquires a new access token from the IdP using given refresh token.
%% The attributes returned by the IdP are mapped to a linked account, which is
%% then passed to linked_accounts:acquire_user/1; on success, the new token and
%% its TTL (in seconds) are returned. An error from acquire_user/1 is logged
%% and returned as is; any exception is logged and reported as an internal
%% server error.
%% @end
%%--------------------------------------------------------------------
-spec refresh_idp_access_token(auth_config:idp(), refresh_token()) ->
    {ok, {access_token(), access_token_ttl()}} | errors:error().
refresh_idp_access_token(IdP, RefreshToken) ->
    % Used in logs in place of a state token, as there is none in this flow
    FlowLabel = <<"refresh_token_flow">>,
    try
        {ok, IdPAttributes} = openid_protocol:refresh_idp_access_token(IdP, RefreshToken),
        RefreshedAccount = attribute_mapping:map_attributes(IdP, IdPAttributes),
        case linked_accounts:acquire_user(RefreshedAccount) of
            {ok, _} ->
                #linked_account{access_token = {NewAccessToken, ExpiresAt}} = RefreshedAccount,
                {ok, {NewAccessToken, ExpiresAt - ?NOW_SECONDS()}};
            {error, _} = AcquireError ->
                log_error(AcquireError, IdP, FlowLabel, []),
                AcquireError
        end
    catch
        throw:Thrown:Stack ->
            log_error(Thrown, IdP, FlowLabel, Stack),
            ?ERR_INTERNAL_SERVER_ERROR(?err_ctx(), undefined);
        Class:Reason:Stack ->
            log_error({Class, Reason}, IdP, FlowLabel, Stack),
            ?ERR_INTERNAL_SERVER_ERROR(?err_ctx(), undefined)
    end.


%%%===================================================================
%%% Internal functions
%%%===================================================================

%%--------------------------------------------------------------------
%% @private
%% @doc
%% Validates a login request whose state token has been resolved to StateInfo.
%% The payload is validated by the IdP's protocol handler (a non-{ok, _}
%% result crashes with badmatch). The returned attributes are mapped to a
%% linked account, which is then handled depending on the mode:
%%  * test mode - test user data is stored instead of logging in a real user,
%%  * account linking ({true, UserId}) - the account is linked to UserId,
%%  * regular login - the user owning the linked account is acquired.
%% Access and refresh tokens are stripped from the logged attributes.
%% @end
%%--------------------------------------------------------------------
-spec validate_login_by_state(Payload :: map(), state_token:state_token(), state_token:state_info()) ->
    {ok, od_user:id()} | {error, term()}.
validate_login_by_state(Payload, StateToken, #{idp := IdP, test_mode := TestMode, link_account := LinkAccount}) ->
    TestMode andalso idp_auth_test_mode:process_enable_test_mode(),
    TestMode andalso idp_auth_test_mode:store_state_token(StateToken),
    Handler = get_protocol_handler(IdP),
    ?auth_debug("Login attempt from IdP '~tp' (state: ~ts), payload:~n~tp", [
        IdP, StateToken, Payload
    ]),

    {ok, Attributes} = Handler:validate_login(IdP, Payload),
    % Do not print sensitive information
    PrintableAttributes = maps:without([<<"access_token">>, <<"refresh_token">>], Attributes),
    ?auth_debug("Login from IdP '~tp' (state: ~ts) validated, attributes:~n~ts", [
        IdP, StateToken, json_utils:encode(PrintableAttributes, [pretty])
    ]),

    LinkedAccount = attribute_mapping:map_attributes(IdP, Attributes),
    LinkedAccountMap = maps:without(
        [<<"name">>, <<"login">>, <<"alias">>, <<"emailList">>, <<"groups">>], % do not include deprecated fields
        linked_accounts:to_map(LinkedAccount, all_fields)
    ),
    ?auth_debug("Attributes from IdP '~tp' (state: ~ts) successfully mapped:~n~ts", [
        IdP, StateToken, json_utils:encode(LinkedAccountMap, [pretty])
    ]),

    % The test mode flag is read back from idp_auth_test_mode, where it is
    % enabled above when the state requests it
    case {idp_auth_test_mode:process_is_test_mode_enabled(), LinkAccount} of
        {true, _} ->
            validate_test_login_by_linked_account(LinkedAccount);
        {false, false} ->
            validate_login_by_linked_account(LinkedAccount);
        {false, {true, UserId}} ->
            validate_link_account_request(LinkedAccount, UserId)
    end.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Performs a regular login: acquires the user owning given linked account
%% (via linked_accounts:acquire_user/1) and returns the user's id. Errors from
%% acquire_user/1 are returned unchanged.
%% @end
%%--------------------------------------------------------------------
-spec validate_login_by_linked_account(od_user:linked_account()) ->
    {ok, od_user:id()} | errors:error().
validate_login_by_linked_account(LinkedAccount) ->
    case linked_accounts:acquire_user(LinkedAccount) of
        {ok, #document{key = LoggedInUserId, value = #od_user{full_name = UserFullName}}} ->
            ?info("User '~ts' has logged in (~ts)", [UserFullName, LoggedInUserId]),
            {ok, LoggedInUserId};
        {error, _} = AcquireError ->
            AcquireError
    end.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handles a request to link given account to the user TargetUserId.
%% If the account already belongs to some user, that user's data is still
%% merged with the account info, but an error is returned (distinguishing
%% whether the owner is TargetUserId or another user). Otherwise, the account
%% is merged into TargetUserId.
%% @end
%%--------------------------------------------------------------------
-spec validate_link_account_request(od_user:linked_account(), od_user:id()) ->
    {ok, od_user:id()} | {error, term()}.
validate_link_account_request(LinkedAccount, TargetUserId) ->
    % Check whether this account is already connected to some profile
    case linked_accounts:find_user(LinkedAccount) of
        {error, not_found} ->
            % Not linked anywhere - add the new linked account to the target user
            {ok, #document{value = #od_user{full_name = FullName}}} =
                linked_accounts:merge(TargetUserId, LinkedAccount),
            ?info("User ~ts (~ts) has linked his account from '~tp'", [
                FullName, TargetUserId, LinkedAccount#linked_account.idp
            ]),
            {ok, TargetUserId};
        {ok, #document{key = TargetUserId}} ->
            % Already linked to this very user - synchronize the info, but report an error
            linked_accounts:merge(TargetUserId, LinkedAccount),
            ?ERROR_ACCOUNT_ALREADY_LINKED_TO_CURRENT_USER(TargetUserId);
        {ok, #document{key = OtherUserId}} ->
            % Used on some other profile - synchronize its info, but cannot proceed
            linked_accounts:merge(OtherUserId, LinkedAccount),
            ?ERROR_ACCOUNT_ALREADY_LINKED_TO_ANOTHER_USER(TargetUserId, OtherUserId)
    end.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Test-mode counterpart of a login: instead of creating or acquiring a real
%% user, builds test user info from the linked account and stores the user data
%% via idp_auth_test_mode for later use by the page_consume_login module.
%% Returns the id of the test user.
%% @end
%%--------------------------------------------------------------------
-spec validate_test_login_by_linked_account(od_user:linked_account()) ->
    {ok, od_user:id()}.
validate_test_login_by_linked_account(LinkedAccount) ->
    {TestUserId, TestUserData} = linked_accounts:build_test_user_info(LinkedAccount),
    idp_auth_test_mode:store_user_data(TestUserData),
    {ok, TestUserId}.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Returns the module implementing the protocol (SAML or OpenID) that is
%% configured for given IdP.
%% @end
%%--------------------------------------------------------------------
-spec get_protocol_handler(auth_config:idp()) -> protocol_handler().
get_protocol_handler(IdP) ->
    Protocol = auth_config:get_protocol(IdP),
    case Protocol of
        openid -> openid_protocol;
        saml -> saml_protocol
    end.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Parses OIDC / SAML payload and returns it as a map, along with the resolved
%% state token. For POST requests, the urlencoded body (up to 128000 bytes) is
%% read and the state is taken from the "RelayState" field; for GET requests,
%% the query string is parsed and the state is taken from the "state" parameter.
%% The returned state token is always a binary - <<>> when the parameter is
%% missing or given without a value.
%% @end
%%--------------------------------------------------------------------
-spec parse_payload(gui:method(), cowboy_req:req()) ->
    {state_token:state_token(), Payload :: map()}.
parse_payload(<<"POST">>, Req) ->
    {ok, PostBody, _} = cowboy_req:read_urlencoded_body(Req, #{length => 128000}),
    params_to_payload(<<"RelayState">>, PostBody);
parse_payload(<<"GET">>, Req) ->
    params_to_payload(<<"state">>, cowboy_req:parse_qs(Req)).


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Builds the {StateToken, Payload} pair from parsed key-value parameters.
%% Cowboy represents a parameter given without a value (e.g. "?state") as
%% {Key, true}, so any non-binary value is normalized to <<>> - this keeps the
%% state token a binary before it is used for a state lookup.
%% @end
%%--------------------------------------------------------------------
-spec params_to_payload(StateKey :: binary(), [{binary(), binary() | true}]) ->
    {state_token:state_token(), Payload :: map()}.
params_to_payload(StateKey, Params) ->
    StateToken = case proplists:get_value(StateKey, Params, <<>>) of
        Value when is_binary(Value) -> Value;
        _ -> <<>>
    end,
    {StateToken, maps:from_list(Params)}.


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Logs authentication errors with proper severity and message, depending on the
%% error type. Recognized errors get a dedicated message (mostly at debug level,
%% IdP communication problems as warnings); any other error is logged at error
%% level together with the stacktrace. IdP may be undefined when the state
%% token could not be resolved. Error is usually an {error, _} or {Class, Reason}
%% pair, but may be any thrown term.
%% @end
%%--------------------------------------------------------------------
-spec log_error(Error :: term(), auth_config:idp() | undefined,
    state_token:state_token(), Stacktrace :: term()) -> ok.
log_error(?ERROR_BAD_AUTH_CONFIG, _, _, Stacktrace) ->
    ?auth_debug(
        "Login request failed due to bad auth config: ~ts", [
            iolist_to_binary(lager:pr_stacktrace(Stacktrace))
        ]
    );
log_error(?ERROR_INVALID_STATE, _, StateToken, _) ->
    ?auth_debug(
        "Cannot validate login request - invalid state ~ts (not found)",
        [StateToken]
    );
log_error(?ERROR_INVALID_AUTH_REQUEST, IdP, StateToken, Stacktrace) ->
    ?auth_debug(
        "Cannot validate login request for IdP '~tp' (state: ~ts) - invalid auth request~n"
        "Stacktrace: ~ts", [IdP, StateToken, iolist_to_binary(lager:pr_stacktrace(Stacktrace))]
    );
log_error(?ERR_USER_BLOCKED, IdP, StateToken, _) ->
    ?auth_debug(
        "Declining login request for IdP '~tp' (state: ~ts) - the user is blocked",
        [IdP, StateToken]
    );
log_error(?ERROR_IDP_UNREACHABLE(Reason), IdP, StateToken, _) ->
    ?auth_warning(
        "Cannot validate login request for IdP '~tp' (state: ~ts) - IdP not reachable: ~tp",
        [IdP, StateToken, Reason]
    );
log_error(?ERROR_BAD_IDP_RESPONSE(Endpoint, Code, Headers, Body), IdP, StateToken, _) ->
    ?auth_warning(
        "Cannot validate login request for IdP '~tp' (state: ~ts) - unexpected response from IdP:~n"
        "Endpoint: ~ts~n"
        "Code: ~tp~n"
        "Headers: ~tp~n"
        "Body: ~ts",
        [IdP, StateToken, Endpoint, Code, Headers, Body]
    );
log_error(?ERROR_CANNOT_RESOLVE_REQUIRED_ATTRIBUTE(Attr), IdP, StateToken, _) ->
    ?auth_debug(
        "Cannot map attributes for IdP '~tp' (state: ~ts) - attribute '~tp' not found",
        [IdP, StateToken, Attr]
    );
log_error(?ERROR_BAD_ATTRIBUTE_TYPE(Attribute, Type), IdP, StateToken, _) ->
    ?auth_debug(
        "Cannot map attributes for IdP '~tp' (state: ~ts) - attribute '~tp' "
        "does not have the required type '~tp'",
        [IdP, StateToken, Attribute, Type]
    );
log_error(?ERROR_ATTRIBUTE_MAPPING_ERROR(Attribute, IdPAttributes, EType, EReason, Stacktrace), IdP, StateToken, _) ->
    % IdP attributes may carry access / refresh tokens - do not print sensitive information
    ?auth_debug(
        "Cannot map attributes for IdP '~tp' (state: ~ts) - attribute '~tp' "
        "could not be mapped due to an error - ~tp:~tp~n"
        "IdP attributes: ~tp~n"
        "Stacktrace: ~ts",
        [
            IdP, StateToken, Attribute, EType, EReason,
            redact_idp_attributes(IdPAttributes),
            iolist_to_binary(lager:pr_stacktrace(Stacktrace))
        ]
    );
log_error(?ERROR_ACCOUNT_ALREADY_LINKED_TO_CURRENT_USER(UserId), IdP, StateToken, _) ->
    ?auth_debug(
        "Cannot link account from IdP '~tp' for user '~ts' (state: ~ts) - account already linked to the user",
        [IdP, UserId, StateToken]
    );
log_error(?ERROR_ACCOUNT_ALREADY_LINKED_TO_ANOTHER_USER(UserId, OtherUserId), IdP, StateToken, _) ->
    ?auth_debug(
        "Cannot link account from IdP '~tp' for user '~ts' (state: ~ts) - account already linked to user '~ts'",
        [IdP, UserId, StateToken, OtherUserId]
    );
log_error(?ERR_INTERNAL_SERVER_ERROR(Ref), IdP, StateToken, _) ->
    % The logging is already done when throwing this error
    ?auth_debug(
        "Cannot validate login request for IdP '~tp' (state: ~ts) - internal server error (ref: ~ts)",
        [IdP, StateToken, Ref]
    );
log_error(Error, IdP, StateToken, Stacktrace) ->
    ?auth_error(
        "Cannot validate login request for IdP '~tp' (state: ~ts) - ~tp~n"
        "Stacktrace: ~ts", [IdP, StateToken, Error, iolist_to_binary(lager:pr_stacktrace(Stacktrace))]
    ).


%%--------------------------------------------------------------------
%% @private
%% @doc
%% Removes access and refresh tokens from IdP attributes before they are logged.
%% Terms other than maps are returned unchanged.
%% @end
%%--------------------------------------------------------------------
-spec redact_idp_attributes(term()) -> term().
redact_idp_attributes(IdPAttributes) when is_map(IdPAttributes) ->
    maps:without([<<"access_token">>, <<"refresh_token">>], IdPAttributes);
redact_idp_attributes(IdPAttributes) ->
    IdPAttributes.
