"""This module contains common functions and fixtures
   for rtransfer tests"""

__author__ = "Bartek Kryza"
__copyright__ = """(C) 2023 ACK CYFRONET AGH,
This software is released under the MIT license cited in 'LICENSE.txt'."""

import sys
import pytest
import random
import json
import string
import subprocess
import hashlib
import time
import base64
import tempfile
import os
import shutil
import boto3
from boto.s3.connection import S3Connection, OrdinaryCallingFormat

# Define variables for use in tests
# Directory containing this file; the project root is two levels above it.
_script_dir = os.path.dirname(os.path.realpath(__file__))
project_dir = os.path.realpath(os.path.join(_script_dir, '..', '..'))
# <project_dir>/bamboos/docker — presumably where the `environment` package
# imported below lives (TODO confirm).
docker_dir = os.path.join(project_dir, 'bamboos', 'docker')

# Prepend docker_dir to the import path so it is searched before any other
# location when `environment` is imported below.
sys.path = [docker_dir] + sys.path

# Debug output: prints the effective import path when this module is loaded.
print(sys.path)

from environment import common, docker, ceph, s3

# Time the `rtransfer` fixture sleeps after starting both instances.
STARTUP_DELAY = 5  # [seconds]
# Time wait_for_connection() sleeps.
CONNECTION_DELAY = 5  # [seconds]
# Passed by the `rtransfer` fixture to every instance as -number_of_data_conns.
DATA_CONNECTION_COUNT = 10
# When True, RTransfer prints every request it sends and response it reads.
VERBOSE = False


def random_int(lower_bound=1, upper_bound=100):
    """Return a random integer N with lower_bound <= N <= upper_bound."""
    # randrange's upper limit is exclusive, hence the +1 (this mirrors
    # exactly how random.randint is defined).
    return random.randrange(lower_bound, upper_bound + 1)


def random_str(size=None,
               characters=string.ascii_uppercase + string.digits):
    """Return a random string of ``size`` characters drawn from ``characters``.

    :param size: length of the result. When omitted (``None``), a new random
        length in [1, 100] is drawn on every call. It is drawn here rather than
        in the signature because a default expression is evaluated only once,
        at definition time, which would give every call the same length.
    :param characters: alphabet to draw from; defaults to uppercase ASCII
        letters and digits.
    """
    if size is None:
        size = random.randint(1, 100)
    return ''.join(random.choice(characters) for _ in range(size))


@pytest.fixture()
def src_log_dir(request):
    """Fresh temporary log directory for the source rtransfer instance.

    The directory is removed when the requesting test finishes; removal is
    best effort (an OSError during cleanup is ignored).
    """
    path = tempfile.mkdtemp()

    def _remove_dir():
        try:
            shutil.rmtree(path)
        except OSError:
            pass

    request.addfinalizer(_remove_dir)
    return path


@pytest.fixture()
def dest_log_dir(request):
    """Fresh temporary log directory for the destination rtransfer instance.

    Removed after the test; an OSError raised during removal is ignored so
    that cleanup never fails the test.
    """
    log_path = tempfile.mkdtemp()

    def _cleanup():
        try:
            shutil.rmtree(log_path)
        except OSError:
            pass

    request.addfinalizer(_cleanup)
    return log_path


@pytest.fixture()
def dest_dir(request):
    """Fresh temporary directory exposed to tests as ``dest_dir``.

    Deleted (best effort, OSError ignored) once the test completes.
    """
    tmp_path = tempfile.mkdtemp()

    def _drop_tmp():
        try:
            shutil.rmtree(tmp_path)
        except OSError:
            pass

    request.addfinalizer(_drop_tmp)
    return tmp_path


@pytest.fixture()
def src_dir(request):
    """Fresh temporary directory exposed to tests as ``src_dir``.

    Deleted (best effort, OSError ignored) once the test completes.
    """
    created = tempfile.mkdtemp()

    def _teardown():
        try:
            shutil.rmtree(created)
        except OSError:
            pass

    request.addfinalizer(_teardown)
    return created


