#!/usr/bin/env python
""" Find enrichments of GO terms for a ranked list of genes. """

from __future__ import print_function, division

import sys
import re
import csv
import textwrap
import argparse
import math
from itertools import product, islice
from collections import defaultdict
from functools import lru_cache

import numpy as np
try:
	# scipy.misc.comb was deprecated in SciPy 1.0 and removed in 1.3; scipy.special.comb is its replacement
	from scipy.special import comb as spcomb
except ImportError:
	# fall back for old SciPy releases that only ship the function in scipy.misc
	from scipy.misc import comb as spcomb

def argument_parser():
	"""
	Build the command line parser of the enrichment tool.

	Returns an argparse.ArgumentParser. The optional positional "genelist"
	defaults to sys.stdin when it is not given on the command line.
	"""
	parser = argparse.ArgumentParser(description=__doc__)
	parser.add_argument("obo", help="The gene ontology in .obo format.")
	parser.add_argument("genelist", nargs="?", default=sys.stdin, help="A ranked list of genes. Each row should contain the gene id followed by a tab separated list of the GO terms the gene is contained in.")
	parser.add_argument("--draw-graph", "-g", help="Draw the graph of enriched GO terms.")
	parser.add_argument("--max-partition", "-m", type=int, metavar="M", help="Consider only the first M partitions when calculating an enrichment of a GO term against the ranked genelist.")
	# "--forground" is the historical (misspelled) name; it is kept so existing invocations keep working
	parser.add_argument("--fixed-partition", "--foreground", "--forground", type=int, metavar="M", help="Consider the first M genes as foreground and the rest as background when calculating an enrichment of a GO term against the ranked genelist.")
	parser.add_argument("--max-pvalue", "-p", type=float, metavar="P", default=0.001, help="Maximum uncorrected p-value to report an enrichment.")
	parser.add_argument("--usefdr", action="store_true", help="Use FDR for p-value cutoff (see above).")
	parser.add_argument("--fast", "-f", action="store_true", help="Calculate an upper bound rather than the exact p-value.")
	return parser

def comb(n, k, spcomb=spcomb):
	"""
	Binomial coefficient "n choose k", requested from spcomb with exact=True.

	spcomb -- the underlying implementation, bound as a default argument
	"""
	exact_count = spcomb(n, k, exact=True)
	return exact_count

############## Enrichment ################
# based on Eden et al. "GOrilla: a tool for discovery and visualization of enriched GO terms in ranked gene lists", BMC Bioinformatics 2009


def hypergeometric_tail(N, B, n, b, comb=comb):
	"""
	Computes HGT(b;N,B,n), i.e. the sum of
	comb(n, i) * comb(N - n, B - i) for i = b .. min(n, B), divided by comb(N, B).

	N -- number of overall genes
	B -- size of the GO term
	n -- number of tested genes
	b -- number of tested genes associated with the GO term
	comb -- function used for the binomial coefficients
	"""
	total = comb(N, B)
	last = min(n, B)
	favourable = 0
	for i in range(b, last + 1):
		favourable += comb(n, i) * comb(N - n, B - i)
	# one division at the very end, after the whole numerator has been summed
	return favourable / total

def tails(N, B, interm, max_partition, hypergeometric_tail=hypergeometric_tail):
	"""
	Generate the hypergeometric tail scores along the ranked list.

	N -- number of overall genes
	B -- size of the GO term
	interm -- the vector lambda stating (with a true value) for gene i in the list whether it is contained in the considered GO term
	max_partition -- largest number n of top ranked genes to consider
	hypergeometric_tail -- function computing HGT(b;N,B,n)

	Yields tuples (HGT(b;N,B,n), n, b): always for n = 1 and afterwards
	only for those n at which the gene at rank n is in the GO term.
	"""
	# the running hit count has to include the first gene, otherwise every
	# later prefix would be scored with one hit too few
	b = int(bool(interm[0]))
	yield hypergeometric_tail(N, B, 1, b), 1, b # yield the first element
	for n in range(2, min(len(interm), max_partition) + 1):
		if interm[n-1]: # only need to compute the tail if b increases, else hgt will become only bigger and we need the minimum
			b += 1
			yield hypergeometric_tail(N, B, n, b), n, b

