package org.infinispan.interceptors.impl;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import org.infinispan.commands.tx.CommitCommand;
import org.infinispan.commands.tx.PrepareCommand;
import org.infinispan.commands.tx.RollbackCommand;
import org.infinispan.commands.write.ClearCommand;
import org.infinispan.commands.write.InvalidateCommand;
import org.infinispan.commands.write.WriteCommand;
import org.infinispan.configuration.cache.Configuration;
import org.infinispan.configuration.cache.MemoryConfiguration;
import org.infinispan.configuration.cache.StorageType;
import org.infinispan.container.entries.CacheEntry;
import org.infinispan.container.entries.InternalCacheEntry;
import org.infinispan.container.impl.InternalDataContainer;
import org.infinispan.container.impl.KeyValueMetadataSizeCalculator;
import org.infinispan.container.offheap.UnpooledOffHeapMemoryAllocator;
import org.infinispan.context.InvocationContext;
import org.infinispan.context.impl.TxInvocationContext;
import org.infinispan.distribution.DistributionManager;
import org.infinispan.factories.annotations.Inject;
import org.infinispan.factories.annotations.Start;
import org.infinispan.factories.annotations.Stop;
import org.infinispan.interceptors.DDAsyncInterceptor;
import org.infinispan.notifications.Listener;
import org.infinispan.notifications.cachelistener.CacheNotifier;
import org.infinispan.notifications.cachelistener.annotation.CacheEntryExpired;
import org.infinispan.notifications.cachelistener.event.CacheEntryExpiredEvent;
import org.infinispan.transaction.xa.GlobalTransaction;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;

/**
 * Interceptor that prevents the cache from inserting too many entries over a configured maximum amount.
 * This interceptor assumes that there is a transactional cache without one phase commit semantics.
 * <p>
 * The tracked size is reserved when a transaction is prepared, given back when that transaction is rolled
 * back, and reduced when entries are reported as removed by the data container, are invalidated, or (for
 * non clustered caches) expire. A prepare that would push the tracked size above the configured maximum
 * is rejected with an exception.
 * @author wburns
 * @since 9.0
 */
@Listener(observation = Listener.Observation.POST)
public class TransactionalExceptionEvictionInterceptor extends DDAsyncInterceptor {
   private static final Log log = LogFactory.getLog(TransactionalExceptionEvictionInterceptor.class);
   private static final boolean isTrace = log.isTraceEnabled();

   // Size currently accounted for, including reservations made by prepared but unfinished transactions
   private final AtomicLong currentSize = new AtomicLong();
   // Size change reserved by each prepared transaction, kept so that a rollback can undo it
   private final ConcurrentMap<GlobalTransaction, Long> pendingSize = new ConcurrentHashMap<>();
   private MemoryConfiguration memoryConfiguration;
   private CacheNotifier cacheNotifier;
   private InternalDataContainer container;
   private DistributionManager dm;
   private long maxSize;
   private long minSize;
   private KeyValueMetadataSizeCalculator calculator;

   // Kept in a field so that the very same instance can be unregistered in stop()
   private Consumer<Iterable<InternalCacheEntry>> listener;

   /**
    * @return the size currently accounted for by this interceptor
    */
   public long getCurrentSize() {
      return currentSize.get();
   }

   /**
    * @return the configured maximum size, as read from the memory configuration
    */
   public long getMaxSize() {
      return maxSize;
   }

   /**
    * @return the baseline the size is reset to on clear; non zero only when off heap storage is configured
    */
   public long getMinSize() {
      return minSize;
   }

   /**
    * @return how many prepared transactions still have a recorded size change, i.e. have been neither
    * committed nor rolled back yet
    */
   public long pendingTransactionCount() {
      return pendingSize.size();
   }

   /**
    * Receives the components this interceptor depends on.
    *
    * @param config cache configuration, used for the memory settings and the maximum size
    * @param cacheNotifier notifier used to listen for expiration events
    * @param dataContainer container that is peeked for existing entries and that reports removals
    * @param calculator computes the size of a key/value/metadata triple
    * @param dm distribution manager used for ownership checks; handled as optional ({@code null}) by
    *           {@link #visitPrepareCommand(TxInvocationContext, PrepareCommand)}
    */
   @Inject
   public void inject(Configuration config, CacheNotifier cacheNotifier,
                      InternalDataContainer dataContainer, KeyValueMetadataSizeCalculator calculator, DistributionManager dm) {
      this.memoryConfiguration = config.memory();
      this.cacheNotifier = cacheNotifier;
      this.container = dataContainer;
      this.maxSize = config.memory().size();
      this.calculator = calculator;
      this.dm = dm;
   }

