#include <fstream>
#include <iomanip> 
#include <iostream>
#include <vector>

#include <assert.h>
#include <pthread.h>
#include <semaphore.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#ifdef USE_OMP
# include <omp.h>
#endif

using namespace std;

// #define VERBOSE (uncomment the line below, or compile with -DVERBOSE) to
// trace what each thread is doing. This disables the timing statistics.
//
// #define VERBOSE

//-----------------------------------------------------------------------------
#ifdef VERBOSE

// INFO(th, message): print "thread #<th> says: <message>" to cout while
// holding print_sema, so that lines printed by different threads do not
// interleave. message may be a chain of << operands, e.g. INFO(id, "x=" << x).
//
// NOTE: the expansion is three separate statements, and most call sites in
// this file invoke it without a trailing semicolon. It is therefore only safe
// inside a braced block, not as the sole body of an unbraced if/for/while.
# define INFO(th, message)      \
   sem_wait(&print_sema);   \
   cout << "thread #" << th << " says: " << message << endl;  \
   sem_post(&print_sema);

#else

// without VERBOSE, INFO() expands to nothing
# define INFO(th, x)

#endif   // VERBOSE

/// Compile-time upper bound on the number of cores/threads, so that the
/// per-thread tables can be sized statically. It is kept a power of two:
/// THREAD_LIMIT == 2^LOG_THREAD_LIMIT.
enum
{
   LOG_THREAD_LIMIT = 8,                       ///< log2 of THREAD_LIMIT
   THREAD_LIMIT     = 1 << LOG_THREAD_LIMIT    ///< at most 256 threads
};

/// a semaphore to coordinate printouts from several threads (used by INFO();
/// the non-OMP build of main() initializes it to 1, i.e. uses it as a mutex)
sem_t print_sema;

/// a semaphore to coordinate thread creation: each new worker posts it once
/// it is running, and setup_threads() waits for that post before creating
/// the next worker
sem_t pthread_create_sema;

//-----------------------------------------------------------------------------
inline uint64_t
cycle_counter()
{
   // RDTSC returns the 64-bit time-stamp counter split across EDX:EAX
   uint32_t low_half, high_half;
   __asm__ __volatile__ ("rdtsc" : "=a" (low_half), "=d" (high_half));

   const uint64_t high = high_half;
   return (high << 32) | low_half;
}
//-----------------------------------------------------------------------------
struct Thread_context
{
   /// constructor. A context's index in thread_contexts[] is its thread id.
   /// The fork/join bookkeeping starts out empty (the same state as reset()),
   /// and no core is assigned yet.
   Thread_context()
   : thread_id(this - thread_contexts),
     forked_threads_count(0),
     join_thread(-1),
     CPU(-1)
   {}

   /// start parallel execution of work: wait until fork_sema is posted (by
   /// our forker, or by main() for thread #0), then post the fork_sema of
   /// every thread that we fork.
   void fork()
      {
        sem_wait(&fork_sema);
        INFO(thread_id, "forked")

        // fork our threads
        //
        for (int t = 0; t < forked_threads_count; ++t)
            sem_post(&(Thread_context::thread_contexts[forked_threads[t]]
                                                      .fork_sema));
        INFO(thread_id, forked_threads_count << " worker-threads forked")
      }

   /// end parallel execution of work: wait for the threads that we forked,
   /// then (unless we are thread #0) signal our forker
   inline void join();

   /// the semaphore that controls forking of threads
   sem_t fork_sema;

   /// the semaphore that controls joining of threads; posted once by each
   /// thread that we forked
   sem_t join_sema;

   // initialize all but thread and thread_id
   void init(int thread_count, const vector<int> & cores);

   /// clear the fork/join bookkeeping (needed before init() re-builds it)
   void reset()
      { forked_threads_count = 0; join_thread = -1; }

   const int thread_id;                    ///< index in thread_contexts[]
   int forked_threads[LOG_THREAD_LIMIT];   ///< ids of the threads we fork
   int forked_threads_count;               ///< used entries of forked_threads