def min_hypergeometric_tail(N, B, interm, max_partition, tails=tails):
	"""
	Computes mHGT(lambda)

	N -- number of overall genes
	B -- size of the GO term
	interm -- the vector lambda stating (with a boolean value True) for gene i in the list whether it is contained in the considered GO term
	max_partition -- passed through to tails
	tails -- generator of the candidate tuples

	Returns the tuple produced by tails whose first element is smallest.
	"""
	def score(candidate):
		# the tail score is the first entry of each candidate tuple
		return candidate[0]

	candidates = tails(N, B, interm, max_partition)
	return min(candidates, key=score)

def not_visiting_paths(N, B, R):
	"""
	Computes the number of paths not visiting R (PI_R(N,B)) by dynamic programming.

	N -- number of overall genes
	B -- size of the GO term
	R -- container of the grid points (n, b) that must not be visited

	Each row of the table depends only on the row before it, so just two
	rows are kept instead of the whole (N+1) x (B+2) table.
	"""
	# we shift b by 1 to allow for b=-1 at the index of 0, i.e. the real b equals b - 1
	previous = [0] * (B + 2)
	previous[1] = 1
	for n in range(1, N + 1):
		# cells outside the range visited below have to stay 0, hence a fresh row
		current = [0] * (B + 2)
		for b in range(max(1, B + 1 - N + n), min(B + 1, n + 1) + 1):
			if (n, b - 1) not in R:
				current[b] = previous[b] + previous[b - 1]
		previous = current
	return previous[B + 1] # i.e. PI[N,B] in the paper

def R(N, B, mhgt, comb=comb):
	"""
	Points n,b in the Grid N,B where HGT(b;N,B,n) <= mhgt

	N -- number of overall genes
	B -- size of the GO term
	mhgt -- the score threshold
	comb -- function used for the binomial coefficients

	Returns a set of (n, b) tuples.
	"""
	validpoints = set()
	NB = comb(N,B)
	for n in range(1, N+1):
		# The numerator of the tail is accumulated as an exact integer and divided
		# only for the comparison. This is the same "integer sum / comb(N, B)" that
		# hypergeometric_tail evaluates, so a point whose score equals mhgt is not
		# lost to the rounding of a running float sum.
		tail = 0
		for b in reversed(range(max(0, B - N + n), min(B, n) + 1)):
			tail += comb(n, b) * comb(N - n, B - b) # add the next element of the hypergeometric tail sum
			if tail / NB <= mhgt:
				validpoints.add((n,b))
			else:
				# the tail only grows while b decreases, no smaller b can qualify
				break
	return validpoints

def pvalue(N, B, mhgt, not_visiting_paths=not_visiting_paths, comb=comb):
	"""
	Calculate the p-value for the given minimum Hypergeometric Tail score mhgt.
	Then the pvalue is the probability to see a score <= mhgt given N genes in total
	and B genes in the GO term.

	It is computed as the fraction of the comb(N, B) paths that are not
	counted by not_visiting_paths for the point set R(N, B, mhgt).
	"""
	forbidden = R(N, B, mhgt)
	avoiding = not_visiting_paths(N, B, forbidden)
	total = comb(N, B)
	visiting = total - avoiding
	return visiting / total

############### GO Parser ###################

def parse_obo(obofile):
	"""
	Iterate over the terms of an ontology file.

	obofile -- path of the file to open

	Yields one GOTerm for every line that starts with "[Term]".
	"""
	with open(obofile) as handle:
		for line in handle:
			if not line.startswith("[Term]"):
				continue
			# the GOTerm constructor is handed the same open file and reads the stanza from it
			yield GOTerm(handle)

