View Javadoc

1   /**
2    *
3    * Copyright 2003-2004 The Apache Software Foundation
4    *
5    *  Licensed under the Apache License, Version 2.0 (the "License");
6    *  you may not use this file except in compliance with the License.
7    *  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   *  Unless required by applicable law or agreed to in writing, software
12   *  distributed under the License is distributed on an "AS IS" BASIS,
13   *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   *  See the License for the specific language governing permissions and
15   *  limitations under the License.
16   */
17  
18  package org.apache.geronimo.security.realm.providers;
19  
20  import java.io.IOException;
21  import java.sql.Connection;
22  import java.sql.Driver;
23  import java.sql.PreparedStatement;
24  import java.sql.ResultSet;
25  import java.sql.SQLException;
26  import java.util.HashSet;
27  import java.util.Iterator;
28  import java.util.Map;
29  import java.util.Properties;
30  import java.util.Set;
31  import javax.security.auth.Subject;
32  import javax.security.auth.callback.Callback;
33  import javax.security.auth.callback.CallbackHandler;
34  import javax.security.auth.callback.NameCallback;
35  import javax.security.auth.callback.PasswordCallback;
36  import javax.security.auth.callback.UnsupportedCallbackException;
37  import javax.security.auth.login.FailedLoginException;
38  import javax.security.auth.login.LoginException;
39  import javax.security.auth.spi.LoginModule;
40  import javax.sql.DataSource;
41  
42  import org.apache.geronimo.gbean.AbstractName;
43  import org.apache.geronimo.gbean.AbstractNameQuery;
44  import org.apache.geronimo.j2ee.j2eeobjectnames.NameFactory;
45  import org.apache.geronimo.kernel.GBeanNotFoundException;
46  import org.apache.geronimo.kernel.Kernel;
47  import org.apache.geronimo.kernel.KernelRegistry;
48  import org.apache.geronimo.management.geronimo.JCAManagedConnectionFactory;
49  import org.apache.geronimo.security.jaas.JaasLoginModuleUse;
50  
51  
52  /**
53   * A login module that loads security information from a SQL database.  Expects
54   * to be run by a GenericSecurityRealm (doesn't work on its own).
55   * <p>
56   * This requires database connectivity information (either 1: a dataSourceName and
57   * optional dataSourceApplication or 2: a JDBC driver, URL, username, and password)
58   * and 2 SQL queries.
59   * <p>
60   * The userSelect query should return 2 values, the username and the password in
61   * that order.  It should include one PreparedStatement parameter (a ?) which
62   * will be filled in with the username.  In other words, the query should look
63   * like: <tt>SELECT user, password FROM users WHERE username=?</tt>
64   * <p>
65   * The groupSelect query should return 2 values, the username and the group name in
66   * that order (but it may return multiple rows, one per group).  It should include
67   * one PreparedStatement parameter (a ?) which will be filled in with the username.
68   * In other words, the query should look like:
69   * <tt>SELECT user, role FROM user_roles WHERE username=?</tt>
70   *
71   * @version $Rev: 407961 $ $Date: 2006-05-20 00:26:08 -0700 (Sat, 20 May 2006) $
72   */
73  public class SQLLoginModule implements LoginModule {
74      public final static String USER_SELECT = "userSelect";
75      public final static String GROUP_SELECT = "groupSelect";
76      public final static String CONNECTION_URL = "jdbcURL";
77      public final static String USER = "jdbcUser";
78      public final static String PASSWORD = "jdbcPassword";
79      public final static String DRIVER = "jdbcDriver";
80      public final static String DATABASE_POOL_NAME = "dataSourceName";
81      public final static String DATABASE_POOL_APP_NAME = "dataSourceApplication";
82      private String connectionURL;
83      private Properties properties;
84      private Driver driver;
85      private JCAManagedConnectionFactory factory;
86      private String userSelect;
87      private String groupSelect;
88  
89      private Subject subject;
90      private CallbackHandler handler;
91      private String cbUsername;
92      private String cbPassword;
93      private final Set groups = new HashSet();
94  
95      public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) {
96          this.subject = subject;
97          this.handler = callbackHandler;
98          userSelect = (String) options.get(USER_SELECT);
99          groupSelect = (String) options.get(GROUP_SELECT);
100 
101         String dataSourceName = (String) options.get(DATABASE_POOL_NAME);
102         if(dataSourceName != null) {
103             dataSourceName = dataSourceName.trim();
104             String dataSourceAppName = (String) options.get(DATABASE_POOL_APP_NAME);
105             if(dataSourceAppName == null || dataSourceAppName.trim().equals("")) {
106                 dataSourceAppName = "null";
107             } else {
108                 dataSourceAppName = dataSourceAppName.trim();
109             }
110             String kernelName = (String) options.get(JaasLoginModuleUse.KERNEL_NAME_LM_OPTION);
111             Kernel kernel = KernelRegistry.getKernel(kernelName);
112             Set set = kernel.listGBeans(new AbstractNameQuery(JCAManagedConnectionFactory.class.getName()));
113             JCAManagedConnectionFactory factory;
114             for (Iterator it = set.iterator(); it.hasNext();) {
115                 AbstractName name = (AbstractName) it.next();
116                 if(name.getName().get(NameFactory.J2EE_APPLICATION).equals(dataSourceAppName) &&
117                     name.getName().get(NameFactory.J2EE_NAME).equals(dataSourceName)) {
118                     try {
119                         factory = (JCAManagedConnectionFactory) kernel.getGBean(name);
120                         String type = factory.getConnectionFactoryInterface();
121                         if(type.equals(DataSource.class.getName())) {
122                             this.factory = factory;
123                             break;
124                         }
125                     } catch (GBeanNotFoundException e) {
126                         // ignore... GBean was unregistered
127                     }
128                 }
129             }
130         } else {
131             connectionURL = (String) options.get(CONNECTION_URL);
132             properties = new Properties();
133             if(options.get(USER) != null) {
134                 properties.put("user", options.get(USER));
135             }
136             if(options.get(PASSWORD) != null) {
137                 properties.put("password", options.get(PASSWORD));
138             }
139             ClassLoader cl = (ClassLoader) options.get(JaasLoginModuleUse.CLASSLOADER_LM_OPTION);
140             try {
141                 driver = (Driver) cl.loadClass((String) options.get(DRIVER)).newInstance();
142             } catch (ClassNotFoundException e) {
143                 throw new IllegalArgumentException("Driver class " + options.get(DRIVER) + " is not available.  Perhaps you need to add it as a dependency in your deployment plan?");
144             } catch (Exception e) {
145                 throw new IllegalArgumentException("Unable to load, instantiate, register driver " + options.get(DRIVER) + ": " + e.getMessage());
146             }
147         }
148     }
149 
150     public boolean login() throws LoginException {
151         Callback[] callbacks = new Callback[2];
152 
153         callbacks[0] = new NameCallback("User name");
154         callbacks[1] = new PasswordCallback("Password", false);
155         try {
156             handler.handle(callbacks);
157         } catch (IOException ioe) {
158             throw (LoginException) new LoginException().initCause(ioe);
159         } catch (UnsupportedCallbackException uce) {
160             throw (LoginException) new LoginException().initCause(uce);
161         }
162         assert callbacks.length == 2;
163         cbUsername = ((NameCallback) callbacks[0]).getName();
164         if (cbUsername == null || cbUsername.equals("")) {
165             return false;
166         }
167         char[] provided = ((PasswordCallback) callbacks[1]).getPassword();
168         cbPassword = provided == null ? null : new String(provided);
169 
170         boolean found = false;
171         try {
172             Connection conn;
173             if(factory != null) {
174                 DataSource ds = (DataSource) factory.getConnectionFactory();
175                 conn = ds.getConnection();
176             } else {
177                 conn = driver.connect(connectionURL, properties);
178             }
179 
180             try {
181                 PreparedStatement statement = conn.prepareStatement(userSelect);
182                 try {
183                     int count = countParameters(userSelect);
184                     for(int i=0; i<count; i++) {
185                         statement.setObject(i+1, cbUsername);
186                     }
187                     ResultSet result = statement.executeQuery();
188 
189                     try {
190                         while (result.next()) {
191                             String userName = result.getString(1);
192                             String userPassword = result.getString(2);
193 
194                             if (cbUsername.equals(userName)) {
195                                 found = (cbPassword == null && userPassword == null) ||
196                                         (cbPassword != null && userPassword != null && cbPassword.equals(userPassword));
197                                 break;
198                             }
199                         }
200                     } finally {
201                         result.close();
202                     }
203                 } finally {
204                     statement.close();
205                 }
206 
207                 if (!found) {
208                     throw new FailedLoginException();
209                 }
210 
211                 statement = conn.prepareStatement(groupSelect);
212                 try {
213                     int count = countParameters(groupSelect);
214                     for(int i=0; i<count; i++) {
215                         statement.setObject(i+1, cbUsername);
216                     }
217                     ResultSet result = statement.executeQuery();
218 
219                     try {
220                         while (result.next()) {
221                             String userName = result.getString(1);
222                             String groupName = result.getString(2);
223 
224                             if (cbUsername.equals(userName)) {
225                                 groups.add(new GeronimoGroupPrincipal(groupName));
226                             }
227                         }
228                     } finally {
229                         result.close();
230                     }
231                 } finally {
232                     statement.close();
233                 }
234             } finally {
235                 conn.close();
236             }
237         } catch (SQLException sqle) {
238             throw (LoginException) new LoginException("SQL error").initCause(sqle);
239         }
240 
241         return true;
242     }
243 
244     public boolean commit() throws LoginException {
245         Set principals = subject.getPrincipals();
246         principals.add(new GeronimoUserPrincipal(cbUsername));
247         Iterator iter = groups.iterator();
248         while (iter.hasNext()) {
249             principals.add(iter.next());
250         }
251 
252         return true;
253     }
254 
255     public boolean abort() throws LoginException {
256         cbUsername = null;
257         cbPassword = null;
258 
259         return true;
260     }
261 
262     public boolean logout() throws LoginException {
263         cbUsername = null;
264         cbPassword = null;
265         //todo: should remove principals put in by commit
266         return true;
267     }
268 
269     private static int countParameters(String sql) {
270         int count = 0;
271         int pos = -1;
272         while((pos = sql.indexOf('?', pos+1)) != -1) {
273             ++count;
274         }
275         return count;
276     }
277 }