001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements. See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership. The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License. You may obtain a copy of the License at
009     *
010     * http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing,
013     * software distributed under the License is distributed on an
014     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015     * KIND, either express or implied. See the License for the
016     * specific language governing permissions and limitations
017     * under the License.
018     */
019    
020    package org.apache.geronimo.axis2;
021    
022    import java.io.FileNotFoundException;
023    import java.io.OutputStream;
024    import java.net.MalformedURLException;
025    import java.net.URI;
026    import java.net.URISyntaxException;
027    import java.net.URL;
028    import java.util.ArrayList;
029    import java.util.Collection;
030    import java.util.Iterator;
031    import java.util.List;
032    import java.util.Map;
033    import java.util.concurrent.ConcurrentHashMap;
034    
035    import javax.wsdl.Definition;
036    import javax.wsdl.Import;
037    import javax.wsdl.Port;
038    import javax.wsdl.Service;
039    import javax.wsdl.Types;
040    import javax.wsdl.extensions.ExtensibilityElement;
041    import javax.wsdl.extensions.schema.Schema;
042    import javax.wsdl.extensions.schema.SchemaImport;
043    import javax.wsdl.extensions.schema.SchemaReference;
044    import javax.wsdl.extensions.soap.SOAPAddress;
045    import javax.wsdl.extensions.soap12.SOAP12Address;
046    import javax.wsdl.factory.WSDLFactory;
047    import javax.wsdl.xml.WSDLReader;
048    import javax.wsdl.xml.WSDLWriter;
049    import javax.xml.namespace.QName;
050    import javax.xml.transform.OutputKeys;
051    import javax.xml.transform.Source;
052    import javax.xml.transform.Transformer;
053    import javax.xml.transform.TransformerException;
054    import javax.xml.transform.TransformerFactory;
055    import javax.xml.transform.dom.DOMSource;
056    import javax.xml.transform.stream.StreamResult;
057    
058    import org.apache.axis2.description.AxisService;
059    import org.apache.commons.logging.Log;
060    import org.apache.commons.logging.LogFactory;
061    import org.apache.geronimo.webservices.WebServiceContainer.Request;
062    import org.w3c.dom.Element;
063    import org.w3c.dom.Node;
064    import org.w3c.dom.NodeList;
065    
066    public class WSDLQueryHandler {
067    
068        private static final Log LOG = LogFactory.getLog(WSDLQueryHandler.class);
069        
070        private Map<String, Definition> mp = new ConcurrentHashMap<String, Definition>();
071        private Map<String, SchemaReference> smp = new ConcurrentHashMap<String, SchemaReference>();
072        private AxisService service;
073        
074        public WSDLQueryHandler(AxisService service) {
075            this.service = service;
076        }
077        
078        public void writeResponse(String baseUri, String wsdlUri, OutputStream os) throws Exception {
079    
080            String base = null;
081            String wsdl = "";
082            String xsd = null;
083            
084            int idx = baseUri.toLowerCase().indexOf("?wsdl");
085            if (idx != -1) {
086                base = baseUri.substring(0, idx);
087                wsdl = baseUri.substring(idx + 5);
088                if (wsdl.length() > 0) {
089                    wsdl = wsdl.substring(1);
090                }
091            } else {
092                idx = baseUri.toLowerCase().indexOf("?xsd");
093                if (idx != -1) {
094                    base = baseUri.substring(0, idx);
095                    xsd = baseUri.substring(idx + 4);
096                    if (xsd.length() > 0) {
097                        xsd = xsd.substring(1);
098                    }
099                } else {
100                    throw new Exception("Invalid request: " + baseUri);
101                }
102            }
103    
104            if (!mp.containsKey(wsdl)) {
105                WSDLFactory factory = WSDLFactory.newInstance();
106                WSDLReader reader = factory.newWSDLReader();
107                reader.setFeature("javax.wsdl.importDocuments", true);
108                reader.setFeature("javax.wsdl.verbose", false);
109                Definition def = reader.readWSDL(wsdlUri);
110                updateDefinition(def, mp, smp, base);
111                updateServices(this.service.getName(), this.service.getEndpointName(), def, base);
112                mp.put("", def);
113            }
114    
115            Element rootElement;
116    
117            if (xsd == null) {
118                Definition def = mp.get(wsdl);
119    
120                if (def == null) {
121                    throw new FileNotFoundException("WSDL not found: " + wsdl);
122                }
123                
124                WSDLFactory factory = WSDLFactory.newInstance();
125                WSDLWriter writer = factory.newWSDLWriter();
126    
127                rootElement = writer.getDocument(def).getDocumentElement();
128            } else {
129                SchemaReference si = smp.get(xsd);
130                
131                if (si == null) {
132                    throw new FileNotFoundException("Schema not found: " + xsd);
133                }
134                
135                rootElement = si.getReferencedSchema().getElement();
136            }
137    
138            NodeList nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema",
139                    "import");
140            for (int x = 0; x < nl.getLength(); x++) {
141                Element el = (Element) nl.item(x);
142                String sl = el.getAttribute("schemaLocation");
143                if (smp.containsKey(sl)) {
144                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
145                }
146            }
147            nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema", "include");
148            for (int x = 0; x < nl.getLength(); x++) {
149                Element el = (Element) nl.item(x);
150                String sl = el.getAttribute("schemaLocation");
151                if (smp.containsKey(sl)) {
152                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
153                }
154            }
155            nl = rootElement.getElementsByTagNameNS("http://schemas.xmlsoap.org/wsdl/", "import");
156            for (int x = 0; x < nl.getLength(); x++) {
157                Element el = (Element) nl.item(x);
158                String sl = el.getAttribute("location");
159                if (mp.containsKey(sl)) {
160                    el.setAttribute("location", base + "?wsdl=" + sl);
161                }
162            }
163    
164            writeTo(rootElement, os);
165        }
166           
167        protected void updateDefinition(Definition def,
168                                        Map<String, Definition> done,
169                                        Map<String, SchemaReference> doneSchemas,
170                                        String base) {
171            Collection<List> imports = def.getImports().values();
172            for (List lst : imports) {
173                List<Import> impLst = lst;
174                for (Import imp : impLst) {
175                    String start = imp.getLocationURI();
176                    try {
177                        //check to see if it's aleady in a URL format.  If so, leave it.
178                        new URL(start);
179                    } catch (MalformedURLException e) {
180                        done.put(start, imp.getDefinition());
181                        updateDefinition(imp.getDefinition(), done, doneSchemas, base);
182                    }
183                }
184            }      
185            
186            
187            /* This doesn't actually work.   Setting setSchemaLocationURI on the import
188            * for some reason doesn't actually result in the new URI being written
189            * */
190            Types types = def.getTypes();
191            if (types != null) {
192                for (ExtensibilityElement el : (List<ExtensibilityElement>)types.getExtensibilityElements()) {
193                    if (el instanceof Schema) {
194                        Schema see = (Schema)el;
195                        updateSchemaImports(see, doneSchemas, base);
196                    }
197                }
198            }
199        }
200        
201        protected void updateSchemaImports(Schema schema,
202                                           Map<String, SchemaReference> doneSchemas,
203                                           String base) {
204            Collection<List>  imports = schema.getImports().values();
205            for (List lst : imports) {
206                List<SchemaImport> impLst = lst;
207                for (SchemaImport imp : impLst) {
208                    String start = imp.getSchemaLocationURI();
209                    if (start != null) {
210                        try {
211                            //check to see if it's aleady in a URL format.  If so, leave it.
212                            new URL(start);
213                        } catch (MalformedURLException e) {
214                            if (!doneSchemas.containsKey(start)) {
215                                doneSchemas.put(start, imp);
216                                updateSchemaImports(imp.getReferencedSchema(), doneSchemas, base);
217                            }
218                        }
219                    }
220                }
221            }
222            List<SchemaReference> includes = schema.getIncludes();
223            for (SchemaReference included : includes) {
224                String start = included.getSchemaLocationURI();
225                if (start != null) {
226                    try {
227                        //check to see if it's aleady in a URL format.  If so, leave it.
228                        new URL(start);
229                    } catch (MalformedURLException e) {
230                        if (!doneSchemas.containsKey(start)) {
231                            doneSchemas.put(start, included);
232                            updateSchemaImports(included.getReferencedSchema(), doneSchemas, base);
233                        }
234                    }
235                }
236            }
237        }
238        
239        public static void writeTo(Node node, OutputStream os) {
240            writeTo(new DOMSource(node), os);
241        }
242        
243        public static void writeTo(Source src, OutputStream os) {
244            Transformer it;
245            try {
246                it = TransformerFactory.newInstance().newTransformer();
247                it.setOutputProperty(OutputKeys.METHOD, "xml");
248                it.setOutputProperty(OutputKeys.INDENT, "yes");
249                it.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
250                it.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "false");
251                it.setOutputProperty(OutputKeys.ENCODING, "utf-8");
252                it.transform(src, new StreamResult(os));
253            } catch (TransformerException e) {
254                // TODO Auto-generated catch block
255                e.printStackTrace();
256            }
257        }
258        
259        private void updateServices(String serviceName, String portName, Definition def, String baseUri)
260                throws Exception {
261            boolean updated = false;
262            Map services = def.getServices();
263            if (services != null) {
264                ArrayList<QName> servicesToRemove = new ArrayList<QName>();
265                
266                Iterator serviceIterator = services.entrySet().iterator();
267                while (serviceIterator.hasNext()) {
268                    Map.Entry serviceEntry = (Map.Entry) serviceIterator.next();
269                    QName currServiceName = (QName) serviceEntry.getKey();
270                    if (currServiceName.getLocalPart().equals(serviceName)) {
271                        Service service = (Service) serviceEntry.getValue();
272                        updatePorts(portName, service, baseUri);
273                        updated = true;
274                    } else {
275                        servicesToRemove.add(currServiceName);
276                    }
277                }
278                
279                for (QName serviceToRemove : servicesToRemove) {
280                    def.removeService(serviceToRemove);                
281                }
282            }
283            if (!updated) {
284                LOG.warn("WSDL '" + serviceName + "' service not found.");
285            }
286        }
287    
288        private void updatePorts(String portName, Service service, String baseUri) throws Exception {
289            boolean updated = false;
290            Map ports = service.getPorts();
291            if (ports != null) {
292                ArrayList<String> portsToRemove = new ArrayList<String>();
293                
294                Iterator portIterator = ports.entrySet().iterator();
295                while (portIterator.hasNext()) {
296                    Map.Entry portEntry = (Map.Entry) portIterator.next();
297                    String currPortName = (String) portEntry.getKey();
298                    if (currPortName.equals(portName)) {
299                        Port port = (Port) portEntry.getValue();
300                        updatePortLocation(port, baseUri);
301                        updated = true;
302                    } else {
303                        portsToRemove.add(currPortName);
304                    }
305                }
306                
307                for (String portToRemove : portsToRemove) {
308                    service.removePort(portToRemove);               
309                }
310            }
311            if (!updated) {
312                LOG.warn("WSDL '" + portName + "' port not found.");
313            }
314        }
315    
316        private void updatePortLocation(Port port, String baseUri) throws URISyntaxException {
317            List<?> exts = port.getExtensibilityElements();
318            if (exts != null && exts.size() > 0) {
319                ExtensibilityElement el = (ExtensibilityElement) exts.get(0);
320                if (el instanceof SOAP12Address) {
321                    SOAP12Address add = (SOAP12Address) el;
322                    add.setLocationURI(baseUri);
323                } else if (el instanceof SOAPAddress) {
324                    SOAPAddress add = (SOAPAddress) el;
325                    add.setLocationURI(baseUri);
326                }
327            }
328        }
329    }