#include "shaper.hpp"

#include "shaperTimer.hpp"

#include <glog/logging.h>
#include <monitoring/monitoring.h>

#include <chrono>

// Upper bound on the size of a single read handed to the reader; larger
// fetches are cut into sub-requests no bigger than this. Also seeds the
// initial pacing rate. (Presumably expressed in bytes - confirm.)
DEFINE_uint64(single_fetch_max_size, 10 * 1024 * 1024,
    "maximum size of single read block");
// Window returned by Shaper::inFlight() while no RTT sample exists, and the
// delivered-data threshold below which the congestion window keeps growing.
DEFINE_uint64(
    shaper_initial_window, 10 * 1024 * 1024, "initial window size for shaper");
// Multiplier applied to the bandwidth estimate to obtain the send quantum.
DEFINE_uint64(shaper_quantum_ms_size, 25,
    "the quantum size of shaper will be set to this value times bandwidth per "
    "second");

namespace sc = std::chrono;
using namespace std::chrono_literals;

namespace {

// Narrows a 32-bit priority to 8 bits, saturating at the largest value a
// std::uint8_t can hold instead of wrapping around.
std::uint8_t clampPriority(std::uint32_t priority)
{
    constexpr auto ceiling = std::numeric_limits<std::uint8_t>::max();
    if (priority > ceiling)
        return ceiling;

    return static_cast<std::uint8_t>(priority);
}

// floor / round implementation from c++17

// Type trait: false for arbitrary types...
template <typename Candidate>
struct is_duration : std::false_type {
};
// ...and true for any specialization of std::chrono::duration.
template <typename Ticks, typename Ratio>
struct is_duration<sc::duration<Ticks, Ratio>> : std::true_type {
};

// Converts `d` to the duration type `To`, rounding towards negative
// infinity. duration_cast truncates towards zero, so a result that ended up
// above the input is stepped down by one tick.
template <class To, class Rep, class Period,
    class = std::enable_if_t<is_duration<To>{}>>
constexpr To floor(const sc::duration<Rep, Period> &d)
{
    const To truncated = sc::duration_cast<To>(d);
    if (truncated > d)
        return truncated - To{1};

    return truncated;
}

// Converts `d` to the integral duration type `To`, rounding to the nearest
// tick. Exact ties go to the neighbour with an even tick count.
template <class To, class Rep, class Period,
    class = std::enable_if_t<is_duration<To>{} &&
                             !sc::treat_as_floating_point<typename To::rep>{}>>
constexpr To round(const sc::duration<Rep, Period> &d)
{
    const To lower = floor<To>(d);
    const To upper = lower + To{1};
    const auto distanceToLower = d - lower;
    const auto distanceToUpper = upper - d;

    // Halfway between both candidates: choose the even one.
    if (distanceToLower == distanceToUpper)
        return (lower.count() & 1) ? upper : lower;

    return distanceToLower < distanceToUpper ? lower : upper;
}

}  // namespace