   /**
    * Initialises the baseline size and registers the listeners that keep the tracked size up to date.
    */
   @Start
   public void start() {
      if (memoryConfiguration.storageType() == StorageType.OFF_HEAP) {
         // Off heap storage has a fixed overhead that is counted from the beginning
         // NOTE(review): the << 3 presumably converts the address count into bytes (8 per address) - confirm
         minSize = UnpooledOffHeapMemoryAllocator.estimateSizeOverhead(memoryConfiguration.addressCount() << 3);
         currentSize.set(minSize);
      }

      listener = this::entriesRemoved;
      container.addRemovalListener(listener);

      // Local caches just remove the entry, so we have to listen for those events
      if (!cacheConfiguration.clustering().cacheMode().isClustered()) {
         // We want the raw values and no transformations for our listener
         cacheNotifier.addListener(this);
      }
   }

   /**
    * Unregisters the removal listener that was added to the data container in {@link #start()}.
    */
   @Stop
   public void stop() {
      container.removeRemovalListener(listener);
   }

   /**
    * Removal callback registered with the data container: gives back the size of every reported entry.
    *
    * @param entries the entries the container reports as removed
    */
   private void entriesRemoved(Iterable<InternalCacheEntry> entries) {
      long changeAmount = 0;
      for (InternalCacheEntry entry : entries) {
         changeAmount -= calculator.calculateSize(entry.getKey(), entry.getValue(), entry.getMetadata());
      }
      if (changeAmount != 0) {
         increaseSize(changeAmount);
      }
   }

   /**
    * Expiration callback: gives back the size of the expired entry.
    *
    * @param event the expiration event
    */
   @CacheEntryExpired
   public void entryExpired(CacheEntryExpiredEvent event) {
      // If this is null it means it was from the store, so we don't care about that
      if (event.getValue() != null) {
         Object key = event.getKey();
         if (isTrace) {
            log.tracef("Key %s found to have expired", key);
         }
         increaseSize(- calculator.calculateSize(key, event.getValue(), event.getMetadata()));
      }
   }

   /**
    * Applies a size change to the tracked size.
    * <p>
    * A positive change is only applied when the resulting size does not exceed the maximum. A zero or
    * negative change is always applied: the tracked size can legitimately sit above the maximum (the
    * off heap baseline is set without a check and a rollback adds space back unconditionally), and
    * dropping a decrement in that state would leave the counter permanently too high.
    *
    * @param increaseAmount the change to apply, negative to release space
    * @return {@code true} if the change was applied, {@code false} if it was rejected because it would
    * have exceeded the maximum size
    */
   private boolean increaseSize(long increaseAmount) {
      if (increaseAmount <= 0) {
         // Releasing space can never be the reason the limit is violated, so it needs no gate
         long newSize = currentSize.addAndGet(increaseAmount);
         if (isTrace) {
            log.tracef("Increased exception based size by %d to %d", increaseAmount, newSize);
         }
         return true;
      }
      while (true) {
         long size = currentSize.get();
         long targetSize = size + increaseAmount;
         if (targetSize > maxSize) {
            return false;
         }
         // Retry when another thread changed the size between the read and the update
         if (currentSize.compareAndSet(size, targetSize)) {
            if (isTrace) {
               log.tracef("Increased exception based size by %d to %d", increaseAmount, targetSize);
            }
            return true;
         }
      }
   }

   /**
    * Gives back the size of every invalidated key that is currently present in the data container,
    * before the command continues down the chain.
    */
   @Override
   public Object visitInvalidateCommand(InvocationContext ctx, InvalidateCommand command) throws Throwable {
      // State transfer uses invalidate command to remove entries that are outside of the tx context
      Object[] keys = command.getKeys();
      long changeAmount = 0;
      for (Object key : keys) {
         InternalCacheEntry entry = container.peek(key);
         if (entry != null) {
            changeAmount -= calculator.calculateSize(key, entry.getValue(), entry.getMetadata());
         }
      }
      if (changeAmount != 0) {
         increaseSize(changeAmount);
      }
      return super.visitInvalidateCommand(ctx, command);
   }

