package org.infinispan.server.hotrod.iteration

import java.util.{BitSet => JavaBitSet, Set => JavaSet, UUID}
import java.util.concurrent.atomic.AtomicReference

import org.infinispan.commons.marshall.Marshaller
import org.infinispan.commons.util.{BitSetUtils, CloseableIterator, CollectionFactory, InfinispanCollections}
import org.infinispan.configuration.cache.CompatibilityModeConfiguration
import org.infinispan.container.entries.CacheEntry
import org.infinispan.filter.{KeyValueFilterConverter, KeyValueFilterConverterFactory, ParamKeyValueFilterConverterFactory}
import org.infinispan.iteration.EntryRetriever
import org.infinispan.iteration.EntryRetriever.SegmentListener
import org.infinispan.manager.EmbeddedCacheManager
import org.infinispan.server.hotrod.OperationStatus.OperationStatus
import org.infinispan.server.hotrod._
import org.infinispan.server.hotrod.iteration.MarshallerBuilder._
import org.infinispan.server.hotrod.logging.Log
import org.infinispan.util.concurrent.ConcurrentHashSet

import scala.collection.JavaConversions._

/**
 * Tracks server-side entry iterations opened by Hot Rod clients.
 *
 * An iteration is opened with `start`, advanced batch by batch with `next` and released with
 * `close`; every open iteration is referred to by the [[IterationId]] returned from `start`.
 *
 * @author gustavonalle
 * @since 8.0
 */
trait IterationManager {
   /** Opaque handle identifying one open iteration. */
   type IterationId = String
   /**
    * Opens an iteration over `cacheName`, optionally restricted to `segments` and to the
    * filter/converter produced by the factory named in `filterConverterFactory`.
    * `batch` is the maximum number of entries handed out per `next` call (as implemented by
    * DefaultIterationManager); `metadata` is forwarded to every result (presumably whether entry
    * metadata should be sent to the client — confirm in the encoder).
    */
   def start(cacheName: String, segments: Option[JavaBitSet], filterConverterFactory: NamedFactory, batch: Integer, metadata: Boolean): IterationId
   /** Returns the next batch of `iterationId`; DefaultIterationManager reports unknown ids through the result's status code. */
   def next(cacheName: String, iterationId: IterationId): IterableIterationResult
   /** Closes and forgets `iterationId`; the Boolean presumably reports whether such an iteration existed — confirm with callers. */
   def close(cacheName: String, iterationId: IterationId): Boolean
   /** Registers `factory` under `name` so that `start` can refer to it by name. */
   def addKeyValueFilterConverterFactory[K, V, C](name: String, factory: KeyValueFilterConverterFactory[K, V, C]): Unit
   /** Unregisters the factory stored under `name`. */
   def removeKeyValueFilterConverterFactory(name: String): Unit
   /** Sets the marshaller used for filter parameters; DefaultIterationManager falls back to a factory-derived one when `None`. */
   def setMarshaller(maybeMarshaller: Option[Marshaller]): Unit
   /** Number of iterations currently open. */
   def activeIterations: Int
}

/**
 * Segment listener that accumulates the ids of segments reported as transferred.
 *
 * Ids are collected in a concurrent set; `getFinished` hands the current set to the caller and
 * installs a fresh, empty one in its place.
 */
class IterationSegmentsListener extends SegmentListener {
   private val completedSegments = new AtomicReference(new ConcurrentHashSet[Integer]())

   /** Returns the segments recorded since the previous call and resets the accumulator. */
   def getFinished: ConcurrentHashSet[Integer] = {
      val fresh = new ConcurrentHashSet[Integer]()
      completedSegments.getAndSet(fresh)
   }

   /** Records `segment` as finished; `sentLastEntry` is not consulted. */
   override def segmentTransferred(segment: Int, sentLastEntry: Boolean): Unit = {
      val current = completedSegments.get
      current.add(Int.box(segment))
   }
}