class GOTerm:
	"""
	One [Term] stanza of an ontology file.

	Every constructed term registers itself in the class-wide table _byid
	so that it can be looked up again with GOTerm.byid().
	"""
	# all terms constructed so far, keyed by their numeric id (shared by all instances)
	_byid = dict()
	# a "key: value" line; the value ends at the first "!" (presumably to drop the trailing "! name" comment of .obo lines)
	_tagline = re.compile(r"(?P<key>\w+): (?P<value>[^!]+)")

	@classmethod
	def byid(cls, id):
		"""
		Return the registered term with the given id.

		id -- the numeric id, or a string that goid() can convert to it

		Raises KeyError if no such term has been constructed.
		"""
		if isinstance(id, str):
			return cls._byid[goid(id)]
		return cls._byid[id]

	def __init__(self, obofile):
		"""
		Read one stanza from obofile (an iterator over lines) and register the term.

		The tags "id", "name" and "namespace" are required; "def" and "subset"
		are set to None when the stanza does not contain them.
		"""
		values = self.parse_goterm(obofile)
		self.id = goid(values["id"][0])
		self.name = values["name"][0]
		self.namespace = values["namespace"][0]
		self.definition = values["def"][0] if values["def"] else None
		self.subset = values["subset"][0] if values["subset"] else None
		self.is_a = list(map(goid, values["is_a"]))
		self._byid[self.id] = self

	@staticmethod
	def parse_goterm(obofile):
		"""
		Collect the "key: value" lines up to the next blank line.

		obofile -- iterator over the lines of the stanza

		Returns a defaultdict mapping each key to the list of its stripped values.
		Lines starting with "!" are skipped; any other line that is not of the
		form "key: value" raises ValueError.
		"""
		values = defaultdict(list)
		for l in obofile:
			# a blank line (also "\r\n" or whitespace only) ends the stanza
			if not l.strip():
				break
			if l.startswith("!"):
				continue
			match = GOTerm._tagline.match(l)
			if match is None:
				raise ValueError("unexpected line in [Term] stanza: {!r}".format(l))
			values[match.group("key")].append(match.group("value").strip())
		return values

	def __repr__(self):
		return "GO:{:07} {}".format(self.id, self.name)

def goid(idstring):
	"""
	Convert an id string to its numeric id by dropping the first
	three characters (the "GO:" of e.g. "GO:0008150") and parsing the rest as int.
	"""
	digits = idstring[len("GO:"):]
	return int(digits)

################ Genelist parser ##############

def parse_genelist(genelist):
	"""
	Read the ranked gene list.

	genelist -- iterable of tab separated lines: the gene id followed by its GO terms

	Yields a Gene for every row whose second column is present and not empty;
	all other rows are skipped.
	"""
	rows = csv.reader(genelist, delimiter="\t")
	for row in rows:
		if len(row) <= 1 or not row[1]:
			continue
		yield Gene(row[0], row[1:])

class Gene:
	"""
	A gene of the ranked list together with the numeric ids of its GO terms.
	"""

	def __init__(self, id, goterms):
		"""
		id -- identifier of the gene
		goterms -- iterable of GO id strings, each converted with goid()
		"""
		self.id = id
		self.goterms = {goid(term) for term in goterms}

	def interm(self, goterm):
		"""Tell whether the id of goterm is among the GO terms of this gene."""
		term_id = goterm.id
		return term_id in self.goterms

################ process data #################

def calc_fdrs(pvalues, sortedindex, n):
	"""
	Calculate FDR with the algorithm of Benjamini-Hochberg as implemented in the R package multtest.
	From Benjamini-Hochberg, 1995:
	let k be the largest i for which
	P_i <= i / n * (q*)
	then reject all H_i for i = 1,...,k. Thereby, above procedure controls the false discovery rate at q*.
	In other words, the false discovery rate FDR_i for P_i is
	P_i * n / i <= FDR_i .

	pvalues -- the p-values
	sortedindex -- indices into pvalues, ordered by ascending p-value
	n -- number of hypotheses; only the first n entries of sortedindex are corrected

	Returns an array shaped like pvalues. Entries that are not among the
	first n of sortedindex are 1.
	"""
	# start from 1 everywhere, so that entries outside the first n hypotheses
	# hold a defined value instead of uninitialised memory
	fdr = np.empty_like(pvalues, dtype=float)
	fdr[...] = 1.0
	if n:
		last = sortedindex[n-1]
		fdr[last] = pvalues[last]
		# walk from the largest to the smallest p-value so that the FDR never increases towards smaller p-values
		for i in reversed(range(n-1)):
			current = sortedindex[i]
			fdr[current] = min(fdr[sortedindex[i+1]], pvalues[current] * (n / (i+1)), 1)
			assert fdr[current] >= pvalues[current]
	return fdr

