package org.infinispan.distribution.rehash;

import static org.testng.Assert.assertEquals;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import jakarta.transaction.Transaction;
import jakarta.transaction.TransactionManager;
import javax.transaction.xa.XAResource;
import javax.transaction.xa.Xid;

import org.infinispan.Cache;
import org.infinispan.distribution.BaseDistFunctionalTest;
import org.infinispan.distribution.MagicKey;
import org.infinispan.test.TestingUtil;
import org.infinispan.commons.test.TestResourceTracker;
import org.testng.annotations.Test;

/**
 * A base test for all rehashing tests
 */
@Test(groups = "functional")
public abstract class RehashTestBase extends BaseDistFunctionalTest<Object, String> {

   protected RehashTestBase() {
      cleanup = CleanupPhase.AFTER_METHOD;
      transactional = true;
      performRehashing = true;
   }

   // this setup has 4 running caches: {c1, c2, c3, c4}

   /**
    * This is overridden by subclasses.  Could typically be a JOIN or LEAVE event.
    * @param offline
    */
   abstract void performRehashEvent(boolean offline) throws Throwable;

   /**
    * Blocks until a rehash completes.
    */
   abstract void waitForRehashCompletion();

   protected List<MagicKey> init() {

      List<MagicKey> keys = new ArrayList<>(Arrays.asList(
            new MagicKey("k1", c1), new MagicKey("k2", c2),
            new MagicKey("k3", c3), new MagicKey("k4", c4)
      ));
      assertEquals(caches.size(), keys.size(), "Received caches" + caches);

      int i = 0;
      for (Cache<Object, String> c : caches) c.put(keys.get(i++), "v0");

      for (MagicKey key : keys) assertOwnershipAndNonOwnership(key, false);

      log.infof("Initialized with keys %s", keys);
      return keys;
   }

   /**
    * Simple test.  Put some state, trigger event, test results
    */
   @Test
   public void testNonTransactional() throws Throwable {
      List<MagicKey> keys = init();

      log.info("Invoking rehash event");
      performRehashEvent(false);

      waitForRehashCompletion();
      log.info("Rehash complete");

      for (MagicKey key : keys) assertOnAllCachesAndOwnership(key, "v0");
   }


   /**
    * More complex - init some state.  Start a new transaction, and midway trigger a rehash.  Then complete transaction
    * and test results.
    */
   @Test
   public void testTransactional() throws Throwable {
      final List<MagicKey> keys = init();
      final CountDownLatch l = new CountDownLatch(1);
      final AtomicBoolean rollback = new AtomicBoolean(false);

      Future<Void> future = fork(() -> {
         try {
            // start a transaction on c1.
            TransactionManager t1 = TestingUtil.getTransactionManager(c1);
            t1.begin();
            c1.put(keys.get(0), "transactionally_replaced");
            Transaction tx = t1.getTransaction();
            tx.enlistResource(new XAResourceAdapter() {
               public int prepare(Xid id) {
                  // this would be called *after* the cache prepares.
                  try {
                     log.debug("Unblocking commit");
                     l.await();
                  } catch (InterruptedException e) {
                     Thread.currentThread().interrupt();
                  }
                  return XAResource.XA_OK;
               }
            });
            t1.commit();
         } catch (Exception e) {
            log.error("Error committing transaction", e);
            rollback.set(true);
            throw new RuntimeException(e);
         }
      });

      log.info("Invoking rehash event");
      performRehashEvent(true);
      l.countDown();
      future.get(30, TimeUnit.SECONDS);

      //ownership can only be verified after the rehashing has completed
      waitForRehashCompletion();
      log.info("Rehash complete");

      //only check for these values if tx was not rolled back
      if (!rollback.get()) {
         // the ownership of k1 might change during the tx and a cache might end up with it in L1
         assertOwnershipAndNonOwnership(keys.get(0), true);
         assertOwnershipAndNonOwnership(keys.get(1), false);
         assertOwnershipAndNonOwnership(keys.get(2), false);
         assertOwnershipAndNonOwnership(keys.get(3), false);

         // checking the values will bring the keys to L1, so we want to do it after checking ownership
         assertOnAllCaches(keys.get(0), "transactionally_replaced");
         assertOnAllCaches(keys.get(1), "v0");
         assertOnAllCaches(keys.get(2), "v0");
         assertOnAllCaches(keys.get(3), "v0");
      }
   }