namespace rtransfer {

// Creates an empty queue. `simultaneous` bounds how many entries from the
// head of the queue front() rotates over.
Shaper::RequestQueue::RequestQueue(const std::size_t simultaneous)
    : simultaneous_{simultaneous}
{
}

// Appends `req` at the back of the queue.
void Shaper::RequestQueue::push(std::shared_ptr<Request> req)
{
    queue_.emplace_back(std::move(req));
}

// Tells whether no request is queued at all.
bool Shaper::RequestQueue::empty() const
{
    return queue_.empty();
}

// Returns the next request to serve and advances the internal cursor, so
// consecutive calls rotate over the first `simultaneous_` queued entries.
// The queue must not be empty.
std::shared_ptr<Shaper::Request> Shaper::RequestQueue::front()
{
    const auto window = std::min(simultaneous_, queue_.size());
    if (idx_ >= window)
        idx_ = 0;  // wrap around to the head of the queue

    auto &picked = queue_[idx_];
    ++idx_;
    return picked;
}

// Removes the entry most recently handed out by front(). That entry is
// first swapped into the head position so that it can be dropped with
// pop_front(); the cursor steps back to account for the removed slot.
// Must only be called after front().
void Shaper::RequestQueue::pop()
{
    --idx_;
    std::swap(queue_[idx_], queue_.front());
    queue_.pop_front();
}

// Builds a shaper that reads through `reader` and schedules paced wake-ups
// on `timer`. The RTprop timestamp starts at the current time, the initial
// pacing rate is derived from the single_fetch_max_size flag scaled by
// BBRHighGain / 1000, and the state machine begins in the startup state.
Shaper::Shaper(Reader &reader, ShaperTimer<Clock> &timer)
    : reader_{reader}
    , rtPropStamp_{Clock::now()}
    , pacingRate_{FLAGS_single_fetch_max_size * BBRHighGain / 1000}
    , timer_{timer}
{
    enterStartup();
}

// Queues a fetch for shaping. The work is handed over to the shaper's
// executor, where doRead() runs; the returned future carries doRead()'s
// result.
folly::Future<std::size_t> Shaper::read(MsgPtr msg)
{
    auto task = [this, msg = std::move(msg)]() mutable {
        return doRead(std::move(msg));
    };
    return via(&executor_).then(std::move(task));
}

// Registers a new fetch request: remembers it by request id, accounts for
// it in the queue metrics and in the pending byte counter, appends it to
// the queue of delayed requests and then tries to send right away.
// Returns the future of the request's promise.
folly::Future<std::size_t> Shaper::doRead(MsgPtr msg)
{
    auto request = std::make_shared<Request>(std::move(msg));
    auto &fetchMsg = request->fetch->fetch();
    VLOG(1) << "Storing req_id: " << fetchMsg.req_id() << " for shaping";

    auto clamped = clampPriority(fetchMsg.priority());
    auto result = request->promise.getFuture();
    reqs_.emplace(fetchMsg.req_id(), request);

    // Per-priority metrics: number of queued requests and queued bytes.
    const std::string metric{
        "link.shaper." + std::to_string(clamped) + ".queue"};
    ONE_METRIC_COUNTER_ADD(metric + "_bytes", fetchMsg.size());
    ONE_METRIC_COUNTER_ADD(metric, 1);

    delayedReqsSize_ += fetchMsg.size();
    delayedRequests_.push(std::move(request));

    sendPacket(Clock::now());

    return result;
}

// Cancels a previously queued fetch. The actual work happens in doCancel()
// on the shaper's executor.
folly::Future<folly::Unit> Shaper::cancel(MsgPtr msg)
{
    auto task = [this, msg = std::move(msg)]() mutable {
        doCancel(std::move(msg));
    };
    return via(&executor_).then(std::move(task));
}

// Cancels the request named by `msg`:
//  * fails the request's promise (unless it has already been fulfilled),
//  * removes the not-yet-sent remainder from the pending byte counter and
//    from the per-priority queue metrics,
//  * forgets every unacknowledged packet of the request and releases the
//    corresponding in-flight bytes,
//  * and finally tries to send, as the in-flight amount may have dropped.
void Shaper::doCancel(MsgPtr msg)
{
    const auto reqId = msg->cancel().req_id();

    auto it = reqs_.find(reqId);
    if (it != reqs_.end()) {
        auto &req = *it->second;

        // The promise can already be fulfilled when all storage reads have
        // completed but the cleanup continuation has not yet removed the
        // entry from reqs_. Setting it again would throw and skip the
        // in-flight cleanup below.
        if (!req.promise.isFulfilled()) {
            req.promise.setException(
                folly::make_exception_wrapper<std::runtime_error>(
                    "Fetch cancelled by user"));
        }

        // If !req.fetch, then we fulfilled the whole request already
        // and cleared req.fetch by std::move in fulfillWholeRequest.
        if (req.fetch) {
            const auto &fetch = req.fetch->fetch();

            // Undo the queue accounting done in doRead for the part of the
            // request that was never sent; the metric name must match the
            // (clamped) one used there.
            const std::string m{"link.shaper." +
                                std::to_string(clampPriority(fetch.priority())) +
                                ".queue"};
            ONE_METRIC_COUNTER_SUB(m + "_bytes", fetch.size());
            ONE_METRIC_COUNTER_SUB(m, 1);

            delayedReqsSize_ -= fetch.size();
            req.fetch.reset();
        }

        reqs_.erase(it);
    }

    auto range = unackedPerReq_.equal_range(reqId);
    for (auto offsetIt = range.first; offsetIt != range.second; ++offsetIt) {
        auto sentPacketIt = sentPackets_.find({reqId, offsetIt->second});
        if (sentPacketIt != sentPackets_.end()) {
            inFlight_ -= sentPacketIt->second.size;
            sentPackets_.erase(sentPacketIt);
        }
    }
    if (range.first != range.second)
        unackedPerReq_.erase(range.first, range.second);

    // inFlight_ may have changed
    sendPacket(Clock::now());
}

// Acknowledges the packets of request `reqId` identified by `offsets`.
// Runs on the shaper's executor with SHAPER_OPS_PRIO, processes every
// offset with one common timestamp and then tries to send more data.
// The future yields true when the request has neither unacknowledged
// packets nor a pending entry left.
folly::Future<bool> Shaper::ack(
    std::uint64_t reqId, folly::fbvector<std::uint64_t> offsets)
{
    auto task = [this, reqId, offsets = std::move(offsets)] {
        const auto now = Clock::now();
        for (const std::uint64_t offset : offsets)
            doAck(now, reqId, offset);

        sendPacket(now);

        // true if nothing to send and nothing to ack
        const bool nothingUnacked = unackedPerReq_.count(reqId) == 0;
        const bool nothingPending = reqs_.count(reqId) == 0;
        return nothingUnacked && nothingPending;
    };
    return via(&executor_, SHAPER_OPS_PRIO).then(std::move(task));
}

// Handles the acknowledgement of a single packet (reqId, offset): measures
// its round-trip time, drops it from the unacked bookkeeping, releases its
// in-flight bytes and feeds the sample into the rate / model / control
// updates. Unknown packets are logged and ignored.
void Shaper::doAck(
    Clock::time_point now, std::uint64_t reqId, std::uint64_t offset)
{
    auto sentIt = sentPackets_.find(std::make_pair(reqId, offset));
    if (sentIt == sentPackets_.end()) {
        LOG(WARNING) << "No packet to ack " << reqId << " " << offset;
        return;
    }

    // Stays valid until the entry is erased at the very end.
    const auto &acked = sentIt->second;
    auto rtt = round<sc::microseconds>(now - acked.sentTime);

    // The packet must also be tracked in the per-request unacked index.
    auto unacked = unackedPerReq_.equal_range(reqId);
    auto unackedIt = std::find_if(unacked.first, unacked.second,
        [offset](const auto &entry) { return entry.second == offset; });
    CHECK(unackedIt != unacked.second);
    unackedPerReq_.erase(unackedIt);

    VLOG(2) << "Acked " << acked.size << " bytes for reqId: " << reqId;

    inFlight_ -= acked.size;
    generateRateSample(acked, now);
    updateModelAndState(acked, now, rtt);
    updateControlParameters(acked);

    LOG_EVERY_N(INFO, 1000)
        << "rtt: " << (rtProp_ ? rtProp_->count() / 1000 : -1)
        << " ms, bw: " << btlBw_ << ", cwnd: " << cwnd_
        << ", targetCwnd: " << targetCwnd_ << ", inFlight: " << inFlight_
        << ", delayedReqsSize: " << delayedReqsSize_;

    sentPackets_.erase(sentIt);
}

// Switches to the startup state, in which both the pacing gain and the
// congestion window gain are set to BBRHighGain.
void Shaper::enterStartup()
{
    cwndGain_ = BBRHighGain;
    pacingGain_ = BBRHighGain;
    state_ = State::startup;
}

// Runs the per-ack model update: refreshes the bottleneck bandwidth
// estimate, advances the gain cycle, checks whether the pipe is full and
// whether draining is due or done, updates the round-trip propagation
// estimate and finally handles entering / leaving the RTT probing state.
// The steps read state written by earlier ones, so their order matters.
void Shaper::updateModelAndState(
    const Packet &packet, Clock::time_point now, sc::microseconds rtt)
{
    updateBtlBw(packet);
    checkCyclePhase(now);
    checkFullPipe();
    checkDrain(now);
    updateRtProp(now, rtt);
    checkProbeRTT(now);
}

// Updates the round counter and feeds the current delivery rate sample into
// the bandwidth filter. App-limited samples are only taken into account
// when they are at least as high as the current estimate.
void Shaper::updateBtlBw(const Packet &packet)
{
    updateRound(packet);

    const bool sampleUsable =
        !rs_.isAppLimited || rs_.deliveryRate >= btlBw_;
    if (sampleUsable)
        btlBw_ = btlBwFilter_.add(rs_.deliveryRate, roundCount_);
}

// Detects the start of a new round: a round begins once a packet is acked
// that was sent after the previous round's delivered mark was reached.
void Shaper::updateRound(const Packet &packet)
{
    if (packet.delivered < nextRoundDelivered_) {
        roundStart_ = false;
        return;
    }

    roundStart_ = true;
    ++roundCount_;
    nextRoundDelivered_ = delivered_;
}

// Maintains the round-trip propagation estimate: records whether the
// current estimate is older than RTpropFilterLen and accepts the new
// non-negative sample when there is no estimate yet, when the sample is
// not larger than the estimate, or when the estimate has expired.
void Shaper::updateRtProp(Clock::time_point now, sc::microseconds rtt)
{
    rtPropExpired_ = now > rtPropStamp_ + RTpropFilterLen;

    if (rtt < 0us)
        return;

    const bool notWorse = !rtProp_ || rtt <= *rtProp_;
    if (notWorse || rtPropExpired_) {
        rtProp_ = rtt;
        rtPropStamp_ = now;
    }
}

// Recomputes the three control outputs after an ack, in order: pacing
// rate, send quantum and congestion window. setCwnd() reads the send
// quantum through inFlight(), so the quantum is refreshed first.
void Shaper::updateControlParameters(const Packet &packet)
{
    setPacingRate();
    setSendQuantum();
    setCwnd(packet);
}

// Sets the pacing rate to `pacingGain` times the bottleneck bandwidth
// estimate. Until the pipe has been filled the rate is only ever raised,
// never lowered.
//
// The gain is taken from the argument rather than from pacingGain_, so that
// callers such as handleRestartFromIdle() can pace with a gain that differs
// from the current cycle gain.
void Shaper::setPacingRateWithGain(double pacingGain)
{
    const auto rate = pacingGain * btlBw_;
    if (filledPipe_ || rate > pacingRate_)
        pacingRate_ = rate;
}

// Updates the pacing rate using the currently active pacing gain.
void Shaper::setPacingRate()
{
    setPacingRateWithGain(pacingGain_);
}

// Derives the send quantum from the bandwidth estimate scaled by the
// shaper_quantum_ms_size flag, never going below 150 * 25.
void Shaper::setSendQuantum()
{
    const std::size_t scaledQuantum = FLAGS_shaper_quantum_ms_size * btlBw_;
    const std::size_t minimumQuantum = 150 * 25;
    sendQuantum_ = std::max(scaledQuantum, minimumQuantum);
}

// Computes the amount of data that may be in flight for the given gain:
// gain times the estimated bandwidth-delay product plus three send quanta.
// Falls back to the shaper_initial_window flag while no RTT sample exists.
std::size_t Shaper::inFlight(double gain)
{
    if (rtProp_) {
        const auto headroom = 3 * sendQuantum_;
        const auto bdp = btlBw_ * rtProp_->count() / 1000;
        return gain * bdp + headroom;
    }

    return FLAGS_shaper_initial_window;  // no valid RTT samples yet
}

// Recomputes the target congestion window from the current cwnd gain.
void Shaper::updateTargetCwnd()
{
    targetCwnd_ = inFlight(cwndGain_);
}

// Remembers the congestion window so that it can be restored later. While
// probing RTT the larger of the saved and the current window is kept.
void Shaper::saveCwnd()
{
    if (state_ == State::probeRTT)
        priorCwnd_ = std::max(priorCwnd_, cwnd_);
    else
        priorCwnd_ = cwnd_;
}

// Raises the congestion window back to the saved value if it is smaller.
void Shaper::restoreCwnd()
{
    if (cwnd_ < priorCwnd_)
        cwnd_ = priorCwnd_;
}

// While probing RTT, caps the congestion window at BBRMinPipeCwnd.
void Shaper::modulateCwndForProbeRTT()
{
    if (state_ != State::probeRTT)
        return;

    cwnd_ = std::min(cwnd_, BBRMinPipeCwnd);
}

// Adjusts the congestion window after `packet` has been acknowledged.
// Once the pipe is full the window grows by the acked size but is capped at
// the target; before that it grows without a cap while it is below the
// target or while less than the initial window has been delivered. The
// result never drops below BBRMinPipeCwnd, and is finally capped for the
// RTT probing state.
void Shaper::setCwnd(const Packet &packet)
{
    updateTargetCwnd();

    const bool stillRamping =
        cwnd_ < targetCwnd_ || delivered_ < FLAGS_shaper_initial_window;

    if (filledPipe_) {
        cwnd_ = std::min(cwnd_ + packet.size, targetCwnd_);
    }
    else if (stillRamping) {
        cwnd_ += packet.size;
    }

    cwnd_ = std::max(cwnd_, BBRMinPipeCwnd);

    modulateCwndForProbeRTT();
}

// Estimates whether the pipe has been filled: at the start of each
// non-app-limited round the bandwidth estimate is compared against the
// recorded baseline. Growth of at least 25% resets the count and moves the
// baseline; three rounds without such growth mark the pipe as full.
void Shaper::checkFullPipe()
{
    const bool nothingToCheck =
        filledPipe_ || !roundStart_ || rs_.isAppLimited;
    if (nothingToCheck)
        return;

    const bool stillGrowing = btlBw_ >= fullBw_ * 1.25;
    if (stillGrowing) {
        fullBw_ = btlBw_;  // record new baseline level
        fullBwCount_ = 0;
        return;
    }

    // another round w/o much growth
    ++fullBwCount_;
    if (fullBwCount_ >= 3)
        filledPipe_ = true;
}

// Switches to the drain state: pacing slows down to the inverse of the
// startup gain while the congestion window gain stays high.
void Shaper::enterDrain()
{
    cwndGain_ = BBRHighGain;        // maintain cwnd
    pacingGain_ = 1 / BBRHighGain;  // pace slowly
    state_ = State::drain;
}

// Moves from startup to drain once the pipe is full, and from drain on to
// bandwidth probing once the in-flight amount has fallen to the estimated
// bandwidth-delay level. Both transitions can happen in a single call.
void Shaper::checkDrain(Clock::time_point now)
{
    const bool pipeJustFilled = state_ == State::startup && filledPipe_;
    if (pipeJustFilled)
        enterDrain();

    if (state_ != State::drain)
        return;

    // we estimate queue is drained
    if (inFlight_ <= inFlight(1.0))
        enterProbeBw(now);
}

// Switches to the bandwidth probing state with neutral pacing gain and a
// cwnd gain of 2, then starts the gain cycle at a randomly chosen phase.
void Shaper::enterProbeBw(Clock::time_point now)
{
    state_ = State::probeBW;
    cwndGain_ = 2.0;
    pacingGain_ = 1.0;

    // advanceCyclePhase() below steps the index forward by one.
    const auto randomOffset = folly::Random::rand32(7);
    cycleIndex_ = BBRGainCycleLen - 1 - randomOffset;
    advanceCyclePhase(now);
}

// While probing bandwidth, moves on to the next gain cycle phase when the
// current one is over.
void Shaper::checkCyclePhase(Clock::time_point now)
{
    if (state_ != State::probeBW)
        return;

    if (isNextCyclePhase(now))
        advanceCyclePhase(now);
}

// Starts the next phase of the pacing gain cycle: stamps its begin time,
// steps the cycle index (wrapping at BBRGainCycleLen) and applies the gain
// that belongs to the new phase.
void Shaper::advanceCyclePhase(Clock::time_point now)
{
    // One probing phase, one draining phase, then six cruising phases.
    static constexpr double gainPerPhase[] = {
        1.25, 0.75, 1, 1, 1, 1, 1, 1};

    cycleStamp_ = now;
    cycleIndex_ = (cycleIndex_ + 1) % BBRGainCycleLen;
    pacingGain_ = gainPerPhase[cycleIndex_];
}

// Decides whether the current gain cycle phase is over. A phase normally
// lasts longer than one RTprop. A phase with gain above 1 additionally
// waits until the prior in-flight amount has reached the level for that
// gain; a phase with gain below 1 may end early once the prior in-flight
// amount has fallen to the level for gain 1.
bool Shaper::isNextCyclePhase(Clock::time_point now)
{
    const bool isFullLength = rtProp_ && (now - cycleStamp_) > *rtProp_;

    if (pacingGain_ == 1)
        return isFullLength;

    const bool probingUp = pacingGain_ > 1;
    return probingUp
               ? (isFullLength && priorInFlight_ >= inFlight(pacingGain_))
               : (isFullLength || priorInFlight_ <= inFlight(1));
}

// Detects a restart after an idle period (nothing in flight while being
// app-limited) and flags it. When that happens while probing bandwidth,
// the pacing rate is set with a gain of 1.
void Shaper::handleRestartFromIdle()
{
    const bool wasIdle = inFlight_ == 0 && appLimited_;
    if (!wasIdle)
        return;

    idleRestart_ = true;
    if (state_ == State::probeBW)
        setPacingRateWithGain(1);
}

// Enters the RTT probing state when the RTprop estimate has expired (unless
// we are just restarting from idle), drives the probing state while it is
// active and clears the idle-restart flag afterwards.
void Shaper::checkProbeRTT(Clock::time_point now)
{
    const bool shouldEnter =
        state_ != State::probeRTT && rtPropExpired_ && !idleRestart_;
    if (shouldEnter) {
        enterProbeRTT();
        saveCwnd();
        probeRttDoneStamp_.clear();
    }

    if (state_ == State::probeRTT)
        handleProbeRTT(now);

    idleRestart_ = false;
}

// Switches to the RTT probing state with neutral pacing and cwnd gains.
void Shaper::enterProbeRTT()
{
    cwndGain_ = 1;
    pacingGain_ = 1;
    state_ = State::probeRTT;
}

// Drives the RTT probing state. The probing period starts once the
// in-flight amount has dropped to BBRMinPipeCwnd; it ends after
// ProbeRTTDuration has passed and at least one round has started, at which
// point the congestion window is restored and the state is left.
void Shaper::handleProbeRTT(Clock::time_point now)
{
    // Ignore low rate samples during ProbeRTT
    appLimited_ = std::max<decltype(appLimited_)>(delivered_ + inFlight_, 1);

    if (probeRttDoneStamp_) {
        if (roundStart_)
            probeRttRoundDone_ = true;

        const bool finished =
            probeRttRoundDone_ && now > *probeRttDoneStamp_;
        if (finished) {
            rtPropStamp_ = now;
            restoreCwnd();
            exitProbeRTT(now);
        }
        return;
    }

    // Probing period not started yet: wait until the pipe is nearly empty.
    if (inFlight_ <= BBRMinPipeCwnd) {
        probeRttDoneStamp_ = now + ProbeRTTDuration;
        probeRttRoundDone_ = false;
        nextRoundDelivered_ = delivered_;
    }
}

// Leaves the RTT probing state: continues with bandwidth probing when the
// pipe has been filled before, otherwise goes back to startup.
void Shaper::exitProbeRTT(Clock::time_point now)
{
    if (!filledPipe_) {
        enterStartup();
        return;
    }

    enterProbeBw(now);
}

// Records a packet that is being sent: snapshots the current delivery
// state into the packet, registers it as sent and unacknowledged, handles a
// possible restart from idle and accounts for its size as in flight.
void Shaper::onSendPacket(Clock::time_point now, std::uint64_t reqId,
    std::uint64_t offset, std::size_t size)
{
    // Nothing in flight: this send starts a fresh measurement interval.
    if (inFlight_ == 0) {
        deliveredTime_ = now;
        firstSentTime_ = deliveredTime_;
    }

    Packet sent;
    sent.size = size;
    sent.sentTime = now;
    sent.firstSentTime = firstSentTime_;
    sent.deliveredTime = deliveredTime_;
    sent.delivered = delivered_;
    sent.isAppLimited = (appLimited_ != 0);

    sentPackets_.emplace(std::make_pair(reqId, offset), sent);
    unackedPerReq_.emplace(reqId, offset);

    handleRestartFromIdle();

    priorInFlight_ = inFlight_;
    inFlight_ += sent.size;
}

// Produces a delivery rate sample for the acknowledged packet `p`.
// Returns true when the rate sample was filled in, false when no reliable
// sample could be taken.
bool Shaper::generateRateSample(const Packet &p, Clock::time_point now)
{
    updateRateSample(p, now);

    // Clear app-limited field if bubble is ACKed and gone.
    const bool bubbleAcked = appLimited_ && delivered_ > appLimited_;
    if (bubbleAcked)
        appLimited_ = 0;

    const bool hasPriorTime = rs_.priorTime.time_since_epoch().count() != 0;
    if (!hasPriorTime)
        return false;  // nothing delivered on this ACK

    // Use the longer of the send_elapsed and ack_elapsed
    rs_.interval = std::max(rs_.sendElapsed, rs_.ackElapsed);
    rs_.delivered = delivered_ - rs_.priorDelivered;

    /* Normally we expect interval >= MinRTT.
     * Note that rate may still be over-estimated when a spuriously
     * retransmitted skb was first (s)acked because "interval" is
     * under-estimated (up to an RTT). However, continuously measuring the
     * delivery rate during loss recovery is crucial for connections that
     * suffer heavy or prolonged losses.
     */
    const bool reliable = rtProp_ && *rs_.interval >= *rtProp_;
    if (!reliable) {
        rs_.interval.clear();
        return false;  // no reliable sample
    }

    if (*rs_.interval != 0us)
        rs_.deliveryRate = rs_.delivered / rs_.interval->count();

    return true;  // we filled in rs with a rate sample
}

/* Accounts for an acknowledged packet in the delivery counters and, when
 * the packet is newer than the one the rate sample is currently based on,
 * rebases the rate sample on it. */
void Shaper::updateRateSample(const Packet &p, Clock::time_point now)
{
    delivered_ += p.size;
    deliveredTime_ = now;

    /* Only the newest packet contributes its send-time snapshot. */
    if (p.delivered <= rs_.priorDelivered)
        return;

    rs_.priorDelivered = p.delivered;
    rs_.priorTime = p.deliveredTime;
    rs_.isAppLimited = p.isAppLimited;
    rs_.sendElapsed = round<sc::milliseconds>(p.sentTime - p.firstSentTime);
    rs_.ackElapsed = round<sc::milliseconds>(deliveredTime_ - p.deliveredTime);
    firstSentTime_ = p.sentTime;
}

// Queues a paced wake-up on the shaper's executor: the pending-update flag
// is cleared there and another send attempt is made.
void Shaper::scheduledSendPacket()
{
    auto wakeup = [this] {
        isUpdateScheduled_ = false;
        sendPacket(Clock::now());
    };
    executor_.addWithPriority(std::move(wakeup), SHAPER_OPS_PRIO);
}

// Sends up to one send quantum of queued data, provided that the congestion
// window has room and no paced wake-up is pending, and then schedules the
// next wake-up according to the pacing rate.
void Shaper::sendPacket(Clock::time_point now)
{
    if (inFlight_ >= cwnd_) {
        VLOG(2) << "inFlight_ = " << inFlight_ << " >= " << cwnd_ << " = cwnd_";
        return;  // wait for ack or retransmission timeout
    }

    if (isUpdateScheduled_) {
        VLOG(2) << "isUpdateScheduled_ = true";
        return;
    }

    // Less data queued than we are allowed to send: mark as app-limited.
    const bool appLimitedNow = delayedReqsSize_ < sendQuantum_;
    if (appLimitedNow) {
        appLimited_ =
            std::max<decltype(appLimited_)>(delivered_ + inFlight_, 1);
    }

    const std::size_t budget = std::min(delayedReqsSize_, sendQuantum_);
    if (budget == 0) {
        VLOG(2) << "toSend = 0";
        return;
    }

    // budget is non-zero here, so at least one packet is always sent.
    std::size_t dispatched = 0;
    do {
        dispatched += sendPacket(now, budget - dispatched);
    } while (dispatched < budget);

    sc::microseconds sendNextIn{
        std::llround(dispatched * 1000 / pacingRate_)};
    if (sendNextIn > 0us) {
        VLOG(2) << "Scheduling wakeup in " << sendNextIn.count()
                << " us (pacingRate_: " << pacingRate_ << ")";
        nextSendTime_ = now + sendNextIn;
        scheduleUpdate(now);
        return;
    }

    if (delayedReqsSize_ > 0) {
        LOG(WARNING) << "Not scheduling even though delayedReqsSize_ = "
                     << delayedReqsSize_ << " (sendQuantum_ = " << sendQuantum_
                     << ", toSend: " << budget
                     << ", pacingRate: " << pacingRate_ << ")";
    }
}

// Arms the timer for the next paced send at nextSendTime_, which is expected
// to lie after `now`. The pending-update flag is raised before the timer is
// armed; sendPacket() refuses to send while it is set.
void Shaper::scheduleUpdate(Clock::time_point now)
{
    DCHECK(nextSendTime_ > now);
    isUpdateScheduled_ = true;
    timer_.scheduleSendPacket(nextSendTime_, this);
}

// Sends a single packet of at most `toSend` bytes taken from the next
// queued request and returns the number of bytes sent. Requests whose
// promise is already fulfilled (cancelled ones) are dropped on the way.
// Must not be called when no request is waiting.
std::size_t Shaper::sendPacket(Clock::time_point now, const std::size_t toSend)
{
    for (;;) {
        if (delayedRequests_.empty())
            break;

        auto candidate = delayedRequests_.front();

        // Ignore already fulfilled entries - cancels
        if (candidate->promise.isFulfilled()) {
            DLOG(INFO) << "Already fulfilled promise present in queue.";
            delayedRequests_.pop();
            continue;
        }

        // Copy what is needed up front: fulfilling the request below
        // modifies or releases the fetch message.
        auto reqId = candidate->fetch->fetch().req_id();
        auto offset = candidate->fetch->fetch().offset();
        const std::size_t remaining = candidate->fetch->fetch().size();
        const std::size_t chunk =
            std::min(std::min(toSend, FLAGS_single_fetch_max_size), remaining);

        if (remaining != chunk) {
            fulfillSubRequest(candidate, chunk);
        }
        else {
            fulfillWholeRequest(candidate);
            delayedRequests_.pop();
        }

        delayedReqsSize_ -= chunk;
        onSendPacket(now, reqId, offset, chunk);

        return chunk;
    }

    LOG(FATAL) << "sendPacket called while no packets wait to be sent";
    return -1;
}

void Shaper::fulfillWholeRequest(std::shared_ptr<Request> req)
{
    auto reqId = req->fetch->fetch().req_id();
    VLOG(2) << "fulfillWholeRequest req_id: " << reqId
            << ", offset: " << req->fetch->fetch().offset()
            << ", size: " << req->fetch->fetch().size();

    const std::string m{"link.shaper." +
                        std::to_string(req->fetch->fetch().priority()) +
                        ".queue"};
    ONE_METRIC_COUNTER_SUB(m + "_bytes", req->fetch->fetch().size());
    ONE_METRIC_COUNTER_SUB(m, 1);

    req->readFutures.emplace_back(readFromStorage(std::move(req->fetch)));
    collect(req->readFutures)
        .then([req](const std::vector<std::size_t> &vals) {
            std::size_t sum = 0;
            for (auto val : vals)
                sum += val;
            req->promise.setValue(sum);
        })
        .onError([req](folly::exception_wrapper ew) {
            req->promise.setException(std::move(ew));
        })
        .via(&executor_)
        .ensure([=] { reqs_.erase(reqId); });
}

void Shaper::fulfillSubRequest(std::shared_ptr<Request> req, std::size_t size)
{
    VLOG(2) << "fulfillSubRequest req_id: " << req->fetch->fetch().req_id()
            << ", offset: " << req->fetch->fetch().offset()
            << ", size: " << size;

    ONE_METRIC_COUNTER_SUB("link.shaper." +
                               std::to_string(req->fetch->fetch().priority()) +
                               ".queue_bytes",
        size);

    auto subMsg = std::make_unique<proto::LinkMessage>(*req->fetch);
    subMsg->mutable_fetch()->set_size(size);
    auto origFetch = req->fetch->mutable_fetch();
    origFetch->set_size(origFetch->size() - size);
    origFetch->set_offset(origFetch->offset() + size);
    req->readFutures.emplace_back(readFromStorage(std::move(subMsg)));
}

// Passes the fetch described by `req` on to the reader (with the priority
// clamped to 8 bits) and logs the outcome. The returned future carries the
// reader's result or its exception unchanged.
folly::Future<std::size_t> Shaper::readFromStorage(MsgPtr req)
{
    auto &msg = req->fetch();
    auto priority = clampPriority(msg.priority());

    // Copied for logging: `req` is gone by the time the read completes.
    const auto reqId = msg.req_id();
    const auto offset = msg.offset();
    const auto size = msg.size();

    auto logOutcome = [reqId, offset, size](folly::Try<std::size_t> res) {
        if (!res.hasException()) {
            VLOG(2) << "Successfuly read req_id: " << reqId
                    << ", offset: " << offset << ", size: " << size;
            return folly::makeFuture(res.value());
        }

        VLOG(1) << "Error fulfilling req_id: " << reqId
                << ", offset: " << offset << ", size: " << size;
        return folly::makeFuture<std::size_t>(std::move(res).exception());
    };

    return reader_
        .read(msg.req_id(), msg.src(), msg.file_id(), msg.file_guid(),
            msg.offset(), msg.size(), priority)
        .then(std::move(logOutcome));
}

}  // namespace rtransfer
