/*
 * JBoss, Home of Professional Open Source
 * Copyright 2006, JBoss Inc., and others contributors as indicated
 * by the @authors tag. All rights reserved.
 * See the copyright.txt in the distribution for a
 * full listing of individual contributors.
 * This copyrighted material is made available to anyone wishing to use,
 * modify, copy, or redistribute it subject to the terms and conditions
 * of the GNU Lesser General Public License, v. 2.1.
 * This program is distributed in the hope that it will be useful, but WITHOUT A
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
 * PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.
 * You should have received a copy of the GNU Lesser General Public License,
 * v.2.1 along with this distribution; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA  02110-1301, USA.
 *
 * (C) 2005-2006, JBoss Inc.
 */
package org.jboss.internal.soa.esb.webservice;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.wsdl.Binding;
import javax.wsdl.BindingFault;
import javax.wsdl.BindingInput;
import javax.wsdl.BindingOperation;
import javax.wsdl.BindingOutput;
import javax.wsdl.Definition;
import javax.wsdl.Fault;
import javax.wsdl.Input;
import javax.wsdl.Message;
import javax.wsdl.Operation;
import javax.wsdl.Output;
import javax.wsdl.Part;
import javax.wsdl.PortType;
import javax.wsdl.Types;
import javax.wsdl.WSDLException;
import javax.wsdl.extensions.ExtensibilityElement;
import javax.wsdl.extensions.ExtensionRegistry;
import javax.wsdl.extensions.ExtensionSerializer;
import javax.wsdl.extensions.soap.SOAPOperation;
import javax.wsdl.factory.WSDLFactory;
import javax.xml.namespace.QName;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.Result;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.jboss.internal.soa.esb.util.XMLHelper;
import org.jboss.soa.esb.ConfigurationException;
import org.jboss.soa.esb.Service;
import org.jboss.soa.esb.dom.YADOMUtil;
import org.jboss.soa.esb.listeners.config.WebserviceInfo;
import org.jboss.soa.esb.util.ClassUtil;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

import com.ibm.wsdl.extensions.schema.SchemaImpl;
import com.ibm.wsdl.extensions.soap.SOAPAddressImpl;
import com.ibm.wsdl.extensions.soap.SOAPBindingImpl;
import com.ibm.wsdl.extensions.soap.SOAPBodyImpl;
import com.ibm.wsdl.extensions.soap.SOAPFaultImpl;
import com.ibm.wsdl.extensions.soap.SOAPOperationImpl;

public class ESBContractGenerator {
	private static final QName XSD_QN = new QName("http://www.w3.org/2001/XMLSchema", "schema");
	private static final String WSDL_NAMESPACE = "http://schemas.xmlsoap.org/wsdl/" ;
	private static final String WSDL_REQUIRED = "required" ;
	private static final String WSAW_NAMESPACE = "http://www.w3.org/2006/05/addressing/wsdl" ;
	private static final String WSAW_PREFIX = "wsaw" ;
	private static final QName WSAW_ACTION_QN = new QName(WSAW_NAMESPACE, "Action", WSAW_PREFIX) ;
	private static final QName WSAW_USING_ADDRESSING_QN = new QName(WSAW_NAMESPACE, "UsingAddressing", WSAW_PREFIX) ;
	private static WSDLFactory wsdlFactory ;
	
	public static String generateWSDL(final WebserviceInfo serviceConfig, final ESBServiceEndpointInfo serviceInfo) throws ConfigurationException {
	    return generateWSDL(serviceConfig, serviceInfo, null);
	}
	
