/*
 * Copyright 2010 Red Hat, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
package org.jboss.soa.dsp.ws;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import javax.wsdl.Definition;
import javax.wsdl.Operation;
import javax.wsdl.WSDLException;
import javax.wsdl.factory.WSDLFactory;
import javax.wsdl.xml.WSDLReader;
import javax.xml.namespace.QName;
import javax.xml.soap.*;
import javax.xml.ws.Provider;

/**
 * Base class for BPEL endpoints that are created through javassist.
 * Represents a JAX-WS {@link javax.xml.ws.Provider} implementation.
 *
 * @see org.jboss.soa.dsp.ws.WebServiceProviderGenerator
 *
 * @author Heiko.Braun <heiko.braun@jboss.com>
 */
public class BaseWebServiceEndpoint implements Provider<SOAPMessage>
{
  protected final Log log = LogFactory.getLog(getClass());

  // The four fields below are written once inside init() and only read
  // afterwards; the volatile write of isInitialized publishes them.
  private SOAPMessageAdapter soapAdapter;
  private QName serviceQName;
  private Definition wsdlDefinition;
  private WebServiceProviderFactory serviceFactory;

  private volatile boolean isInitialized;

  /**
   * Lazily prepares the endpoint on first use: reads the WSDL from
   * {@link #getWsdlLocation()}, builds the SOAP adapter and instantiates the
   * factory named by the {@link WebServiceDetails} annotation of the concrete
   * class.
   *
   * <p>The method is safe to call from concurrent invocations. State is only
   * published once every step succeeded, so a failed attempt leaves the
   * endpoint uninitialized and the next call retries.
   *
   * @throws RuntimeException if the WSDL cannot be parsed, the annotation is
   *         missing, or the factory class cannot be found or instantiated
   */
  private void init()
  {
    if (isInitialized)
    {
      return;
    }

    synchronized (this)
    {
      // Re-check under the lock: another thread may have finished meanwhile.
      if (isInitialized)
      {
        return;
      }

      Definition definition;
      try
      {
        WSDLReader wsdlReader = WSDLFactory.newInstance().newWSDLReader();
        definition = wsdlReader.readWSDL(getWsdlLocation());
      }
      catch (WSDLException e)
      {
        throw new RuntimeException("Failed to parse WSDL", e);
      }

      QName service = QName.valueOf(getServiceName());
      SOAPMessageAdapter adapter = new SOAPMessageAdapter(definition, service, getPortName());

      WebServiceDetails details = getClass().getAnnotation(WebServiceDetails.class);
      if (details == null)
      {
        throw new RuntimeException("Web service details not defined on Web Service endpoint");
      }
      WebServiceProviderFactory factory = createServiceFactory(details.factory());

      // Publish only after everything succeeded; the flag goes last.
      this.wsdlDefinition = definition;
      this.serviceQName = service;
      this.soapAdapter = adapter;
      this.serviceFactory = factory;
      isInitialized = true;
    }
  }

  /**
   * Loads the given class by name and instantiates it through its no-argument
   * constructor.
   *
   * @param factoryClassName fully qualified name of the factory class
   * @return the new factory instance
   * @throws RuntimeException if the class cannot be found or instantiated
   */
  private WebServiceProviderFactory createServiceFactory(String factoryClassName)
  {
    try
    {
      Class<?> cls = Class.forName(factoryClassName);

      // Class.newInstance() is deprecated; go through the constructor instead.
      return (WebServiceProviderFactory) cls.getDeclaredConstructor().newInstance();
    }
    catch (ClassNotFoundException cnfe)
    {
      throw new RuntimeException("Unable to find Web Service Factory class '" +
          factoryClassName + "'", cnfe);
    }
    catch (Exception ex)
    {
      throw new RuntimeException("Failed to instantiate Web Service Factory class '" +
          factoryClassName + "'", ex);
    }
  }

  /**
   * Handles one inbound SOAP message: resolves the WSDL operation from the
   * first element of the SOAP body, hands the message to the service provider
   * obtained from the factory, and returns the invocation result.
   *
   * @param soapMessage the inbound SOAP message
   * @return the response message; an empty message when no response was
   *         produced (for example for a one-way operation), never {@code null}
   * @throws RuntimeException if initialization fails or anything goes wrong
   *         while processing the message; the cause is preserved
   */
  public SOAPMessage invoke(SOAPMessage soapMessage)
  {
    if (log.isDebugEnabled())
    {
      log.debug("Invoking endpoint "+getEndpointId());
    }
    init();

    try
    {
      SOAPPart soapPart = soapMessage.getSOAPPart();
      SOAPEnvelope soapEnvelope = soapPart.getEnvelope();
      Element messageElement = getMessagePayload(soapEnvelope);

      if (log.isDebugEnabled())
      {
        log.debug( "ODE inbound message: \n" +DOMWriter.printNode(soapEnvelope, true) );
      }

      // Without a payload element there is nothing to resolve an operation
      // from; fail with a clear message instead of a NullPointerException.
      if (messageElement == null)
      {
        throw new IllegalStateException("SOAP body does not contain a payload element");
      }

      // Create invocation context
      final String operationName = resolveOperationName(messageElement);

      WSInvocationAdapter invocationContext = serviceFactory.getInvocationAdapter(
          operationName, serviceQName, getPortName(), soapAdapter);

      invocationContext.setSOAPMessage(soapMessage);

      // Invoke ODE
      serviceFactory.getServiceProvider().invoke(invocationContext);

      // Handle response
      SOAPMessage responseMessage = null;

      if (isResponseExpected(messageElement))
      {
        responseMessage = invocationContext.getInvocationResult();

        // Guard against a missing result so that enabling debug logging
        // cannot change the outcome of the call.
        if (responseMessage != null && log.isDebugEnabled())
        {
          log.debug( "ODE outbound message: \n" +
              DOMWriter.printNode(responseMessage.getSOAPPart().getEnvelope(), true)
          );
        }
      }
      else if (log.isDebugEnabled())
      {
        log.debug( "ODE no outbound message");
      }

      if (responseMessage == null)
      {
        log.debug("No response, probably due to oneway request");

        // Need to create an empty response to avoid npe in jbossws (RIFTSAW-154)
        responseMessage = MessageFactory.newInstance().createMessage();
      }

      return responseMessage;
    }
    catch (Exception e)
    {
      throw new RuntimeException("Failed to invoke BPEL process: "+e.getMessage(), e);
    }
  }

