#! /usr/bin/env escript
%% This line tells emacs to use -*- erlang -*- mode -*- coding: iso-8859-1 -*-
%%! -pa _build/default/lib/enif_protobuf/ebin/ -pa tmp/ -sname protoexerciser

%%% Copyright (C) 2011  Tomas Abrahamsson
%%%
%%% Author: Tomas Abrahamsson <tab@lysator.liu.se>
%%%
%%% This library is free software; you can redistribute it and/or
%%% modify it under the terms of the GNU Lesser General Public
%%% License as published by the Free Software Foundation; either
%%% version 2.1 of the License, or (at your option) any later version.
%%%
%%% This library is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%%% Lesser General Public License for more details.
%%%
%%% You should have received a copy of the GNU Lesser General Public
%%% License along with this library; if not, write to the Free Software
%%% Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
%%% MA  02110-1301  USA

-mode(compile).
%-compile(export_all).
-export([main/1]).

%% escript entry point: hands the command-line arguments straight to
%% run_tests/1 and returns whatever it returns.
main(CmdLineArgs) ->
    run_tests(CmdLineArgs).

%% Consume the command-line arguments group by group.
%%
%% Recognised groups:
%%   "--echo" Str                  -- print Str on its own line.
%%   MsgModule MsgName MsgFile(s)  -- benchmark encoding and decoding of the
%%                                    message(s) read from the file(s); the
%%                                    file part is parsed by
%%                                    pickup_msg_file_or_files/1.
%%
%% Returns ok once all arguments are consumed. A file that cannot be read
%% crashes with a badmatch whose value includes the file name.
run_tests([]) ->
    ok;
run_tests(["--echo", Text | Remaining]) ->
    io:format("~s~n", [Text]),
    run_tests(Remaining);
run_tests([ModStr, NameStr | FilesAndRemaining]) ->
    Mod = list_to_atom(ModStr),
    Name = list_to_atom(NameStr),
    {Files, Remaining} = pickup_msg_file_or_files(FilesAndRemaining),
    %% Pair the read result with the file name so that a failed read
    %% shows which file was the culprit.
    ReadFile = fun(File) ->
                       {{ok, Bin}, _} = {file:read_file(File), File},
                       Bin
               end,
    Bins = [ReadFile(File) || File <- Files],
    NeedsName = encode_needs_msgname(Mod),
    EncName = case NeedsName of
                  true  -> Name;
                  false -> undefined
              end,
    ok = enif_protobuf:load_cache(Mod:get_msg_defs()),
    %% The terms to encode are produced up front by the message module's
    %% own decoder, outside of the timed sections.
    Decoded = [Mod:decode_msg(Bin, Name) || Bin <- Bins],
    DecodeAll = fun() -> decode_bins(Bins, Mod, Name) end,
    EncodeAll = fun() -> encode_msgs(Decoded, Mod, EncName) end,
    TotalBytes = iolist_size(Bins),
    maybe_warn_skewed_results(),
    io:format("Benchmarking ~s ~s with file ~s~n",
              [Mod, Name, string:join(Files, ",")]),
    run_test("Serialize to binary", TotalBytes, EncodeAll),
    run_test("Deserialize from binary", TotalBytes, DecodeAll),
    io:format("~n"),
    run_tests(Remaining).

%% Tell whether encoding with MsgModule requires the message name.
%%
%% Returns false when the module exports encode_msg/1 (records), and true
%% when it exports encode_msg/2 but not encode_msg/1 (maps).
%%
%% The check is done by membership on specific arities rather than by
%% taking the first encode_msg entry of the export list: the order of
%% module_info(exports) is not something to rely on, and a module may
%% export encode_msg with several arities.
%%
%% Crashes with case_clause if the module exports neither arity.
encode_needs_msgname(MsgModule) ->
    Exports = MsgModule:module_info(exports),
    HasArity = fun(Arity) -> lists:member({encode_msg, Arity}, Exports) end,
    case {HasArity(1), HasArity(2)} of
        {true, _}     -> false; %% records
        {false, true} -> true   %% maps
    end.

%% Split off the message file argument(s) from the front of the list.
%%
%% Either a single file name, or several file names bracketed by
%% "--multi" and "--end-multi". Returns {Files, RemainingArgs}.
%% A "--multi" without a matching "--end-multi" crashes with badmatch.
pickup_msg_file_or_files(["--multi" | Args]) ->
    NotEndMarker = fun(Arg) -> Arg /= "--end-multi" end,
    {Files, ["--end-multi" | Remaining]} = lists:splitwith(NotEndMarker, Args),
    {Files, Remaining};
pickup_msg_file_or_files([File | Remaining]) ->
    {[File], Remaining}.


%% Decode every binary in the list with enif_protobuf:decode/2, in order,
%% discarding the results. Returns ok.
%% The module argument is only passed along; decoding itself does not use it.
decode_bins([], _Mod, _Name) ->
    ok;
decode_bins([Bin | Bins], Mod, Name) ->
    enif_protobuf:decode(Bin, Name),
    decode_bins(Bins, Mod, Name).