/**
 * Book-keeping for one open iteration.
 *
 * @param listener   collects the segments finished by the underlying retrieval
 * @param iterator   source of the entries handed out batch by batch
 * @param batch      maximum number of entries returned per `next` call
 * @param compatInfo compatibility-mode settings of the iterated cache
 * @param metadata   flag copied into every [[IterableIterationResult]] of this iteration
 */
class IterationState(
   val listener: IterationSegmentsListener,
   val iterator: CloseableIterator[CacheEntry],
   val batch: Integer,
   val compatInfo: CompatInfo,
   val metadata: Boolean
)

/**
 * One batch of an iteration, as produced by `IterationManager.next`.
 *
 * @param finishedSegments segments completed since the previous batch
 * @param statusCode       outcome of the `next` call
 * @param entries          entries contained in this batch
 * @param compatInfo       compatibility settings deciding whether values must be unboxed
 * @param metadata         flag carried over from the iteration's start request
 */
class IterableIterationResult(finishedSegments: JavaSet[Integer], val statusCode: OperationStatus, val entries: List[CacheEntry], compatInfo: CompatInfo, val metadata: Boolean) {

   /** True only when compatibility mode is enabled and a Hot Rod type converter is available. */
   lazy val compatEnabled = compatInfo.enabled && compatInfo.hotRodTypeConverter.nonEmpty

   /** Encodes the finished segment ids as a bit set serialized with `BitSetUtils.toByteArray`. */
   def segmentsToBytes = {
      val segmentBits = new JavaBitSet()
      val segmentIds = finishedSegments.iterator()
      while (segmentIds.hasNext) {
         segmentBits.set(segmentIds.next().intValue)
      }
      BitSetUtils.toByteArray(segmentBits)
   }

   /** Unboxes `value` with the compatibility type converter; requires that a converter is present. */
   def unbox(value: AnyRef) = {
      val converter = compatInfo.hotRodTypeConverter.get
      converter.unboxValue(value)
   }

}

/**
 * Compatibility-mode settings of a cache, as needed by the iteration code.
 *
 * @param enabled             whether compatibility mode is enabled
 * @param hotRodTypeConverter converter built from the configured compatibility marshaller, if any
 */
class CompatInfo(val enabled: Boolean, val hotRodTypeConverter: Option[HotRodTypeConverter])

object CompatInfo {
   /** Builds a [[CompatInfo]] from `config`; a null compatibility marshaller yields no converter. */
   def apply(config: CompatibilityModeConfiguration): CompatInfo = {
      val isEnabled = config.enabled()
      val converter = Option(config.marshaller()).map(m => HotRodTypeConverter(m))
      new CompatInfo(isEnabled, converter)
   }
}

/**
 * Default [[IterationManager]].
 *
 * Each open iteration is stored in a concurrent map under a random UUID and is served from the
 * cache's `EntryRetriever`. An iteration leaves the map (and has its iterator closed) only through
 * `close`. The `cacheName` argument of `next`/`close` is not consulted, because iteration ids are
 * global to this manager.
 */
class DefaultIterationManager(val cacheManager: EmbeddedCacheManager) extends IterationManager with Log {
   /** Marshaller for filter parameters and `IterationFilter`; `None` until `setMarshaller` is called. */
   @volatile var marshaller: Option[_ <: Marshaller] = None

   private val iterationStateMap = CollectionFactory.makeConcurrentMap[String, IterationState]()
   private val filterConverterFactoryMap = CollectionFactory.makeConcurrentMap[String, KeyValueFilterConverterFactory[_, _, _]]()