  /**
   * Determines the name of the WSDL operation addressed by a payload element.
   * For RPC style bindings this is the local name of the element; otherwise
   * the operation is looked up in the WSDL by the element's qualified name.
   *
   * @param payload the first element of the SOAP body
   * @return the operation name
   * @throws RuntimeException if no matching document/literal operation is
   *         found in the WSDL
   */
  public String resolveOperationName(Element payload)
  {
    if (soapAdapter.isRPC())
    {
      return payload.getLocalName();
    }

    QName elementName = new QName(payload.getNamespaceURI(), payload.getLocalName());
    Operation op = new WSDLParser(wsdlDefinition).getDocLitOperation(
        this.serviceQName, getPortName(), elementName
    );

    if (op == null)
    {
      throw new RuntimeException("Failed to locate operation definition for: "+elementName);
    }

    return op.getName();
  }

  /**
   * Tells whether the operation addressed by the payload declares an output
   * in the WSDL.
   *
   * @param payload the first element of the SOAP body
   * @return {@code true} if the operation has an output, {@code false} otherwise
   * @throws RuntimeException if no matching operation is found in the WSDL
   */
  public boolean isResponseExpected(Element payload)
  {
    Operation op = findOperation(payload);

    if (op == null)
    {
      throw new RuntimeException("Failed to locate operation definition for: "+payload);
    }

    return op.getOutput() != null;
  }

  /**
   * Looks up the WSDL operation for a payload element, using the RPC or the
   * document/literal lookup depending on the style reported by the adapter.
   *
   * @param payload the first element of the SOAP body
   * @return whatever the WSDL parser returns for the lookup; callers check
   *         for {@code null}
   */
  private Operation findOperation(Element payload)
  {
    QName elementName = new QName(payload.getNamespaceURI(), payload.getLocalName());
    WSDLParser parser = new WSDLParser(wsdlDefinition);

    if (soapAdapter.isRPC())
    {
      return parser.getRPCOperation(this.serviceQName, getPortName(), elementName);
    }
    return parser.getDocLitOperation(this.serviceQName, getPortName(), elementName);
  }

  /**
   * Returns the first child of the SOAP body that is an element node.
   *
   * @param soapEnvelope the envelope to inspect
   * @return the first element child of the body, or {@code null} if the
   *         envelope has no body or the body has no element child
   * @throws SOAPException if the body cannot be obtained from the envelope
   */
  public static Element getMessagePayload(SOAPEnvelope soapEnvelope)
      throws SOAPException
  {
    SOAPBody body = soapEnvelope.getBody();
    if (body == null)
    {
      return null;
    }

    NodeList children = body.getChildNodes();
    for (int i = 0; i < children.getLength(); i++)
    {
      Node child = children.item(i);
      if (Node.ELEMENT_NODE == child.getNodeType())
      {
        return (Element) child;
      }
    }
    return null;
  }

  /**
   * Identifier of this endpoint, used in log output. This base implementation
   * returns {@code null}; presumably the generated subclass overrides it.
   *
   * @return the endpoint id, {@code null} in this base class
   */
  public String getEndpointId() {
    return null;
  }

  /**
   * Service name in the string form accepted by {@link QName#valueOf(String)}.
   * This base implementation returns {@code null}; presumably the generated
   * subclass overrides it.
   *
   * @return the service name, {@code null} in this base class
   */
  public String getServiceName() {
    return null;
  }

  /**
   * Location the WSDL is read from during initialization. This base
   * implementation returns {@code null}; presumably the generated subclass
   * overrides it.
   *
   * @return the WSDL location, {@code null} in this base class
   */
  public String getWsdlLocation() {
    return null;
  }

  /**
   * Name of the WSDL port served by this endpoint. This base implementation
   * returns {@code null}; presumably the generated subclass overrides it.
   *
   * @return the port name, {@code null} in this base class
   */
  public String getPortName() {
    return null;
  }
}
