#include "shaper.hpp"

#include <glog/logging.h>
#include <monitoring/monitoring.h>

#include <chrono>

#include "shaperTimer.hpp"

// Upper bound for a single sub-request handed to the reader; also seeds the
// initial pacing rate in the Shaper constructor.
DEFINE_uint64(single_fetch_max_size, 10 * 1024 * 1024,
    "maximum size of single read block");
// Value returned by Shaper::inFlight() while no RTT sample is available, and
// the delivered-bytes threshold used by Shaper::setCwnd().
DEFINE_uint64(
    shaper_initial_window, 10 * 1024 * 1024, "initial window size for shaper");
// Multiplier applied to the bandwidth estimate in Shaper::setSendQuantum().
DEFINE_uint64(shaper_quantum_ms_size, 25,
    "the quantum size of shaper will be set to this value times bandwidth per "
    "second");
// Initial value of Request::retries; the budget is shared by all sub-requests
// of one request.
DEFINE_uint64(
    request_read_retry_count, 10, "Per request storage read retry count");

namespace sc = std::chrono;
using namespace std::chrono_literals;

namespace {

/**
 * Set of errno values for which a failed read from source storage is retried.
 * A std::system_error whose code is in this set triggers a retry in
 * Shaper::enqueueFetchSubRequestWithRetry (while retries remain); any other
 * failure is propagated immediately.
 */
const std::set<int> retryErrors = {EINTR, EIO, EAGAIN, EACCES, EBUSY, EMFILE,
    ETXTBSY, ESPIPE, EMLINK, EPIPE, EDEADLK, EWOULDBLOCK, ENOLINK, EADDRINUSE,
    EADDRNOTAVAIL, ENETDOWN, ENETUNREACH, ECONNABORTED, ECONNRESET, ENOTCONN,
    EHOSTUNREACH, ECANCELED, ESTALE};

/**
 * Narrows a 32-bit priority to 8 bits, saturating at the largest value
 * representable by std::uint8_t instead of wrapping around.
 *
 * @param priority Priority as carried by the fetch message.
 * @return @c priority if it fits in 8 bits, otherwise 255.
 */
std::uint8_t clampPriority(std::uint32_t priority)
{
    constexpr std::uint32_t kMaxPriority =
        std::numeric_limits<std::uint8_t>::max();

    return static_cast<std::uint8_t>(std::min(priority, kMaxPriority));
}

// Divisor applied to the initial pacing rate in the Shaper constructor.
// NOTE(review): the name suggests a bytes -> kilobytes conversion, but the
// rate is later combined with millisecond/microsecond quantities - confirm
// the intended unit.
constexpr int kToKBytes = 1000;
}  // namespace