class RTransfer(object):
    """Driver for a single ``build/link`` rtransfer process.

    Requests are written to the process' stdin as one JSON document per line;
    responses are read from its stdout, also one JSON document per line.
    """

    def __init__(self, server_port, log_dir, data_connection_number):
        """
        :param server_port: value passed to the binary as ``-server_port``.
        :param log_dir: value passed as ``-log_dir``.
        :param data_connection_number: value passed as
            ``-number_of_data_conns``.
        """
        self.server_port = server_port
        self.log_dir = log_dir
        self.data_connection_number = data_connection_number
        # Set by run(); None until the process has been started.
        self.process = None

    def run(self):
        """Start ``build/link`` (path relative to the current directory)."""
        # stdbuf -oL makes the child's stdout line-buffered so each response
        # line is delivered as soon as it is printed.
        self.process = subprocess.Popen(["stdbuf", "-oL", "build/link",
                                         "-v=1",
                                         "-number_of_data_conns=" + str(self.data_connection_number),
                                         "-alsologtostderr=true",
                                         "-auth_cache_duration=1000000",
                                         "-auth_cache_tick=100000",
                                         "-log_dir=" + self.log_dir,
                                         "-logbufsecs=2",
                                         "-server_port", str(self.server_port)],
                                        stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        # Text mode with line buffering: each
                                        # request line is flushed on write.
                                        text=True,
                                        bufsize=1,
                                        encoding='ascii')

    def request_async(self, req):
        """Send ``req`` (JSON-serialisable) without waiting for a response."""
        req_str = json.dumps(req) + "\n"
        if VERBOSE:
            print(f"----------------->>> {req_str}")
        self.process.stdin.write(req_str)

    def _read_message(self):
        """Read one line from the process' stdout and decode it as JSON.

        :raises RuntimeError: if stdout is at EOF (e.g. the process exited).
            This replaces the opaque JSONDecodeError that json.loads('')
            would otherwise raise.
        """
        line = self.process.stdout.readline()
        if not line:
            raise RuntimeError(
                "rtransfer process closed its stdout (exit code: {})".format(
                    self.process.poll()))
        return json.loads(line)

    def get_response(self, skip_updates=True):
        """Return the next decoded message from the process.

        :param skip_updates: when true, messages with a truthy ``isUpdate``
            field are discarded and the first other message is returned.
        """
        resp = self._read_message()
        while skip_updates and resp.get("isUpdate"):
            resp = self._read_message()

        if VERBOSE:
            print("<<<----------------- "+str(resp))

        return resp

    def request(self, req):
        """Send ``req`` and return the first non-update response."""
        self.request_async(req)
        resp = self.get_response()
        return resp

    def stop(self):
        """Kill the process (if started), reap it, and close its pipes.

        wait() after kill() prevents a zombie process. Closing the pipes
        avoids leaking file descriptors. Calling stop() more than once, or
        before run(), is harmless.
        """
        if self.process is None:
            return
        self.process.kill()
        self.process.wait()
        for stream in (self.process.stdin, self.process.stdout):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    # E.g. BrokenPipeError when flushing to the dead child.
                    pass


def start_rtransfer(server_port, log_dir, data_connections_number):
    """Build an RTransfer with the given settings, launch it, and return it."""
    instance = RTransfer(server_port, log_dir, data_connections_number)
    instance.run()
    return instance


@pytest.fixture()
def rtransfer(request, src_log_dir, dest_log_dir):
    """Start a destination (port 2345) and a source (port 2346) instance.

    Returns ``(dest, src)`` after sleeping STARTUP_DELAY seconds. Every
    instance that was started is stopped when the test finishes.
    """
    started = []

    def fin():
        # Stop only what was actually started, in start order (dest, src).
        for instance in started:
            instance.stop()

    # Registered before anything is started, so that a failure while starting
    # the second instance does not leak the first one.
    request.addfinalizer(fin)

    dest = start_rtransfer(2345, dest_log_dir, DATA_CONNECTION_COUNT)
    started.append(dest)
    src = start_rtransfer(2346, src_log_dir, DATA_CONNECTION_COUNT)
    started.append(src)

    time.sleep(STARTUP_DELAY)
    return dest, src