	public static String generateWSDL(final WebserviceInfo serviceConfig, final ESBServiceEndpointInfo serviceInfo, final ClassLoader classLoader) throws ConfigurationException {
		final Definition def = getWSDLFactory().newDefinition() ;
		final String namespace = serviceInfo.getNamespace() ;
		def.setTargetNamespace(namespace);
		def.addNamespace("tns", namespace);
		def.addNamespace("soap", "http://schemas.xmlsoap.org/wsdl/soap/");
		
		if (serviceInfo.isAddressing()) {
			def.getExtensionRegistry().registerSerializer(Binding.class, WSAW_USING_ADDRESSING_QN, new UsingAddressingSerializer()) ;
			def.addNamespace(WSAW_PREFIX, WSAW_NAMESPACE) ;
		}
		// add types
		Types types = def.createTypes();
		def.setTypes(types);
		
		// Keeps track of schema types added to avoid duplicates.
		Set<String> schemasAdded = new HashSet<String>(); 

		Message reqMessage = null;
		Message resMessage = null;
		List<Message> faultMessages = null;

		String inXsd = serviceConfig.getInXsd();
		String outXsd = serviceConfig.getOutXsd();
		String faultXsd = serviceConfig.getFaultXsd();

		int nsSuffixCounter = 0 ;
		if (inXsd != null) {
			try {
				Document doc = YADOMUtil.parseStream(getResourceAsStream(inXsd, classLoader), false, false);
				if (doc != null) {
					reqMessage = addMessage(def, doc.getDocumentElement(),
							serviceInfo.getRequestName(), "in", ++nsSuffixCounter, schemasAdded);
				}
			} catch (Exception e) {
				throw new ConfigurationException("File defined in inXsd attribute '" + inXsd + "' not found in classpath.", e);
			} 
		}

		if (outXsd != null) {
			try {
				Document doc = YADOMUtil.parseStream(getResourceAsStream(outXsd, classLoader), false, false);
				if (doc != null) {
					resMessage = addMessage(def, doc.getDocumentElement(),
						serviceInfo.getResponseName(), "out", ++nsSuffixCounter, schemasAdded);
				}
			} catch (Exception e) {
				throw new ConfigurationException("File defined in outXsd attribute '" + outXsd + "' not found in classpath.", e);
			} 

		}

		if ((faultXsd != null) && !serviceInfo.isOneWay()) {
			try {
				final String[] xsds = faultXsd.split(",") ;
				faultMessages = new ArrayList<Message>();
				for(String xsd: xsds) {
					Document doc = YADOMUtil.parseStream(getResourceAsStream(xsd, classLoader), false, false);
					if (doc != null) {
						addFaultMessage(faultMessages, def, doc.getDocumentElement(),
							serviceInfo.getFaultName(), "fault", ++nsSuffixCounter, schemasAdded);
					}
				}
			} catch (Exception e) {
				throw new ConfigurationException("File defined in faultXsd attribute '" + faultXsd + "' not found in classpath.", e);
			} 

		}

		PortType portType = addPortType(def, serviceInfo, reqMessage,
			resMessage, faultMessages);
		Binding binding = addBinding(def, serviceInfo, portType);
		addService(def, serviceInfo, binding);
		StringWriter sw = new java.io.StringWriter();
		try {
			getWSDLFactory().newWSDLWriter().writeWSDL(def, sw);
		} catch (WSDLException e) {
			final Service service = serviceConfig.getService() ;
			throw new ConfigurationException("Failed to generate wsdl for service:" + service.getCategory() + "/" + service.getName() , e);
		}
		return sw.toString();
	}
	
	private static void addSchema(Types types, Element xsdElement, Set<String> schemasAdded) throws SAXException, IOException, TransformerException, ParserConfigurationException {
	    if (add(xsdElement, schemasAdded))
	    {
    		SchemaImpl schemaImpl = new SchemaImpl();
    		schemaImpl.setElement(xsdElement);
    		schemaImpl.setElementType(XSD_QN);
    		types.addExtensibilityElement(schemaImpl);
	    }
	}
	
	private static InputStream getResourceAsStream(final String resource, final ClassLoader classLoader)
	{
	    if (classLoader != null)
	    {
	        final InputStream in = classLoader.getResourceAsStream(resource);
	        if (in !=null )
	        {
	            return in;
	        }
	    }
	    // Fallback to using the class's clasloader.
		return ClassUtil.getResourceAsStream(resource, ESBContractGenerator.class);
	}

	private static boolean add(final Element schemaElement, final Set<String> schemasAdded) throws SAXException, IOException, TransformerException, ParserConfigurationException
    {
        TransformerFactory factory = TransformerFactory.newInstance();
        Transformer transformer = factory.newTransformer();
        StringWriter writer = new StringWriter();
        Result result = new StreamResult(writer);
        transformer.transform(new DOMSource(schemaElement), result);
        String newType = writer.toString();
        if (schemasAdded.size() == 0)
        {
            return schemasAdded.add(newType);
        }
        else
        {
            boolean addSchema = true ;
            for (String existingType : schemasAdded)
            {
                if (XMLHelper.compareXMLContent(existingType, newType))
                {
                    addSchema = false ;
                    break ;
                }
            }
            if (addSchema)
            {
                return schemasAdded.add(newType);
            }
        } 
        
        return false;
    }

