"""This module contains test cases for parallel transfers"""

__author__ = "Bartek Kryza"
__copyright__ = """(C) 2023 ACK CYFRONET AGH,
This software is released under the MIT license cited in 'LICENSE.txt'."""

import base64

from test_common import *

@pytest.mark.parametrize("file_size,file_count,random_priority",
                         [(1024, 100, True), (1024*1024, 100, True), (10*1024*1024, 100, True),
                          (1024, 100, False), (1024*1024, 100, False), (10*1024*1024, 100, False)])
def test_ideal_null_storage_simultaneous_transfers_different_priorities(rtransfer, file_size, file_count, random_priority):
    """Schedule many equally sized transfers at once and check each completes fully.

    Creates a null storage helper on both ends, schedules ``file_count``
    fetches of ``file_size`` bytes (offset 0) spread round-robin over all
    connections returned by ``connect_links``, then collects one
    non-update response per fetch and asserts each reports exactly
    ``file_size`` bytes written.

    :param rtransfer: fixture unpacked as a ``(dest, src)`` pair.
    :param file_size: number of bytes requested by each fetch.
    :param file_count: number of distinct random file ids to fetch.
    :param random_priority: if True each fetch gets a priority drawn with
        ``random_int(1, 255)``; otherwise every fetch uses priority 1.
    """
    (dest, src) = rtransfer
    dest_storage_id = create_null_helper("dest_null", dest)
    src_storage_id = create_null_helper("src_null", src)
    conns = connect_links(dest, src)

    wait_for_connection()

    fetch_offset = 0
    fetch_size = file_size

    files_to_fetch = ["/" + random_str() for _ in range(file_count)]

    authorization_cached = False
    # Request ids are 1-based; the first one also triggers the auth handshake.
    for req_id_it, f in enumerate(files_to_fetch, start=1):
        srcfd = to_b64(f)
        destfd = to_b64(f)
        priority = random_int(1, 255) if random_priority else 1
        # Spread requests round-robin over all established connections.
        conn = conns[(req_id_it - 1) % len(conns)]
        do_fetch(dest, src, conn, src_storage_id, srcfd,
                 dest_storage_id, destfd, fetch_offset,
                 fetch_size, priority, authorization_cached, req_id_it)

        if req_id_it == 1:
            # If this is a first request - wait for auth cache to be filled
            time.sleep(2)
            authorization_cached = True

        if VERBOSE:
            print("======= " + str(req_id_it))

    # Bytes still expected from the destination. Looping on "> 0" (rather
    # than breaking on "== 0" after a blocking read) avoids waiting forever
    # when no transfers were scheduled.
    remaining = fetch_size * len(files_to_fetch)

    while remaining > 0:
        resp = dest.get_response(skip_updates=True)
        written_bytes = int(resp["wrote"])
        if VERBOSE:
            print("------------------------------------------")
            print("     Transferred: " + str(written_bytes))
            # Report what is left *after* accounting for this response.
            print("     Remaining:   " + str(remaining - written_bytes))
            print("------------------------------------------")
        assert fetch_size == written_bytes
        remaining -= fetch_size


@pytest.mark.parametrize("file_size,file_count",
                         [(10*1024*1024, 100)])
def test_ideal_null_storage_on_the_fly_before_scheduled_transfers(rtransfer, file_size, file_count):
    """Check that a small on-the-fly fetch overtakes a queue of large fetches.

    Schedules ``file_count`` fetches of ``file_size`` bytes with priority
    125 (request ids 2..file_count+1), then one 1 KiB fetch with priority 25
    and request id 1. The on-the-fly response must arrive before a quarter
    of the scheduled fetches have completed. (Presumably a lower priority
    value means more urgent - TODO confirm against rtransfer docs.)

    :param rtransfer: fixture unpacked as a ``(dest, src)`` pair.
    :param file_size: number of bytes requested by each scheduled fetch.
    :param file_count: number of scheduled (background) fetches.
    """
    (dest, src) = rtransfer
    dest_storage_id = create_null_helper("dest_null", dest)
    src_storage_id = create_null_helper("src_null", src)
    conns = connect_links(dest, src)

    wait_for_connection()

    fetch_offset = 0
    fetch_size = file_size

    files_to_fetch = ["/" + random_str() for _ in range(file_count)]

    authorization_cached = False

    # First schedule normal transfers. Their request ids start at 2 because
    # id 1 is used by the on-the-fly transfer scheduled afterwards.
    scheduled_priority = 125
    for i, f in enumerate(files_to_fetch):
        srcfd = to_b64(f)
        destfd = to_b64(f)
        # Spread requests round-robin over all established connections.
        do_fetch(dest, src, conns[i % len(conns)], src_storage_id, srcfd,
                 dest_storage_id, destfd, fetch_offset,
                 fetch_size, scheduled_priority, authorization_cached, i + 2)

        # NOTE(review): the simultaneous-transfers test in this module sleeps
        # after the first request before treating authorization as cached;
        # this test does not - confirm that is intended.
        authorization_cached = True

    # Now schedule on the fly transfer with req_id = 1
    on_the_fly_priority = 25
    on_the_fly_req_id = 1
    on_the_fly_fetch_size = 1024
    on_the_fly_fetch_offset = 1024*1024
    f = "/" + random_str()
    srcfd = to_b64(f)
    destfd = to_b64(f)
    # Presumably do_fetch returns the request id as reported in
    # resp["reqId"] - verify in test_common.
    req_id = do_fetch(dest, src, conns[0], src_storage_id, srcfd,
                      dest_storage_id, destfd, on_the_fly_fetch_offset,
                      on_the_fly_fetch_size, on_the_fly_priority,
                      authorization_cached, on_the_fly_req_id)

    # Total bytes the scheduled (non on-the-fly) fetches are expected to write.
    total_size = fetch_size * len(files_to_fetch)

    complete_responses = 0
    scheduled_written = 0
    while True:
        resp = dest.get_response(skip_updates=True)

        if resp["reqId"] == req_id:
            break

        scheduled_written += int(resp["wrote"])
        complete_responses += 1

        # Accumulate across responses: comparing a single response's byte
        # count to the whole batch could never trigger for file_count > 1,
        # and for file_count == 1 ended the test as a silent pass.
        if scheduled_written >= total_size:
            pytest.fail("All %d scheduled transfers completed before the "
                        "on-the-fly transfer (reqId %s)"
                        % (complete_responses, req_id))

    assert complete_responses < (file_count / 4), (
        "on-the-fly transfer completed after %d of %d scheduled transfers"
        % (complete_responses, file_count))