def calc_interm(goterm, genelist):
	"""
	Build the membership vector of a GO term: one boolean per gene of
	genelist (in order), True if the id of goterm is in gene.goterms.
	"""
	membership = []
	for gene in genelist:
		membership.append(goterm.id in gene.goterms)
	return membership

def is_hit(interm, max_partition = None):
	"""
	Tell whether the first max_partition entries of interm sum up to at least 1.

	max_partition -- number of leading entries to look at; all of them if None
	"""
	limit = len(interm) if max_partition is None else max_partition
	considered = interm[:limit]
	return sum(considered) >= 1

def calc_enrichments(goterms, genelist, max_partition = None, fixed_partition = None, max_pvalue = 0.001, fast = False):
	"""
	Score every GO term against the ranked gene list.

	goterms -- list of GO terms
	genelist -- the ranked list of genes
	max_partition -- consider only the first max_partition genes as cutoffs (default: all)
	fixed_partition -- if given, score only this one cutoff instead of searching for the minimum
	max_pvalue -- p-values are only computed for terms whose score is below this value
	fast -- use the upper bound B * score instead of the exact p-value

	Returns (pvalues, interms, params), each with one entry per GO term:
	the p-value (1 where none was computed), the membership vector and the
	tuple (N, B, n, b). Progress is reported on stderr.
	"""
	N = len(genelist)
	if max_partition is None:
		max_partition = N
	pvalues = np.ones(len(goterms))
	interms = []
	params = []
	for i, goterm in enumerate(goterms):
		interm = calc_interm(goterm, genelist)
		interms.append(interm)
		B = sum(interm)
		if is_hit(interm, max_partition=max_partition):
			if fixed_partition is not None:
				# TODO pvalue in this case has to be computed differently
				# a foreground larger than the list is clamped to the whole list; beyond N the tail degenerates to 0
				n = min(fixed_partition, N)
				b = sum(interm[:n])
				mhgt = hypergeometric_tail(N, B, n, b)
			else:
				mhgt, n, b = min_hypergeometric_tail(N, B, interm, max_partition)
				b = int(b) # report a plain count even if the generator hands back a boolean

			if mhgt < max_pvalue: # use lower bound of p-value as in Eden et al. Plos Comp. Biol. 2007 to omit unnecessary computations
				if fast:
					# the bound B * mhgt can exceed 1, which is not a valid p-value
					pvalues[i] = min(1.0, B * mhgt)
				else:
					pvalues[i] = pvalue(N, B, mhgt)
				# sanity check (with float tolerance): mhgt <= p-value <= B * mhgt
				assert mhgt - pvalues[i] <= 0.0001
				assert pvalues[i] - B * mhgt <= 0.0001
		else:
			n, b = 0, 0 # no hit in possible partitions
		params.append((N, B, n, b))
		print(i, "of", len(goterms), "done", file=sys.stderr)

	return pvalues, interms, params

def significant_indices(sortedindex, hits, pvalues, max_pvalue):
	"""
	Return the set of those among the first hits entries of sortedindex
	whose value in pvalues is at most max_pvalue.
	"""
	selected = set()
	for index in islice(sortedindex, hits):
		if pvalues[index] <= max_pvalue:
			selected.add(index)
	return selected

################## drawing ###################

def collect_terms(goterms, significant):
	"""
	Collect the significant terms together with all terms reachable
	from them through is_a.

	goterms -- not used by this function
	significant -- the terms to start from

	Returns (visited, parents): the set of all collected terms and a dict
	mapping each collected term to the list of its direct is_a parents.
	"""
	seen = set(significant)
	pending = list(seen)
	direct_parents = {}
	while pending:
		term = pending.pop(0)
		resolved = [GOTerm.byid(parent_id) for parent_id in term.is_a]
		direct_parents[term] = resolved
		for parent in resolved:
			if parent in seen:
				continue
			seen.add(parent)
			pending.append(parent)
	return seen, direct_parents