def stop_rtransfers(destination, src):
    """Stop the destination instance and every instance in ``src``.

    :param destination: object with a ``stop()`` method (an RTransfer).
    :param src: iterable of objects with a ``stop()`` method.
    """
    destination.stop()
    for instance in src:
        instance.stop()


def connect_links(dest, src):
    """Open ``dest.data_connection_number`` connections from dest to src.

    Returns the connection ids in the order they were created.
    """
    connection_ids = []
    for _ in range(dest.data_connection_number):
        connection_ids.append(connect(dest, src))
    return connection_ids


def wait_for_connection():
    """Sleep for CONNECTION_DELAY seconds.

    Presumably this gives freshly opened connections time to settle.
    """
    delay_seconds = CONNECTION_DELAY
    time.sleep(delay_seconds)


def to_b64(string):
    """Return the base64 encoding of ``string`` as an ASCII ``str``.

    :param string: text (``str``, encoded as UTF-8 first) or a bytes-like
        object, which is encoded as-is.
    """
    if isinstance(string, str):
        # UTF-8 is a superset of ASCII, so ASCII text encodes exactly as
        # before, while non-ASCII text (e.g. in file system paths) no longer
        # raises UnicodeEncodeError.
        string = string.encode('utf-8')
    return base64.b64encode(string).decode('ascii')


def _create_helper(prov, storage_id, helper_name, params=None):
    """Ask ``prov`` to register a storage helper; return its encoded id.

    Sends a ``create_helper`` request with ``io_buffered`` set to False and
    asserts that the instance answered ``{"done": True}``.

    :param prov: RTransfer instance that should create the helper.
    :param storage_id: storage identifier before base64 encoding.
    :param helper_name: helper type, e.g. "nulldevice", "posix", "s3".
    :param params: ordered mapping of parameter name to raw value. Each value
        is base64-encoded before sending; ``None`` means no parameters.
    :return: the base64-encoded ``storage_id`` used in the request.
    """
    sid = to_b64(storage_id)
    encoded_params = [{'key': key, 'value': to_b64(value)}
                      for key, value in (params or {}).items()]
    req = {"create_helper": {"storage_id": sid, "name": helper_name,
                             "params": encoded_params,
                             "io_buffered": False}}
    assert {"done": True} == prov.request(req)
    return sid


def create_null_helper(storage_id, prov):
    """Create a "nulldevice" helper with no parameters."""
    return _create_helper(prov, storage_id, "nulldevice")


def create_slow_null_helper(storage_id, prov):
    """Create a "nulldevice" helper with latencyMin=100, latencyMax=200."""
    # NOTE(review): the unit of the latency values is not visible here.
    return _create_helper(prov, storage_id, "nulldevice",
                          {'latencyMin': "100", 'latencyMax': "200"})


def create_busy_null_helper(storage_id, prov):
    """Create a "nulldevice" helper with timeoutProbability=0.5."""
    return _create_helper(prov, storage_id, "nulldevice",
                          {'timeoutProbability': "0.5"})


def create_verifying_null_helper(storage_id, prov):
    """Create a "nulldevice" helper with enableDataVerification=true."""
    return _create_helper(prov, storage_id, "nulldevice",
                          {'enableDataVerification': "true"})


def create_posix_helper(storage_id, prov, root):
    """Create a "posix" helper whose mountPoint is ``root``."""
    return _create_helper(prov, storage_id, "posix", {'mountPoint': root})