%% Encode every message in the list, in order, discarding the results.
%% Returns ok.
%%
%% When the name argument is the atom undefined, enif_protobuf:encode/1 is
%% used (records carry their own name); otherwise enif_protobuf:encode/2 is
%% called with the message name (needed for maps).
%% The module argument is only passed along; encoding itself does not use it.
encode_msgs([], _Mod, _Name) ->
    ok;
encode_msgs([Term | Terms], Mod, undefined) ->
    enif_protobuf:encode(Term), %% name not needed for records
    encode_msgs(Terms, Mod, undefined);
encode_msgs([Term | Terms], Mod, Name) ->
    enif_protobuf:encode(Term, Name), %% name needed for maps
    encode_msgs(Terms, Mod, Name).


%% Print a notice about skewed benchmark results when running on
%% Erlang/OTP 22 or later; do nothing on other releases. Returns ok.
maybe_warn_skewed_results() ->
    try list_to_integer(erlang:system_info(otp_release)) of
        N when N >= 22 ->
            io:format(
              "\n"
              "NB: The results are currently skewed on Erlang 22 and later.\n"
              "    Optimizations in Erlang 22 and the way the benchmarks\n"
              "    are implemented cause more GC than normal.\n"
              "    The benchmarking code is not yet adapted. It will show\n"
              "    pessimistic values.\n"
              "    See https://github.com/erlang/otp/commit/7d941c529d#commitcomment-31091771\n"
              "    for more info.\n"
              "\n");
        _ ->
            ok
    catch
        %% A release name that is not a plain integer (e.g. the R-prefixed
        %% names of older releases) makes list_to_integer/1 raise badarg.
        %% That is the only failure tolerated here; anything else is
        %% unexpected and is allowed to surface instead of being swallowed.
        error:badarg ->
            ok
    end.



%% Run one benchmark (see run_test_aux/3) in a separate, monitored process
%% and wait for that process to terminate.
%%
%% Returns ok if the process exits normally; otherwise raises
%% error({aux_died, Reason}) with the process's exit reason.
%% NOTE(review): presumably a separate process is used so that each
%% benchmark starts from a fresh heap -- confirm.
run_test(Description, DataSize, Action) ->
    Worker = fun() -> run_test_aux(Description, DataSize, Action) end,
    {Pid, MonRef} = spawn_monitor(Worker),
    receive
        {'DOWN', MonRef, process, Pid, Reason} ->
            case Reason of
                normal -> ok;
                _      -> error({aux_died, Reason})
            end
    end.

%% Benchmark Action and print the outcome.
%%
%% First a calibration phase finds an iteration count that runs for at
%% least 2 seconds. That count is then scaled up to aim for roughly 30
%% seconds, the scaled number of iterations is timed, and one line is
%% printed with Description, the iteration count, the elapsed time and the
%% throughput in MB/s (DataSize bytes are counted per iteration).
%% Returns ok.
run_test_aux(Description, DataSize, Action) ->
    CalibrationSecs = 2,
    GoalSecs = 30,
    {SampleSecs, SampleIters} = iterate_until_elapsed(CalibrationSecs, Action),
    Iters = round((GoalSecs / SampleSecs) * SampleIters),
    Secs = time_action(Iters, Action),
    MBytesPerSec = (Iters * DataSize) / (Secs * 1024 * 1024),
    io:format("~s: ~w iterations in ~.3fs; ~.2fMB/s~n",
              [Description, Iters, Secs, MBytesPerSec]),
    ok.

%% Find an iteration count for which running Action takes at least
%% Threshold seconds, starting the search from a single iteration.
%% Returns {ElapsedSeconds, NumIterations} as produced by
%% iterate_until_elapsed_2/3.
iterate_until_elapsed(Threshold, Action) ->
    InitialCount = 1,
    iterate_until_elapsed_2(InitialCount, Threshold, Action).

%% Time Count iterations of Action; if that took less than Threshold
%% seconds, double Count and try again. Returns {ElapsedSeconds, Count}
%% for the first run that reached the threshold.
iterate_until_elapsed_2(Count, Threshold, Action) ->
    Secs = time_action(Count, Action),
    if
        Secs >= Threshold ->
            {Secs, Count};
        true ->
            iterate_until_elapsed_2(Count * 2, Threshold, Action)
    end.

%% Run Action NumIterations times and return the elapsed time in seconds,
%% as a float.
%%
%% The calling process is garbage collected before the clock is started.
%% The duration is measured with the monotonic clock rather than with
%% os:timestamp/0: wall-clock time can be adjusted while the benchmark is
%% running, which would distort (or even negate) the measured interval.
time_action(NumIterations, Action) ->
    garbage_collect(),
    Start = erlang:monotonic_time(microsecond),
    iterate_action(NumIterations, Action),
    Stop = erlang:monotonic_time(microsecond),
    (Stop - Start) / 1000000.

%% Call the zero-arity fun Action the given number of times, discarding
%% its results. Returns ok.
iterate_action(0, _Action) ->
    ok;
iterate_action(Remaining, Action) when Remaining > 0 ->
    Action(),
    iterate_action(Remaining - 1, Action).
