/*
 * Copyright (C) CNRS, INRIA, Université Bordeaux 1, Télécom SudParis
 * See COPYING in top-level directory.
 */

#include <eztrace-lib/eztrace.h>
#include <eztrace-lib/eztrace_module.h>
#include <starpu.h>
#include <starpu_opencl.h>

#ifdef USE_MPI
#include <starpu_mpi.h>
#endif

/* Name of this EZTrace module. It is passed to the module macros below, to
 * INSTRUMENT_FUNCTIONS / EZT_REGISTER_MODULE, and stringified in log messages. */
#define CURRENT_MODULE starpuv2
DECLARE_CURRENT_MODULE;

/* The intercept list is intentionally empty: this module gets its events from
 * the callbacks handed to StarPU in starpu_prof_tool_library_register() rather
 * than from intercepted symbols. */
PPTRACE_START_INTERCEPT_FUNCTIONS(CURRENT_MODULE)
PPTRACE_END_INTERCEPT_FUNCTIONS(CURRENT_MODULE)

/* Set to 1 by init_starpu() once the module is initialized, and reset to 0 by
 * finalize_starpu(). finalize_starpu() is both registered with
 * EZT_REGISTER_MODULE and called from the library destructor, so this flag
 * ensures eztrace_stop() is only called once.
 */
static int _starpu_initialized = 0;

/* Capacity limits for the table of traced StarPU task functions. */
enum {
  MAX_STRING_LENGTH = 128, /* capacity of starpu_function::name, in bytes */
  NB_FUNCTION_MAX = 1024   /* number of slots in functions[] */
};

/* One traced task function. */
struct starpu_function {
  void* fun_ptr;                /* code pointer reported by StarPU; lookup key */
  int event_id;                 /* id returned by ezt_otf2_register_function() */
  char name[MAX_STRING_LENGTH]; /* copy of the task name */
};

/* Table of registered functions: slots [0, nb_functions) are in use.
 * New slots are appended while holding `lock`. */
static struct starpu_function functions[NB_FUNCTION_MAX];
static _Atomic int nb_functions = 0;
static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;

static inline struct starpu_function* get_function(void* fun_ptr) {
  for(int i=0; i<nb_functions; i++) {
    if(functions[i].fun_ptr == fun_ptr)
      return &functions[i];
  }

  pthread_mutex_lock(&lock);

  /* we need to check again in case of a race condition occured */
  for(int i=0; i<nb_functions; i++) {
    if(functions[i].fun_ptr == fun_ptr) {
      pthread_mutex_unlock(&lock);
      return &functions[i];
    }
  }

  return NULL;
}

static inline struct starpu_function* get_existing_function(void* fun_ptr, const char* task_name) {
  assert(task_name != NULL);

  struct starpu_function* function = get_function(fun_ptr);

  if(function == NULL) {
    eztrace_error("[EZTrace::"STRINGIFY(CURRENT_MODULE)"] Function %p (%s) not found\n", fun_ptr, task_name);
  }

  return function;
}

static inline struct starpu_function* get_or_create_function(void* fun_ptr, const char* task_name) {
  assert(task_name != NULL);

  struct starpu_function* function = get_function(fun_ptr);

  if(function != NULL) {
    return function;
  }

  /* The function is not registered yet, let's register it */
  int id = nb_functions++;
  if(id > NB_FUNCTION_MAX) {
    eztrace_error("[EZTrace::"STRINGIFY(CURRENT_MODULE)"] Too many functions registered!\n");
  }
  functions[id].fun_ptr = fun_ptr;
  strncpy(functions[id].name, task_name, MAX_STRING_LENGTH);
  functions[id].event_id = ezt_otf2_register_function(functions[id].name);
  pthread_mutex_unlock(&lock);
  return &functions[id];
}


/* Profiling callback registered with StarPU for task execution events.
 *
 * Start events (CPU or GPU) enter an OTF2 region for the task's function,
 * registering the function on first use. End events leave that region.
 * prof_info:  event description provided by StarPU; ignored if NULL.
 * event_info: unused.
 * api_info:   unused.
 */