def create_ceph_helper(storage_id, prov, mon_host, username, key,
                       pool_name, block_size):
    """Create a "cephrados" helper (flat storage paths, cluster "ceph")."""
    return _create_helper(prov, storage_id, "cephrados", {
        'monitorHostname': mon_host,
        'username': username,
        'key': key,
        'poolName': pool_name,
        'blockSize': str(block_size),
        'storagePathType': 'flat',
        'clusterName': 'ceph',
    })


def create_s3_helper(storage_id, prov, scheme, hostname, access_key, secret_key,
                     bucket, prefix, block_size):
    """Create an "s3" helper (flat storage paths) for ``bucket``."""
    return _create_helper(prov, storage_id, "s3", {
        'scheme': scheme,
        'hostname': hostname,
        'accessKey': access_key,
        'secretKey': secret_key,
        'bucketName': bucket,
        'prefix': prefix,
        'blockSize': str(block_size),
        'storagePathType': 'flat',
    })


def connect(prov_from, prov_to):
    """Authorize and open a connection from ``prov_from`` to ``prov_to``.

    ``prov_to`` is first told to accept a peer presenting the agreed secrets.
    Then ``prov_from`` connects to it on 127.0.0.1:``prov_to.server_port``.
    Returns the connection id reported by ``prov_from``.
    """
    # Fixed secrets; each side's "my_secret" is the other side's "peer_secret".
    from_secret = to_b64("a" * 64)
    to_secret = to_b64("b" * 64)

    allow_response = prov_to.request({
        "allow_connection": {
            "my_secret": to_secret,
            "peer_secret": from_secret,
            "provider_id": to_b64("otherprovider"),
            "expiration": 60000,
        }
    })
    assert {"done": True} == allow_response

    connect_response = prov_from.request({
        "connect": {
            "my_secret": from_secret,
            "peer_secret": to_secret,
            "peer_host": "127.0.0.1",
            "peer_port": prov_to.server_port,
        }
    })
    assert "connectionId" in connect_response
    return connect_response["connectionId"]


def do_open(prov, conn_id, storage_id):
    """Return the base64-encoded file id of the test file ``/file1.txt``.

    No request is sent; ``prov``, ``conn_id`` and ``storage_id`` are accepted
    for call-site symmetry but are unused.

    The id is returned as ``str`` so it can be embedded directly in the JSON
    requests built by do_fetch (``json.dumps`` rejects ``bytes``).
    """
    filename = "file1.txt"
    return base64.b64encode(("/" + filename).encode('utf-8')).decode('ascii')


def do_fetch(prov, prov_from, conn_id, src_storage_id, src_fd, dest_storage_id,
             dest_fd, offset, size, priority=1, authorization_cached=False,
             fetch_req_id=None):
    """Send an asynchronous fetch to ``prov`` and approve it on ``prov_from``.

    Unless ``authorization_cached`` is true, the next message from
    ``prov_from`` must be an authorization question. It is answered with
    ``is_authorized: True`` for ``src_storage_id``/``src_fd``.

    :param fetch_req_id: numeric id of the fetch; random when omitted.
    :return: the base64 request id attached to the outer fetch message.
    """
    if fetch_req_id is None:
        fetch_req_id = int(random.uniform(0, 2 ** 63))
    req_id = to_b64(random_str())

    fetch_body = {
        "connection_id": conn_id,
        "src_storage_id": src_storage_id,
        "src_file_id": src_fd,
        "dest_storage_id": dest_storage_id,
        "dest_file_id": dest_fd,
        "offset": offset,
        "size": size,
        "req_id": fetch_req_id,
        "priority": priority,
        "file_guid": to_b64(f'guid_{src_fd}'),
    }
    prov.request_async({"fetch": fetch_body, "req_id": req_id})

    if authorization_cached:
        return req_id

    question = prov_from.get_response()
    assert question["isQuestion"] == True  # noqa: E712
    prov_from.request_async({
        "req_id": question["reqId"],
        "is_answer": True,
        "auth_response": {
            "is_authorized": True,
            "storage_id": src_storage_id,
            "file_id": src_fd,
        },
    })
    return req_id


