001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements. See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership. The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License. You may obtain a copy of the License at
009     *
010     * http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing,
013     * software distributed under the License is distributed on an
014     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015     * KIND, either express or implied. See the License for the
016     * specific language governing permissions and limitations
017     * under the License.
018     */
019    
020    package org.apache.geronimo.axis2;
021    
022    import java.io.OutputStream;
023    import java.net.MalformedURLException;
024    import java.net.URI;
025    import java.net.URISyntaxException;
026    import java.net.URL;
027    import java.util.ArrayList;
028    import java.util.Collection;
029    import java.util.Iterator;
030    import java.util.List;
031    import java.util.Map;
032    import java.util.concurrent.ConcurrentHashMap;
033    
034    import javax.wsdl.Definition;
035    import javax.wsdl.Import;
036    import javax.wsdl.Port;
037    import javax.wsdl.Service;
038    import javax.wsdl.Types;
039    import javax.wsdl.extensions.ExtensibilityElement;
040    import javax.wsdl.extensions.schema.Schema;
041    import javax.wsdl.extensions.schema.SchemaImport;
042    import javax.wsdl.extensions.schema.SchemaReference;
043    import javax.wsdl.extensions.soap.SOAPAddress;
044    import javax.wsdl.extensions.soap12.SOAP12Address;
045    import javax.wsdl.factory.WSDLFactory;
046    import javax.wsdl.xml.WSDLReader;
047    import javax.wsdl.xml.WSDLWriter;
048    import javax.xml.namespace.QName;
049    import javax.xml.transform.OutputKeys;
050    import javax.xml.transform.Source;
051    import javax.xml.transform.Transformer;
052    import javax.xml.transform.TransformerException;
053    import javax.xml.transform.TransformerFactory;
054    import javax.xml.transform.dom.DOMSource;
055    import javax.xml.transform.stream.StreamResult;
056    
057    import org.apache.axis2.description.AxisService;
058    import org.apache.commons.logging.Log;
059    import org.apache.commons.logging.LogFactory;
060    import org.apache.geronimo.webservices.WebServiceContainer.Request;
061    import org.w3c.dom.Element;
062    import org.w3c.dom.Node;
063    import org.w3c.dom.NodeList;
064    
065    public class WSDLQueryHandler {
066    
067        private static final Log LOG = LogFactory.getLog(WSDLQueryHandler.class);
068        
069        private Map<String, Definition> mp = new ConcurrentHashMap<String, Definition>();
070        private Map<String, SchemaReference> smp = new ConcurrentHashMap<String, SchemaReference>();
071        private AxisService service;
072        
073        public WSDLQueryHandler(AxisService service) {
074            this.service = service;
075        }
076        
077        public void writeResponse(String baseUri, String wsdlUri, OutputStream os) throws Exception {
078    
079            int idx = baseUri.toLowerCase().indexOf("?wsdl");
080            String base = null;
081            String wsdl = "";
082            String xsd = null;
083            if (idx != -1) {
084                base = baseUri.substring(0, baseUri.toLowerCase().indexOf("?wsdl"));
085                wsdl = baseUri.substring(baseUri.toLowerCase().indexOf("?wsdl") + 5);
086                if (wsdl.length() > 0) {
087                    wsdl = wsdl.substring(1);
088                }
089            } else {
090                base = baseUri.substring(0, baseUri.toLowerCase().indexOf("?xsd="));
091                xsd = baseUri.substring(baseUri.toLowerCase().indexOf("?xsd=") + 5);
092            }
093    
094            if (!mp.containsKey(wsdl)) {
095                WSDLFactory factory = WSDLFactory.newInstance();
096                WSDLReader reader = factory.newWSDLReader();
097                reader.setFeature("javax.wsdl.importDocuments", true);
098                reader.setFeature("javax.wsdl.verbose", false);
099                Definition def = reader.readWSDL(wsdlUri);
100                updateDefinition(def, mp, smp, base);
101                updateServices(this.service.getName(), this.service.getEndpointName(), def, base);
102                mp.put("", def);
103            }
104    
105            Element rootElement;
106    
107            if (xsd == null) {
108                Definition def = mp.get(wsdl);
109    
110                WSDLFactory factory = WSDLFactory.newInstance();
111                WSDLWriter writer = factory.newWSDLWriter();
112    
113                rootElement = writer.getDocument(def).getDocumentElement();
114            } else {
115                SchemaReference si = smp.get(xsd);
116                rootElement = si.getReferencedSchema().getElement();
117            }
118    
119            NodeList nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema",
120                    "import");
121            for (int x = 0; x < nl.getLength(); x++) {
122                Element el = (Element) nl.item(x);
123                String sl = el.getAttribute("schemaLocation");
124                if (smp.containsKey(sl)) {
125                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
126                }
127            }
128            nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema", "include");
129            for (int x = 0; x < nl.getLength(); x++) {
130                Element el = (Element) nl.item(x);
131                String sl = el.getAttribute("schemaLocation");
132                if (smp.containsKey(sl)) {
133                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
134                }
135            }
136            nl = rootElement.getElementsByTagNameNS("http://schemas.xmlsoap.org/wsdl/", "import");
137            for (int x = 0; x < nl.getLength(); x++) {
138                Element el = (Element) nl.item(x);
139                String sl = el.getAttribute("location");
140                if (mp.containsKey(sl)) {
141                    el.setAttribute("location", base + "?wsdl=" + sl);
142                }
143            }
144    
145            writeTo(rootElement, os);
146        }
147           
148        protected void updateDefinition(Definition def,
149                                        Map<String, Definition> done,
150                                        Map<String, SchemaReference> doneSchemas,
151                                        String base) {
152            Collection<List> imports = def.getImports().values();
153            for (List lst : imports) {
154                List<Import> impLst = lst;
155                for (Import imp : impLst) {
156                    String start = imp.getLocationURI();
157                    try {
158                        //check to see if it's aleady in a URL format.  If so, leave it.
159                        new URL(start);
160                    } catch (MalformedURLException e) {
161                        done.put(start, imp.getDefinition());
162                        updateDefinition(imp.getDefinition(), done, doneSchemas, base);
163                    }
164                }
165            }      
166            
167            
168            /* This doesn't actually work.   Setting setSchemaLocationURI on the import
169            * for some reason doesn't actually result in the new URI being written
170            * */
171            Types types = def.getTypes();
172            if (types != null) {
173                for (ExtensibilityElement el : (List<ExtensibilityElement>)types.getExtensibilityElements()) {
174                    if (el instanceof Schema) {
175                        Schema see = (Schema)el;
176                        updateSchemaImports(see, doneSchemas, base);
177                    }
178                }
179            }
180        }
181        
182        protected void updateSchemaImports(Schema schema,
183                                           Map<String, SchemaReference> doneSchemas,
184                                           String base) {
185            Collection<List>  imports = schema.getImports().values();
186            for (List lst : imports) {
187                List<SchemaImport> impLst = lst;
188                for (SchemaImport imp : impLst) {
189                    String start = imp.getSchemaLocationURI();
190                    if (start != null) {
191                        try {
192                            //check to see if it's aleady in a URL format.  If so, leave it.
193                            new URL(start);
194                        } catch (MalformedURLException e) {
195                            if (!doneSchemas.containsKey(start)) {
196                                doneSchemas.put(start, imp);
197                                updateSchemaImports(imp.getReferencedSchema(), doneSchemas, base);
198                            }
199                        }
200                    }
201                }
202            }
203            List<SchemaReference> includes = schema.getIncludes();
204            for (SchemaReference included : includes) {
205                String start = included.getSchemaLocationURI();
206                if (start != null) {
207                    try {
208                        //check to see if it's aleady in a URL format.  If so, leave it.
209                        new URL(start);
210                    } catch (MalformedURLException e) {
211                        if (!doneSchemas.containsKey(start)) {
212                            doneSchemas.put(start, included);
213                            updateSchemaImports(included.getReferencedSchema(), doneSchemas, base);
214                        }
215                    }
216                }
217            }
218        }
219        
220        public static void writeTo(Node node, OutputStream os) {
221            writeTo(new DOMSource(node), os);
222        }
223        
224        public static void writeTo(Source src, OutputStream os) {
225            Transformer it;
226            try {
227                it = TransformerFactory.newInstance().newTransformer();
228                it.setOutputProperty(OutputKeys.METHOD, "xml");
229                it.setOutputProperty(OutputKeys.INDENT, "yes");
230                it.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
231                it.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "false");
232                it.setOutputProperty(OutputKeys.ENCODING, "utf-8");
233                it.transform(src, new StreamResult(os));
234            } catch (TransformerException e) {
235                // TODO Auto-generated catch block
236                e.printStackTrace();
237            }
238        }
239        
240        private void updateServices(String serviceName, String portName, Definition def, String baseUri)
241                throws Exception {
242            boolean updated = false;
243            Map services = def.getServices();
244            if (services != null) {
245                ArrayList<QName> servicesToRemove = new ArrayList<QName>();
246                
247                Iterator serviceIterator = services.entrySet().iterator();
248                while (serviceIterator.hasNext()) {
249                    Map.Entry serviceEntry = (Map.Entry) serviceIterator.next();
250                    QName currServiceName = (QName) serviceEntry.getKey();
251                    if (currServiceName.getLocalPart().equals(serviceName)) {
252                        Service service = (Service) serviceEntry.getValue();
253                        updatePorts(portName, service, baseUri);
254                        updated = true;
255                    } else {
256                        servicesToRemove.add(currServiceName);
257                    }
258                }
259                
260                for (QName serviceToRemove : servicesToRemove) {
261                    def.removeService(serviceToRemove);                
262                }
263            }
264            if (!updated) {
265                LOG.warn("WSDL '" + serviceName + "' service not found.");
266            }
267        }
268    
269        private void updatePorts(String portName, Service service, String baseUri) throws Exception {
270            boolean updated = false;
271            Map ports = service.getPorts();
272            if (ports != null) {
273                ArrayList<String> portsToRemove = new ArrayList<String>();
274                
275                Iterator portIterator = ports.entrySet().iterator();
276                while (portIterator.hasNext()) {
277                    Map.Entry portEntry = (Map.Entry) portIterator.next();
278                    String currPortName = (String) portEntry.getKey();
279                    if (currPortName.equals(portName)) {
280                        Port port = (Port) portEntry.getValue();
281                        updatePortLocation(port, baseUri);
282                        updated = true;
283                    } else {
284                        portsToRemove.add(currPortName);
285                    }
286                }
287                
288                for (String portToRemove : portsToRemove) {
289                    service.removePort(portToRemove);               
290                }
291            }
292            if (!updated) {
293                LOG.warn("WSDL '" + portName + "' port not found.");
294            }
295        }
296    
297        private void updatePortLocation(Port port, String baseUri) throws URISyntaxException {
298            List<?> exts = port.getExtensibilityElements();
299            if (exts != null && exts.size() > 0) {
300                ExtensibilityElement el = (ExtensibilityElement) exts.get(0);
301                if (el instanceof SOAP12Address) {
302                    SOAP12Address add = (SOAP12Address) el;
303                    add.setLocationURI(baseUri);
304                } else if (el instanceof SOAPAddress) {
305                    SOAPAddress add = (SOAPAddress) el;
306                    add.setLocationURI(baseUri);
307                }
308            }
309        }
310    }