Skip to content
Snippets Groups Projects
Commit 39034656 authored by brinn's avatar brinn
Browse files

[BIS-260] Fix SQL injections for custom queries.

SVN: 27601
parent 9e995a13
No related merge requests found
......@@ -16,10 +16,17 @@
package ch.systemsx.cisd.openbis.plugin.query.server;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLDataException;
import java.sql.SQLException;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
......@@ -32,6 +39,7 @@ import javax.sql.DataSource;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.simple.SimpleJdbcDaoSupport;
import org.springframework.jdbc.support.JdbcUtils;
......@@ -193,35 +201,100 @@ class DAO extends SimpleJdbcDaoSupport implements IDAO
template.setFetchSize(FETCH_SIZE);
template.setMaxRows(MAX_ROWS + 1); // fetch one more row than allowed to detect excess
template.setQueryTimeout(QUERY_TIMEOUT_SECS);
final String resolvedQuery = createSQLQueryWithBindingsResolved(sqlQuery, bindingsOrNull);
final PreparedStatementCreator resolvedQuery =
createSQLPreparedStatement(sqlQuery, bindingsOrNull);
return (TableModel) template.execute(resolvedQuery, callback);
}
// FIXME this solution is not safe.
// We should use a prepared statement and set the parameters according to the information
// that PreparedStatement.getParameterMetaData() provides or check whether setObject() does the
// trick for us here.
private static String createSQLQueryWithBindingsResolved(String sqlQuery,
QueryParameterBindings bindingsOrNull)
private static PreparedStatementCreator createSQLPreparedStatement(final String sqlQuery,
final QueryParameterBindings bindingsOrNull)
{
Template template = new Template(sqlQuery);
if (bindingsOrNull != null)
{
for (Entry<String, String> entry : bindingsOrNull.getBindings().entrySet())
return new PreparedStatementCreator()
{
validateParameterValue(entry.getValue());
template.bind(entry.getKey(), entry.getValue());
}
}
return template.createText();
}
private static void validateParameterValue(String value) throws UserFailureException
{
if (value.contains("'"))
{
throw new UserFailureException("Parameter value \"" + value
+ "\" contains invalid character.");
}
@Override
public PreparedStatement createPreparedStatement(Connection con)
throws SQLException
{
final Map<Integer, Entry<String, String>> indexMap =
new HashMap<Integer, Entry<String, String>>();
final Template template = new Template(sqlQuery);
if (bindingsOrNull != null)
{
for (Entry<String, String> entry : bindingsOrNull.getBindings().entrySet())
{
template.bind(entry.getKey(), "?");
indexMap.put(template.tryGetIndex(entry.getKey()), entry);
}
}
final PreparedStatement psm = con.prepareStatement(template.createText());
final ParameterMetaData pmd = psm.getParameterMetaData();
for (int i = 1; i <= pmd.getParameterCount(); ++i)
{
final Entry<String, String> entry = indexMap.get(i - 1);
final String strValue = entry.getValue();
try
{
switch (pmd.getParameterType(i))
{
case Types.BIT:
case Types.BOOLEAN:
psm.setBoolean(i, Boolean.parseBoolean(strValue));
break;
case Types.TINYINT:
psm.setByte(i, Byte.parseByte(strValue));
break;
case Types.SMALLINT:
psm.setShort(i, Short.parseShort(strValue));
break;
case Types.INTEGER:
psm.setInt(i, Integer.parseInt(strValue));
break;
case Types.BIGINT:
psm.setLong(i, Long.parseLong(strValue));
break;
case Types.FLOAT:
case Types.REAL:
psm.setFloat(i, Float.parseFloat(strValue));
break;
case Types.DOUBLE:
psm.setDouble(i, Double.parseDouble(strValue));
break;
case Types.NUMERIC:
case Types.DECIMAL:
psm.setBigDecimal(i, new BigDecimal(strValue));
break;
case Types.CHAR:
case Types.VARCHAR:
case Types.LONGVARCHAR:
case Types.NCHAR:
case Types.NVARCHAR:
case Types.LONGNVARCHAR:
case Types.ARRAY:
psm.setString(i, strValue);
break;
case Types.TIME:
psm.setTime(i, Time.valueOf(strValue));
break;
case Types.DATE:
psm.setDate(i, java.sql.Date.valueOf(strValue));
break;
case Types.TIMESTAMP:
psm.setTimestamp(i, Timestamp.valueOf(strValue));
break;
default:
throw new SQLDataException("Unsupported SQL type "
+ pmd.getParameterTypeName(i) + "("
+ pmd.getParameterType(i) + ") for variable "
+ entry.getKey());
}
} catch (RuntimeException ex)
{
throw new SQLDataException("Invalid value '" + entry.getValue()
+ "' for variable " + entry.getKey(), ex);
}
}
return psm;
}
};
}
}
......@@ -22,6 +22,7 @@ import java.util.List;
import javax.sql.DataSource;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.testng.AbstractTransactionalTestNGSpringContextTests;
import org.springframework.test.context.transaction.TransactionConfiguration;
......@@ -50,6 +51,7 @@ public class DAOTest extends AbstractTransactionalTestNGSpringContextTests
TestInitializer.init();
}
@Test
public void testQuery()
{
String query =
......@@ -57,6 +59,7 @@ public class DAOTest extends AbstractTransactionalTestNGSpringContextTests
testQueryWithBindings(query, null);
}
@Test
public void testQueryWithBindings()
{
String query =
......@@ -90,6 +93,27 @@ public class DAOTest extends AbstractTransactionalTestNGSpringContextTests
assertEquals(2, rows.size());
}
@Test
public void testQueryWithArrayBinding()
{
String query =
"select id, code as DATA_SET_KEY, registration_timestamp, is_valid from data where code = any(${codes}::text[]) order by id";
QueryParameterBindings bindings = new QueryParameterBindings();
bindings.addBinding("codes", "{20081105092159188-3, 20081105092159111-1}");
testQueryWithBindings(query, bindings);
}
@Test(expectedExceptions = DataIntegrityViolationException.class)
public void testQueryWithBindingsSQLInjection()
{
String query =
"select id, code as DATA_SET_KEY, registration_timestamp, is_valid from data where id < ${id} order by id";
QueryParameterBindings bindings = new QueryParameterBindings();
bindings.addBinding("id",
"6 union select id, user_id, registration_timestamp, is_active from persons");
testQueryWithBindings(query, bindings);
}
@Test
public void testQueryWithSpecialCharacters()
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment