%%%-------------------------------------------------------------------
%%% @author Konrad Zemek
%%% @copyright (C) 2017 ACK CYFRONET AGH
%%% This software is released under the MIT license
%%% cited in 'LICENSE.txt'.
%%% @end
%%%-------------------------------------------------------------------
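%%% @doc
%%% gen_statem handling a single rtransfer link connection to a remote
%%% provider. While `disconnected' it queues fetch requests and retries
%%% the connection with backoff; once `connected' it spreads requests
%%% across rtransfer_link_request_server processes.
%%% @end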
%%%-------------------------------------------------------------------
-module(rtransfer_link_connection).
-author("Konrad Zemek").

-behaviour(gen_statem).

%%%===================================================================
%%% Type definitions
%%%===================================================================

-type request() :: rtransfer_link_request:t().

-type state() :: disconnected | connected.

-type data() :: #{
            provider_id := binary(),
            requests := #{binary() => request()},
            backoff := backoff:backoff(),
            retries_left := non_neg_integer(),
            hostname := binary(),
            port := inet:port_number(),
            conn_id => binary(),
            request_servers_map => #{non_neg_integer() => pid()}
           }.

%%%===================================================================
%%% Exports
%%%===================================================================

-export([start_link/3, fetch/2, cancel/2, disconnected/1]).
-export([init/1, callback_mode/0, connected/3, disconnected/3,
         code_change/4, terminate/3]).

%%%===================================================================
%%% API
%%%===================================================================

%% @doc Starts a connection process, linked to the caller, towards provider
%% `ProviderId' reachable at `Hostname':`Port'. The process starts in the
%% `disconnected' state and immediately requests a connection.
-spec start_link(ProviderId :: binary(), Hostname :: binary(),
                 Port :: inet:port_number()) -> {ok, pid()} | {error, any()}.
start_link(ProviderId, Hostname, Port) ->
    InitArgs = {ProviderId, Hostname, Port},
    Options = [],
    gen_statem:start_link(?MODULE, InitArgs, Options).

%% @doc Asynchronously submits transfer request `Req' to connection `Conn'.
%% Requests arriving before the connection is established are queued.
-spec fetch(Conn :: pid(), Req :: request()) -> ok.
fetch(Conn, Req) when is_map(Req) ->
    Event = {fetch, Req},
    gen_statem:cast(Conn, Event).

%% @doc Asynchronously asks connection `Conn' to cancel request `Req'.
-spec cancel(Conn :: pid(), Req :: request()) -> ok.
cancel(Conn, Req) when is_map(Req) ->
    Event = {cancel, Req},
    gen_statem:cast(Conn, Event).

%% @doc Notifies the connection process registered (in gproc, as
%% `{?MODULE, ConnId}') that its link has been lost; the process then stops.
%% If no process is registered under that id, this is a no-op instead of
%% casting to the atom `undefined'.
-spec disconnected(ConnectionId :: binary()) -> ok.
disconnected(ConnId) ->
    case gproc:lookup_local_name({?MODULE, ConnId}) of
        undefined ->
            ok;
        Conn ->
            gen_statem:cast(Conn, disconnected)
    end.

%%%===================================================================
%%% gen_statem callbacks
%%%===================================================================

%% Initializes the connection state machine.
%%
%% The argument tuple is `{ProviderId, Hostname, Port}', in the order built
%% by start_link/3. The machine starts in `disconnected' with an empty
%% request queue and immediately receives an internal `request_connection'
%% event. Up to 5 reconnection attempts are allowed. Attempts are spaced
%% by a jittered backoff; the delays are in milliseconds, and 2000/30000 are
%% presumably backoff:init/2's start and maximum delay (see backoff docs).
-spec init({ProviderId :: binary(), Hostname :: binary(),
            Port :: inet:port_number()}) ->
    {ok, state(), data(), gen_statem:action()}.
init({ProviderId, Hostname, Port}) ->
    Data = #{
      provider_id => ProviderId,
      requests => #{},
      backoff => backoff:type(backoff:init(2000, 30000), jitter),
      retries_left => 5,
      hostname => Hostname,
      port => Port,
      request_servers_map => #{}
     },
    {ok, disconnected, Data, {next_event, internal, request_connection}}.

%% Each state is handled by the same-named function of arity 3
%% (disconnected/3, connected/3).
-spec callback_mode() -> gen_statem:callback_mode_result().
callback_mode() ->
    state_functions.

%% State function for `disconnected'. Clause order matters: the
%% `retries_left := 0' clause must precede the generic `reconnect' clause.
%%
%% Events handled:
%%  - internal `reconnect': stops with `{connection_error, no_retries_left}'
%%    when no retries are left. Otherwise it consumes one retry, advances the
%%    backoff and arms a `request_connection' state timeout. The delay used
%%    is the one held *before* this failure (backoff:get/1 on the old
%%    backoff); the advanced backoff is stored for the next attempt.
%%  - state timeout `connection_timeout': the connect request was not
%%    answered in time, so reconnect.
%%  - `request_connection' (the internal event from init/1 or the backoff
%%    state timeout): fetches the connection secrets and asks
%%    rtransfer_link_port to connect. It then arms a `connection_timeout'
%%    state timeout (app env `connection_timeout', default 20000 ms). An
%%    unexpected return value or an exception from the secret callback
%%    leads to a reconnect.
%%  - cast `fetch': queues the request under its `req_id' until connected.
%%  - cast `cancel': completes the request with `{error, <<"canceled">>}'
%%    through its `on_complete' callback and drops it from the queue (the
%%    callback is invoked even if the request was not queued).
%%  - response carrying `<<"connectionId">>': registers this process in gproc
%%    as `{?MODULE, ConnId}', dispatches every queued request via do_fetch/3
%%    and switches to `connected'.
%%  - response `{error, Reason}': logs and reconnects.
%%  - anything else is logged and ignored.
disconnected(internal, reconnect, #{retries_left := 0, hostname := Hostname, port := Port}) ->
    lager:warning("Connection to ~s:~B will no longer try to reconnect", [Hostname, Port]),
    {stop, {connection_error, no_retries_left}};

disconnected(internal, reconnect, #{retries_left := RetriesLeft, backoff := Backoff} = Data) ->
    {_, NewBackoff} = backoff:fail(Backoff),
    NewData = Data#{retries_left := RetriesLeft - 1, backoff := NewBackoff},
    %% Delay from the backoff state before this failure was recorded.
    NextBackoff = backoff:get(Backoff),
    lager:info("Retrying connection in ~B milliseconds", [NextBackoff]),
    {keep_state, NewData, {state_timeout, NextBackoff, request_connection}};

disconnected(state_timeout, connection_timeout, #{hostname := Hostname, port := Port}) ->
    lager:warning("Reconnecting to ~s:~B due to connection timeout", [Hostname, Port]),
    {keep_state_and_data, {next_event, internal, reconnect}};

disconnected(_InternalOrTimer, request_connection, Data) ->
    #{provider_id := ProviderId, hostname := Hostname, port := Port} = Data,
    try rtransfer_link_callback:get_connection_secret(ProviderId, {Hostname, Port}) of
        {MySecret, PeerSecret} ->
            Req = #{connect => #{my_secret => base64:encode(MySecret),
                                 peer_secret => base64:encode(PeerSecret),
                                 peer_host => Hostname, peer_port => Port}},
            rtransfer_link_port:request(Req),
            ConnTimeout = application:get_env(rtransfer_link, connection_timeout, 20000),
            {keep_state_and_data, {state_timeout, ConnTimeout, connection_timeout}};
        Other ->
            lager:error("Bad response from get_connection_secret(~p, ~p): ~p",
                        [ProviderId, {Hostname, Port}, Other]),
            {keep_state_and_data, {next_event, internal, reconnect}}
    catch
        Class:Error:Stack ->
            lager:error("~p:~p while executing get_connection_secret(~p, ~p)",
                        [Class, Error, ProviderId, {Hostname, Port}]),
            lager:warning("Stacktrace: ~p", [Stack]),
            {keep_state_and_data, {next_event, internal, reconnect}}
    end;

disconnected(cast, {fetch, #{req_id := ReqId} = Req}, #{requests := Requests} = Data) ->
    NewRequests = maps:put(ReqId, Req, Requests),
    {keep_state, Data#{requests := NewRequests}};

disconnected(cast, {cancel, #{ref := Ref, req_id := ReqId, on_complete := OnComplete}},
             #{requests := Requests} = Data) ->
    %% Direct fun call instead of erlang:apply/2: same semantics, visible to
    %% Dialyzer and xref.
    OnComplete(Ref, {error, <<"canceled">>}),
    NewRequests = maps:remove(ReqId, Requests),
    {keep_state, Data#{requests := NewRequests}};

disconnected(info, {response, _ReqId, #{<<"connectionId">> := ConnId}},
             #{requests := Requests, request_servers_map := ServersMap} = Data) ->
    %% Makes this process reachable from disconnected/1.
    gproc:add_local_name({?MODULE, ConnId}),
    ServersMap2 = lists:foldl(fun(Req, TmpServersMap) ->
        do_fetch(Req, ConnId, TmpServersMap)
    end, ServersMap, maps:values(Requests)),
    {next_state, connected, Data#{requests := #{}, conn_id => ConnId,
        request_servers_map := ServersMap2}};

disconnected(info, {response, _ReqId, {error, Reason}}, #{hostname := Hostname, port := Port}) ->
    lager:warning("Failed to connect to ~s:~B due to ~p", [Hostname, Port, Reason]),
    {keep_state_and_data, {next_event, internal, reconnect}};

disconnected(EventType, Message, _Data) ->
    lager:warning("Unhandled event ~p: ~p while disconnected", [EventType, Message]),
    keep_state_and_data.

%% State function for `connected' (entered once the port reported a
%% connection id).
%%
%%  - cast `fetch': dispatches the request immediately through do_fetch/3.
%%  - cast `cancel': sends a `cancel' command to rtransfer_link_port. The
%%    storage ids are base64-encoded; `req_id' is passed through unchanged.
%%  - cast `disconnected' (sent by disconnected/1): logs and stops with
%%    `{shutdown, disconnected}'.
%%  - a response with `<<"done">> := true' is ignored silently; any other
%%    event is logged and ignored.
connected(cast, {fetch, Request},
          #{conn_id := ConnId, request_servers_map := Servers} = Data) ->
    NewServers = do_fetch(Request, ConnId, Servers),
    {keep_state, Data#{request_servers_map := NewServers}};
connected(cast, {cancel, Req}, #{conn_id := ConnId}) ->
    %% A request missing any of these keys crashes here with badmatch.
    #{req_id := ReqId, src_storage_id := Src, dest_storage_id := Dest} = Req,
    Cancel = #{connection_id => ConnId,
               src_storage_id => base64:encode(Src),
               dest_storage_id => base64:encode(Dest),
               req_id => ReqId},
    rtransfer_link_port:request(#{cancel => Cancel}),
    keep_state_and_data;
connected(cast, disconnected, #{hostname := Host, port := PortNo}) ->
    lager:warning("rtransfer connection to ~s:~B lost", [Host, PortNo]),
    {stop, {shutdown, disconnected}};
connected(info, {response, _, #{<<"done">> := true}}, _) ->
    keep_state_and_data;
connected(Type, Event, _) ->
    lager:warning("Unhandled event ~p: ~p while connected", [Type, Event]),
    keep_state_and_data.

%% No cleanup is performed on termination.
-spec terminate(Reason :: term(), StateName :: term(), StateData :: term()) -> ok.
terminate(_Why, _StateName, _StateData) -> ok.

%% Hot code upgrades keep the state name and data unchanged.
-spec code_change(OldVsn :: term(), StateName :: term(), StateData :: term(),
                  Extra :: term()) -> {ok, term(), term()}.
code_change(_OldVsn, StateName, StateData, _Extra) -> {ok, StateName, StateData}.

%%%===================================================================
%%% Helpers
%%%===================================================================

%% Hands `Request' to one of this connection's request servers and returns
%% the (possibly updated) slot => server pid map.
%%
%% A slot is drawn uniformly from 1..N, where N is the app env
%% `request_servers_per_connection' (default 1000). An empty slot gets a
%% freshly started server. If the chosen server turns out to be gone, its
%% slot is dropped and dispatch is retried with a new random draw. "Gone"
%% means the fetch raised `{noproc, _}', or exited with `{normal, _}' or
%% `{{shutdown, timeout}, _}'.
%% NOTE(review): that retry is unbounded.
-spec do_fetch(request(), ConnId :: binary(), #{non_neg_integer() => pid()}) ->
    #{non_neg_integer() => pid()}.
do_fetch(Request, ConnId, ServersMap) ->
    SlotCount = application:get_env(rtransfer_link,
        request_servers_per_connection, 1000),
    Slot = rand:uniform(SlotCount),
    {Server, Servers} =
        case maps:get(Slot, ServersMap, undefined) of
            undefined ->
                {ok, Started} = rtransfer_link_request_server:start(),
                {Started, ServersMap#{Slot => Started}};
            Existing ->
                {Existing, ServersMap}
        end,
    try rtransfer_link_request_server:fetch(Server, ConnId, Request) of
        _ -> Servers
    catch
        Class:{Reason, _} when Reason =:= noproc;
                               Class =:= exit, Reason =:= normal;
                               Class =:= exit, Reason =:= {shutdown, timeout} ->
            %% The chosen server is dead: forget its slot and draw again.
            do_fetch(Request, ConnId, maps:remove(Slot, Servers))
    end.
