#include "reader.hpp"

#include <gflags/gflags.h>

#include <utility>

namespace rtransfer {

// Builds a Reader from the collaborators it needs.
//
// storages    - map in which read() looks up the storage by its id.
// handleCache - cache through which read() issues the actual file reads.
// serverLink  - link that provides the data connections and the peer address
//               used in metric names; the shared pointer is moved into the
//               member.
Reader::Reader(StoragesMap &storages, HandleCache &handleCache,
    std::shared_ptr<ServerSideLink> serverLink)
    : storages_{storages}
    , handleCache_{handleCache}
    , serverLink_{std::move(serverLink)}
{
}

// Reads a range of a file through the handle cache and queues the resulting
// buffer for sending over the server link.
//
// reqId     - request identifier, forwarded to the handle cache and embedded
//             in the queued send task.
// storageId - key looked up in the storages map; an unknown id yields a future
//             failed with std::runtime_error("invalid storage").
// fileId, fileGuid, offset, size, priority - forwarded to the handle cache;
//             offset and priority are also stored in the send task.
//
// Returns a future that completes with the number of bytes actually read
// (the chain length of the returned buffer) once the send task's promise is
// fulfilled, or with the read/send exception.
folly::Future<std::size_t> Reader::read(std::uint64_t reqId,
    const folly::fbstring &storageId, const folly::fbstring &fileId,
    const folly::fbstring &fileGuid, std::uint64_t offset, std::size_t size,
    std::uint8_t priority)
{
    auto storageIt = storages_.find(storageId);  // NOLINT
    if (storageIt == storages_.cend()) {
        return folly::makeFuture<std::size_t>(
            folly::make_exception_wrapper<std::runtime_error>(
                "invalid storage"));
    }

    VLOG(2) << "Reading " << reqId << " offset: " << offset
            << ", size: " << size;

    // Continuation run with the outcome of the storage read. The string ids
    // are captured by copy because the caller's references may not outlive
    // the asynchronous read.
    auto onReadDone = [this, reqId, priority, offset, storageId, fileId](
                          folly::Try<folly::IOBufQueue> &&readResult) {
        if (readResult.hasException()) {
            LOG(WARNING) << "Read of " << fileId << " from storage "
                         << storageId << " failed due to system error: "
                         << readResult.exception().get_exception()->what();

            ONE_METRIC_COUNTER_INC(
                fmt::format("link.{}.read.errors", storageId));

            // Propagate the failure to the returned future.
            readResult.exception().throw_exception();
        }

        auto &data = readResult.value();
        std::size_t bytesRead = data.chainLength();

        // Account for the bytes now waiting to be written; doStartSending()
        // subtracts them again once the write has finished.
        ONE_METRIC_COUNTER_ADD(fmt::format("link.{}.pre_write_buffer.size",
                                   serverLink_->getPeerAddr()),
            bytesRead);

        VLOG(2) << "Sending " << reqId << ", offset: " << offset
                << ", size: " << bytesRead;

        SendTask task{{}, std::move(data), reqId, offset, priority};
        auto sent = task.promise.getFuture();

        // Enqueue under the lock and find out whether this caller is the one
        // that has to kick off the sender.
        bool mustStartSender = false;
        {
            folly::SpinLockGuard guard{spinLock_};
            queuedSendTasks_.emplace_back(std::move(task));
            if (!sendInProgress_) {
                sendInProgress_ = true;
                mustStartSender = true;
            }
        }

        if (mustStartSender) {
            startSending();
        }

        return std::move(sent).thenValue(
            [bytesRead](auto && /*unit*/) { return bytesRead; });
    };

    return handleCache_
        .read(*storageIt->second, reqId, fileId, fileGuid, offset, size,
            priority)
        .thenTry(std::move(onReadDone));
}

void Reader::startSending()
{
    folly::getCPUExecutor()->add(std::bind(&Reader::doStartSending, this));
}

void Reader::doStartSending()
{
    while (true) {
        {
            folly::SpinLockGuard guard{spinLock_};
            usedSendTasks_.swap(queuedSendTasks_);
            if (usedSendTasks_.empty()) {
                sendInProgress_ = false;
                return;
            }
        }

        decltype(usedSendTasks_) usedSendTasks;
        usedSendTasks.reserve(usedSendTasks_.size());
        for (auto &task : usedSendTasks_)
            usedSendTasks.emplace_back(std::move(task));
        usedSendTasks_.clear();

        auto now = std::chrono::steady_clock::now();
        auto dataConns = serverLink_->dataConnections(now);
        if (dataConns.empty()) {
            for (auto &sendTask : usedSendTasks)
                sendTask.promise.setException(
                    folly::make_exception_wrapper<std::runtime_error>(
                        "no data connections"));
            continue;
        }

        folly::fbvector<
            std::pair<ServerSideLink::Conn, std::unique_ptr<folly::IOBuf>>>
            connsToData;

        connsToData.reserve(dataConns.size());

        for (auto &conn : dataConns)
            connsToData.emplace_back(
                std::move(conn), folly::IOBuf::createCombined(0));

        std::size_t size = 0;
        for (auto &sendTask : usedSendTasks) {
            ONE_METRIC_COUNTER_ADD(
                "link.sent_bytes", sendTask.buf.chainLength());
            size += sendTask.buf.chainLength();
            distributeParts(connsToData, sendTask);
        }

        folly::collect(sendParts(connsToData))
            .via(folly::getCPUExecutor().get())
            .thenTry(
                [=, usedSendTasks = std::move(usedSendTasks)](
                    const folly::Try<std::vector<folly::Unit>> &t) mutable {
                    auto metric = "link." + serverLink_->getPeerAddr() +
                                  ".pre_write_buffer.size";
                    ONE_METRIC_COUNTER_SUB(metric, size);
                    if (t.hasValue()) {
                        for (auto &sendTask : usedSendTasks)
                            sendTask.promise.setValue();
                    }
                    else {
                        for (auto &sendTask : usedSendTasks)
                            sendTask.promise.setException(t.exception());
                    }
                });
    }
}

// Picks the subset of connections a task of the given priority may use.
//
// The connections are viewed in windows of a quarter of their count (at least
// one connection wide). The window starts at index 0 and is shifted by one
// quarter for every full step of 64 that `prio` lies above 64. The selected
// entries are returned in random order, each paired with a non-owning pointer
// to the connection's buffer chain.
folly::fbvector<std::pair<ServerSideLink::Conn, folly::IOBuf *>>
Reader::connsForPrio(folly::fbvector<std::pair<ServerSideLink::Conn,
                         std::unique_ptr<folly::IOBuf>>> &allConnsToData,
    std::uint8_t prio)
{
    const int kPriorityThreshold = 64;
    const std::size_t quarter = allConnsToData.size() / 4;

    std::size_t first = 0;
    std::size_t last = std::max<std::size_t>(quarter, 1);
    for (int remaining = prio; remaining > kPriorityThreshold;
         remaining -= kPriorityThreshold) {
        first += quarter;
        last += quarter;
    }

    folly::fbvector<std::pair<ServerSideLink::Conn, folly::IOBuf *>> selected;
    for (std::size_t idx = first; idx < last; ++idx) {
        auto &entry = allConnsToData[idx];
        selected.emplace_back(entry.first, entry.second.get());
    }

    // Per-thread generator, seeded once on first use.
    static thread_local std::mt19937 rng{folly::randomNumberSeed()};
    std::shuffle(selected.begin(), selected.end(), rng);
    return selected;
}

// Splits the buffer of `sendTask` into parts and chains each part, wrapped in
// a data message, onto the buffer of one of the connections selected by
// connsForPrio() for the task's priority.
//
// allConnsToData - all connections with their pending buffer chains; the
//                  chains of the selected connections are extended in place.
// sendTask       - task whose buffer is consumed (split) by this call.
//
// Part sizes are proportional to a per-connection throughput estimate, with a
// lower bound of 1024 bytes per part where the remaining size allows it.
void Reader::distributeParts(
    folly::fbvector<std::pair<ServerSideLink::Conn,
        std::unique_ptr<folly::IOBuf>>> &allConnsToData,
    SendTask &sendTask)
{
    std::size_t size = sendTask.buf.chainLength();
    auto connsToData = connsForPrio(allConnsToData, sendTask.priority);

    // Scratch vector kept per thread and cleared on every call, so its
    // storage is reused between calls.
    static thread_local folly::fbvector<double> throughput;
    throughput.clear();
    double totalThroughput = 0;
    for (auto &el : connsToData) {
        auto &ti = el.first.second;
        // Estimate: mss * cwnd divided by rtt; adding 1.0 keeps the divisor
        // non-zero when the rtt count is 0.
        double tput = static_cast<double>(ti.mss * ti.cwnd) /
                      (static_cast<double>(ti.rtt.count()) + 1.0);
        throughput.push_back(tput);
        totalThroughput += tput;
    }

    // Second per-thread scratch vector: the number of bytes assigned to each
    // selected connection, index-aligned with connsToData.
    static thread_local folly::fbvector<std::size_t> partSize;
    partSize.clear();
    std::size_t assignedSize = 0;
    const int kMinChunkThreshold = 1024;
    for (auto tput : throughput) {
        // Proportional share, or an equal share when no throughput estimate
        // is positive.
        std::size_t s = totalThroughput > 0
                            ? size * tput / totalThroughput
                            : static_cast<double>(size) / connsToData.size();
        // Even after throughput calculations, do not send chunks less than 1024
        // bytes if it can be avoided
        s = std::min(
            std::max<std::size_t>(kMinChunkThreshold, s), size - assignedSize);
        partSize.push_back(s);
        assignedSize += s;
    }

    // Send any remaining bytes through the first connection
    // NOTE(review): front() assumes connsForPrio() returned at least one
    // connection - confirm that callers never pass an empty connection list.
    if (assignedSize < size) {
        partSize.front() += size - assignedSize;
    }

    // `part` numbers only the messages actually emitted; connections with a
    // zero-sized share are skipped without consuming a part index.
    for (auto i = 0U, part = 0U; size > 0 && i < connsToData.size(); ++i) {
        if (partSize[i] == 0)
            continue;

        // We use size to check if there's anything left to send in case the
        // floating point calculations somehow lead to assignedSize > size
        size -= std::min(partSize[i], size);

        // The part is flagged as last once nothing remains to be sent.
        // NOTE(review): every part carries the same sendTask.offset;
        // presumably the receiver orders the data by part index - confirm
        // against the protocol.
        auto msg = constructDataMsg(sendTask.reqId, sendTask.offset, part++,
            size == 0, sendTask.buf.splitAtMost(partSize[i]));

        connsToData[i].second->prependChain(std::move(msg));
    }
}

// Starts a write on every connection whose buffer chain holds data and
// returns the futures of those writes. Connections with an empty chain are
// skipped; for the others the buffer is moved into the write call.
std::vector<folly::Future<folly::Unit>> Reader::sendParts(
    folly::fbvector<std::pair<ServerSideLink::Conn,
        std::unique_ptr<folly::IOBuf>>> &connsToData)
{
    std::vector<folly::Future<folly::Unit>> pendingWrites;
    pendingWrites.reserve(connsToData.size());

    for (auto &entry : connsToData) {
        auto &conn = entry.first;
        auto &payload = entry.second;

        if (!payload->empty()) {
            pendingWrites.emplace_back(conn.first->write(std::move(payload)));
        }
    }

    return pendingWrites;
}

std::unique_ptr<folly::IOBuf> Reader::constructDataMsg(std::uint64_t reqId,
    std::uint64_t offset, std::uint8_t part, bool isLastPart,
    std::unique_ptr<folly::IOBuf> buf)
{
    DCHECK(part < (1 << 7));

    auto nativeSize = buf->computeChainDataLength();

    auto size = folly::Endian::big<std::uint32_t>(nativeSize);
    reqId = folly::Endian::big<std::uint64_t>(reqId);
    offset = folly::Endian::big<std::uint64_t>(offset);

    if (isLastPart)
        part |= proto::last_part_mask;

    auto msg = folly::IOBuf::createCombined(
        sizeof(reqId) + sizeof(part) + sizeof(offset) + sizeof(size));

    auto append = [&](auto val) {
        std::memcpy(msg->writableTail(), &val, sizeof(val));
        msg->append(sizeof(val));
    };

    append(reqId);
    append(part);
    append(offset);
    append(size);

    if (nativeSize > 0)
        msg->prependChain(std::move(buf));

    return msg;
}

}  // namespace rtransfer