   /**
    * Resets the tracked size to the baseline before the clear continues down the chain.
    */
   @Override
   public Object visitClearCommand(InvocationContext ctx, ClearCommand command) throws Throwable {
      if (isTrace) {
         log.tracef("Clear command encountered, resetting size to %d", minSize);
      }
      // Clear is never invoked in the middle of a transaction with others so just set the size
      currentSize.set(minSize);
      return super.visitClearCommand(ctx, command);
   }

   /**
    * Computes the net size change of all keys modified by the transaction and reserves it.
    * <p>
    * Only keys for which this node is a write owner are counted (all keys when no distribution manager
    * is available). If the change cannot be reserved the prepare fails with the exception created by
    * {@code log.containerFull(maxSize)}. Unless the prepare is a one phase commit, the reserved amount is
    * remembered so that a later rollback can undo it.
    */
   @Override
   public Object visitPrepareCommand(TxInvocationContext ctx, PrepareCommand command) throws Throwable {
      // If we just invoke ctx.getModifications it won't return the modifications for REPL state transfer
      List<WriteCommand> modifications = ctx.getCacheTransaction().getAllModifications();
      // A set so that a key touched by several modifications is only counted once
      Set<Object> modifiedKeys = new HashSet<>();
      for (WriteCommand modification : modifications) {
         modifiedKeys.addAll(modification.getAffectedKeys());
      }

      long changeAmount = 0;
      for (Object key : modifiedKeys) {
         if (dm == null || dm.getCacheTopology().isWriteOwner(key)) {
            CacheEntry entry = ctx.lookupEntry(key);
            if (entry.isRemoved()) {
               // Need to subtract old value here
               InternalCacheEntry containerEntry = container.peek(key);
               Object value = containerEntry != null ? containerEntry.getValue() : null;
               if (value != null) {
                  if (isTrace) {
                     log.tracef("Key %s was removed", key);
                  }
                  // NOTE(review): this uses the context entry's metadata whereas the replace branch below
                  // uses the container entry's metadata for the old value - confirm that is intended
                  changeAmount -= calculator.calculateSize(key, value, entry.getMetadata());
               }
            } else {
               // We check the container directly - this is to handle entries that are expired as the command
               // won't think it replaced a value
               InternalCacheEntry containerEntry = container.peek(key);
               if (isTrace) {
                  log.tracef("Key %s was put into cache, replacing existing %s", key, containerEntry != null);
               }
               // Create and replace both add for the new value
               changeAmount += calculator.calculateSize(key, entry.getValue(), entry.getMetadata());
               // Need to subtract old value here
               if (containerEntry != null) {
                  changeAmount -= calculator.calculateSize(key, containerEntry.getValue(), containerEntry.getMetadata());
               }
            }
         }
      }

      if (changeAmount != 0 && !increaseSize(changeAmount)) {
         throw log.containerFull(maxSize);
      }

      if (!command.isOnePhaseCommit()) {
         pendingSize.put(ctx.getGlobalTransaction(), changeAmount);
      }

      return super.visitPrepareCommand(ctx, command);
   }

   /**
    * Undoes the size change that was reserved when the transaction was prepared, if any.
    */
   @Override
   public Object visitRollbackCommand(TxInvocationContext ctx, RollbackCommand command) throws Throwable {
      Long size = pendingSize.remove(ctx.getGlobalTransaction());
      if (size != null) {
         if (isTrace) {
            log.tracef("Rollback encountered subtracting exception size by %d", size);
         }
         // Applied unconditionally: when the transaction had a negative net change this adds space back
         currentSize.addAndGet(-size);
      }
      return super.visitRollbackCommand(ctx, command);
   }

   /**
    * Forgets the reservation of the transaction; the size reserved at prepare time stays accounted for.
    */
   @Override
   public Object visitCommitCommand(TxInvocationContext ctx, CommitCommand command) throws Throwable {
      pendingSize.remove(ctx.getGlobalTransaction());
      return super.visitCommitCommand(ctx, command);
   }
}