   /**
    * A stress test.  One node is constantly modified while a rehash occurs.
    */
   @Test(groups = "stress", timeOut = 15*60*1000)
   public void testNonTransactionalStress() throws Throwable {
      TestResourceTracker.testThreadStarted(this.getTestName());
      stressTest(false);
   }

   /**
    * A stress test.  One node is constantly modified using transactions while a rehash occurs.
    */
   @Test(groups = "stress", timeOut = 15*60*1000)
   public void testTransactionalStress() throws Throwable {
      TestResourceTracker.testThreadStarted(this.getTestName());
      stressTest(true);
   }

   private void stressTest(boolean tx) throws Throwable {
      final List<MagicKey> keys = init();
      final CountDownLatch latch = new CountDownLatch(1);
      List<Updater> updaters = new ArrayList<>(keys.size());
      for (MagicKey k : keys) {
         Updater u = new Updater(c1, k, latch, tx);
         u.start();
         updaters.add(u);
      }

      latch.countDown();

      log.info("Invoking rehash event");
      performRehashEvent(false);

      for (Updater u : updaters) u.complete();
      for (Updater u : updaters) u.join();

      waitForRehashCompletion();

      log.info("Rehash complete");

      int i = 0;
      for (MagicKey key : keys) assertOnAllCachesAndOwnership(key, "v" + updaters.get(i++).currentValue);
   }
}

/**
 * Background writer used by the stress tests. Once the start latch is released, it repeatedly writes
 * {@code "v<n>"}, with an increasing {@code n}, to a single key through one cache, optionally inside a
 * transaction, until {@link #complete()} is called.
 * <p>
 * {@link #currentValue} only advances after a write (and commit, when transactional) succeeded. After the
 * thread has been joined, the caches are therefore expected to contain {@code "v" + currentValue} for
 * {@link #key}.
 */
class Updater extends Thread {
   static final Random r = new Random();
   /** Last counter value whose write succeeded; 0 means only the initial value is expected. */
   volatile int currentValue = 0;
   final MagicKey key;
   // Raw type kept so that callers may pass any Cache parameterization.
   final Cache cache;
   /** Start gate: the updater does not write anything until this latch is released. */
   final CountDownLatch latch;
   volatile boolean running = true;
   /** Transaction manager of {@link #cache} when transactional, {@code null} otherwise. */
   final TransactionManager tm;

   Updater(Cache cache, MagicKey key, CountDownLatch latch, boolean tx) {
      super("Updater-" + key);
      this.key = key;
      this.cache = cache;
      this.latch = latch;
      this.tm = tx ? TestingUtil.getTransactionManager(cache) : null;
   }

   /** Asks the updater to stop once its current iteration is finished. */
   public void complete() {
      running = false;
   }

   @Override
   @SuppressWarnings("unchecked") // cache is a raw Cache
   public void run() {
      if (!awaitStart()) return;
      while (running) {
         // Publish the new value only once it has been stored; a failed attempt retries the same value.
         int next = currentValue + 1;
         try {
            if (tm != null) tm.begin();
            cache.put(key, "v" + next);
            if (tm != null) tm.commit();
            currentValue = next;
            TestingUtil.sleepThread(r.nextInt(10) * 10);
         } catch (Exception e) {
            // Failures (presumably caused by the concurrent rehash) are tolerated. Make sure no transaction
            // stays associated with this thread, since JTA begin() rejects a thread that is still associated.
            rollbackQuietly();
         }
      }
   }

   /** Waits for the start signal; returns {@code false} if interrupted (interrupt flag is restored). */
   private boolean awaitStart() {
      try {
         latch.await();
         return true;
      } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
         return false;
      }
   }

   /** Best-effort rollback of a transaction left bound to this thread by a failed iteration. */
   private void rollbackQuietly() {
      if (tm == null) return;
      try {
         if (tm.getTransaction() != null) tm.rollback();
      } catch (Exception ignored) {
         // Nothing more can be done here; the next iteration simply tries again.
      }
   }
}