   /**
    * Opens a new iteration over `cacheName` and registers it under a fresh id.
    *
    * Throws the exception built by `log.missingKeyValueFilterConverterFactory` when `namedFactory`
    * names a factory that is not registered.
    */
   override def start(cacheName: String, segments: Option[JavaBitSet], namedFactory: NamedFactory, batch: Integer, metadata: Boolean): IterationId = {
      val iterationId = UUID.randomUUID().toString
      val cache = cacheManager.getCache(cacheName).getAdvancedCache
      val entryRetriever = cache.getComponentRegistry.getComponent(classOf[EntryRetriever[_, _]])
      val segmentListener = new IterationSegmentsListener
      // Read the configuration of the running cache: EmbeddedCacheManager.getCacheConfiguration(name)
      // may return null for caches that have no explicit definition.
      val compatInfo = CompatInfo(cache.getCacheConfiguration.compatibility())

      val customFilter = buildCustomFilter(namedFactory)
      // No filter is needed when neither a custom filter nor a segment restriction was requested.
      val filter =
         if (customFilter.isDefined || segments.isDefined)
            new IterationFilter(compatInfo.enabled, customFilter, segments, marshaller)
         else null

      val iterator = entryRetriever.retrieveEntries(filter, null, null, segmentListener)
      iterationStateMap.put(iterationId, new IterationState(segmentListener, iterator, batch, compatInfo, metadata))
      iterationId
   }

   /**
    * Instantiates the filter/converter requested by `namedFactory`.
    *
    * @return `None` when no factory was requested, otherwise the factory's filter/converter
    */
   private def buildCustomFilter(namedFactory: NamedFactory): Option[KeyValueFilterConverter[Any, Any, Any]] =
      namedFactory.map { case (name, params) =>
         val converter = Option(filterConverterFactoryMap.get(name)) match {
            // Checked first: parameterised factories live in the same map as plain ones.
            case Some(factory: ParamKeyValueFilterConverterFactory[_, _, _]) =>
               factory.getFilterConverter(unmarshallParams(params.toArray, factory))
            case Some(factory) =>
               factory.getFilterConverter
            case None =>
               throw log.missingKeyValueFilterConverterFactory(name)
         }
         converter.asInstanceOf[KeyValueFilterConverter[Any, Any, Any]]
      }

   /** Unmarshalls `params` with the configured marshaller, or with the one `genericFromInstance` derives from `factory`. */
   private def unmarshallParams(params: Array[Bytes], factory: AnyRef): Array[AnyRef] = {
      val m = marshaller.getOrElse(genericFromInstance(Some(factory)))
      params.map(m.objectFromByteBuffer)
   }

   /**
    * Returns up to `batch` further entries of `iterationId`, together with the segments finished so far.
    * An unknown id yields an empty result with status `InvalidIteration`.
    */
   override def next(cacheName: String, iterationId: IterationId): IterableIterationResult =
      Option(iterationStateMap.get(iterationId)) match {
         case Some(state) =>
            // take() stops pulling as soon as the iterator is exhausted; a non-positive batch yields nothing.
            val entries = state.iterator.take(math.max(0, state.batch.intValue)).toList
            // Finished segments are collected only after the batch has been drawn, as before.
            new IterableIterationResult(state.listener.getFinished, OperationStatus.Success, entries, state.compatInfo, state.metadata)
         case None =>
            // A disabled CompatInfo instead of null, so inspecting compatEnabled on this result is safe.
            new IterableIterationResult(InfinispanCollections.emptySet(), OperationStatus.InvalidIteration, List.empty, new CompatInfo(false, None), false)
      }

   /**
    * Removes `iterationId` and closes its iterator.
    *
    * @return true if the iteration existed, false otherwise
    */
   override def close(cacheName: String, iterationId: IterationId): Boolean =
      // Remove first so that concurrent closes of the same id close the iterator only once.
      Option(iterationStateMap.remove(iterationId)) match {
         case Some(state) =>
            state.iterator.close()
            true
         case None => false
      }

   override def addKeyValueFilterConverterFactory[K, V, C](name: String, factory: KeyValueFilterConverterFactory[K, V, C]): Unit = filterConverterFactoryMap.put(name, factory)

   override def removeKeyValueFilterConverterFactory(name: String): Unit = filterConverterFactoryMap.remove(name)

   override def activeIterations: Int = iterationStateMap.size()

   override def setMarshaller(maybeMarshaller: Option[Marshaller]): Unit = this.marshaller = maybeMarshaller
}