void myfunction_cb(struct starpu_prof_tool_info *prof_info,
		   union starpu_prof_tool_event_info *event_info,
		   struct starpu_prof_tool_api_info *api_info) {
	(void)event_info;
	(void)api_info;

	if (NULL == prof_info)
	  return;

	/* Use a local fallback instead of writing into StarPU's structure. */
	const char* task_name = (NULL != prof_info->task_name) ? prof_info->task_name : "<no name>";

	switch (prof_info->event_type)
	{
	case starpu_prof_tool_event_start_cpu_exec:
	case starpu_prof_tool_event_start_gpu_exec:
	  {
	    struct starpu_function* f = get_or_create_function(prof_info->fun_ptr, task_name);
	    /* NULL means registration failed; the error was already reported. */
	    if (NULL != f)
	      EZT_OTF2_EvtWriter_Enter(evt_writer, NULL, ezt_get_timestamp(), f->event_id);
	    break;
	  }
	case starpu_prof_tool_event_end_cpu_exec:
	case starpu_prof_tool_event_end_gpu_exec:
	  {
	    struct starpu_function* f = get_existing_function(prof_info->fun_ptr, task_name);
	    /* NULL means no matching start was recorded; the error was already reported. */
	    if (NULL != f)
	      EZT_OTF2_EvtWriter_Leave(evt_writer, NULL, ezt_get_timestamp(), f->event_id);
	    break;
	  }
	/* Transfer events (starpu_prof_tool_event_start_transfer / _end_transfer)
	 * are not registered in starpu_prof_tool_library_register(), so they are
	 * not expected here. */
	default:
		eztrace_error("[EZTrace::"STRINGIFY(CURRENT_MODULE)"] Unknown callback %d\n",  prof_info->event_type);
		break;
	}
}

/* Entry point looked up by StarPU's profiling-tool support. It tells StarPU
 * which events this module wants and which callback handles them.
 *
 * reg:   StarPU-provided function used to attach a callback to an event.
 * unreg: StarPU-provided function to detach callbacks; not used here.
 *
 * Only CPU/GPU task execution start/end events are registered, all routed to
 * myfunction_cb. Driver-init and transfer events are left untraced. */
void starpu_prof_tool_library_register(starpu_prof_tool_entry_register_func reg, starpu_prof_tool_entry_register_func unreg) {
  static const int traced_events[] = {
    starpu_prof_tool_event_start_cpu_exec,
    starpu_prof_tool_event_end_cpu_exec,
    starpu_prof_tool_event_start_gpu_exec,
    starpu_prof_tool_event_end_gpu_exec,
  };
  const enum starpu_prof_tool_command command = 0;

  (void)unreg;

  eztrace_log(dbg_lvl_normal, "[EZTrace::"STRINGIFY(CURRENT_MODULE)"] Registering profiling hooks!\n");

  for(unsigned k = 0; k < sizeof(traced_events) / sizeof(traced_events[0]); k++) {
    reg(traced_events[k], &myfunction_cb, command);
  }
}

/* Module initializer handed to EZT_REGISTER_MODULE.
 * It instruments the module's functions, starts tracing when autostart is
 * enabled, and finally marks the module as initialized so that
 * finalize_starpu() knows it must stop tracing. */
static void init_starpu() {
  INSTRUMENT_FUNCTIONS(CURRENT_MODULE);

  if(eztrace_autostart_enabled()) {
    eztrace_start();
  }

  _starpu_initialized = 1;
}

/* Stop tracing if the module was initialized. The flag is cleared before
 * stopping, so any later call (module finalizer or library destructor) is a
 * no-op. */
static void finalize_starpu() {
  if(!_starpu_initialized)
    return;

  _starpu_initialized = 0;
  eztrace_stop();
}

/* Library destructor (runs at process exit or when the library is unloaded):
 * make sure tracing has been stopped. */
static void __attribute__((destructor)) _starpu_finalize(void) {
  finalize_starpu();
}

/* Library constructor: register this module with EZTrace, passing
 * init_starpu and finalize_starpu as its initialization and finalization
 * routines. */
static void __attribute__((constructor)) _starpu_init(void) {
  EZT_REGISTER_MODULE(CURRENT_MODULE,
		      "Module for StarPU functions using StarPU's profiling API",
		      init_starpu,
		      finalize_starpu);
}