    private static Message addMessage(Definition def, Element element, String msgName, String partName, int nsSuffixCounter, Set<String> schemasAdded) throws SAXException, IOException, TransformerException, ParserConfigurationException {
		String schemaNs = YADOMUtil
				.getAttribute(element, "targetNamespace", "");
		addSchema(def.getTypes(), element, schemasAdded);
		if (def.getNamespace(schemaNs) == null) {
			def.addNamespace("ns" + nsSuffixCounter, schemaNs);
		}
		// add request message
		Node node = YADOMUtil.getNode(element, "/schema/element");
		Message msg = def.createMessage();
		msg.setQName(new QName(def.getTargetNamespace(), msgName));
		msg.setUndefined(false);
		Part part = def.createPart();
		part.setName(partName);
		part.setElementName(new QName(schemaNs, YADOMUtil.getAttribute(
				(Element) node, "name", "")));
		msg.addPart(part);
		def.addMessage(msg);
		return msg;
	}

	private static void addFaultMessage(final List<Message> faultMessages,
			Definition def, Element element, String msgName, String partName,
			int nsSuffixCounter, Set<String> schemasAdded) throws SAXException, IOException, TransformerException, ParserConfigurationException {
		String schemaNs = YADOMUtil
				.getAttribute(element, "targetNamespace", "");
		addSchema(def.getTypes(), element, schemasAdded);
		if (def.getNamespace(schemaNs) == null) {
			def.addNamespace("ns" + nsSuffixCounter, schemaNs);
		}
		// add request message
		NodeList nodes = YADOMUtil.getNodeList(element, "/schema/element");
		for (int i = 0; i < nodes.getLength(); i++) {
			final int nameIndex = i + 1;
			Node node = nodes.item(0);
			Message msg = def.createMessage();
			msg.setQName(new QName(def.getTargetNamespace(), msgName
					+ nameIndex));
			msg.setUndefined(false);
			Part part = def.createPart();
			part.setName(partName + nameIndex);
			part.setElementName(new QName(schemaNs, YADOMUtil.getAttribute(
					(Element) node, "name", "")));
			msg.addPart(part);
			def.addMessage(msg);
			faultMessages.add(msg);
		}
	}

	private static PortType addPortType(Definition def, final ESBServiceEndpointInfo serviceInfo,
			Message inMessage, Message outMessage, List<Message> faultMessages) {
		// add port type
		PortType portType = def.createPortType();
		portType.setQName(new QName(def.getTargetNamespace(), serviceInfo.getPortName())) ;
		Operation op = def.createOperation();
		op.setUndefined(false);
		op.setName(serviceInfo.getOperationName());
		if (inMessage != null) {
			Input in = def.createInput();
			in.setMessage(inMessage);
			in.setName(inMessage.getQName().getLocalPart());
			if (serviceInfo.isAddressing()) {
				in.setExtensionAttribute(WSAW_ACTION_QN, serviceInfo.getRequestAction()) ;
			}
			op.setInput(in);
		}
		if (outMessage != null) {
			Output out = def.createOutput();
			out.setMessage(outMessage);
			out.setName(outMessage.getQName().getLocalPart());
			if (serviceInfo.isAddressing()) {
				out.setExtensionAttribute(WSAW_ACTION_QN, serviceInfo.getResponseAction()) ;
			}
			op.setOutput(out);
		}

		int count = 1 ;
		if (faultMessages != null) {
			for (Message message : faultMessages) {
				Fault fault = def.createFault();
				fault.setMessage(message);
				fault.setName("fault" + (count++));
				op.addFault(fault);
			}
		}
		portType.addOperation(op);
		portType.setUndefined(false);
		def.addPortType(portType);
		return portType;
	}