def draw_terms(outfile, goterms, pvalues, fdrs, params, significant, usefdr, maxpvalue):
	"""
	Write the significant GO terms and the terms collected by collect_terms
	for them as a directed graph in DOT format.

	outfile -- path of the file to write
	goterms -- list of all GO terms
	pvalues, fdrs -- per-term statistics, indexed like goterms
	params -- not used by this function
	significant -- indices into goterms of the significant terms
	usefdr -- label and colour the nodes by fdr instead of p-value
	maxpvalue -- the cutoff; scales the colour saturation of significant nodes
	"""
	stat = fdrs if usefdr else pvalues
	stat_label = "\\nfdr: " if usefdr else "\\npvalue: "
	with open(outfile, "w") as dot:
		dot.write("digraph enrichment {\n")
		dot.write("node [shape=Mrecord,style=filled];")
		significant = set(goterms[i] for i in significant)
		visited, parents = collect_terms(goterms, significant)
		for i, goterm in enumerate(goterms):
			if goterm not in visited:
				continue

			if goterm in significant:
				pval = stat_label + "{:.2e}".format(stat[i])
				# a cutoff of 0 cannot be divided by; whatever passes it gets full saturation
				saturation = 1 - stat[i] / maxpvalue if maxpvalue else 1.0
			else:
				saturation = 0
				pval = ""

			# a double quote inside the name would otherwise end the quoted label early
			wrapped = [line.replace("\"", "\\\"") for line in textwrap.wrap(str(goterm), 30)]
			dot.write("{}[label=\"{}{}\",fillcolor=\"0.0 {} 1.0\"];\n".format(goterm.id, "\\n".join(wrapped), pval, saturation))
			for parent in parents[goterm]:
				dot.write("{} -> {};\n".format(parent.id, goterm.id))
		dot.write("}")




def main():
	"""
	Command line entry point: read the gene list and the ontology, compute
	the enrichments, optionally draw the graph and print one tab separated
	row per significant GO term to stdout.
	"""
	args = argument_parser().parse_args()

	# the gene list comes from stdin unless a file name was given
	if args.genelist is sys.stdin:
		genelist = list(parse_genelist(args.genelist))
	else:
		with open(args.genelist) as genefile:
			genelist = list(parse_genelist(genefile))
	goterms = list(parse_obo(args.obo))

	pvalues, interms, params = calc_enrichments(
		goterms,
		genelist,
		max_partition=args.max_partition,
		fixed_partition=args.fixed_partition,
		max_pvalue=args.max_pvalue,
		fast=args.fast,
	)
	hits = sum(is_hit(interm) for interm in interms)
	sortedindex = sorted(range(len(pvalues)), key=pvalues.__getitem__)
	fdrs = calc_fdrs(pvalues, sortedindex, hits)

	# the cutoff is applied either to the FDRs or to the raw p-values
	cutoff_values = fdrs if args.usefdr else pvalues
	significant = significant_indices(sortedindex, hits, cutoff_values, args.max_pvalue)

	if args.draw_graph:
		draw_terms(args.draw_graph, goterms, pvalues, fdrs, params, significant, args.usefdr, args.max_pvalue)

	print("goterm\tp-value\tfdr\ttotal num of genes\tgenes in GO term\tnum of genes in partition\tnum of genes in partition and GO term\tgenes")
	for i in islice(sortedindex, hits):
		if i not in significant:
			continue
		N, B, n, b = params[i]
		# ids of the genes within the first n ranks that are in the GO term
		members = [gene.id for gene, isin in islice(zip(genelist, interms[i]), n) if isin]
		counts = "\t".join(str(value) for value in (N, B, n, b))
		print("{goterm}\t{pvalue}\t{fdr}\t{params}\t{genes}".format(goterm=goterms[i], pvalue=pvalues[i], fdr=fdrs[i], params=counts, genes="\t".join(members)))


def test():
	"""
	Manual sanity check: for N = 330, B = 30 and scores 1e-7, 1e-6, ... below 1,
	assert that score <= pvalue <= B * score and print the three numbers.
	"""
	N = 330
	B = 30
	mhgt = 0.0000001
	while mhgt < 1:
		# the p-value is expensive to compute, so evaluate it only once per score
		p = pvalue(N, B, mhgt)
		assert mhgt <= p <= B*mhgt
		print(mhgt, p, B*mhgt)
		mhgt *= 10

# Run the command line interface only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
	main()
