#include "client.hpp"

#include <gflags/gflags.h>
#include <wangle/channel/OutputBufferingHandler.h>

#include "common.hpp"
#include "eventBaseHandler.hpp"
#include "httpHeader.hpp"
#include "pingHandler.hpp"
#include "shaper.hpp"
#include "ssl.hpp"

// How many data connections Client::connect() opens to a peer once the
// control-connection handshake has completed.
DEFINE_uint64(number_of_data_conns, 10, "number of data connections per link");
// SO_RCVBUF value requested for every data connection socket (default 10 MiB);
// applied in Client::newDataConn().
// NOTE(review): the value is passed through folly's socket-option map, whose
// value type may be int — sizes above INT_MAX could be truncated; confirm.
DEFINE_uint64(
    recv_buf_size, 10 * 1024 * 1024, "size of data conns receive buffer");

using namespace std::literals;

namespace rtransfer {
namespace client {

/// Creates the handler; `addr` is the peer's address, used to annotate log
/// messages.
CtrlIdHandler::CtrlIdHandler(folly::fbstring addr) : peerAddr_(std::move(addr))
{
}

/// Extracts the control id the server sends back on the control connection.
///
/// Until the id has been received, this handler waits for the complete
/// 18-byte response: a 2-byte version followed by the id. It then fulfils
/// ctrlIdPromise_ with the id and forwards any remaining bytes. After that,
/// all reads pass straight through.
void CtrlIdHandler::read(Context *ctx, folly::IOBufQueue &buf)
{
    if (!ctrlIdPromise_.isFulfilled()) {
        const auto kCtrlIdVersionLength = 2U;
        const auto kCtrlIdHeaderLength = 18U;
        const auto kCtrlIdLength = kCtrlIdHeaderLength - kCtrlIdVersionLength;

        if (buf.chainLength() < kCtrlIdHeaderLength) {
            VLOG(2) << "More data needed to parse ctrlId response from "
                    << peerAddr_;
            return;
        }

        VLOG(2) << "Successfuly parsed ctrlId response from " << peerAddr_;

        // Only version 00 is supported, so the version bytes are skipped
        // without being checked.
        buf.trimStart(kCtrlIdVersionLength);
        auto ctrlIdBuf = buf.split(kCtrlIdLength);
        ctrlIdPromise_.setValue(ctrlIdBuf->moveToFbString());
    }

    ctx->fireRead(buf);
}

/// Future fulfilled by read() with the control id received from the peer.
folly::Future<folly::fbstring> CtrlIdHandler::ctrlId()
{
    auto future = ctrlIdPromise_.getFuture();
    return future;
}

/// Creates the handler; `addr` is the peer's address, used to annotate log
/// messages.
HttpResponseHandler::HttpResponseHandler(folly::fbstring addr)
    : peerAddr_(std::move(addr))
{
}

/// Consumes and validates the server's fixed-length HTTP response.
///
/// Once http::responseLen() bytes are buffered, they are parsed:
///  - ok: promise_ is fulfilled and the remaining bytes are forwarded;
///  - bad_header: promise_ fails with std::runtime_error;
///  - more_data: nothing happens, and the handler waits for another read.
/// After promise_ is fulfilled, every read passes straight through.
void HttpResponseHandler::read(Context *ctx, folly::IOBufQueue &buf)
{
    if (promise_.isFulfilled()) {
        ctx->fireRead(buf);
        return;
    }

    buf.gather(buf.chainLength());

    const auto headerLength = http::responseLen();
    if (buf.chainLength() < headerLength) {
        VLOG(2) << "More data needed to parse HTTP response from " << peerAddr_;
        return;
    }

    // NOTE(review): the header bytes are removed from `buf` before parsing.
    // If parseResponse ever reports more_data for a buffer of exactly
    // responseLen() bytes, those bytes are dropped and never re-parsed —
    // confirm that status cannot occur here.
    auto header = buf.split(headerLength)->moveToFbString();
    const auto parsed = http::parseResponse(header);
    const auto status = parsed.first;

    if (status == http::ParseStatus::more_data) {
        VLOG(2) << "More data needed to parse HTTP response from "
                << peerAddr_;
        return;
    }

    if (status == http::ParseStatus::bad_header) {
        LOG(WARNING) << "Server responded with bad HTTP response "
                     << peerAddr_;
        promise_.setException(folly::make_exception_wrapper<std::runtime_error>(
            "bad header response from server"));
        return;
    }

    if (status == http::ParseStatus::ok) {
        VLOG(2) << "Successfuly read HTTP response from " << peerAddr_;
        promise_.setValue();
        ctx->fireRead(buf);
    }
}

/// Future that read() fulfils on a valid HTTP response and fails on a
/// malformed one.
folly::Future<folly::Unit> HttpResponseHandler::done()
{
    auto future = promise_.getFuture();
    return future;
}

void PeerSecretHandler::read(Context *ctx, folly::IOBufQueue &buf)
{
    if (promise_.isFulfilled()) {
        ctx->fireRead(buf);
        return;
    }

    if (buf.chainLength() < proto::secret_size) {
        VLOG(2) << "More data needed to parse peer secret from " << peerAddr_;
        return;
    }

    auto receivedSecret = buf.split(proto::secret_size)->moveToFbString();
    if (peerSecret_ != receivedSecret) {
        LOG(WARNING) << "Bad peer secret received from " << peerAddr_;
        promise_.setException(folly::make_exception_wrapper<std::runtime_error>(
            "bad peer secret"));
        return;
    }

    VLOG(2) << "Successfuly parsed peer secret from " << peerAddr_;

    promise_.setValue();
    ctx->fireRead(buf);
}

/// Future that read() fulfils once the peer secret matched and fails when it
/// did not.
folly::Future<folly::Unit> PeerSecretHandler::done()
{
    auto future = promise_.getFuture();
    return future;
}

/// Stores what every control pipeline needs: the shared close handler, the
/// expected peer secret, and a printable form of the peer address for logs.
PipelineFactory::PipelineFactory(
    std::shared_ptr<ConnectionCloseHandler> connectionCloseHandler,
    folly::fbstring peerSecret, const folly::SocketAddress &addr)
    : connectionCloseHandler_(std::move(connectionCloseHandler))
    , peerSecret_(std::move(peerSecret))
    , peerAddr_(addr.describe())
{
}

/// Builds the control-connection pipeline on top of `sock`.
///
/// Order (front to back): socket, output buffering, event-base tracking, the
/// three handshake handlers, the shared close handler, then length-prefixed
/// framing and protobuf (de)serialisation of proto::LinkMessage.
Pipeline::Ptr PipelineFactory::newPipeline(
    std::shared_ptr<folly::AsyncTransportWrapper> sock)
{
    // Control frames carry a 2-byte length prefix at offset 0, which is
    // stripped before the frame reaches ProtoHandler.
    const auto kLengthFieldLength = 2;
    const auto kMaxFrameLength = UINT_MAX;
    const auto kLengthFieldOffset = 0;
    const auto kLengthAdjustment = 0;
    const auto kInitialBytesToStrip = 2;

    auto pipeline = Pipeline::create();

    // Transport and handshake stages; Client::connect removes the three
    // handshake handlers once the handshake has finished.
    pipeline->addBack(wangle::AsyncSocketHandler{sock});
    pipeline->addBack(wangle::OutputBufferingHandler{});
    pipeline->addBack(EventBaseHandler{});
    pipeline->addBack(HttpResponseHandler{peerAddr_});
    pipeline->addBack(PeerSecretHandler{peerAddr_, peerSecret_});
    pipeline->addBack(CtrlIdHandler{peerAddr_});
    pipeline->addBack(connectionCloseHandler_);

    // Message framing and decoding.
    pipeline->addBack(wangle::LengthFieldBasedFrameDecoder{kLengthFieldLength,
        kMaxFrameLength, kLengthFieldOffset, kLengthAdjustment,
        kInitialBytesToStrip});
    pipeline->addBack(wangle::LengthFieldPrepender{2});
    pipeline->addBack(ProtoHandler<proto::LinkMessage>{});

    pipeline->finalize();
    return pipeline;
}

/// Stores what every data pipeline needs: the fetch manager that consumes
/// frames, the shared close handler, the expected peer secret, and a
/// printable form of the peer address for logs.
DataPipelineFactory::DataPipelineFactory(
    std::shared_ptr<FetchManager> fetchManager,
    std::shared_ptr<ConnectionCloseHandler> connectionCloseHandler,
    folly::fbstring peerSecret, const folly::SocketAddress &addr)
    : fetchManager_(std::move(fetchManager))
    , connectionCloseHandler_(std::move(connectionCloseHandler))
    , peerSecret_(std::move(peerSecret))
    , peerAddr_(addr.describe())
{
}

/// Builds a data-connection pipeline on top of `sock`.
///
/// Order (front to back): socket, event-base tracking, the two handshake
/// handlers, PingSender, the shared close handler, frame decoding, and
/// finally fetchManager_.
DataPipeline::Ptr DataPipelineFactory::newPipeline(
    std::shared_ptr<folly::AsyncTransportWrapper> sock)
{
    // Data frames carry a 4-byte length field starting at byte 17; nothing is
    // stripped, so fetchManager_ receives whole frames.
    const auto kLengthFieldLength = 4;
    const auto kMaxFrameLength = UINT_MAX;
    const auto kLengthFieldOffset = 17;
    const auto kLengthAdjustment = 0;
    const auto kInitialBytesToStrip = 0;

    auto pipeline = DataPipeline::create();

    // Transport and handshake stages; Client::newDataConn removes the two
    // handshake handlers once the handshake has finished.
    pipeline->addBack(wangle::AsyncSocketHandler{sock});
    pipeline->addBack(EventBaseHandler{});
    pipeline->addBack(HttpResponseHandler{peerAddr_});
    pipeline->addBack(PeerSecretHandler{peerAddr_, peerSecret_});
    pipeline->addBack(PingSender{});
    pipeline->addBack(connectionCloseHandler_);

    // Framing, then complete frames go to the fetch manager.
    pipeline->addBack(wangle::LengthFieldBasedFrameDecoder{kLengthFieldLength,
        kMaxFrameLength, kLengthFieldOffset, kLengthAdjustment,
        kInitialBytesToStrip});
    pipeline->addBack(fetchManager_);

    pipeline->finalize();
    return pipeline;
}

/// Creates a client for the rtransfer peer at `addr`.
///
/// Only the control-connection bootstrap is configured here (pipeline
/// factory and IO thread pool); no connection is made until connect().
/// executor_ is a folly::SerialExecutor layered on `clientExecutor`. A local
/// secret whose length differs from proto::secret_size is reported, not
/// rejected.
Client::Client(folly::SocketAddress addr,
    std::shared_ptr<folly::IOThreadPoolExecutor> clientExecutor,
    StoragesMap &storages, folly::fbstring mySecret, folly::fbstring peerSecret)
    : addr_{std::move(addr)}
    , storages_{storages}
    , mySecret_{std::move(mySecret)}
    , peerSecret_{std::move(peerSecret)}
    , clientExecutor_{clientExecutor}
    , executor_{folly::SerialExecutor::create(
          folly::getKeepAliveToken(clientExecutor_.get()))}
{
    auto ctrlFactory = std::make_shared<PipelineFactory>(
        connectionCloseHandler_, peerSecret_, addr_);
    ctrl_->pipelineFactory(std::move(ctrlFactory));
    ctrl_->group(clientExecutor);

    const bool secretSizeMismatch = mySecret_.size() != proto::secret_size;
    if (secretSizeMismatch) {
        LOG(ERROR) << "Size of secret " << mySecret_.size()
                   << " is different from expected " << proto::secret_size
                   << ". The handshake will not finish correctly.";
    }
}

/// Tears the client's connections down via closeNow(). closeNow() does
/// nothing if it has already run.
Client::~Client()
{
    closeNow();
}

/// Builds the handshake sent on a new connection as one IOBuf chain,
/// containing in order:
///   1. the header produced by http::makeHeader("example.com");
///   2. the first header_size + version_size bytes of "rtransfer00";
///   3. mySecret_;
///   4. `ctrlId`.
std::unique_ptr<folly::IOBuf> Client::makeHandshake(folly::StringPiece ctrlId)
{
    auto handshake = folly::IOBuf::copyBuffer(http::makeHeader("example.com"));

    // prependChain() on the head of a (circular) chain inserts just before the
    // head, i.e. at the end of the chain — so each piece is appended in turn.
    auto protocolTag = folly::IOBuf::wrapBuffer(
        "rtransfer00", proto::header_size + proto::version_size);
    handshake->prependChain(std::move(protocolTag));

    auto secretPart =
        folly::IOBuf::copyBuffer(mySecret_.data(), mySecret_.size());
    handshake->prependChain(std::move(secretPart));

    auto ctrlIdPart = folly::IOBuf::copyBuffer(ctrlId.data(), ctrlId.size());
    handshake->prependChain(std::move(ctrlIdPart));

    return handshake;
}

/// Connects the control channel to addr_ and performs the client side of the
/// handshake.
///
/// Sequence (continuations after the connect run on the control connection's
/// event base unless marked otherwise):
///   1. send the handshake: HTTP header, protocol tag, mySecret_ and, since no
///      control id is known yet, proto::ctrlid_size '0' characters;
///   2. wait for HttpResponseHandler, PeerSecretHandler and CtrlIdHandler to
///      resolve, in that order;
///   3. remove those three handlers from the pipeline;
///   4. on executor_: store the received id in ctrlId_, attach the pipeline
///      to dispatcher_ and open FLAGS_number_of_data_conns data connections;
///   5. on executor_: install a PeriodicHandler (10s) calling ping() that
///      holds only a weak reference to this client.
///
/// A connect failure, or any failure in steps 1-5, calls closeNow() and is
/// propagated through the returned future. Step 4 waits with
/// folly::collectAll, so a failing data connection does not by itself fail
/// this future (newDataConn() calls closeNow() on its own failure).
///
/// NOTE(review): continuations capture raw `this`; the caller is assumed to
/// keep the client alive until the returned future completes.
folly::Future<folly::Unit> Client::connect()
{
    if (rtransfer::useSSL())
        ctrl_->sslContext(rtransfer::sslContext());

    return ctrl_->connect(addr_, std::chrono::milliseconds{0})
        .thenTry([this](auto &&maybe) {
            if (maybe.hasException()) {
                LOG(ERROR) << "Failed to connect to " << addr_;
                closeNow();
                return folly::makeFuture<folly::Unit>(maybe.exception());
            }

            auto transport = ctrl_->getPipeline()->getTransport();
            auto *base = transport->getEventBase();

            return via(base)
                .thenValue([this, transport](auto && /*unit*/) {
                    common::setNoDelay("client ctrl", transport);
                    DLOG(INFO)
                        << "Sending control header to " << addr_.describe();
                    // No control id has been assigned yet, so an all-'0' id of
                    // proto::ctrlid_size characters is sent.
                    auto handshake =
                        makeHandshake(folly::fbstring(proto::ctrlid_size, '0'));
                    // Written from LengthFieldPrepender's own context, so the
                    // write skips that handler: the handshake goes out without
                    // a length prefix.
                    return ctrl_->getPipeline()
                        ->getContext<wangle::LengthFieldPrepender>()
                        ->fireWrite(std::move(handshake));
                })
                .via(base)
                .thenValue([this](auto && /*unit*/) {
                    DLOG(INFO) << "HttpResponseHandler";
                    return ctrl_->getPipeline()
                        ->getHandler<HttpResponseHandler>()
                        ->done();
                })
                .via(base)
                .thenValue([this](auto && /*unit*/) {
                    DLOG(INFO) << "PeerSecretHandler";
                    return ctrl_->getPipeline()
                        ->getHandler<PeerSecretHandler>()
                        ->done();
                })
                .via(base)
                .thenValue([this](auto && /*unit*/) {
                    DLOG(INFO) << "CtrlIdHandler";
                    return ctrl_->getPipeline()
                        ->getHandler<CtrlIdHandler>()
                        ->ctrlId();
                })
                .via(base)
                .thenValue([this](folly::fbstring &&ctrlId) {
                    // Handshake complete: drop the one-shot handlers so later
                    // reads go straight to the framing/protobuf handlers.
                    DLOG(INFO) << "Removing finished handlers";
                    ctrl_->getPipeline()
                        ->remove<HttpResponseHandler>()
                        .remove<PeerSecretHandler>()
                        .remove<CtrlIdHandler>()
                        .finalize();
                    return ctrlId;
                })
                .via(executor_.get())
                .thenValue([this](folly::fbstring &&ctrlId) {
                    DLOG(INFO) << "Starting data conns";
                    // ctrlId_ must be set before newDataConn(), which sends it
                    // in each data connection's handshake.
                    ctrlId_ = std::move(ctrlId);
                    dispatcher_.setPipeline(ctrl_->getPipeline());

                    folly::fbvector<folly::Future<folly::Unit>> newDataConns;
                    for (std::size_t i = 0; i < FLAGS_number_of_data_conns; ++i)
                        newDataConns.emplace_back(newDataConn());

                    return folly::collectAll(newDataConns);
                })
                .via(executor_.get())
                .thenValue([this](auto /*futures*/) {
                    // Weak reference: the periodic ping must not keep the
                    // client alive.
                    pinger_ = std::make_unique<PeriodicHandler>(10s,
                        [s = std::weak_ptr<Client>{this->shared_from_this()}] {
                            if (auto self = s.lock())
                                self->ping();
                        });
                })
                .thenError(folly::tag_t<folly::exception_wrapper>{},
                    [this](auto &&ew) {
                        closeNow();
                        return folly::makeFuture<folly::Unit>(
                            std::forward<decltype(ew)>(ew));
                    });
        });
}

/// Sends a ping over the control link and expects a pong.
///
/// The returned future fails when the request itself fails (logged), when
/// the peer answers with an error payload, or when the reply is not a pong.
folly::Future<folly::Unit> Client::ping()
{
    auto pingMsg = std::make_unique<proto::LinkMessage>();
    pingMsg->mutable_ping();

    // Runs on executor_ and hands the message to the request/response service.
    auto sendPing = [this, pingMsg = std::move(pingMsg)](
                        auto && /*unit*/) mutable {
        return service_(std::move(pingMsg));
    };

    auto expectPong = [](folly::Try<MsgPtr> &&maybeResponse) {
        if (maybeResponse.hasException()) {
            LOG(ERROR) << "Pong not received for ping";
            return folly::makeFuture<MsgPtr>(maybeResponse.exception());
        }

        return ensureResponse(
            std::move(maybeResponse.value()), proto::LinkMessage::kPong);
    };

    return via(executor_.get())
        .thenValue(std::move(sendPing))
        .thenTry(std::move(expectPong))
        .thenValue([](MsgPtr && /*unused*/) {});
}

void Client::closeNow()
{
    if (closeNowFlag_.test_and_set())
        return;

    auto *pipeline = ctrl_->getPipeline();
    if (pipeline == nullptr)
        return;

    auto transport = pipeline->getTransport();
    if (transport == nullptr)
        return;

    auto *base = transport->getEventBase();
    if (base == nullptr)
        return;

    base->runInEventBaseThread(
        [pinger = std::move(pinger_), ctrl = std::move(ctrl_),
            data = std::move(data_)]() mutable {
            pinger = nullptr;
            for (auto &d : data) {
                auto *base = d->getPipeline()->getTransport()->getEventBase();
                base->runInEventBaseThread(
                    [d = std::move(d)]() mutable { d = nullptr; });
            }
            ctrl = nullptr;
        });
}

/// Asks the peer for `size` bytes at `offset` of srcFileId on srcStorageId.
///
/// The destination (destStorageId, destFileId, fileGuid), offset, size and
/// notifyCb go to fetchManager_->newFetch(). The future it returns is the
/// result of this call, so completion is decided by FetchManager (presumably
/// with the number of bytes transferred — confirm in FetchManager).
///
/// The peer's direct reply to the fetch request is only checked for failure:
/// an exception, an error payload, or any payload other than kTotalSize
/// cancels the fetch via fetchManager_->cancelFetch(). When the fetch future
/// completes, the still-pending reply future is cancelled.
///
/// For object storages, the destination block size is obtained with a
/// blocking .get() on the calling thread; otherwise 0 is sent.
///
/// @throws std::runtime_error synchronously (not through the future) when
///         destStorageId is not present in storages_.
folly::Future<std::size_t> Client::fetch(folly::StringPiece srcStorageId,
    folly::StringPiece srcFileId, folly::StringPiece destStorageId,
    folly::StringPiece destFileId, folly::StringPiece fileGuid,
    std::uint64_t offset, std::size_t size, std::uint8_t priority,
    std::uint64_t reqId, folly::StringPiece transferData,
    folly::Function<void(std::uint64_t, std::size_t)> notifyCb)
{
    auto msg = std::make_unique<proto::LinkMessage>();
    auto *req = msg->mutable_fetch();
    req->set_file_id(srcFileId.data(), srcFileId.size());
    req->set_src(srcStorageId.data(), srcStorageId.size());
    req->set_dest(destStorageId.data(), destStorageId.size());
    req->set_offset(offset);
    req->set_size(size);
    req->set_req_id(reqId);
    req->set_priority(priority);
    req->set_transfer_data(transferData.data(), transferData.size());

    auto destStorage = storages_.find(destStorageId.str());  // NOLINT
    if (destStorage == storages_.cend())
        throw std::runtime_error("Cannot find storage " + destStorageId.str());

    // Only object storages report a block size; 0 is sent for all others.
    const auto destHelper = (*destStorage).second->helper();
    auto destBlockSize =
        destHelper->isObjectStorage()
            ? destHelper->blockSizeForPath(destFileId.str()).get()
            : 0;

    req->set_dest_block_size(destBlockSize);

    VLOG(2) << "Destination block size for storage " << destStorageId << " is "
            << destBlockSize;

    // Registered before the request is sent.
    auto fetchFuture = fetchManager_->newFetch(reqId, destStorageId.str(),
        destFileId.str(), fileGuid.str(), offset, size, std::move(notifyCb));

    return via(executor_.get())
        .thenValue([this, reqId, msg = std::move(msg),
                       fetchFuture = std::move(fetchFuture)](
                       auto && /*unit*/) mutable {
            // Watches the peer's direct reply; it only matters when it
            // signals a failure, which then cancels the fetch.
            auto replyFuture =
                service_(std::move(msg))
                    .via(executor_.get())
                    .thenTry([this, reqId](folly::Try<MsgPtr> &&maybeResponse) {
                        if (maybeResponse
                                .hasException<folly::FutureCancellation>()) {
                            // We're on a happy path - just return anything
                            return folly::makeFuture<MsgPtr>(
                                maybeResponse.exception());
                        }

                        if (maybeResponse.hasException()) {
                            LOG(WARNING) << "Fetch request " << reqId
                                         << " failed due to exception: "
                                         << maybeResponse.exception().what();
                            fetchManager_->cancelFetch(
                                reqId, maybeResponse.exception());
                            return folly::makeFuture<MsgPtr>(
                                maybeResponse.exception());
                        }

                        auto &response = maybeResponse.value();

                        if (response->payload_case() ==
                            proto::LinkMessage::kError) {
                            LOG(WARNING) << "Fetch request " << reqId
                                         << " failed due to error: "
                                         << response->error();
                            fetchManager_->cancelFetch(reqId,
                                folly::make_exception_wrapper<
                                    std::runtime_error>(response->error()));
                        }
                        else if (response->payload_case() !=
                                 proto::LinkMessage::kTotalSize) {
                            LOG(WARNING) << "Fetch request " << reqId
                                         << " canceled due to invalid response "
                                            "- expected kTotalSize";
                            fetchManager_->cancelFetch(reqId,
                                folly::make_exception_wrapper<
                                    std::runtime_error>("Invalid response"));
                        }

                        return folly::makeFuture<MsgPtr>(std::move(response));
                    });

            return std::move(fetchFuture)
                .via(executor_.get())
                .ensure([replyFuture = std::move(replyFuture)]() mutable {
                    // We have our result, no need to wait for a reply future
                    // now - if we're on the happy path, the reply's not coming
                    // at all
                    replyFuture.cancel();
                });
        });
}

/// Cancels fetch `reqId` locally and asks the peer to cancel it too.
///
/// On executor_, the local fetch is cancelled first and then the cancel
/// request is sent. The returned future fails when the request fails
/// (logged), when the peer answers with an error payload, or when the reply
/// is not kDone.
folly::Future<folly::Unit> Client::cancel(std::uint64_t reqId,
    folly::StringPiece srcStorageId, folly::StringPiece destStorageId)
{
    VLOG(2) << "Cancel fetch called for reqId " << reqId;

    auto cancelMsg = std::make_unique<proto::LinkMessage>();
    {
        auto *cancelReq = cancelMsg->mutable_cancel();
        cancelReq->set_req_id(reqId);
        cancelReq->set_src(srcStorageId.data(), srcStorageId.size());
        cancelReq->set_dest(destStorageId.data(), destStorageId.size());
    }

    auto cancelLocallyThenRemotely =
        [this, reqId, cancelMsg = std::move(cancelMsg)](
            auto && /*unit*/) mutable {
            fetchManager_->cancelFetch(reqId);
            return service_(std::move(cancelMsg));
        };

    auto expectDone = [](folly::Try<MsgPtr> &&maybeResponse) {
        if (!maybeResponse.hasException()) {
            return ensureResponse(std::move(maybeResponse.value()),
                proto::LinkMessage::kDone)
                .thenValue([](auto && /*ignore*/) {});
        }

        LOG(ERROR) << "Cancel failed due to exception: "
                   << maybeResponse.exception().what();
        return folly::makeFuture<folly::Unit>(maybeResponse.exception());
    };

    return via(executor_.get())
        .thenValue(std::move(cancelLocallyThenRemotely))
        .thenTry(std::move(expectDone));
}

void Client::ack(std::uint64_t reqId, std::uint64_t offset)
{
    bool shouldIHandleAcks = false;
    {
        folly::SpinLockGuard guard{acksLock_};
        pendingAcks_.emplace_back(reqId, offset);
        if (!someoneHandlingAcks_)
            shouldIHandleAcks = someoneHandlingAcks_ = true;
    }

    if (shouldIHandleAcks)
        folly::getUnsafeMutableGlobalCPUExecutor()->addWithPriority(
            std::bind(&Client::ackLoop, this), SHAPER_OPS_PRIO);
}

void Client::ackLoop()
{
    while (true) {
        usedAcks_.clear();
        {
            folly::SpinLockGuard guard{acksLock_};
            if (pendingAcks_.empty()) {
                someoneHandlingAcks_ = false;
                return;
            }
            pendingAcks_.swap(usedAcks_);
        }

        std::unordered_map<std::uint64_t, folly::fbvector<std::uint64_t>> byReq;
        for (auto &ack : usedAcks_)
            byReq[ack.first].emplace_back(ack.second);

        auto msg = std::make_unique<proto::LinkMessage>();
        bool isEmpty{true};
        for (auto &ack : byReq) {
            auto *a = msg->mutable_acks()->add_acks();
            a->set_req_id(ack.first);
            isEmpty = false;
            for (auto offset : ack.second)
                a->add_offsets(offset);
        }

        if (!isEmpty) {
            executor_->add([this, msg = std::move(msg)]() mutable {
                service_(std::move(msg)).cancel();
            });
        }
    }
}

/// Turns a reply into a future that succeeds only for the expected payload.
///
/// An error payload always fails the future with its message, even if
/// `expected` is kError. Any other payload except `expected` fails with
/// "bad response". Otherwise `msg` is passed through.
folly::Future<MsgPtr> Client::ensureResponse(
    MsgPtr msg, proto::LinkMessage::PayloadCase expected)
{
    const auto actual = msg->payload_case();

    if (actual == proto::LinkMessage::kError) {
        return folly::makeFuture<MsgPtr>(
            folly::make_exception_wrapper<std::runtime_error>(msg->error()));
    }

    if (actual == expected)
        return folly::makeFuture<MsgPtr>(std::move(msg));

    return folly::makeFuture<MsgPtr>(
        folly::make_exception_wrapper<std::runtime_error>("bad response"));
}

/// Opens one data connection to addr_ and performs its handshake.
///
/// The socket is configured with SO_RCVBUF = FLAGS_recv_buf_size. The
/// handshake is the same as the control connection's but carries ctrlId_
/// (set by connect() beforehand). The function then waits for the HTTP
/// response and the peer secret, and removes those two handlers. Finally, on
/// executor_, the bootstrap is appended to data_. Any failure calls
/// closeNow() and fails the returned future.
///
/// NOTE(review): continuations capture raw `this`; the client is assumed to
/// outlive the returned future.
folly::Future<folly::Unit> Client::newDataConn()
{
    VLOG(2) << "Establishing new data connection to " << addr_.describe();
    auto client = std::make_shared<rtransfer::ClientBootstrap<DataPipeline>>();

    if (rtransfer::useSSL())
        client->sslContext(rtransfer::sslContext());

    client->pipelineFactory(std::make_shared<DataPipelineFactory>(
        fetchManager_, connectionCloseHandler_, peerSecret_, addr_));

    // NOTE(review): FLAGS_recv_buf_size is uint64; confirm the socket-option
    // value type can hold it without truncation.
    folly::SocketOptionKey recvBufSize{SOL_SOCKET, SO_RCVBUF};
    client->group(clientExecutor_)
        ->setSocketOptions({{recvBufSize, FLAGS_recv_buf_size}});

    // `client` is captured by every continuation to keep the bootstrap alive
    // until it is stored in data_.
    return client->connect(addr_, std::chrono::milliseconds{0})
        .thenTry([client, this](auto &&maybe) {
            if (maybe.hasException()) {
                LOG(ERROR) << "Connection to " << addr_
                           << " failed due to: " << maybe.exception().what();
                closeNow();
                return folly::makeFuture<folly::Unit>(maybe.exception());
            }
            auto transport = client->getPipeline()->getTransport();
            auto *base = transport->getEventBase();

            // The handshake continuations below have no explicit via(); they
            // presumably stay on `base` selected here (folly keeps the
            // previous executor).
            return via(base)
                .thenValue([this, transport, client](auto && /*unit*/) {
                    common::setNoDelay("client data sock", transport);
                    VLOG(2)
                        << "Sending data conn header to " << addr_.describe();
                    auto handshake = makeHandshake(ctrlId_);
                    return client->getPipeline()->write(std::move(handshake));
                })
                .thenValue([client](auto && /*unit*/) {
                    return client->getPipeline()
                        ->getHandler<HttpResponseHandler>()
                        ->done();
                })
                .thenValue([client](auto && /*unit*/) {
                    return client->getPipeline()
                        ->getHandler<PeerSecretHandler>()
                        ->done();
                })
                .thenValue([client](auto && /*unit*/) {
                    // Handshake complete: drop the one-shot handlers.
                    client->getPipeline()
                        ->remove<HttpResponseHandler>()
                        .remove<PeerSecretHandler>()
                        .finalize();
                })
                .via(executor_.get())
                .thenValue([this, client](auto && /*unit*/) {
                    data_.emplace_back(client);
                })
                .thenError(folly::tag_t<folly::exception_wrapper>{},
                    [this](auto &&ew) {
                        closeNow();
                        return folly::makeFuture<folly::Unit>(
                            std::forward<decltype(ew)>(ew));
                    });
        });
}

/// Returns a copy of the control-connection bootstrap handle (ctrl_).
Client::ClientBootstrapPtr Client::ctrl()
{
    return ctrl_;
}

/// Returns a mutable reference to the established data-connection
/// bootstraps, which newDataConn() appends to.
folly::fbvector<DataBootstrapPtr> &Client::data()
{
    return data_;
}

}  // namespace client
}  // namespace rtransfer