namespace rtransfer {

/**
 * Wraps a fetch message into a shaper request.
 *
 * @param msg Fetch message; ownership is taken. The retry budget is
 *            initialised from the request_read_retry_count flag.
 */
Shaper::Request::Request(MsgPtr msg)
    : fetch{std::move(msg)}
    , retries{FLAGS_request_read_retry_count}
{
}

/**
 * @param simultaneous Number of requests at the head of the queue that
 *                     front() cycles through.
 */
Shaper::RequestQueue::RequestQueue(const std::size_t simultaneous)
    : simultaneous_{simultaneous}
{
}

/**
 * Appends a request at the back of the queue.
 *
 * @param req Request to enqueue; shared ownership is moved into the queue.
 */
void Shaper::RequestQueue::push(std::shared_ptr<Request> req)
{
    queue_.push_back(std::move(req));
}

/** @return true when no request is queued. */
bool Shaper::RequestQueue::empty() const { return queue_.empty(); }

/**
 * Returns the next request in round-robin order among the first
 * min(simultaneous_, queue size) entries and advances the cursor.
 *
 * The queue must not be empty; callers check empty() first.
 */
std::shared_ptr<Shaper::Request> Shaper::RequestQueue::front()
{
    const auto window = std::min(simultaneous_, queue_.size());
    if (idx_ >= window)
        idx_ = 0;  // wrap around to the head of the window

    const auto current = idx_;
    ++idx_;
    return queue_[current];
}

/**
 * Removes the request most recently returned by front().
 *
 * The entry is swapped with the head of the queue and popped from there; the
 * cursor is stepped back so the next front() returns whatever now occupies
 * the vacated slot. Must only be called after front().
 */
void Shaper::RequestQueue::pop()
{
    --idx_;
    std::swap(queue_[idx_], queue_.front());
    queue_.pop_front();
}

/**
 * Creates a shaper for a single link identified by @c id.
 *
 * All shaper state is driven through a SerialExecutor layered on top of the
 * global CPU executor. The initial pacing rate is derived from the
 * single_fetch_max_size flag and the shaper starts in the startup state.
 *
 * @param id     Identifier of this shaper (source, destination, priority).
 * @param reader Storage reader used to fulfil sub-requests; kept by reference.
 * @param timer  Timer used to schedule paced wake-ups; kept by reference.
 */
Shaper::Shaper(Shaper::ID id, Reader &reader, ShaperTimer<Clock> &timer)
    : id_{std::move(id)}
    , reader_{reader}
    , executor_{folly::SerialExecutor::create(folly::getKeepAliveToken(
          folly::getUnsafeMutableGlobalCPUExecutor().get()))}
    , rtPropStamp_{Clock::now()}
    , pacingRate_{FLAGS_single_fetch_max_size * BBRHighGain / kToKBytes}
    , timer_{timer}
{
    enterStartup();
}

/**
 * Schedules a fetch request for shaping.
 *
 * The actual bookkeeping (doRead) is executed on the shaper's serial
 * executor.
 *
 * @param msg Fetch message; ownership is taken.
 * @return Future completed with the value produced by doRead().
 */
folly::Future<std::size_t> Shaper::read(MsgPtr msg)
{
    VLOG(2) << "Shaper::read " << msg->req_id();

    auto task = [this, msg = std::move(msg)](auto && /*unit*/) mutable {
        return doRead(std::move(msg));
    };

    return via(executor_.get()).thenValue(std::move(task));
}

/**
 * Registers a fetch request with the shaper and tries to send immediately.
 *
 * The request is stored in reqs_ (for ack/cancel lookup) and in the pending
 * queue, queue metrics are incremented, and sendPacket() is invoked so the
 * request can start as soon as the congestion window allows.
 *
 * @param msg Fetch message; ownership is taken.
 * @return Future of the request's promise (fulfilled in fulfillWholeRequest
 *         or failed in doCancel).
 */
folly::Future<std::size_t> Shaper::doRead(MsgPtr msg)
{
    auto req = std::make_shared<Request>(std::move(msg));
    const auto &fetch = req->fetch->fetch();

    VLOG(2) << "Storing req_id: " << fetch.req_id() << " for shaping";

    const auto priority = clampPriority(fetch.priority());
    auto f = req->promise.getFuture();
    reqs_.emplace(fetch.req_id(), req);

    // priority is a std::uint8_t; streaming it directly would print a raw
    // character instead of the numeric value.
    VLOG(2) << "Scheduled request " << fetch.req_id() << " with priority "
            << std::to_string(priority);

    const std::string m{"link.shaper." + std::to_string(priority) + ".queue"};
    ONE_METRIC_COUNTER_ADD(m + "_bytes", fetch.size());
    ONE_METRIC_COUNTER_ADD(m, 1);

    pendingReqsSize_ += fetch.size();
    pendingRequests_.push(std::move(req));

    readRequestsCount_++;

    sendPacket(Clock::now());

    return f;
}

/**
 * Requests cancellation of a previously scheduled fetch.
 *
 * The cancellation itself (doCancel) runs on the shaper's serial executor.
 *
 * @param msg Cancel message; ownership is taken.
 * @return Future completed once doCancel() has run.
 */
folly::Future<folly::Unit> Shaper::cancel(MsgPtr msg)
{
    VLOG(2) << "Cancelling request " << msg->req_id();

    auto task = [this, msg = std::move(msg)](auto && /*unit*/) mutable {
        doCancel(std::move(msg));
    };

    return via(executor_.get()).thenValue(std::move(task));
}

/**
 * Cancels the request named by the cancel message.
 *
 * Fails the request's promise, releases the bytes it still had pending
 * (including the queue metrics incremented in doRead), forgets all of its
 * unacknowledged packets and finally tries to send, since the in-flight size
 * may have dropped.
 *
 * @param msg Cancel message; ownership is taken.
 */
void Shaper::doCancel(MsgPtr msg)
{
    VLOG(1) << "Cancelling request " << msg->req_id();

    const auto reqId = msg->cancel().req_id();

    auto reqIt = reqs_.find(reqId);
    if (reqIt != reqs_.end()) {
        auto &request = *reqIt->second;

        request.promise.setException(
            folly::make_exception_wrapper<std::runtime_error>(
                "Fetch canceled by user"));

        // If !request.fetch, then we fulfilled the whole request already
        // and cleared request.fetch by std::move in fulfillWholeRequest.
        if (request.fetch) {
            const auto &fetch = request.fetch->fetch();

            // The request is still queued: undo the queue metrics added in
            // doRead for the part that has not been dispatched yet, which
            // would otherwise never be subtracted.
            const std::string m{"link.shaper." +
                                std::to_string(clampPriority(fetch.priority())) +
                                ".queue"};
            ONE_METRIC_COUNTER_SUB(m + "_bytes", fetch.size());
            ONE_METRIC_COUNTER_SUB(m, 1);

            pendingReqsSize_ -= fetch.size();
            request.fetch.reset();
        }

        reqs_.erase(reqIt);
    }

    // Drop every packet of this request that is still awaiting an ack.
    auto range = unackedPerReq_.equal_range(reqId);
    for (auto offsetIt = range.first; offsetIt != range.second; ++offsetIt) {
        auto sentPacketIt = sentPackets_.find({reqId, offsetIt->second});
        if (sentPacketIt != sentPackets_.end()) {
            decreaseInFlightSize(sentPacketIt->second.size);
            sentPackets_.erase(sentPacketIt);
        }
    }
    if (range.first != range.second)
        unackedPerReq_.erase(range.first, range.second);

    // inFlight_ may have changed
    sendPacket(Clock::now());
}

/**
 * Acknowledges delivery of the packets of @c reqId at the given offsets.
 *
 * Processing happens on the shaper's serial executor with SHAPER_OPS_PRIO.
 *
 * @param reqId   Request the acknowledged packets belong to.
 * @param offsets Offsets of the acknowledged packets.
 * @return Future holding true when the request has neither unacknowledged
 *         packets nor an entry in reqs_ left.
 */
folly::Future<bool> Shaper::ack(
    std::uint64_t reqId, folly::fbvector<std::uint64_t> offsets)
{
    VLOG(2) << "Got ack for " << reqId;

    auto handler = [this, reqId, offsets = std::move(offsets)](
                       auto && /*unit*/) {
        const auto now = Clock::now();

        for (std::uint64_t offset : offsets)
            doAck(now, reqId, offset);

        VLOG(2) << "Remaining reqs to ack " << unackedPerReq_.count(reqId);

        VLOG(2) << "Remaining data inFlight " << inFlight_;

        sendPacket(Clock::now());

        // true if nothing to send and nothing to ack
        const bool nothingToAck = unackedPerReq_.count(reqId) == 0u;
        const bool nothingToSend = reqs_.count(reqId) == 0u;
        return nothingToAck && nothingToSend;
    };

    return via(executor_.get(), SHAPER_OPS_PRIO).thenValue(std::move(handler));
}

/**
 * Processes the acknowledgement of a single packet.
 *
 * Removes the packet from the unacked bookkeeping, decreases the in-flight
 * size and feeds the sample into the model (rate sample, state machine,
 * control parameters). Unknown packets are logged and ignored.
 *
 * @param now    Time at which the ack is processed.
 * @param reqId  Request the packet belongs to.
 * @param offset Offset identifying the packet within the request.
 */
void Shaper::doAck(
    Clock::time_point now, std::uint64_t reqId, std::uint64_t offset)
{
    auto it = sentPackets_.find(std::make_pair(reqId, offset));
    if (it == sentPackets_.end()) {
        LOG(WARNING) << "No packet to ack " << reqId << " " << offset;
        return;
    }

    // packet stays valid until sentPackets_.erase(it) at the end of the
    // function.
    const auto &packet = it->second;
    auto rtt = round<sc::microseconds>(now - packet.sentTime);

    // Every sent packet is expected to have a matching unackedPerReq_ entry
    // (both are inserted together in onSendPacket).
    auto range = unackedPerReq_.equal_range(reqId);
    auto unackedIt = std::find_if(range.first, range.second,
        [offset](const auto &i) { return i.second == offset; });
    CHECK(unackedIt != range.second);
    unackedPerReq_.erase(unackedIt);

    VLOG(2) << "Acked " << packet.size << " bytes for reqId: " << reqId;

    decreaseInFlightSize(packet.size);

    generateRateSample(packet, now);
    updateModelAndState(packet, now, rtt);
    updateControlParameters(packet);

    // rtProp_ is kept in microseconds; dividing by 1000 logs milliseconds.
    const int kRTPopDivider = 1000;

    acknowledgedPacketsCount_++;

    // NOLINTNEXTLINE
    LOG_EVERY_N(INFO, 1000)
        << "Shaper(" << std::get<ID_SRC_>(id_) << "," << std::get<ID_DEST_>(id_)
        << "," << std::to_string(std::get<ID_PRIO_>(id_)) << ") stats - rtt: "
        << (rtProp_ ? rtProp_->count() / kRTPopDivider : -1)
        << " ms, bw: " << btlBw_ << ", cwnd: " << cwnd_
        << ", targetCwnd: " << targetCwnd_ << ", inFlight: " << inFlight_
        << ", pendingReqsSize: " << pendingReqsSize_
        << ", readRequests: " << readRequestsCount_
        << ", scheduledPackets: " << scheduledPacketsCount_
        << ", sentPackets: " << sentPacketsCount_
        << ", acknowledgedPackets_: " << acknowledgedPacketsCount_;

    sentPackets_.erase(it);
}

/**
 * Switches to the startup state, in which both the pacing gain and the cwnd
 * gain are set to BBRHighGain.
 */
void Shaper::enterStartup()
{
    cwndGain_ = BBRHighGain;
    pacingGain_ = BBRHighGain;
    state_ = State::startup;
}

/**
 * Runs the per-ack model update and state machine steps, in order:
 * bandwidth estimate, gain cycling, full-pipe detection, drain check,
 * round-trip propagation estimate and ProbeRTT handling.
 *
 * @param packet Packet that has just been acknowledged.
 * @param now    Time at which the ack is processed.
 * @param rtt    Round-trip time measured for @c packet.
 */
void Shaper::updateModelAndState(
    const Packet &packet, Clock::time_point now, sc::microseconds rtt)
{
    updateBtlBw(packet);
    checkCyclePhase(now);
    checkFullPipe();
    checkDrain(now);
    updateRtProp(now, rtt);
    checkProbeRTT(now);
}

/**
 * Updates the round counter and feeds the current delivery rate into the
 * bottleneck-bandwidth filter.
 *
 * A sample is skipped only when it is app-limited and lower than the current
 * estimate.
 *
 * @param packet Packet that has just been acknowledged.
 */
void Shaper::updateBtlBw(const Packet &packet)
{
    updateRound(packet);

    const bool raisesEstimate = rs_.deliveryRate >= btlBw_;
    const bool notAppLimited = rs_.isAppLimited == 0U;
    if (!raisesEstimate && !notAppLimited)
        return;

    btlBw_ = btlBwFilter_.add(rs_.deliveryRate, roundCount_);
}

/**
 * Detects the start of a new round trip.
 *
 * A round starts when the acknowledged packet was sent at or after the
 * delivery mark recorded at the previous round start; roundStart_ reflects
 * whether this ack started a round.
 *
 * @param packet Packet that has just been acknowledged.
 */
void Shaper::updateRound(const Packet &packet)
{
    roundStart_ = packet.delivered >= nextRoundDelivered_;
    if (roundStart_) {
        nextRoundDelivered_ = delivered_;
        ++roundCount_;
    }
}

/**
 * Updates the round-trip propagation estimate.
 *
 * rtPropExpired_ is refreshed on every call. A non-negative @c rtt replaces
 * the estimate when there is none yet, when it is not larger than the
 * current one, or when the current one has expired.
 *
 * @param now Time at which the ack is processed.
 * @param rtt Round-trip time measured for the acknowledged packet.
 */
void Shaper::updateRtProp(Clock::time_point now, sc::microseconds rtt)
{
    rtPropExpired_ = now > rtPropStamp_ + RTpropFilterLen;

    if (rtt < 0us)
        return;

    const bool improvesEstimate = !rtProp_ || rtt <= *rtProp_;
    if (improvesEstimate || rtPropExpired_) {
        rtProp_ = rtt;
        rtPropStamp_ = now;
    }
}

/**
 * Recomputes the control parameters after an ack: pacing rate, send quantum
 * and congestion window, in that order (setCwnd uses the send quantum through
 * inFlight()).
 *
 * @param packet Packet that has just been acknowledged.
 */
void Shaper::updateControlParameters(const Packet &packet)
{
    setPacingRate();
    setSendQuantum();
    setCwnd(packet);
}

/**
 * Sets the pacing rate to @c pacingGain times the bottleneck bandwidth.
 *
 * Until the pipe is known to be full the rate is only ever raised, never
 * lowered.
 *
 * The gain argument is honoured (it used to be ignored in favour of
 * pacingGain_), so that handleRestartFromIdle() can actually pace at gain 1
 * as it requests.
 *
 * @param pacingGain Gain to apply to the bandwidth estimate.
 */
void Shaper::setPacingRateWithGain(double pacingGain)
{
    const auto rate = pacingGain * btlBw_;
    if (filledPipe_ || rate > pacingRate_)
        pacingRate_ = rate;
}

/** Updates the pacing rate using the current pacing gain. */
void Shaper::setPacingRate() { setPacingRateWithGain(pacingGain_); }

void Shaper::setSendQuantum()
{
    const int kQuantumMS = 25;
    const int kQuantumCount = 150;
    sendQuantum_ = std::max<std::size_t>(
        FLAGS_shaper_quantum_ms_size * btlBw_, kQuantumCount * kQuantumMS);
}

/**
 * Computes the target amount of in-flight data for a given gain.
 *
 * @param gain Multiplier applied to the estimated bandwidth-delay product.
 * @return gain * (btlBw_ * rtProp_ / 1000) plus three send quanta, or the
 *         shaper_initial_window flag while no RTT sample exists.
 */
std::size_t Shaper::inFlight(double gain)
{
    if (rtProp_) {
        constexpr int kToMilliSeconds = 1000;
        const auto estimatedBdp =
            btlBw_ * rtProp_->count() / kToMilliSeconds;
        const auto quanta = 3 * sendQuantum_;
        return gain * estimatedBdp + quanta;
    }

    // no valid RTT samples yet
    return FLAGS_shaper_initial_window;
}

/** Sets the target congestion window to inFlight() at the current cwnd gain. */
void Shaper::updateTargetCwnd() { targetCwnd_ = inFlight(cwndGain_); }

/**
 * Remembers the congestion window so it can be restored later.
 *
 * While in ProbeRTT the larger of the saved and the current window is kept;
 * in any other state the current window is saved as is.
 */
void Shaper::saveCwnd()
{
    if (state_ == State::probeRTT)
        priorCwnd_ = std::max(priorCwnd_, cwnd_);
    else
        priorCwnd_ = cwnd_;
}

/** Raises the congestion window back to the saved value if it is larger. */
void Shaper::restoreCwnd()
{
    if (priorCwnd_ > cwnd_)
        cwnd_ = priorCwnd_;
}

/** Caps the congestion window at BBRMinPipeCwnd while in ProbeRTT. */
void Shaper::modulateCwndForProbeRTT()
{
    if (state_ != State::probeRTT)
        return;

    cwnd_ = std::min(cwnd_, BBRMinPipeCwnd);
}

/**
 * Adjusts the congestion window after an ack.
 *
 * Once the pipe is full the window grows by the acked size up to the target;
 * before that it grows while below the target or while less than
 * shaper_initial_window bytes have been delivered. The result never drops
 * below BBRMinPipeCwnd and is capped during ProbeRTT.
 *
 * @param packet Packet that has just been acknowledged.
 */
void Shaper::setCwnd(const Packet &packet)
{
    updateTargetCwnd();

    if (filledPipe_) {
        cwnd_ = std::min(cwnd_ + packet.size, targetCwnd_);
    }
    else {
        const bool belowTarget = cwnd_ < targetCwnd_;
        if (belowTarget || delivered_ < FLAGS_shaper_initial_window)
            cwnd_ += packet.size;
    }

    cwnd_ = std::max(cwnd_, BBRMinPipeCwnd);

    modulateCwndForProbeRTT();
}

/**
 * Detects whether the pipe has been filled.
 *
 * Evaluated only at the start of a round, for samples that are not
 * app-limited and while the pipe is not yet marked full. The pipe is declared
 * full after three such rounds in which the bandwidth estimate did not grow
 * by at least 25% over the recorded baseline.
 */
void Shaper::checkFullPipe()
{
    const bool nothingToCheck =
        filledPipe_ || !roundStart_ || (rs_.isAppLimited != 0U);
    if (nothingToCheck)
        return;

    constexpr double kBwFactor = 1.25;
    const bool stillGrowing = btlBw_ >= fullBw_ * kBwFactor;

    if (stillGrowing) {
        // record new baseline level and restart the count
        fullBw_ = btlBw_;
        fullBwCount_ = 0;
    }
    else {
        // another round w/o much growth
        ++fullBwCount_;
        if (fullBwCount_ >= 3)
            filledPipe_ = true;
    }
}

/**
 * Switches to the drain state: the pacing gain becomes the inverse of
 * BBRHighGain while the cwnd gain stays at BBRHighGain.
 */
void Shaper::enterDrain()
{
    cwndGain_ = BBRHighGain;        // maintain cwnd
    pacingGain_ = 1 / BBRHighGain;  // pace slowly
    state_ = State::drain;
}

/**
 * Handles the startup -> drain -> probeBW transitions.
 *
 * Startup moves to drain once the pipe is full; drain moves to probeBW once
 * the in-flight size (sampled on entry) is not above inFlight(1.0).
 *
 * @param now Time at which the ack is processed.
 */
void Shaper::checkDrain(Clock::time_point now)
{
    const auto inFlightNow = increaseInFlightSize(0);

    if (filledPipe_ && state_ == State::startup)
        enterDrain();

    if (state_ != State::drain)
        return;

    // we estimate queue is drained
    if (inFlightNow <= inFlight(1.0))
        enterProbeBw(now);
}

/**
 * Switches to the probeBW state.
 *
 * The cwnd gain is set to 2, the gain cycle is positioned at a random phase
 * and then advanced once, which also sets the pacing gain for that phase.
 *
 * @param now Time used as the start of the first cycle phase.
 */
void Shaper::enterProbeBw(Clock::time_point now)
{
    constexpr double kInitialPacingGain = 1.0;
    constexpr double kInitialCwndGain = 2.0;
    constexpr int kCycleLenRand = 7;

    cwndGain_ = kInitialCwndGain;
    pacingGain_ = kInitialPacingGain;
    state_ = State::probeBW;

    cycleIndex_ = BBRGainCycleLen - 1 - folly::Random::rand32(kCycleLenRand);
    advanceCyclePhase(now);
}

/**
 * Advances the probeBW gain cycle when the current phase is over; does
 * nothing in other states.
 *
 * @param now Time at which the ack is processed.
 */
void Shaper::checkCyclePhase(Clock::time_point now)
{
    if (state_ != State::probeBW)
        return;

    if (isNextCyclePhase(now))
        advanceCyclePhase(now);
}

/**
 * Moves to the next phase of the probeBW gain cycle and applies its pacing
 * gain (5/4, 3/4, then six phases at 1).
 *
 * @param now Time recorded as the start of the new phase.
 */
void Shaper::advanceCyclePhase(Clock::time_point now)
{
    static constexpr double pacingGainCycle[] = {
        5. / 4, 3. / 4, 1, 1, 1, 1, 1, 1};

    cycleIndex_ = (cycleIndex_ + 1) % BBRGainCycleLen;
    cycleStamp_ = now;
    pacingGain_ = pacingGainCycle[cycleIndex_];  // NOLINT
}

/**
 * Decides whether the current probeBW gain phase should end.
 *
 * A phase has "full length" once more than rtProp_ elapsed since it started.
 * With gain 1 that alone ends the phase; with gain above 1 the prior
 * in-flight size must also have reached inFlight(pacingGain_); with gain
 * below 1 the phase also ends as soon as the prior in-flight size is not
 * above inFlight(1).
 *
 * @param now Time at which the ack is processed.
 */
bool Shaper::isNextCyclePhase(Clock::time_point now)
{
    const bool fullLength = rtProp_ && (now - cycleStamp_) > *rtProp_;

    if (pacingGain_ > 1)
        return fullLength && priorInFlight_ >= inFlight(pacingGain_);

    if (pacingGain_ == 1)
        return fullLength;

    return fullLength || priorInFlight_ <= inFlight(1);
}

/**
 * Flags a restart from idle when nothing is in flight and the flow is
 * app-limited; in probeBW the pacing rate is additionally reset with gain 1.
 */
void Shaper::handleRestartFromIdle()
{
    const auto inFlightNow = increaseInFlightSize(0);
    if (inFlightNow != 0 || appLimited_ == 0)
        return;

    idleRestart_ = true;

    if (state_ == State::probeBW)
        setPacingRateWithGain(1);
}

/**
 * Enters ProbeRTT when the RTprop estimate has expired (unless restarting
 * from idle) and drives ProbeRTT while in that state. The idle-restart flag
 * is cleared on every call.
 *
 * @param now Time at which the ack is processed.
 */
void Shaper::checkProbeRTT(Clock::time_point now)
{
    const bool shouldEnter =
        state_ != State::probeRTT && rtPropExpired_ && !idleRestart_;

    if (shouldEnter) {
        enterProbeRTT();
        saveCwnd();
        probeRttDoneStamp_.clear();
    }

    if (state_ == State::probeRTT)
        handleProbeRTT(now);

    idleRestart_ = false;
}

/** Switches to the ProbeRTT state with both gains set to 1. */
void Shaper::enterProbeRTT()
{
    cwndGain_ = 1;
    pacingGain_ = 1;
    state_ = State::probeRTT;
}

/**
 * Drives the ProbeRTT state.
 *
 * Marks the flow app-limited, arms the ProbeRTT timer once the in-flight
 * size has dropped to BBRMinPipeCwnd, and leaves ProbeRTT after the timer
 * has elapsed and at least one round start has been seen since it was armed.
 *
 * @param now Time at which the ack is processed.
 */
void Shaper::handleProbeRTT(Clock::time_point now)
{
    const auto currentInFlight = increaseInFlightSize(0);

    // Ignore low rate samples during ProbeRTT
    // (the max with 1 keeps appLimited_ non-zero, i.e. "app-limited")
    appLimited_ =
        std::max<decltype(appLimited_)>(delivered_ + currentInFlight, 1);
    if (!probeRttDoneStamp_ && currentInFlight <= BBRMinPipeCwnd) {
        // In-flight data is low enough: start the ProbeRTT timer and a new
        // round.
        probeRttDoneStamp_ = now + ProbeRTTDuration;
        probeRttRoundDone_ = false;
        nextRoundDelivered_ = delivered_;
    }
    else if (probeRttDoneStamp_) {
        if (roundStart_)
            probeRttRoundDone_ = true;
        if (probeRttRoundDone_ && now > *probeRttDoneStamp_) {
            // ProbeRTT finished: refresh the RTprop stamp, restore the
            // window and leave the state.
            rtPropStamp_ = now;
            restoreCwnd();
            exitProbeRTT(now);
        }
    }
}

/**
 * Leaves ProbeRTT: back to startup if the pipe was never filled, otherwise
 * to probeBW.
 *
 * @param now Time passed on to enterProbeBw().
 */
void Shaper::exitProbeRTT(Clock::time_point now)
{
    if (!filledPipe_) {
        enterStartup();
        return;
    }

    enterProbeBw(now);
}

/**
 * Records a packet that is being sent.
 *
 * Snapshots the delivery state into a Packet entry keyed by (reqId, offset),
 * registers it as unacknowledged, handles a restart from idle and increases
 * the in-flight size.
 *
 * @param now    Send time.
 * @param reqId  Request the packet belongs to.
 * @param offset Offset of the packet within the request.
 * @param size   Size of the packet in bytes.
 */
void Shaper::onSendPacket(Clock::time_point now, std::uint64_t reqId,
    std::uint64_t offset, std::size_t size)
{
    // With nothing in flight both sampling clocks restart at this send.
    if (increaseInFlightSize(0) == 0)
        firstSentTime_ = deliveredTime_ = now;

    Packet packet;
    packet.size = size;
    packet.sentTime = now;
    packet.firstSentTime = firstSentTime_;
    packet.deliveredTime = deliveredTime_;
    packet.delivered = delivered_;
    packet.isAppLimited = static_cast<std::size_t>(appLimited_ != 0);

    sentPackets_.emplace(std::make_pair(reqId, offset), packet);
    unackedPerReq_.emplace(reqId, offset);

    handleRestartFromIdle();

    // increaseInFlightSize returns the value before the increase
    priorInFlight_ = increaseInFlightSize(size);

    ++sentPacketsCount_;
}

/**
 * Adds @c size to the in-flight counter under inFlightMutex_.
 *
 * Calling it with 0 is used throughout the class as a locked read.
 *
 * @param size Number of bytes to add.
 * @return The in-flight size before the increase.
 */
std::size_t Shaper::increaseInFlightSize(std::size_t size)
{
    std::lock_guard<std::mutex> g{inFlightMutex_};

    const auto before = inFlight_;

    if (size != 0UL) {
        VLOG(2) << "Increasing inFlight_ counter by " << size;
        inFlight_ += size;
    }

    return before;
}

/**
 * Subtracts @c size from the in-flight counter under inFlightMutex_,
 * saturating at zero.
 *
 * @param size Number of bytes to subtract.
 * @return The in-flight size before the decrease.
 */
std::size_t Shaper::decreaseInFlightSize(std::size_t size)
{
    std::lock_guard<std::mutex> g{inFlightMutex_};

    const auto before = inFlight_;

    if (size != 0UL) {
        VLOG(4) << "Decreasing inFlight_ counter by " << size;
        inFlight_ -= std::min<std::size_t>(inFlight_, size);
    }

    return before;
}

/**
 * Produces a delivery-rate sample in rs_ for an acknowledged packet.
 *
 * @param p   Packet that has just been acknowledged.
 * @param now Time at which the ack is processed.
 * @return true when rs_ holds a usable sample; false when nothing was
 *         delivered yet or the sampling interval is shorter than rtProp_.
 */
bool Shaper::generateRateSample(const Packet &p, Clock::time_point now)
{
    updateRateSample(p, now);

    // Clear app-limited field if bubble is ACKed and gone.
    if ((appLimited_ != 0) and delivered_ > appLimited_)
        appLimited_ = 0;

    if (rs_.priorTime.time_since_epoch().count() == 0)
        return false;  // nothing delivered on this ACK

    // Use longest of the send_elapsed and ack_elapsed
    rs_.interval = std::max(rs_.sendElapsed, rs_.ackElapsed);
    rs_.delivered = delivered_ - rs_.priorDelivered;

    /* Normally we expect interval >= MinRTT.
     * Note that rate may still be over-estimated when a
     * spuriously retransmitted skb was first (s)acked because
     * "interval" is under-estimated (up to an RTT). However,
     * continuously measuring the delivery rate during loss recovery is
     * crucial for connections that suffer heavy or prolonged losses.
     */
    if (!rtProp_ || *rs_.interval < *rtProp_) {
        rs_.interval.clear();
        return false;  // no reliable sample
    }

    // A zero interval leaves the previous deliveryRate untouched.
    if (*rs_.interval != 0us)
        rs_.deliveryRate =
            static_cast<double>(rs_.delivered) / rs_.interval->count();

    return true;  // we filled in rs with a rate sample
}

/**
 * Updates delivery accounting and rs_ when a packet is SACKed or ACKed.
 *
 * @param p   Packet that has just been acknowledged.
 * @param now Time at which the ack is processed; becomes deliveredTime_.
 */
void Shaper::updateRateSample(const Packet &p, Clock::time_point now)
{
    delivered_ += p.size;
    deliveredTime_ = now;

    /* Update info using the newest packet: */
    // "newest" = sent after more data had been delivered than for the packet
    // currently backing the sample
    if (p.delivered > rs_.priorDelivered) {
        rs_.priorDelivered = p.delivered;
        rs_.priorTime = p.deliveredTime;
        rs_.isAppLimited = p.isAppLimited;
        rs_.sendElapsed = round<sc::milliseconds>(p.sentTime - p.firstSentTime);
        rs_.ackElapsed =
            round<sc::milliseconds>(deliveredTime_ - p.deliveredTime);
        firstSentTime_ = p.sentTime;
    }
}

/**
 * Entry point for a paced wake-up: enqueues, with SHAPER_OPS_PRIO on the
 * shaper's executor, a task that clears the "update scheduled" flag and
 * tries to send.
 */
void Shaper::scheduledSendPacket()
{
    auto wakeUp = [this] {
        isUpdateScheduled_ = false;
        sendPacket(Clock::now());
    };

    executor_->addWithPriority(std::move(wakeUp), SHAPER_OPS_PRIO);
}

/**
 * Sends up to one send quantum of pending data and schedules the next paced
 * wake-up.
 *
 * Nothing is sent when the congestion window is exhausted, when a wake-up is
 * already scheduled, or when there is no pending data.
 *
 * @param now Current time, used as send time and as base for pacing.
 */
void Shaper::sendPacket(Clock::time_point now)
{
    const static auto kMinSendNextIn = sc::microseconds{1000};

    const auto inFlightCurrent = increaseInFlightSize(0);

    if (inFlightCurrent >= cwnd_) {
        VLOG(2) << "Waiting for ack because inFlight_ == " << inFlightCurrent
                << " >= " << cwnd_ << " == cwnd_";

        return;  // wait for ack or retransmission timeout
    }

    if (isUpdateScheduled_) {
        VLOG(2) << "Waiting for already scheduled update: isUpdateScheduled_ "
                   "== true, at: "
                << nextSendTime_.time_since_epoch().count()
                << " now is: " << now.time_since_epoch().count();
        return;
    }

    // Less than a full quantum is pending: mark the flow as app-limited.
    if (pendingReqsSize_ < sendQuantum_)
        appLimited_ =
            std::max<decltype(appLimited_)>(delivered_ + inFlightCurrent, 1);

    const std::size_t toSend = std::min(pendingReqsSize_, sendQuantum_);

    if (toSend == 0) {
        VLOG(2) << "Nothing to send: toSend == 0, sendQuantum_ == "
                << sendQuantum_;
        return;
    }

    std::size_t sent = 0;
    while (sent < toSend) {
        const auto justSent = sendPacket(now, toSend - sent);

        // The per-packet overload returns 0 when the queue has run dry. If
        // pendingReqsSize_ ever disagrees with the queue contents, looping on
        // would never terminate, so stop here instead.
        if (justSent == 0 && pendingRequests_.empty())
            break;

        sent += justSent;
    }

    if (sent == 0)
        return;  // nothing went out, so there is nothing to pace

    // sent / pacingRate_ scaled by 1000 yields the pacing delay in
    // microseconds.
    const int kPacingScale = 1000;
    sc::microseconds sendNextIn{
        std::llround(static_cast<double>(sent) * kPacingScale / pacingRate_)};

    if (sendNextIn == 0us)
        sendNextIn = kMinSendNextIn;

    VLOG(2) << "Scheduling wakeup in " << sendNextIn.count()
            << " us (pacingRate_: " << pacingRate_ << ")";

    nextSendTime_ = now + sendNextIn;

    scheduleUpdate(now);
}

/**
 * Arms the timer for the next paced send at nextSendTime_.
 *
 * @param now Current time; a wake-up that would not lie in the future is
 *            pushed to 1000 microseconds after @c now.
 */
void Shaper::scheduleUpdate(Clock::time_point now)
{
    DCHECK(nextSendTime_ > now);

    isUpdateScheduled_ = true;

    // Make sure we're not scheduling into the past
    auto wakeUpAt = nextSendTime_;
    if (!(wakeUpAt > now))
        wakeUpAt = now + std::chrono::microseconds(1000);

    timer_.scheduleSendPacket(wakeUpAt, this);
}

/**
 * Dispatches a single packet (sub-request) from the pending queue.
 *
 * Already fulfilled (cancelled) entries are dropped. The packet size is
 * bounded by @c toSend, the single_fetch_max_size flag and the remaining
 * size of the request, then adjusted to the destination block size if one is
 * set.
 *
 * @param now    Send time recorded for the packet.
 * @param toSend Number of bytes the caller would still like to send.
 * @return Size of the dispatched packet, or 0 if the queue was empty.
 */
std::size_t Shaper::sendPacket(Clock::time_point now, const std::size_t toSend)
{
    while (!pendingRequests_.empty()) {
        auto frontRequest = pendingRequests_.front();

        // Ignore already fulfilled entries - cancels
        if (frontRequest->promise.isFulfilled()) {
            DLOG(INFO) << "Already fulfilled promise present in queue.";
            pendingRequests_.pop();
            continue;
        }

        // Read these before fulfill*Request below: fulfillSubRequest advances
        // the fetch offset/size and fulfillWholeRequest moves the fetch away.
        const auto reqId = frontRequest->fetch->fetch().req_id();
        const auto offset = frontRequest->fetch->fetch().offset();
        const auto destStorageBlockSize =
            frontRequest->fetch->fetch().dest_block_size();

        const std::size_t remainingFetchSize =
            frontRequest->fetch->fetch().size();

        // Calculate default subrequest size based on maximum fetch size,
        // requested fetch size and estimated optimal to send size
        std::size_t subRequestSize = std::min(
            std::min(toSend, FLAGS_single_fetch_max_size), remainingFetchSize);

        if (destStorageBlockSize > 0) {
            // If the destination storage requires blocks of certain size,
            // ensure that the reads are aligned to that block size, in favor of
            // optimizing for network performance
            // NOTE(review): the "+ 1" adds a whole extra block even when
            // subRequestSize is already a multiple of the block size -
            // confirm this is intended rather than a round-up.
            subRequestSize =
                std::min((subRequestSize / destStorageBlockSize + 1) *
                             destStorageBlockSize,
                    remainingFetchSize);

            if (offset % destStorageBlockSize != 0) {
                // If offset is not aligned on block size boundary, adjust the
                // size to match the nearest block boundary
                auto nextBlockBoundary = ((offset / destStorageBlockSize) + 1) *
                                         destStorageBlockSize;
                subRequestSize =
                    std::min(nextBlockBoundary - offset, remainingFetchSize);
            }
        }

        scheduledPacketsCount_++;

        if (remainingFetchSize == subRequestSize) {
            // Last part of the request: dispatch it and drop it from the
            // queue.
            fulfillWholeRequest(frontRequest);
            pendingRequests_.pop();
        }
        else {
            fulfillSubRequest(frontRequest, subRequestSize);
        }

        pendingReqsSize_ -= subRequestSize;
        onSendPacket(now, reqId, offset, subRequestSize);

        return subRequestSize;
    }

    LOG(WARNING) << "sendPacket called while no packets wait to be sent";

    return 0;
}

void Shaper::fulfillWholeRequest(std::shared_ptr<Request> req)
{
    auto reqId = req->fetch->fetch().req_id();

    VLOG(2) << "fulfillWholeRequest req_id: " << reqId
            << ", offset: " << req->fetch->fetch().offset()
            << ", size: " << req->fetch->fetch().size();

    const std::string m{"link.shaper." +
                        std::to_string(req->fetch->fetch().priority()) +
                        ".queue"};
    ONE_METRIC_COUNTER_SUB(m + "_bytes", req->fetch->fetch().size());
    ONE_METRIC_COUNTER_SUB(m, 1);

    auto req_total_size = req->fetch->total_size();
    auto sub_request_size = req->fetch->fetch().size();

    auto reqPtrCopy = req;
    enqueueFetchSubRequestWithRetry(
        reqPtrCopy, std::move(req->fetch), sub_request_size);

    collect(reqPtrCopy->readFutures)
        .via(executor_.get())
        .thenTry(
            [req = std::move(reqPtrCopy), req_total_size, reqId](auto &&t) {
                VLOG(2) << "Waiting finished for futures for " << reqId;
                if (t.hasException()) {
                    LOG(ERROR) << "Request " << reqId
                               << " failed due to: " << t.exception().what();
                    req->promise.setException(t.exception());
                }
                else {
                    std::size_t sum = 0;
                    for (auto val : t.value())
                        sum += val;

                    VLOG(2) << "Entire request " << reqId
                            << " fulfilled successfully (" << sum << ") of ("
                            << req_total_size << ")";

                    req->promise.setValue(sum);
                }
            })
        .ensure([this, reqId] {
            VLOG(2) << "Removing request " << reqId << " from queue";
            reqs_.erase(reqId);
        });
}

void Shaper::fulfillSubRequest(std::shared_ptr<Request> req, std::size_t size)
{
    auto reqId = req->fetch->fetch().req_id();

    VLOG(2) << "fulfillSubRequest req_id: " << reqId
            << ", offset: " << req->fetch->fetch().offset()
            << ", size: " << size;

    ONE_METRIC_COUNTER_SUB("link.shaper." +
                               std::to_string(req->fetch->fetch().priority()) +
                               ".queue_bytes",
        size);

    // Create new fetch submessage
    auto subRequestMessage = std::make_unique<proto::LinkMessage>(*req->fetch);
    subRequestMessage->mutable_fetch()->set_size(size);

    // Remove the submessage block from the original fetch message
    auto *origFetch = req->fetch->mutable_fetch();
    origFetch->set_size(origFetch->size() - size);
    origFetch->set_offset(origFetch->offset() + size);

    enqueueFetchSubRequestWithRetry(
        std::move(req), std::move(subRequestMessage), size);
}

/**
 * Starts a storage read for one sub-request and records its future in
 * req->readFutures.
 *
 * When the read fails with a std::system_error whose code is listed in
 * retryErrors and the request still has retries left, the same sub-request
 * is enqueued again and the failed attempt completes with 0 bytes. Otherwise
 * the in-flight size is reduced by the sub-request size and the error is
 * propagated.
 *
 * NOTE(review): the continuation is not bound to executor_, so it appears to
 * run on whichever thread completes the read while touching shaper state -
 * confirm against Reader::read.
 *
 * @param req                    Request the sub-request belongs to.
 * @param subRequestFetchMessage Fetch message describing the sub-request;
 *                               consumed by this call.
 * @param size                   Size of the sub-request in bytes.
 */
void Shaper::enqueueFetchSubRequestWithRetry(std::shared_ptr<Request> req,
    std::unique_ptr<proto::LinkMessage> &&subRequestFetchMessage,
    std::size_t size)
{
    // Keep a copy for a potential retry; the original is consumed by
    // readFromStorage.
    auto subRequestFetchMessageCopy =
        std::make_unique<proto::LinkMessage>(*subRequestFetchMessage);

    req->readFutures.emplace_back(
        readFromStorage(std::move(subRequestFetchMessage))
            .thenTry([this, req, subRequestSize = size,
                         subRequestFetchMessageCopy =
                             std::move(subRequestFetchMessageCopy)](
                         auto &&maybeSize) mutable {
                if (maybeSize.hasException()) {
                    auto retries = req->retries;
                    bool retryableException{false};
                    maybeSize.exception().handle(
                        [&retryableException](std::system_error &e) {
                            if (retryErrors.find(e.code().value()) !=
                                retryErrors.end())
                                retryableException = true;
                        });

                    if (retryableException && retries > 0) {
                        // Take the id from the sub-request message:
                        // req->fetch is null once the whole request has been
                        // dispatched or the request has been cancelled.
                        VLOG(1)
                            << "Retrying fetch request "
                            << subRequestFetchMessageCopy->req_id()
                            << " at offset "
                            << subRequestFetchMessageCopy->fetch().offset()
                            << " - retries left for this request: " << retries;

                        req->retries = retries - 1;

                        enqueueFetchSubRequestWithRetry(req,
                            std::move(subRequestFetchMessageCopy),
                            subRequestSize);

                        return folly::makeFuture<size_t>(0);
                    }

                    VLOG(2)
                        << "Decreasing inFlight_ counter by " << subRequestSize
                        << " due to exception " << maybeSize.exception().what();

                    // We have to give up for this sub request due to too many
                    // retries
                    decreaseInFlightSize(subRequestSize);

                    sendPacket(Clock::now());

                    return folly::makeFuture<size_t>(maybeSize.exception());
                }

                return folly::makeFuture<size_t>(std::move(maybeSize.value()));
            }));
}

/**
 * Issues the read described by a fetch message to the reader and logs the
 * outcome.
 *
 * @param req Fetch message describing what to read; ownership is taken.
 * @return Future holding the value produced by the reader, or its error.
 */
folly::Future<std::size_t> Shaper::readFromStorage(MsgPtr req)
{
    const auto &fetch = req->fetch();

    const auto reqId = fetch.req_id();
    const auto offset = fetch.offset();
    const auto size = fetch.size();

    auto logOutcome = [reqId, offset, size](folly::Try<std::size_t> &&res) {
        if (res.hasException()) {
            VLOG(1) << "Error fulfilling req_id: " << reqId
                    << ", offset: " << offset << ", size: " << size;

            return folly::makeFuture<std::size_t>(std::move(res).exception());
        }

        VLOG(2) << "Successfuly read req_id: " << reqId
                << ", offset: " << offset << ", size: " << size;

        return folly::makeFuture(res.value());
    };

    return reader_
        .read(reqId, fetch.src(), fetch.file_id(), fetch.file_guid(), offset,
            size, clampPriority(fetch.priority()))
        .thenTry(std::move(logOutcome));
}

}  // namespace rtransfer