   int join_thread;   ///< id of the thread that forked us (-1: none)

   pthread_t thread;  ///< the pthread that runs this context

   int CPU;   // the core to which thread binds (-1: none)

   static Thread_context thread_contexts[];

protected:
   void add_forked(int peer)
      {
        // init() adds at most one peer per power of two below THREAD_LIMIT,
        // i.e. at most LOG_THREAD_LIMIT peers; never write past the array.
        assert(forked_threads_count < LOG_THREAD_LIMIT);
        forked_threads[forked_threads_count++] = peer;
      }

   void add_joined(int peer)
      { assert(join_thread == -1);   join_thread = peer; }

} Thread_context::thread_contexts[THREAD_LIMIT];
// ----------------------------------------------------------------------------
void
Thread_context::join()
{
   // collect one join_sema post from every thread that we forked
   //
   for (int done = 1; done <= forked_threads_count; ++done)
       {
         sem_wait(&join_sema);
         INFO(thread_id, "worker-thread " << done << " (of "
                         << forked_threads_count << ") has joined")
       }

   // thread #0 was not forked by anybody, so there is nobody to inform
   //
   if (thread_id == 0)   return;

   // tell our forker that we (and everything we forked) are done
   //
   Thread_context & forker = Thread_context::thread_contexts[join_thread];
   INFO(thread_id, "joining #" << join_thread)
   sem_post(&forker.join_sema);
}
//-----------------------------------------------------------------------------
void do_work(int id);

/// Prepare this context for a run with thread_count threads. This chooses
/// its core, records which threads it forks and which thread it joins, and
/// initializes its semaphores. Contexts with thread_id >= thread_count stay
/// unused (CPU = -1).
///
/// The fork/join structure is a tree rooted at thread #0. For each power of
/// two dist (largest first), a thread whose id is a multiple of 2*dist forks
/// thread thread_id + dist (if that thread exists), and that thread joins it.
///
/// reset() must have been called on all contexts beforehand.
void Thread_context::init(int thread_count, const vector<int> & cores)
{
   CPU = -1;
   if (thread_id >= thread_count)   return;

   // cores can be empty (no usable core found): avoid the division by zero
   // and leave CPU at -1 in that case.
   if (!cores.empty())   CPU = cores[thread_id % cores.size()];

   for (int dist = THREAD_LIMIT >> 1; dist; dist >>= 1)
       {
         const int mask = dist - 1;
         if (thread_id & mask)   continue;

         const int peer = thread_id ^ dist;
         if (peer >= thread_count)   continue;
         if (thread_id & dist)          continue;

         // we fork peer and peer joins us.
         this->add_forked(peer);
         thread_contexts[peer].add_joined(thread_id);
       }

   sem_init(&fork_sema, 0, 0);
   sem_init(&join_sema, 0, 0);

   // report the structure (thread_id < thread_count holds here)
   //
   cerr << "thread #" << thread_id << " will start ";
   if (forked_threads_count == 0)   cerr << "no threads";
   else if (forked_threads_count == 1)   cerr << "1 thread";
   else cerr << forked_threads_count << " threads";
   for (int c = 0; c < forked_threads_count; ++c)
       cerr << " #" << forked_threads[c];

   if (thread_id)
      {
        assert(join_thread != -1);
        cerr << " and will join thread #" << join_thread;
      }
   cerr << endl;
}
//-----------------------------------------------------------------------------

uint64_t work_start;            ///< cycle counter before work
int64_t  work[THREAD_LIMIT];    ///< cycle counters during work
uint64_t work_end;              ///< cycle counter after work

