001    /**
002     *  Licensed to the Apache Software Foundation (ASF) under one or more
003     *  contributor license agreements.  See the NOTICE file distributed with
004     *  this work for additional information regarding copyright ownership.
005     *  The ASF licenses this file to You under the Apache License, Version 2.0
006     *  (the "License"); you may not use this file except in compliance with
007     *  the License.  You may obtain a copy of the License at
008     *
009     *     http://www.apache.org/licenses/LICENSE-2.0
010     *
011     *  Unless required by applicable law or agreed to in writing, software
012     *  distributed under the License is distributed on an "AS IS" BASIS,
013     *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     *  See the License for the specific language governing permissions and
015     *  limitations under the License.
016     */
017    
018    package org.apache.geronimo.security.realm.providers;
019    
020    import java.io.IOException;
021    import java.security.MessageDigest;
022    import java.security.NoSuchAlgorithmException;
023    import java.security.Principal;
024    import java.sql.Connection;
025    import java.sql.Driver;
026    import java.sql.PreparedStatement;
027    import java.sql.ResultSet;
028    import java.sql.SQLException;
029    import java.util.HashSet;
030    import java.util.Map;
031    import java.util.Properties;
032    import java.util.Set;
033    
034    import javax.security.auth.Subject;
035    import javax.security.auth.callback.Callback;
036    import javax.security.auth.callback.CallbackHandler;
037    import javax.security.auth.callback.NameCallback;
038    import javax.security.auth.callback.PasswordCallback;
039    import javax.security.auth.callback.UnsupportedCallbackException;
040    import javax.security.auth.login.FailedLoginException;
041    import javax.security.auth.login.LoginException;
042    import javax.security.auth.spi.LoginModule;
043    import javax.sql.DataSource;
044    
045    import org.apache.commons.logging.Log;
046    import org.apache.commons.logging.LogFactory;
047    import org.apache.geronimo.gbean.AbstractName;
048    import org.apache.geronimo.gbean.AbstractNameQuery;
049    import org.apache.geronimo.j2ee.j2eeobjectnames.NameFactory;
050    import org.apache.geronimo.kernel.GBeanNotFoundException;
051    import org.apache.geronimo.kernel.Kernel;
052    import org.apache.geronimo.kernel.KernelRegistry;
053    import org.apache.geronimo.management.geronimo.JCAManagedConnectionFactory;
054    import org.apache.geronimo.security.jaas.JaasLoginModuleUse;
055    import org.apache.geronimo.util.encoders.Base64;
056    import org.apache.geronimo.util.encoders.HexTranslator;
057    
058    
059    /**
060     * A login module that loads security information from a SQL database.  Expects
061     * to be run by a GenericSecurityRealm (doesn't work on its own).
062     * <p/>
063     * This requires database connectivity information (either 1: a dataSourceName and
064     * optional dataSourceApplication or 2: a JDBC driver, URL, username, and password)
065     * and 2 SQL queries.
066     * <p/>
067     * The userSelect query should return 2 values, the username and the password in
068     * that order.  It should include one PreparedStatement parameter (a ?) which
069     * will be filled in with the username.  In other words, the query should look
070     * like: <tt>SELECT user, password FROM credentials WHERE username=?</tt>
071     * <p/>
072     * The groupSelect query should return 2 values, the username and the group name in
073     * that order (but it may return multiple rows, one per group).  It should include
074     * one PreparedStatement parameter (a ?) which will be filled in with the username.
075     * In other words, the query should look like:
076     * <tt>SELECT user, role FROM user_roles WHERE username=?</tt>
077     * <p/>
078     * This login module checks security credentials so the lifecycle methods must return true to indicate success
079     * or throw LoginException to indicate failure.
080     *
081     * @version $Rev: 565912 $ $Date: 2007-08-14 17:03:11 -0400 (Tue, 14 Aug 2007) $
082     */
083    public class SQLLoginModule implements LoginModule {
084        private static Log log = LogFactory.getLog(SQLLoginModule.class);
085        public final static String USER_SELECT = "userSelect";
086        public final static String GROUP_SELECT = "groupSelect";
087        public final static String CONNECTION_URL = "jdbcURL";
088        public final static String USER = "jdbcUser";
089        public final static String PASSWORD = "jdbcPassword";
090        public final static String DRIVER = "jdbcDriver";
091        public final static String DATABASE_POOL_NAME = "dataSourceName";
092        public final static String DATABASE_POOL_APP_NAME = "dataSourceApplication";
093        public final static String DIGEST = "digest";
094        public final static String ENCODING = "encoding";
095        private String connectionURL;
096        private Properties properties;
097        private Driver driver;
098        private JCAManagedConnectionFactory factory;
099        private String userSelect;
100        private String groupSelect;
101        private String digest;
102        private String encoding;
103    
104        private Subject subject;
105        private CallbackHandler handler;
106        private String cbUsername;
107        private String cbPassword;
108        private final Set<Principal> groups = new HashSet<Principal>();
109    
110        public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) {
111            this.subject = subject;
112            this.handler = callbackHandler;
113            userSelect = (String) options.get(USER_SELECT);
114            groupSelect = (String) options.get(GROUP_SELECT);
115    
116            digest = (String) options.get(DIGEST);
117            encoding = (String) options.get(ENCODING);
118            if (digest != null && !digest.equals("")) {
119                // Check if the digest algorithm is available
120                try {
121                    MessageDigest.getInstance(digest);
122                } catch (NoSuchAlgorithmException e) {
123                    log.error("Initialization failed. Digest algorithm " + digest + " is not available.", e);
124                    throw new IllegalArgumentException("Unable to configure SQL login module: " + e.getMessage(), e);
125                }
126                if (encoding != null && !"hex".equalsIgnoreCase(encoding) && !"base64".equalsIgnoreCase(encoding)) {
127                    log.error("Initialization failed. Digest Encoding " + encoding + " is not supported.");
128                    throw new IllegalArgumentException(
129                            "Unable to configure SQL login module. Digest Encoding " + encoding + " not supported.");
130                }
131            }
132    
133            String dataSourceName = (String) options.get(DATABASE_POOL_NAME);
134            if (dataSourceName != null) {
135                dataSourceName = dataSourceName.trim();
136                String dataSourceAppName = (String) options.get(DATABASE_POOL_APP_NAME);
137                if (dataSourceAppName == null || dataSourceAppName.trim().equals("")) {
138                    dataSourceAppName = "null";
139                } else {
140                    dataSourceAppName = dataSourceAppName.trim();
141                }
142                String kernelName = (String) options.get(JaasLoginModuleUse.KERNEL_NAME_LM_OPTION);
143                Kernel kernel = KernelRegistry.getKernel(kernelName);
144                Set<AbstractName> set = kernel.listGBeans(new AbstractNameQuery(JCAManagedConnectionFactory.class.getName()));
145                JCAManagedConnectionFactory factory;
146                for (AbstractName name : set) {
147                    if (name.getName().get(NameFactory.J2EE_APPLICATION).equals(dataSourceAppName) &&
148                            name.getName().get(NameFactory.J2EE_NAME).equals(dataSourceName)) {
149                        try {
150                            factory = (JCAManagedConnectionFactory) kernel.getGBean(name);
151                            String type = factory.getConnectionFactoryInterface();
152                            if (type.equals(DataSource.class.getName())) {
153                                this.factory = factory;
154                                break;
155                            }
156                        } catch (GBeanNotFoundException e) {
157                            // ignore... GBean was unregistered
158                        }
159                    }
160                }
161            } else {
162                connectionURL = (String) options.get(CONNECTION_URL);
163                properties = new Properties();
164                if (options.get(USER) != null) {
165                    properties.put("user", options.get(USER));
166                }
167                if (options.get(PASSWORD) != null) {
168                    properties.put("password", options.get(PASSWORD));
169                }
170                ClassLoader cl = (ClassLoader) options.get(JaasLoginModuleUse.CLASSLOADER_LM_OPTION);
171                try {
172                    driver = (Driver) cl.loadClass((String) options.get(DRIVER)).newInstance();
173                } catch (ClassNotFoundException e) {
174                    throw new IllegalArgumentException("Driver class " + options.get(
175                            DRIVER) + " is not available.  Perhaps you need to add it as a dependency in your deployment plan?",
176                            e);
177                } catch (Exception e) {
178                    throw new IllegalArgumentException(
179                            "Unable to load, instantiate, register driver " + options.get(DRIVER) + ": " + e.getMessage(),
180                            e);
181                }
182            }
183        }
184    
185        public boolean login() throws LoginException {
186            Callback[] callbacks = new Callback[2];
187    
188            callbacks[0] = new NameCallback("User name");
189            callbacks[1] = new PasswordCallback("Password", false);
190            try {
191                handler.handle(callbacks);
192            } catch (IOException ioe) {
193                throw (LoginException) new LoginException().initCause(ioe);
194            } catch (UnsupportedCallbackException uce) {
195                throw (LoginException) new LoginException().initCause(uce);
196            }
197            assert callbacks.length == 2;
198            cbUsername = ((NameCallback) callbacks[0]).getName();
199            if (cbUsername == null || cbUsername.equals("")) {
200                throw new FailedLoginException();
201            }
202            char[] provided = ((PasswordCallback) callbacks[1]).getPassword();
203            cbPassword = provided == null ? null : new String(provided);
204    
205            try {
206                Connection conn;
207                if (factory != null) {
208                    DataSource ds = (DataSource) factory.getConnectionFactory();
209                    conn = ds.getConnection();
210                } else {
211                    conn = driver.connect(connectionURL, properties);
212                }
213    
214                try {
215                    PreparedStatement statement = conn.prepareStatement(userSelect);
216                    try {
217                        int count = countParameters(userSelect);
218                        for (int i = 0; i < count; i++) {
219                            statement.setObject(i + 1, cbUsername);
220                        }
221                        ResultSet result = statement.executeQuery();
222    
223                        try {
224                            while (result.next()) {
225                                String userName = result.getString(1);
226                                String userPassword = result.getString(2);
227    
228                                if (cbUsername.equals(userName)) {
229                                    if (!checkPassword(userPassword, cbPassword)) {
230                                        throw new FailedLoginException();
231                                    }
232                                    break;
233                                }
234                            }
235                        } finally {
236                            result.close();
237                        }
238                    } finally {
239                        statement.close();
240                    }
241    
242                    statement = conn.prepareStatement(groupSelect);
243                    try {
244                        int count = countParameters(groupSelect);
245                        for (int i = 0; i < count; i++) {
246                            statement.setObject(i + 1, cbUsername);
247                        }
248                        ResultSet result = statement.executeQuery();
249    
250                        try {
251                            while (result.next()) {
252                                String userName = result.getString(1);
253                                String groupName = result.getString(2);
254    
255                                if (cbUsername.equals(userName)) {
256                                    groups.add(new GeronimoGroupPrincipal(groupName));
257                                }
258                            }
259                        } finally {
260                            result.close();
261                        }
262                    } finally {
263                        statement.close();
264                    }
265                } finally {
266                    conn.close();
267                }
268            } catch (SQLException sqle) {
269                throw (LoginException) new LoginException("SQL error").initCause(sqle);
270            } catch (Exception e) {
271                throw (LoginException) new LoginException("Could not access datasource").initCause(e);
272            }
273    
274            return true;
275        }
276    
277        public boolean commit() throws LoginException {
278            Set<Principal> principals = subject.getPrincipals();
279            principals.add(new GeronimoUserPrincipal(cbUsername));
280            principals.addAll(groups);
281    
282            return true;
283        }
284    
285        public boolean abort() throws LoginException {
286            cbUsername = null;
287            cbPassword = null;
288    
289            return true;
290        }
291    
292        public boolean logout() throws LoginException {
293            cbUsername = null;
294            cbPassword = null;
295            //todo: should remove principals put in by commit
296            return true;
297        }
298    
299        private static int countParameters(String sql) {
300            int count = 0;
301            int pos = -1;
302            while ((pos = sql.indexOf('?', pos + 1)) != -1) {
303                ++count;
304            }
305            return count;
306        }
307    
308        /**
309         * This method checks if the provided password is correct.  The original password may have been digested.
310         *
311         * @param real     Original password in digested form if applicable
312         * @param provided User provided password in clear text
313         * @return true     If the password is correct
314         */
315        private boolean checkPassword(String real, String provided) {
316            if (real == null && provided == null) {
317                return true;
318            }
319            if (real == null || provided == null) {
320                return false;
321            }
322    
323            //both are non-null
324            if (digest == null || digest.equals("")) {
325                // No digest algorithm is used
326                return real.equals(provided);
327            }
328            try {
329                // Digest the user provided password
330                MessageDigest md = MessageDigest.getInstance(digest);
331                byte[] data = md.digest(provided.getBytes());
332                if (encoding == null || "hex".equalsIgnoreCase(encoding)) {
333                    // Convert bytes to hex digits
334                    byte[] hexData = new byte[data.length * 2];
335                    HexTranslator ht = new HexTranslator();
336                    ht.encode(data, 0, data.length, hexData, 0);
337                    // Compare the digested provided password with the actual one
338                    return real.equalsIgnoreCase(new String(hexData));
339                } else if ("base64".equalsIgnoreCase(encoding)) {
340                    return real.equals(new String(Base64.encode(data)));
341                }
342            } catch (NoSuchAlgorithmException e) {
343                // Should not occur.  Availability of algorithm has been checked at initialization
344                log.error("Should not occur.  Availability of algorithm has been checked at initialization.", e);
345            }
346            return false;
347        }
348    }