def do_cancel(prov, conn_id, src_storage_id, dest_storage_id, req_id):
    """Asynchronously ask ``prov`` to cancel the fetch identified by ``req_id``.

    Returns the newly generated base64 id of the cancel request itself.
    """
    cancel_req_id = to_b64(random_str())
    cancel_body = dict(connection_id=conn_id,
                       src_storage_id=src_storage_id,
                       dest_storage_id=dest_storage_id,
                       req_id=req_id)
    prov.request_async({"cancel": cancel_body, "req_id": cancel_req_id})
    return cancel_req_id


def md5(fname):
    """Return the hex MD5 digest of the whole file at ``fname``.

    The file is read in 4 KiB chunks so it is never loaded fully into memory.
    """
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    # Returned after the loop so every chunk contributes; an empty file
    # yields the digest of b"".
    return hash_md5.hexdigest()


def do_fetch_on_all_connections(p1, p2, dest_storage_id, src_storage_id, conn_id,
                                num=100, size=1024):
    """Issue ``num * num`` fetches over ``conn_id`` and check every write.

    Each of the ``num`` rounds obtains file ids via do_open and sends ``num``
    consecutive ``size``-byte fetches. The last round uses priority 0; all
    other rounds use priority 1. Afterwards, one ``{"wrote": str(size)}``
    response per fetch is expected from ``p1``.

    :param p1: instance the fetches are sent to (do_fetch's ``prov``).
    :param p2: instance answering authorization questions (do_fetch's
        ``prov_from``).
    :param num: number of rounds, and of fetches per round (default 100).
    :param size: bytes requested per fetch (default 1024).
    """
    for i in range(num):
        destfd = do_open(p1, None, dest_storage_id)
        srcfd = do_open(p1, conn_id, src_storage_id)
        # Depends only on the round, so compute it once per round.
        priority = 1 if i < num - 1 else 0
        for j in range(num):
            do_fetch(p1, p2, conn_id, src_storage_id, srcfd, dest_storage_id,
                     destfd, j * size, size, priority)

    for _ in range(num * num):
        assert {"wrote": str(size)} == p1.get_response()


@pytest.fixture(scope='module')
def ceph_server(request):
    """Start a Ceph container with pool 'data' and return a ``Server`` handle.

    The handle carries the connection settings returned by ``ceph.up`` and
    helpers to inspect stored objects. The container is removed, with its
    volumes, when the module's tests finish.
    """
    class Server(object):
        def __init__(self, mon_host, username, key, pool_name):
            self.mon_host = mon_host
            self.username = username
            self.key = key
            self.pool_name = pool_name
            # Docker container id; assigned by the fixture after construction.
            self.container = None
            self.block_size = 1024 * 1024

        def list(self, file_id):
            """Return names of pool objects matching ``file_id`` (grep pattern)."""
            # The only way to list objects in Ceph which start with prefix 'file_id'
            # is to list all objects and then grep through the results.
            output = docker.exec_(self.container,
                                  ['bash', '-c', "rados -p {} ls | grep {} || true".format(
                                      self.pool_name, file_id)], output=True, stdout=sys.stdout)
            return output.splitlines()

        def read_file_contents(self, storage_file_id, size):
            """Return at least ``size`` bytes of ``storage_file_id`` as bytes.

            Reads objects named ``storage_file_id`` + block number, with block
            numbers counting down from 999999. Whole blocks are concatenated
            until ``size`` bytes have been collected.
            """
            # Imported here rather than at module level, so only this method
            # needs the rados bindings.
            import rados
            cluster = rados.Rados()
            cluster.conf_set('mon_host', self.mon_host)
            # NOTE(review): 'username' may not be a recognized Ceph option;
            # the client id is normally given via Rados(rados_id=...) — confirm.
            cluster.conf_set('username', self.username)
            cluster.conf_set('key', self.key)
            # open_ioctx() requires a connected cluster handle.
            cluster.connect()
            try:
                ioctx = cluster.open_ioctx(self.pool_name)
                try:
                    blocks = []
                    first_block = 999999
                    size_so_far = size
                    block_it = 0
                    while size_so_far > 0:
                        # Ioctx.read() reads only 8192 bytes unless a length
                        # is given, so request a whole block explicitly.
                        block = ioctx.read(storage_file_id + str(first_block - block_it),
                                           self.block_size)
                        block_it += 1
                        size_so_far -= len(block)
                        blocks.append(block)
                finally:
                    ioctx.close()
            finally:
                cluster.shutdown()

            return b''.join(blocks)

    pool_name = 'data'
    result = ceph.up('onedata/ceph', [(pool_name, '8')], 'storage',
                     common.generate_uid())

    [container] = result['docker_ids']
    username = result['username'].encode('ascii')
    key = result['key'].encode('ascii')
    mon_host = result['host_name'].encode('ascii')

    def fin():
        docker.remove([container], force=True, volumes=True)

    request.addfinalizer(fin)

    server = Server(mon_host, username, key, pool_name)
    server.container = container

    return server