void
do_work(int thread_id)
{
INFO(thread_id, "start work")
   work[thread_id] = cycle_counter();
INFO(thread_id, "done work")
}
//-----------------------------------------------------------------------------
void *
pthread_main(void * arg)
{
Thread_context & ctx = *(Thread_context *)arg;

   INFO(ctx.thread_id, "thread " << ctx.thread_id << " created");
   sem_post(&pthread_create_sema);

   for (;;)
      {
        ctx.fork();
        do_work(ctx.thread_id);
        ctx.join();
      }

   // not reached
   //
   return 0;
}
//-----------------------------------------------------------------------------
/// Prepare the thread contexts and create the worker threads for a run with
/// thread_count threads (limited to 1 ... THREAD_LIMIT).
///
/// Thread #0 is the calling thread. Threads #1 ... #thread_count-1 are
/// created here, and each runs pthread_main(). Every thread is then bound to
/// the core chosen for it by Thread_context::init(). Exits the program if a
/// thread cannot be created or bound.
///
/// \return the number of threads actually set up.
int
setup_threads(int thread_count, const vector<int> & cores)
{
   // limit thread_count to 1 ... THREAD_LIMIT. Thread #0 (the caller) must
   // always be set up, because main() forks and joins through its context.
   //
   if (thread_count > THREAD_LIMIT)   thread_count = THREAD_LIMIT;
   if (thread_count < 1)              thread_count = 1;

   // clear and initialize all Thread_contexts
   //
   cerr << "\nInitializing thread contexts ..." << endl;
   for (int c = 0; c < THREAD_LIMIT; ++c)
       Thread_context::thread_contexts[c].reset();

   for (int c = 0; c < THREAD_LIMIT; ++c)
       Thread_context::thread_contexts[c].init(thread_count, cores);

   // the main thread is #0 and we create worker-threads for #1 #2 ...
   //
   cerr << "\nCreating worker threads ..." << endl;
   Thread_context::thread_contexts[0].thread = pthread_self();
   for (int c = 1; c < thread_count; ++c)
       {
         Thread_context * ctx = Thread_context::thread_contexts + c;
         const int err = pthread_create(&(ctx->thread), /* attr */ 0,
                                        pthread_main, ctx);
         if (err)
            {
              // no thread was created, so the sem_wait() below would
              // block forever
              cerr << "pthread_create() failed for thread #" << c
                   << " with error " << err << endl;
              exit(4);
            }

         // wait until new thread has reached its loop
         sem_wait(&pthread_create_sema);
       }

   // bind threads to cores
   //
   cerr << "\nBinding threads to cores..." << endl;
   for (int c = 0; c < thread_count; ++c)
       {
         Thread_context * ctx = Thread_context::thread_contexts + c;
         if (ctx->CPU < 0)
            {
              cerr << "no core available for thread #" << c << endl;
              exit(3);
            }

         cpu_set_t cpus;
         CPU_ZERO(&cpus);
         CPU_SET(ctx->CPU, &cpus);

         const int err = pthread_setaffinity_np(ctx->thread,
                                                sizeof(cpu_set_t), &cpus);
         if (err)
            {
              cerr << "pthread_setaffinity_np() failed with error "
                   << err << endl;
              exit(3);
            }

         cerr << "bound thread #" << c << " to core " << ctx->CPU << endl;
       }
   return thread_count;
}
//-----------------------------------------------------------------------------

