import pytest

from hypothesis import given, settings, reproduce_failure, unlimited, HealthCheck, seed
from hypothesis.extra.pandas import data_frames, columns, range_indexes, column, indexes
from hypothesis.extra.numpy import arrays
import hypothesis.strategies as st

from itertools import product
import tempfile
import subprocess
from io import StringIO

from pyrle import Rle
import pyranges as pr

import pandas as pd
import numpy as np


from tests.helpers import assert_df_equal
from tests.hypothesis_helper import dfs_min2, dfs_min



import numpy as np

from os import environ

# Hypothesis example budget: 100 when the TRAVIS environment variable is set,
# 1000 otherwise. The per-example deadline is disabled in both cases.
max_examples = 100 if environ.get("TRAVIS") else 1000
deadline = None


# Strandedness modes used to parametrize the tests below.
strandedness = [False, "same", "opposite"]
# Same list without the "opposite" mode.
no_opposite = [False, "same"]


def run_bedtools(command, gr, gr2, strandedness, nearest_overlap=False, nearest_how=None):
    """Run a shell command template on two range objects and return its stdout.

    The ``df`` attribute of ``gr`` and ``gr2`` is written as headerless,
    tab-separated text to ``f1.bed`` and ``f2.bed`` inside a temporary
    directory. ``command`` is then filled in with ``str.format`` and executed
    through ``/bin/bash``.

    Args:
        command: Template that may reference ``{f1}``, ``{f2}``, ``{strand}``,
            ``{overlap}`` and ``{bedtools_how}``.
        gr: Object whose ``df`` is written to the file substituted for ``{f1}``.
        gr2: Object whose ``df`` is written to the file substituted for ``{f2}``.
        strandedness: ``False``, ``"same"`` or ``"opposite"``; substituted for
            ``{strand}`` as ``""``, ``"-s"`` or ``"-S"``.
        nearest_overlap: ``True`` substitutes ``""`` for ``{overlap}``,
            ``False`` substitutes ``"-io"``.
        nearest_how: ``None``, ``"upstream"`` or ``"downstream"``; substituted
            for ``{bedtools_how}`` as ``""``, ``"-id"`` or ``"-iu"``, always
            followed by ``" -D a"``.

    Returns:
        The decoded standard output of the command.

    Raises:
        KeyError: If one of the three option arguments has an unlisted value.
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    strand_flag = {False: "", "same": "-s", "opposite": "-S"}[strandedness]
    overlap_flag = {True: "", False: "-io"}[nearest_overlap]
    how_flag = {"upstream": "-id", "downstream": "-iu", None: ""}[nearest_how] + " -D a"

    with tempfile.TemporaryDirectory() as temp_dir:
        f1 = "{}/f1.bed".format(temp_dir)
        f2 = "{}/f2.bed".format(temp_dir)
        for path, ranges in ((f1, gr), (f2, gr2)):
            ranges.df.to_csv(path, sep="\t", header=False, index=False)

        cmd = command.format(f1=f1, f2=f2, strand=strand_flag, overlap=overlap_flag,
                             bedtools_how=how_flag)
        print("cmd " * 5)
        print(cmd)
        # bash is required: several callers use process substitution ``<(...)``.
        raw_output = subprocess.check_output(cmd, shell=True, executable="/bin/bash")

    return raw_output.decode()


def read_bedtools_result_set_op(bedtools_result, strandedness):
    """Parse tab-separated set-operation output into a DataFrame.

    Args:
        bedtools_result: Headerless, tab-separated text.
        strandedness: When truthy, column 5 (0-based) is also read and
            labelled ``Strand``.

    Returns:
        A DataFrame with columns ``Chromosome``, ``Start``, ``End`` (taken
        from columns 0-2) and, when ``strandedness`` is truthy, ``Strand``.
    """
    stranded = bool(strandedness)
    usecols = [0, 1, 2, 5] if stranded else [0, 1, 2]
    names = ("Chromosome Start End Strand" if stranded else "Chromosome Start End").split()

    return pd.read_csv(StringIO(bedtools_result), sep="\t", header=None, usecols=usecols, names=names)


def compare_results(bedtools_df, result):
    """Assert that ``result.df`` matches the frame parsed from bedtools output.

    When ``bedtools_df`` is empty, ``result.df`` must be empty too; otherwise
    the two frames are compared with ``assert_df_equal``.
    """
    if bedtools_df.empty:
        # Nothing to compare row by row: both sides just have to be empty.
        assert result.df.empty
        return

    assert_df_equal(result.df, bedtools_df)


def compare_results_nearest(bedtools_df, result):
    """Assert that a nearest-feature result matches the bedtools frame.

    Rows of ``bedtools_df`` whose ``Distance`` equals -1 are ignored. When
    ``result.df`` has no rows, the filtered bedtools frame must be empty;
    otherwise the ``Chromosome``, ``Start``, ``End``, ``Strand`` and
    ``Distance`` columns of ``result.df`` are compared with it using
    ``assert_df_equal``.
    """
    if not bedtools_df.empty:
        keep = bedtools_df.Distance != -1
        bedtools_df = bedtools_df[keep]

    result_df = result.df

    if len(result_df) == 0:
        assert bedtools_df.empty
        return

    compared_columns = "Chromosome Start End Strand Distance".split()
    assert_df_equal(result_df[compared_columns], bedtools_df)


@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", no_opposite)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_set_intersect(gr, gr2, strandedness):
    """Check ``set_intersect`` against bedtools intersect run on merged inputs."""
    # Each input file is sorted and merged before the two are intersected.
    command = (
        "bedtools intersect {strand} "
        "-a <(sort -k1,1 -k2,2n {f1} | bedtools merge {strand} -c 4,5,6 -o first -i -) "
        "-b <(sort -k1,1 -k2,2n {f2} | bedtools merge {strand} -c 4,5,6 -o first -i -)"
    )

    raw_output = run_bedtools(command, gr, gr2, strandedness)
    expected = read_bedtools_result_set_op(raw_output, strandedness)

    observed = gr.set_intersect(gr2, strandedness=strandedness)

    compare_results(expected, observed)


@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", no_opposite)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_set_union(gr, gr2, strandedness):
    """Check ``set_union`` against concatenating, sorting and merging with bedtools."""
    command = "cat {f1} {f2} | bedtools sort | bedtools merge {strand} -c 4,5,6 -o first -i -"

    raw_output = run_bedtools(command, gr, gr2, strandedness)
    expected = read_bedtools_result_set_op(raw_output, strandedness)

    observed = gr.set_union(gr2, strandedness=strandedness)

    compare_results(expected, observed)


@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", strandedness)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_overlap(gr, gr2, strandedness):
    """Check ``overlap`` against ``bedtools intersect -wa``."""
    column_names = "Chromosome Start End Name Score Strand".split()

    raw_output = run_bedtools("bedtools intersect -wa {strand} -a {f1} -b {f2}", gr, gr2, strandedness)
    expected = pd.read_csv(StringIO(raw_output), sep="\t", header=None, names=column_names)

    observed = gr.overlap(gr2, strandedness=strandedness)

    compare_results(expected, observed)


@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", strandedness)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_intersect(gr, gr2, strandedness):
    """Check ``intersect`` against ``bedtools intersect``."""
    column_names = "Chromosome Start End Name Score Strand".split()

    raw_output = run_bedtools("bedtools intersect {strand} -a {f1} -b {f2}", gr, gr2, strandedness)
    expected = pd.read_csv(StringIO(raw_output), sep="\t", header=None, names=column_names)

    observed = gr.intersect(gr2, strandedness=strandedness)

    compare_results(expected, observed)



@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", ["same", "opposite", False])
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_subtraction(gr, gr2, strandedness):
    """Check ``subtract`` against ``bedtools subtract``."""
    column_names = "Chromosome Start End Name Score Strand".split()

    raw_output = run_bedtools("bedtools subtract {strand} -a {f1} -b {f2}", gr, gr2, strandedness)
    expected = pd.read_csv(StringIO(raw_output), sep="\t", header=None, names=column_names)

    observed = gr.subtract(gr2, strandedness=strandedness)

    compare_results(expected, observed)


# Values of the ``how`` and ``overlap`` arguments exercised by test_nearest.
nearest_hows = [None, "upstream", "downstream"]
overlaps = [True, False]


@pytest.mark.bedtools
@pytest.mark.parametrize("nearest_how,overlap,strandedness", product(nearest_hows, overlaps, strandedness))
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_nearest(gr, gr2, nearest_how, overlap, strandedness):
    """Check ``nearest`` against ``bedtools closest``.

    Runs once for every combination of ``nearest_how``, ``overlap`` and
    ``strandedness``. Only the chromosome, start, end and strand of the first
    file plus the absolute distance are compared.
    """
    nearest_command = "bedtools closest {bedtools_how} {strand} {overlap} -t first -d -a <(sort -k1,1 -k2,2n {f1}) -b <(sort -k1,1 -k2,2n {f2})"

    bedtools_result = run_bedtools(nearest_command, gr, gr2, strandedness, overlap, nearest_how)

    bedtools_df = pd.read_csv(StringIO(bedtools_result), header=None, names="Chromosome Start End Strand Chromosome2 Distance".split(), usecols=[0, 1, 2, 5, 6, 12], sep="\t")

    # Only the magnitude of the distance is compared, so the sign is dropped.
    bedtools_df.Distance = bedtools_df.Distance.abs()

    # Rows whose second-file chromosome is "." are dropped (presumably the
    # bedtools placeholder for "no nearest feature found" -- TODO confirm).
    bedtools_df = bedtools_df[bedtools_df.Chromosome2 != "."]
    # ``axis`` must be passed by keyword: the positional form was deprecated
    # in pandas 1.3 and raises TypeError from pandas 2.0 on.
    bedtools_df = bedtools_df.drop("Chromosome2", axis=1)

    result = gr.nearest(gr2, strandedness=strandedness, overlap=overlap, how=nearest_how)

    # Printed so that both frames show up in pytest's captured output on failure.
    print("bedtools " * 5)
    print(bedtools_df)
    print("result " * 5)
    print(result)

    compare_results_nearest(bedtools_df, result)



@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", no_opposite)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_jaccard(gr, gr2, strandedness):
    """Check that ``jaccard`` returns a value between 0 and 1 inclusive."""
    # The comparison against bedtools is disabled until
    # https://github.com/arq5x/bedtools2/issues/645 is fixed. Once it is,
    # the intended check is:
    #   jaccard_command = "bedtools jaccard {strand}  -a <(sort -k1,1 -k2,2n {f1}) -b <(sort -k1,1 -k2,2n {f2})"
    #   bedtools_result = run_bedtools(jaccard_command, gr, gr2, strandedness)
    #   bedtools_jaccard = float(bedtools_result.split("\n")[1].split()[2])
    #   assert abs(result - bedtools_jaccard) < 0.001
    jaccard_value = gr.jaccard(gr2, strandedness=strandedness)

    assert 0 <= jaccard_value <= 1




@pytest.mark.bedtools
@pytest.mark.parametrize("strandedness", strandedness)
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min(), gr2=dfs_min())
def test_join(gr, gr2, strandedness):
    """Check ``join`` against ``bedtools intersect -wo``."""
    # Columns of the first file, then of the second file, then the overlap size.
    join_columns = (
        "Chromosome Start End Name Score Strand "
        "Chromosome_b Start_b End_b Name_b Score_b Strand_b Overlap"
    ).split()

    raw_output = run_bedtools("bedtools intersect {strand} -wo -a {f1} -b {f2}", gr, gr2, strandedness)

    expected = pd.read_csv(StringIO(raw_output), sep="\t", header=None, names=join_columns,
                           dtype={"Chromosome": "category", "Strand": "category"})
    # These two columns are dropped before comparing with the join result.
    expected = expected.drop("Chromosome_b Overlap".split(), axis=1)

    joined = gr.join(gr2, strandedness=strandedness).df

    if not joined.empty:
        assert_df_equal(joined, expected)
    else:
        assert expected.empty



@pytest.mark.bedtools
@settings(max_examples=max_examples, deadline=deadline, timeout=unlimited, suppress_health_check=HealthCheck.all())
@given(gr=dfs_min2(), gr2=dfs_min2())
def test_reldist(gr, gr2):
    """Check ``relative_distance`` against ``bedtools reldist``.

    Only the ``count`` column of the two results is compared.
    """
    # NOTE: no @reproduce_failure decorator is kept on this test. It is a
    # temporary debugging aid: it replays a single example, is tied to one
    # Hypothesis version, and makes the test fail once that example passes.
    reldist_command = "bedtools reldist -a <(sort -k1,1 -k2,2n {f1}) -b <(sort -k1,1 -k2,2n {f2})"

    bedtools_result = run_bedtools(reldist_command, gr, gr2, False)
    # The first line of the output is used as the header row.
    bedtools_result = pd.read_csv(StringIO(bedtools_result), sep="\t")
    print(bedtools_result)

    result = gr.relative_distance(gr2)
    print(result)

    assert list(bedtools_result["count"]) == result["count"].to_list()
