001    /**
002     *  Licensed to the Apache Software Foundation (ASF) under one or more
003     *  contributor license agreements.  See the NOTICE file distributed with
004     *  this work for additional information regarding copyright ownership.
005     *  The ASF licenses this file to You under the Apache License, Version 2.0
006     *  (the "License"); you may not use this file except in compliance with
007     *  the License.  You may obtain a copy of the License at
008     *
009     *     http://www.apache.org/licenses/LICENSE-2.0
010     *
011     *  Unless required by applicable law or agreed to in writing, software
012     *  distributed under the License is distributed on an "AS IS" BASIS,
013     *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     *  See the License for the specific language governing permissions and
015     *  limitations under the License.
016     */
017    package org.apache.geronimo.cxf;
018    
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.security.Principal;
import java.util.Iterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Enumeration;
import java.util.logging.Logger;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.ws.handler.MessageContext;
035    
036    import org.apache.cxf.Bus;
037    import org.apache.cxf.message.Message;
038    import org.apache.cxf.message.MessageImpl;
039    import org.apache.cxf.security.SecurityContext;
040    import org.apache.cxf.service.model.EndpointInfo;
041    import org.apache.cxf.transport.Conduit;
042    import org.apache.cxf.transport.ConduitInitiator;
043    import org.apache.cxf.transport.Destination;
044    import org.apache.cxf.transport.MessageObserver;
045    import org.apache.cxf.transport.http.AbstractHTTPDestination;
046    import org.apache.cxf.ws.addressing.EndpointReferenceType;
047    import org.apache.geronimo.webservices.WebServiceContainer;
048    import org.apache.geronimo.webservices.WebServiceContainer.Request;
049    import org.apache.geronimo.webservices.WebServiceContainer.Response;
050    
051    public class GeronimoDestination extends AbstractHTTPDestination
052            implements Serializable {
053    
054        private MessageObserver messageObserver;
055        private boolean passSecurityContext = false;
056    
057        public GeronimoDestination(Bus bus, 
058                                   ConduitInitiator conduitInitiator, 
059                                   EndpointInfo endpointInfo) throws IOException {
060            super(bus, conduitInitiator, endpointInfo, true);
061        }
062    
063        public void setPassSecurityContext(boolean passSecurityContext) {
064            this.passSecurityContext = passSecurityContext;
065        }
066        
067        public boolean getPassSecurityContext() {
068            return this.passSecurityContext;
069        }
070        
071        public EndpointInfo getEndpointInfo() {
072            return this.endpointInfo;
073        }
074    
075        public void invoke(Request request, Response response) throws Exception {
076            MessageImpl message = new MessageImpl();
077            message.setContent(InputStream.class, request.getInputStream());
078            message.setDestination(this);
079    
080            message.put(Request.class, request);
081            message.put(Response.class, response);
082    
083            final HttpServletRequest servletRequest = 
084                (HttpServletRequest)request.getAttribute(WebServiceContainer.SERVLET_REQUEST);
085            message.put(MessageContext.SERVLET_REQUEST, servletRequest);
086            
087            HttpServletResponse servletResponse =
088                (HttpServletResponse)request.getAttribute(WebServiceContainer.SERVLET_RESPONSE);
089            message.put(MessageContext.SERVLET_RESPONSE, servletResponse);
090            
091            ServletContext servletContext = 
092                (ServletContext)request.getAttribute(WebServiceContainer.SERVLET_CONTEXT);
093            message.put(MessageContext.SERVLET_CONTEXT, servletContext);
094            
095            if (this.passSecurityContext) {
096                message.put(SecurityContext.class, new SecurityContext() {
097                    public Principal getUserPrincipal() {
098                        return servletRequest.getUserPrincipal();
099                    }
100                    public boolean isUserInRole(String role) {
101                        return servletRequest.isUserInRole(role);
102                    }
103                });
104            }
105            
106            // this calls copyRequestHeaders()
107            setHeaders(message);
108            
109            message.put(Message.HTTP_REQUEST_METHOD, servletRequest.getMethod());
110            message.put(Message.PATH_INFO, servletRequest.getPathInfo());
111            message.put(Message.QUERY_STRING, servletRequest.getQueryString());
112            message.put(Message.CONTENT_TYPE, servletRequest.getContentType());
113            message.put(Message.ENCODING, getCharacterEncoding(servletRequest.getCharacterEncoding()));
114            
115            messageObserver.onMessage(message);
116        }
117    
118        private static String getCharacterEncoding(String encoding) {
119            if (encoding != null) {
120                encoding = encoding.trim();
121                // work around a bug with Jetty which results in the character
122                // encoding not being trimmed correctly:
123                // http://jira.codehaus.org/browse/JETTY-302
124                if (encoding.endsWith("\"")) {
125                    encoding = encoding.substring(0, encoding.length() - 1);
126                }
127            }
128            return encoding;
129        }
130        
131        protected void copyRequestHeaders(Message message, Map<String, List<String>> headers) {
132            HttpServletRequest servletRequest = (HttpServletRequest)message.get(MessageContext.SERVLET_REQUEST);
133            if (servletRequest != null) {
134                Enumeration names = servletRequest.getHeaderNames();
135                while(names.hasMoreElements()) {
136                    String name = (String)names.nextElement();
137                    
138                    List<String> headerValues = headers.get(name);
139                    if (headerValues == null) {
140                        headerValues = new ArrayList<String>();
141                        headers.put(name, headerValues);
142                    }
143                    
144                    Enumeration values = servletRequest.getHeaders(name);
145                    while(values.hasMoreElements()) {
146                        String value = (String)values.nextElement();
147                        headerValues.add(value);
148                    }
149                }
150            }
151        }
152    
153        public Logger getLogger() {
154            return Logger.getLogger(GeronimoDestination.class.getName());
155        }
156    
157        public Conduit getInbuiltBackChannel(Message inMessage) {
158            return new BackChannelConduit(null, inMessage);
159        }
160    
161        public Conduit getBackChannel(Message inMessage,
162                                      Message partialResponse,
163                                      EndpointReferenceType address) throws IOException {
164            Conduit backChannel = null;
165            if (address == null) {
166                backChannel = new BackChannelConduit(address, inMessage);
167            } else {
168                if (partialResponse != null) {
169                    // setup the outbound message to for 202 Accepted
170                    partialResponse.put(Message.RESPONSE_CODE,
171                                        HttpURLConnection.HTTP_ACCEPTED);
172                    backChannel = new BackChannelConduit(address, inMessage);
173                } else {
174                    backChannel = conduitInitiator.getConduit(endpointInfo, address);
175                    // ensure decoupled back channel input stream is closed
176                    backChannel.setMessageObserver(new MessageObserver() {
177                        public void onMessage(Message m) {
178                            if (m.getContentFormats().contains(InputStream.class)) {
179                                InputStream is = m.getContent(InputStream.class);
180                                try {
181                                    is.close();
182                                } catch (Exception e) {
183                                    // ignore
184                                }
185                            }
186                        }
187                    });
188                }
189            }
190            return backChannel;
191        }
192    
193        public void shutdown() {
194        }
195    
196        public void setMessageObserver(MessageObserver messageObserver) {
197            this.messageObserver = messageObserver;
198        }
199    
200        protected class BackChannelConduit implements Conduit {
201    
202            protected Message request;
203            protected EndpointReferenceType target;
204    
205            BackChannelConduit(EndpointReferenceType target, Message request) {
206                this.target = target;
207                this.request = request;
208            }
209    
210            public void close(Message msg) throws IOException {
211                msg.getContent(OutputStream.class).close();
212            }
213    
214            /**
215             * Register a message observer for incoming messages.
216             *
217             * @param observer the observer to notify on receipt of incoming
218             */
219            public void setMessageObserver(MessageObserver observer) {
220                // shouldn't be called for a back channel conduit
221            }
222            
223            public void prepare(Message message) throws IOException {
224                send(message);
225            }
226    
227            /**
228             * Send an outbound message, assumed to contain all the name-value
229             * mappings of the corresponding input message (if any).
230             *
231             * @param message the message to be sent.
232             */
233            public void send(Message message) throws IOException {
234                Response response = (Response)request.get(Response.class);
235                
236                // handle response headers
237                updateResponseHeaders(message);
238    
239                Map<String, List<String>> protocolHeaders = 
240                    (Map<String, List<String>>) message.get(Message.PROTOCOL_HEADERS);
241    
242                // set headers of the HTTP response object
243                Iterator headers = protocolHeaders.entrySet().iterator();
244                while (headers.hasNext()) {
245                    Map.Entry entry = (Map.Entry) headers.next();
246                    String headerName = (String) entry.getKey();
247                    String headerValue = getHeaderValue((List) entry.getValue());
248                    response.setHeader(headerName, headerValue);
249                }
250                
251                message.setContent(OutputStream.class, new WrappedOutputStream(message, response));
252            }
253            
254            /**
255             * @return the reference associated with the target Destination
256             */
257            public EndpointReferenceType getTarget() {
258                return target;
259            }
260    
261            /**
262             * Retreive the back-channel Destination.
263             *
264             * @return the backchannel Destination (or null if the backchannel is
265             *         built-in)
266             */
267            public Destination getBackChannel() {
268                return null;
269            }
270    
271            /**
272             * Close the conduit
273             */
274            public void close() {
275            }
276        }
277            
278        private String getHeaderValue(List<String> values) {
279            Iterator iter = values.iterator();
280            StringBuffer buf = new StringBuffer();
281            while(iter.hasNext()) {
282                buf.append(iter.next());
283                if (iter.hasNext()) {
284                    buf.append(", ");
285                }
286            }
287            return buf.toString();
288        }
289        
290        protected void setContentType(Message message, Response response) {                
291            Map<String, List<String>> protocolHeaders =
292                (Map<String, List<String>>)message.get(Message.PROTOCOL_HEADERS);
293            
294            if (protocolHeaders == null || !protocolHeaders.containsKey(Message.CONTENT_TYPE)) {
295                String ct = (String) message.get(Message.CONTENT_TYPE);
296                String enc = (String) message.get(Message.ENCODING);
297                
298                if (null != ct) {
299                    if (enc != null && ct.indexOf("charset=") == -1) {
300                        ct = ct + "; charset=" + enc;
301                    }
302                    response.setContentType(ct);
303                } else if (enc != null) {
304                    response.setContentType("text/xml; charset=" + enc);
305                }
306            }     
307        }
308                   
309        private class WrappedOutputStream extends OutputStream {
310    
311            private Message message;
312            private Response response;
313            private OutputStream rawOutputStream;
314    
315            WrappedOutputStream(Message message, Response response) {
316                this.message = message;
317                this.response = response;
318            }
319    
320            public void write(int b) throws IOException {
321                flushHeaders();
322                this.rawOutputStream.write(b);
323            }
324    
325            public void write(byte b[]) throws IOException {
326                flushHeaders();
327                this.rawOutputStream.write(b);
328            }
329    
330            public void write(byte b[], int off, int len) throws IOException {
331                flushHeaders();
332                this.rawOutputStream.write(b, off, len);
333            }
334    
335            public void flush() throws IOException {
336                flushHeaders();
337                this.rawOutputStream.flush();
338            }
339    
340            public void close() throws IOException {
341                flushHeaders();
342                this.rawOutputStream.close();
343            }
344            
345             protected void flushHeaders() throws IOException {
346                if (this.rawOutputStream != null) {
347                    return;
348                }
349    
350                // set response code
351                Integer i = (Integer) this.message.get(Message.RESPONSE_CODE);
352                if (i != null) {
353                    this.response.setStatusCode(i.intValue());
354                }
355                
356                // set content-type
357                setContentType(this.message, this.response);
358                
359                this.rawOutputStream = this.response.getOutputStream();
360            }
361    
362        }
363           
364    }