/// Compute the cores available to the calling thread (its CPU affinity
/// mask), restricted to cores 0 ... THREAD_LIMIT-1, and print them.
///
/// \param cores  the ids of the usable cores are appended to it (ascending)
/// \param CPUs   receives the affinity mask of the calling thread
/// \return cores.size()
///
/// Exits the program if the affinity mask cannot be read or contains no
/// usable core, because nothing can be bound to a core then.
static int
setup_cores(vector<int> & cores, cpu_set_t & CPUs)
{
const int err = pthread_getaffinity_np(pthread_self(),
                                       sizeof(cpu_set_t), &CPUs);
   if (err)
      {
        cerr << "pthread_getaffinity_np() failed with error "
             << err << endl;
        exit(2);
      }

   // get available CPUs (cores) but at most THREAD_LIMIT
   //
   for (int c = 0; c < THREAD_LIMIT; ++c)
       {
         if (CPU_ISSET(c, &CPUs))   cores.push_back(c);
       }

   cout << "\nThis machine has " << cores.size() << " cores:";
   for (size_t c = 0; c < cores.size(); ++c)   cout << " " << cores[c];
   cout << endl;

   if (cores.empty())
      {
        cerr << "no usable core among cores 0 ... " << (THREAD_LIMIT - 1)
             << endl;
        exit(2);
      }

   return cores.size();
}
//-----------------------------------------------------------------------------
/// Print the total number of cycles of one pass to out, and write the
/// per-thread cycle counts (relative to work_start) to the file
/// cores_<thread_count>_pass_<pass>.<EXT>. The file gets one "id, cycles"
/// line per thread, followed by one line (id thread_count + 10) for the end
/// of the pass. If VERBOSE is #defined, it only prints a note.
void
print_times(ostream & out, int thread_count, int pass)
{
   // print execution statistics, unless VERBOSE is on (in which case that
   // makes no sense).
   //
#ifdef VERBOSE
        out  << endl << "No statistics because VERBOSE was #defined" << endl;
        return;
#endif

   out << " " << thread_count << " cores/threads, "
       << (work_end - work_start) << " cycles total" << endl;

   // use file extension .omp for OMP and .manual for hand-crafted
   // parallelization
   //
#ifdef USE_OMP
# define EXT "omp"
#else
# define EXT "manual"
#endif

char filename[1000];
   snprintf(filename, sizeof(filename), "cores_%d_pass_%d." EXT,
            thread_count, pass);

ofstream plot (filename);
   if (!plot)
      {
        cerr << "cannot open " << filename << " for writing" << endl;
        return;
      }

   for (int c = 0; c < thread_count; ++c)
       plot << "    " << setw(3) << c << ", " << (work[c] - work_start) << endl;

   // one extra entry, set apart from the thread ids, for the end of the pass
   //
   plot << "    " << setw(3) << (thread_count + 10)
        << ", " << (work_end - work_start) << endl;
}
//-----------------------------------------------------------------------------
int
main(int argc, const char * argv[])
{
int thread_count = 10;   if (argc >= 2)   thread_count = atoi(argv[1]);

   // determine the available cores and remember them
   //
vector<int> cores;
cpu_set_t CPUs;
const int core_count = setup_cores(cores, CPUs);

   // setup for parallel execution. This is normally done only once (unless
   // the number of cores/threads is changed.
   //
#ifdef USE_OMP

   omp_set_dynamic(false);
   omp_set_num_threads(thread_count);

#else

   sem_init(&print_sema, 0, 1);
   sem_init(&pthread_create_sema, 0, 0);

   thread_count = setup_threads(thread_count, cores);

   cerr << endl;

#endif

   // check that core_count == thread_count and abort if not.
   //  However, a mismatch is allowed for debugging purposes.
   //
#ifndef VERBOSE
   if (core_count != thread_count)
      {
        cerr << "Mismatch between core count " << core_count
             << " and thread count " << thread_count << ". Aborting." << endl;
        exit(1);
      }
#endif

   // we run the loop several times so that we can see cache effects
   //
   for (int pass = 0; pass < 5; ++pass)
       {
         cout << "Pass " << pass << ":";

         work_start = cycle_counter();

#ifdef USE_OMP
         {
#pragma omp parallel default(none) shared(thread_count, work)

#pragma omp for
         for (int c = 0; c < thread_count; c++)
             {
               work[c] = cycle_counter();
             }
         }

#else  // hand-crafted

         {
           // do the same as the workers, but without calling pthread_main().
           // this is to avoid unneccessary tests for worker vs. master in
           // pthread_main().
           //
           Thread_context & master = Thread_context::thread_contexts[0];
           sem_post(&master.fork_sema);
           master.fork();
           do_work(master.thread_id);
           master.join();
         } 

#endif // USE_OMP

         work_end = cycle_counter();

#ifndef VERBOSE
         print_times(cout, thread_count, pass);
#endif
       }
}
//-----------------------------------------------------------------------------