#include <stdio.h>
#include <stdlib.h>
#include <math.h> // For fabs

// --- CUDA Best Practice: Error Checking Macro ---
// Evaluates `call` exactly once. If the result is not cudaSuccess, prints the
// file, line and CUDA error string to stderr and terminates the process with
// EXIT_FAILURE.
//
// The argument is parenthesized so any expression can be passed, and the
// temporary has a macro-specific name so that an argument mentioning a caller
// variable called `err` is not shadowed by the macro's own local.
#define checkCudaErrors(call)                                                  \
  do {                                                                         \
    const cudaError_t checkCudaErrors_status_ = (call);                        \
    if (checkCudaErrors_status_ != cudaSuccess) {                              \
      fprintf(stderr, "CUDA Error at %s:%d: %s\n", __FILE__, __LINE__,         \
              cudaGetErrorString(checkCudaErrors_status_));                    \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

// --- 1. KERNEL: BASELINE (from Example 1) ---
// This is our control: y[i] = a * x[i] + b for every i in [0, n).
//
// Launch layout: 1D grid of 1D blocks, any size. Each thread starts at its
// global index and advances by the total number of launched threads, so
// neighbouring threads always touch neighbouring elements and there is no
// data-dependent branching.
//
// Parameters:
//   y - device pointer, at least n floats, written
//   x - device pointer, at least n floats, read
//   a - scale factor
//   b - offset
//   n - element count; values <= 0 make the kernel a no-op
//
// Indices are kept in size_t: with int arithmetic the `i += stride` step can
// exceed INT_MAX (large n or a very large launch), wrap to a negative value
// that still satisfies `i < n`, and index out of bounds.
__global__ void axpb_baseline(float *y, const float *x, float a, float b, int n) {
  const size_t count = (n > 0) ? static_cast<size_t>(n) : 0;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (size_t i = first; i < count; i += stride) {
    y[i] = a * x[i] + b;
  }
}

// --- 2. KERNEL: BRANCHING (WARP DIVERGENCE) ---
// Same traversal as the baseline, but the operation depends on the data:
//   x[i] >  0.5f  ->  y[i] = a * x[i] + b
//   otherwise     ->  y[i] = a * x[i] - b
//
// When the input is random, threads within the same warp (32 threads) will
// generally disagree on the condition, which is what this kernel is meant to
// demonstrate. NOTE(review): for a branch this small the compiler may emit
// predicated/select code instead of a real divergent branch - inspect the
// SASS before attributing timing differences to divergence.
//
// Launch layout: 1D grid of 1D blocks, any size.
// Parameters: see the argument names; n <= 0 makes the kernel a no-op.
//
// Indices are kept in size_t so that the stride step cannot overflow int and
// wrap to a negative (out-of-bounds) index.
__global__ void axpb_branching(float *y, const float *x, float a, float b, int n) {
  const size_t count = (n > 0) ? static_cast<size_t>(n) : 0;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (size_t i = first; i < count; i += stride) {
    const float xi = x[i];
    if (xi > 0.5f) {
      // Path 1
      y[i] = a * xi + b;
    } else {
      // Path 2: different operation
      y[i] = a * xi - b;
    }
  }
}

// --- 3. KERNEL: SHARED MEMORY (GOOD ACCESS) ---
// This is our control for the bank conflict test: y[i] = a * x[i] + b, with
// every element staged through shared memory using a conflict-free pattern
// (thread `tid` only ever touches s_data[tid]).
//
// Launch layout: 1D grid of 1D blocks, any size.
// Shared memory: dynamic ('extern'), at least blockDim.x * sizeof(float)
//                bytes must be passed as the third launch argument.
// Parameters: see the argument names; n <= 0 makes the kernel a no-op.
//
// The loop variable is the block's base index, so every thread of a block
// runs the same number of iterations and all of them reach both
// __syncthreads() calls. Indices are kept in size_t so that neither the
// stride step nor `base + tid` can overflow int and wrap to a negative
// (out-of-bounds) index.
__global__ void axpb_shared_good(float *y, const float *x, float a, float b, int n) {
  // 's_data' is a pointer to the dynamic shared memory allocated at launch
  extern __shared__ float s_data[];

  const unsigned int tid = threadIdx.x; // Thread's ID within the block
  const size_t count = (n > 0) ? static_cast<size_t>(n) : 0;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  // Grid-stride loop over block-sized tiles (starts at the block's base index)
  for (size_t base = static_cast<size_t>(blockIdx.x) * blockDim.x; base < count; base += stride) {
    const size_t current_idx = base + tid; // This thread's global data index
    const bool in_range = current_idx < count;

    // 1. Load from global to shared (neighbouring threads read neighbouring floats)
    if (in_range) {
      s_data[tid] = x[current_idx];
    }
    __syncthreads(); // Wait for ALL threads in block to finish loading

    // 2. Read from shared (conflict-free) and compute
    if (in_range) {
      // GOOD: tid=0 -> bank 0, tid=1 -> bank 1...
      y[current_idx] = a * s_data[tid] + b;
    }
    __syncthreads(); // Wait for all threads to finish before the next tile overwrites s_data
  }
}

// --- 4. KERNEL: SHARED MEMORY (BANK CONFLICT) ---
// This kernel correctly computes y[i] = a * x[i] + b, but *also* performs an
// extra read from shared memory in a strided pattern that causes bank
// conflicts. The extra value is discarded and never influences y.
//
// Launch layout: 1D grid of 1D blocks, any size.
// Shared memory: dynamic ('extern'), at least blockDim.x * sizeof(float)
//                bytes must be passed as the third launch argument.
// Parameters: see the argument names; n <= 0 makes the kernel a no-op.
//
// Conflict pattern (32 banks, 4 bytes wide): thread `tid` reads word
// (tid * 16) % blockDim.x, so every access lands in bank 0 or bank 16.
//   e.g., thread 0 accesses s_data[0]  (Bank 0)
//         thread 2 accesses s_data[32] (Bank 0)  -> CONFLICT
//         thread 1 accesses s_data[16] (Bank 16)
//         thread 3 accesses s_data[48] (Bank 16) -> CONFLICT
// With blockDim.x = 256 a warp touches 16 distinct words, 8 per bank, i.e. an
// 8-way conflict (lanes t and t+16 read the same word, which is a broadcast,
// not a conflict). The exact degree depends on blockDim.x.
//
// Indices are kept in size_t so that neither the stride step nor `base + tid`
// can overflow int and wrap to a negative (out-of-bounds) index.
__global__ void axpb_shared_conflict(float *y, const float *x, float a, float b, int n) {
  extern __shared__ float s_data[];
  // Volatile view of the same buffer: a volatile load must be issued even
  // though its result is unused, so the conflicting read survives optimization
  // without being folded into the arithmetic. (Mixing it in as `0.0f * dummy`
  // is not neutral: it yields NaN whenever dummy is Inf or NaN.)
  const volatile float *s_probe = s_data;

  const unsigned int tid = threadIdx.x;
  const unsigned int conflict_idx = (tid * 16u) % blockDim.x; // loop-invariant
  const size_t count = (n > 0) ? static_cast<size_t>(n) : 0;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  for (size_t base = static_cast<size_t>(blockIdx.x) * blockDim.x; base < count; base += stride) {
    const size_t current_idx = base + tid;
    const bool in_range = current_idx < count;

    // 1. Load from global to shared. Every thread writes its slot (0.0f past
    //    the end of the data) so the strided read below never sees a slot
    //    that is uninitialized or left over from the previous tile.
    s_data[tid] = in_range ? x[current_idx] : 0.0f;
    __syncthreads();

    // 2. Read from shared and compute
    if (in_range) {
      // --- This is the bottleneck: conflicting read, value discarded ---
      const float dummy = s_probe[conflict_idx];
      (void)dummy;

      // This is the *correct* read (no conflict)
      y[current_idx] = a * s_data[tid] + b;
    }
    __syncthreads(); // s_data is overwritten by the next tile
  }
}


// --- Utility function for verification ---
// Checks a kernel result against a reference computed on the host and prints
// a one-line verdict prefixed with `label`.
//
// Parameters:
//   label               - name printed in the verdict
//   h_y_result          - host copy of the kernel output, n floats
//   h_x                 - host copy of the kernel input, n floats
//   a, b                - coefficients that were passed to the kernel
//   n                   - number of elements; all of them are checked
//   use_branching_logic - true:  expect a*x + b when x > 0.5f, else a*x - b
//                         false: expect a*x + b everywhere
//
// An element passes when it equals the reference exactly (this also covers
// matching infinities) or when |result - expected| <= atol + rtol*|expected|.
// The comparison is written so that a NaN result FAILS; a plain
// `fabs(diff) > tol` test is false for NaN and would let it through.
// Only the first mismatch is reported.
void verify(const char *label, const float *h_y_result, const float *h_x, float a, float b, int n, bool use_branching_logic) {
  const float abs_tol = 1e-5f;
  const float rel_tol = 1e-5f;

  for (int i = 0; i < n; i++) {
    float expected;
    if (use_branching_logic && !(h_x[i] > 0.5f)) {
      expected = a * h_x[i] - b;
    } else {
      expected = a * h_x[i] + b;
    }

    const float actual = h_y_result[i];
    if (actual == expected) {
      continue;
    }

    const float error = fabsf(actual - expected);
    if (!(error <= abs_tol + rel_tol * fabsf(expected))) {
      printf("[%s] Verification FAILED at index %d!\n", label, i);
      printf("  Expected: %f, Got: %f\n", expected, actual);
      return;
    }
  }

  printf("[%s] Verification SUCCESSFUL!\n", label);
}

// --- HOST CODE (Main Function) ---
// Runs the four axpb kernels on the same random input. For each kernel the
// output buffer is zeroed, the launch is timed with CUDA events, the result
// is copied back to the host and handed to verify().
//
// Every launch is followed by cudaGetLastError(): a kernel launch returns no
// status itself, so an invalid configuration would otherwise go unnoticed and
// the timing/verification that follows would be meaningless.
//
// NOTE(review): there is no untimed warm-up launch, so a kernel's first (and
// only) launch may include one-time initialization cost - treat the printed
// timings as rough, or repeat each launch and time the later runs.
int main() {
  // --- 1. Define Problem Size ---
  const int n = 1 << 24; // 16.7 million elements
  const size_t bytes = (size_t)n * sizeof(float);
  const float a = 2.0f;
  const float b = 1.0f;

  // --- 2. Allocate Host (CPU) Memory ---
  float *h_x = (float *)malloc(bytes);
  float *h_y_ref = (float *)malloc(bytes); // Receives each kernel's output
  if (h_x == NULL || h_y_ref == NULL) {
    fprintf(stderr, "Failed to allocate %zu bytes of host memory\n", bytes);
    free(h_x);
    free(h_y_ref);
    return EXIT_FAILURE;
  }

  // Initialize host data with pseudo-random values in [0, 1]
  for (int i = 0; i < n; i++) {
    h_x[i] = (float)rand() / (float)RAND_MAX;
  }

  // --- 3. Allocate Device (GPU) Memory ---
  float *d_x, *d_y;
  checkCudaErrors(cudaMalloc(&d_x, bytes));
  checkCudaErrors(cudaMalloc(&d_y, bytes));

  // --- 4. Copy Data from Host to Device ---
  printf("Copying %zu MB from Host to Device...\n", bytes / (1024 * 1024));
  checkCudaErrors(cudaMemcpy(d_x, h_x, bytes, cudaMemcpyHostToDevice));

  // --- 5. Create CUDA Events for Timing ---
  cudaEvent_t start, stop;
  checkCudaErrors(cudaEventCreate(&start));
  checkCudaErrors(cudaEventCreate(&stop));

  // --- 6. Set up launch parameters ---
  // For the shared memory examples, we'll fix the block size.
  // 256 or 512 are common, good choices.
  const int blockSize = 256;
  const int gridSize = (n + blockSize - 1) / blockSize; // ceil-div

  // Dynamic shared memory: one float per thread of the block
  const size_t sharedMemBytes = blockSize * sizeof(float);

  printf("Problem Size: %d elements\n", n);
  printf("Launch Config: Grid=%d, Block=%d, SharedMem=%zu bytes\n\n",
         gridSize, blockSize, sharedMemBytes);

  float milliseconds = 0;

  // --- RUN 1: BASELINE ---
  // All-zero bytes are 0.0f, so cudaMemset is enough to reset the output.
  checkCudaErrors(cudaMemset(d_y, 0, bytes));
  checkCudaErrors(cudaEventRecord(start));
  axpb_baseline<<<gridSize, blockSize>>>(d_y, d_x, a, b, n);
  checkCudaErrors(cudaGetLastError()); // launch-configuration errors
  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));
  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("1. Baseline:          \t%f ms\n", milliseconds);
  checkCudaErrors(cudaMemcpy(h_y_ref, d_y, bytes, cudaMemcpyDeviceToHost));
  verify("Baseline", h_y_ref, h_x, a, b, n, false);

  // --- RUN 2: BRANCHING (WARP DIVERGENCE) ---
  checkCudaErrors(cudaMemset(d_y, 0, bytes));
  checkCudaErrors(cudaEventRecord(start));
  axpb_branching<<<gridSize, blockSize>>>(d_y, d_x, a, b, n);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));
  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("\n2. Branching:         \t%f ms\n", milliseconds);
  checkCudaErrors(cudaMemcpy(h_y_ref, d_y, bytes, cudaMemcpyDeviceToHost));
  verify("Branching", h_y_ref, h_x, a, b, n, true); // Use branching verify

  // --- RUN 3: SHARED MEMORY (GOOD) ---
  checkCudaErrors(cudaMemset(d_y, 0, bytes));
  checkCudaErrors(cudaEventRecord(start));
  axpb_shared_good<<<gridSize, blockSize, sharedMemBytes>>>(d_y, d_x, a, b, n);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));
  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("\n3. Shared (Good):     \t%f ms\n", milliseconds);
  checkCudaErrors(cudaMemcpy(h_y_ref, d_y, bytes, cudaMemcpyDeviceToHost));
  verify("Shared (Good)", h_y_ref, h_x, a, b, n, false);

  // --- RUN 4: SHARED MEMORY (BANK CONFLICT) ---
  checkCudaErrors(cudaMemset(d_y, 0, bytes));
  checkCudaErrors(cudaEventRecord(start));
  axpb_shared_conflict<<<gridSize, blockSize, sharedMemBytes>>>(d_y, d_x, a, b, n);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));
  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));
  printf("\n4. Shared (Conflict): \t%f ms\n", milliseconds);
  checkCudaErrors(cudaMemcpy(h_y_ref, d_y, bytes, cudaMemcpyDeviceToHost));
  verify("Shared (Conflict)", h_y_ref, h_x, a, b, n, false);

  // --- 7. Clean Up Memory ---
  printf("\nCleaning up memory...\n");
  free(h_x);
  free(h_y_ref);
  checkCudaErrors(cudaFree(d_x));
  checkCudaErrors(cudaFree(d_y));
  checkCudaErrors(cudaEventDestroy(start));
  checkCudaErrors(cudaEventDestroy(stop));

  return 0;
}