	private static Binding addBinding(Definition def, final ESBServiceEndpointInfo serviceInfo, PortType portType) {
		// add binding
		Binding binding = def.createBinding();
		binding.setUndefined(false);
		binding.setPortType(portType);
		binding.setQName(new QName(def.getTargetNamespace(), serviceInfo.getBindingName())) ;
		SOAPBindingImpl soapBinding = new SOAPBindingImpl();
		soapBinding.setStyle("document");
		soapBinding.setTransportURI("http://schemas.xmlsoap.org/soap/http");
		binding.addExtensibilityElement(soapBinding);
		if (serviceInfo.isAddressing()) {
			binding.addExtensibilityElement(new UsingAddressingExtension()) ;
		}

		BindingOperation bop = def.createBindingOperation();

		bop.setName(serviceInfo.getOperationName());
		
		Operation op = (Operation) portType.getOperations().get(0);
		bop.setOperation(op);
		SOAPOperation soapOperation = new SOAPOperationImpl() ;
		soapOperation.setSoapActionURI(serviceInfo.getResponseAction()) ;
		bop.addExtensibilityElement(soapOperation) ;
		
		if (op.getInput() != null) {
			BindingInput binput = def.createBindingInput();
			bop.setBindingInput(binput);
			SOAPBodyImpl soapBody = new SOAPBodyImpl();
			soapBody.setUse("literal");
			binput.setName(serviceInfo.getRequestName()) ;
			binput.addExtensibilityElement(soapBody);
		}
		if (op.getOutput() != null) {
			BindingOutput boutput = def.createBindingOutput();
			bop.setBindingOutput(boutput);
			SOAPBodyImpl soapBody = new SOAPBodyImpl();
			soapBody.setUse("literal");
			boutput.setName(serviceInfo.getResponseName()) ;
			boutput.addExtensibilityElement(soapBody);
		}
		final Map faults = op.getFaults() ;
		if (faults != null) {
			Iterator iterator = op.getFaults().values().iterator();
			while (iterator.hasNext()) {
				Fault fault = (Fault) iterator.next();
				BindingFault bfault = def.createBindingFault();
				bfault.setName(fault.getName());
				bop.addBindingFault(bfault);
				SOAPFaultImpl soapFault = new SOAPFaultImpl();
				soapFault.setName(fault.getName());
				soapFault.setUse("literal");
				bfault.addExtensibilityElement(soapFault);
			}
		}
		binding.addBindingOperation(bop);
		def.addBinding(binding);
		return binding;

	}

	private static void addService(Definition def, final ESBServiceEndpointInfo serviceInfo, Binding binding) {
		// create service
		javax.wsdl.Service service = def.createService();
		service.setQName(new QName(def.getTargetNamespace(), serviceInfo.getServiceName()));
		javax.wsdl.Port port = def.createPort();
		port.setBinding(binding);
		port.setName(serviceInfo.getPortName());
		SOAPAddressImpl soapAddress = new SOAPAddressImpl();
		soapAddress.setLocationURI("http://change_this_URI/"+serviceInfo.getServletPath());
		port.addExtensibilityElement(soapAddress);
		service.addPort(port);
		def.addService(service);
	}
	
	private synchronized static WSDLFactory getWSDLFactory()
	    throws ConfigurationException
	{
	    if (wsdlFactory == null)
	    {
	        try
	        {
	            wsdlFactory = AccessController.doPrivileged(new PrivilegedExceptionAction<WSDLFactory>() {
	                public WSDLFactory run() throws WSDLException
	                {
	                    return WSDLFactory.newInstance();
	                }
	            }) ;
	        }
	        catch (final PrivilegedActionException pae)
	        {
	            throw new ConfigurationException("Failed to instantiate the WSDL factory", pae.getCause()) ;
	        }
	    }
	    return wsdlFactory ;
	}
	
	private static class UsingAddressingExtension implements ExtensibilityElement {
		public QName getElementType() {
			return WSAW_USING_ADDRESSING_QN ;
		}

		public Boolean getRequired() {
			return Boolean.TRUE ;
		}

		public void setElementType(final QName qname) {
		}

		public void setRequired(final Boolean required) {
		}
	}
	
	private static class UsingAddressingSerializer implements ExtensionSerializer
	{
		public void marshall(final Class parentType, final QName elementType,
			final ExtensibilityElement extension, final PrintWriter pw,
			final Definition definition, final ExtensionRegistry registry)
			throws WSDLException {
			if (extension != null) {
				final String prefix = definition.getPrefix(elementType.getNamespaceURI()) ;
				pw.print("    <"+prefix+":"+elementType.getLocalPart()) ;
				if (extension.getRequired().booleanValue()) {
					final String wsdlPrefix = definition.getPrefix(WSDL_NAMESPACE) ;
					pw.print(" " + wsdlPrefix + ":" + WSDL_REQUIRED + "=\"true\"") ;
				}
				pw.println("/>") ;
			}
		}
	}
}