@pytest.fixture(scope='module')
def s3_server(request):
    """Start a MinIO container with bucket 'data' and return a ``Server`` handle.

    The container is removed, with its volumes, when the module's tests end.
    """
    class Server(object):
        def __init__(self, scheme, hostname, bucket, access_key, secret_key, prefix=""):
            self.scheme = scheme
            # Endpoint authority used in the boto3 endpoint URL below;
            # presumably "<host>:<port>" as reported by s3.up — TODO confirm.
            self.hostname = hostname
            self.access_key = access_key
            self.secret_key = secret_key
            self.bucket = bucket
            self.prefix = prefix
            self.block_size = 1024 * 1024
            self.s3 = boto3.resource('s3', endpoint_url=self.scheme + "://" + self.hostname,
                                     aws_access_key_id=self.access_key,
                                     aws_secret_access_key=self.secret_key)

        def list(self, file_id):
            """Return object keys directly under ``<prefix>/<file_id>/``."""
            test_bucket = self.s3.Bucket(self.bucket)
            return [o.key for o in
                    test_bucket.objects.filter(Prefix=os.path.join(self.prefix, file_id) + '/', Delimiter='/')]

        def create_file(self, storage_file_id, count):
            """Upload ``count`` blocks of ``b'x'`` for ``storage_file_id``.

            Each block is ``block_size`` bytes and is stored under
            ``storage_file_id/<n>``, with ``n`` counting down from 999999.
            """
            data = b'x' * self.block_size
            first_block = 999999
            for block_it in range(count):
                object_path = os.path.join(storage_file_id, str(first_block - block_it))
                self.s3.Object(self.bucket, object_path).put(Body=data)

        def read_file_contents(self, storage_file_id, size):
            """Return at least ``size`` bytes of ``storage_file_id`` as bytes.

            Whole block objects (``storage_file_id/<n>``, ``n`` counting down
            from 999999) are concatenated until ``size`` bytes are collected.
            """
            blocks = []
            first_block = 999999
            size_so_far = size
            block_it = 0
            while size_so_far > 0:
                object_path = os.path.join(storage_file_id, str(first_block - block_it))
                block = self.s3.Object(self.bucket, object_path).get()['Body'].read()
                block_it += 1
                size_so_far -= len(block)
                blocks.append(block)

            # Joined once at the end instead of repeated bytes concatenation.
            return b''.join(blocks)

    bucket = 'data'
    result = s3.up('onedata/minio:v1', [bucket], 'storage',
                   common.generate_uid())
    [container] = result['docker_ids']

    def fin():
        docker.remove([container], force=True, volumes=True)

    request.addfinalizer(fin)

    return Server('http', result['host_name'], bucket, result['access_key'],
                  result['secret_key'])
