Commit 16693fb4 by 郑冰晶

数据库加密组件

parent 53c87e57
...@@ -107,6 +107,7 @@ public class ActivityController { ...@@ -107,6 +107,7 @@ public class ActivityController {
PriceRuleTask updatePriceRuleTask = new PriceRuleTask(); PriceRuleTask updatePriceRuleTask = new PriceRuleTask();
updatePriceRuleTask.setId(id); updatePriceRuleTask.setId(id);
updatePriceRuleTask.setCreator(priceRuleTaskCriteria.getCreator()); updatePriceRuleTask.setCreator(priceRuleTaskCriteria.getCreator());
updatePriceRuleTask.setModifier(priceRuleTaskCriteria.getModifier());
updatePriceRuleTasks.add(updatePriceRuleTask); updatePriceRuleTasks.add(updatePriceRuleTask);
} }
...@@ -114,4 +115,10 @@ public class ActivityController { ...@@ -114,4 +115,10 @@ public class ActivityController {
log.debug("rows={}", rows); log.debug("rows={}", rows);
return String.valueOf(rows); return String.valueOf(rows);
} }
@RequestMapping("testJoin")
public String testJoin(@RequestBody PriceRuleTaskCriteria priceRuleTaskCriteria) {
this.priceRuleTaskMapper.testJoin(priceRuleTaskCriteria);
return null;
}
} }
...@@ -30,4 +30,6 @@ public interface PriceRuleTaskMapper { ...@@ -30,4 +30,6 @@ public interface PriceRuleTaskMapper {
List<PriceRuleTask> queryExecutePriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria); List<PriceRuleTask> queryExecutePriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria);
List<PriceRuleTask> testJoin(PriceRuleTaskCriteria priceRuleTaskCriteria);
} }
\ No newline at end of file
...@@ -482,4 +482,10 @@ ...@@ -482,4 +482,10 @@
</select> </select>
<select id="testJoin" resultMap="BaseResultMap" parameterType="com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria">
select prt.id,br.brand_id,prt.remark from t_price_rule_task prt
inner join t_brand_rule br on prt.brand_id = br.brand_id
where prt.task_id=#{taskId}
</select>
</mapper> </mapper>
\ No newline at end of file
...@@ -57,7 +57,7 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -57,7 +57,7 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
.collect(Collectors.toList()); .collect(Collectors.toList());
propertyNameSections = propertyNameSections.stream().filter(key -> key.length == 4).collect(Collectors.toList()); propertyNameSections = propertyNameSections.stream().filter(key -> key.length == 4).collect(Collectors.toList());
Map<String, Class> fieldMap = FieldUtil.getAllFieldsList(ColumnRule.class).stream().collect(Collectors.toMap(Field::getName, Field::getType, (key1, key2) -> key1)); Map<String, Class<?>> fieldMap = FieldUtil.getAllFieldsList(ColumnRule.class).stream().collect(Collectors.toMap(Field::getName, Field::getType, (key1, key2) -> key1));
Set<String> notSupportPropertySet = new HashSet<>(); Set<String> notSupportPropertySet = new HashSet<>();
// db // db
...@@ -124,10 +124,14 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -124,10 +124,14 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
if(StringUtils.isBlank(columnRule.getEncryptType()) || StringUtils.isBlank(columnRule.getEncryptKey()) || StringUtils.isBlank(columnRule.getCipherColumn())){ if(StringUtils.isBlank(columnRule.getEncryptType()) || StringUtils.isBlank(columnRule.getEncryptKey()) || StringUtils.isBlank(columnRule.getCipherColumn())){
continue; continue;
} }
// 加密器
Properties properties = new Properties(); Properties properties = new Properties();
properties.setProperty(EncryptAlgorithm.ENCRYPT_KEY,columnRule.getEncryptKey()); properties.setProperty(EncryptAlgorithm.ENCRYPT_KEY,columnRule.getEncryptKey());
EncryptAlgorithm encryptAlgorithm = SecurityAlgorithmFactory.getObject(EncryptAlgorithm.class,columnRule.getEncryptType().toLowerCase(), properties); EncryptAlgorithm encryptAlgorithm = SecurityAlgorithmFactory.getObject(EncryptAlgorithm.class,columnRule.getEncryptType().toLowerCase(), properties);
encryptAlgorithm.init(); if(encryptAlgorithm == null){
continue;
}
columnRule.setEncryptAlgorithm(encryptAlgorithm); columnRule.setEncryptAlgorithm(encryptAlgorithm);
if(tableRule.getColumnRules() == null){ if(tableRule.getColumnRules() == null){
...@@ -135,11 +139,19 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -135,11 +139,19 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
} }
tableRule.getColumnRules().add(columnRule); tableRule.getColumnRules().add(columnRule);
} }
if(tableRule.getColumnRules() == null || tableRule.getColumnRules().isEmpty()){
continue;
}
if(dbRule.getTableRules() == null){ if(dbRule.getTableRules() == null){
dbRule.setTableRules(new HashSet<>()); dbRule.setTableRules(new HashSet<>());
} }
dbRule.getTableRules().add(tableRule); dbRule.getTableRules().add(tableRule);
} }
if(dbRule.getTableRules() == null || dbRule.getTableRules().isEmpty()){
continue;
}
dbRules.add(dbRule); dbRules.add(dbRule);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new SecurityBizException("!!! Load security rule from apollo error !!!",e); throw new SecurityBizException("!!! Load security rule from apollo error !!!",e);
......
...@@ -13,9 +13,10 @@ import com.alibaba.druid.util.Utils; ...@@ -13,9 +13,10 @@ import com.alibaba.druid.util.Utils;
import com.secoo.mall.datasource.security.constant.PropertyProviderType; import com.secoo.mall.datasource.security.constant.PropertyProviderType;
import com.secoo.mall.datasource.security.exception.SecurityBizException; import com.secoo.mall.datasource.security.exception.SecurityBizException;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.util.CollectionUtil;
import com.secoo.mall.datasource.security.util.SecurityUtil; import com.secoo.mall.datasource.security.util.SecurityUtil;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor; import com.secoo.mall.datasource.security.visitor.MySqlEncryptParameterVisitor;
import com.secoo.mall.datasource.security.visitor.Parameter; import com.secoo.mall.datasource.security.visitor.model.Parameter;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -27,18 +28,33 @@ import java.sql.ResultSetMetaData; ...@@ -27,18 +28,33 @@ import java.sql.ResultSetMetaData;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Types; import java.sql.Types;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
@AutoLoad @AutoLoad
public class SecurityFilter extends SecurityFilterEventAdapter { public class SecurityFilter extends SecurityFilterEventAdapter {
private static final Logger log = LoggerFactory.getLogger(SecurityFilter.class); private static final Logger log = LoggerFactory.getLogger(SecurityFilter.class);
private SecurityFilterContext securityFilterContext; private static final ThreadLocal<List<Parameter>> ENCRYPT_PARAMETERS_TL = new ThreadLocal<>();
private final SecurityFilterContext securityFilterContext;
public SecurityFilter(){ public SecurityFilter(){
super(); super();
this.securityFilterContext = new SecurityFilterContext(PropertyProviderType.APOLLO); this.securityFilterContext = new SecurityFilterContext(PropertyProviderType.APOLLO);
} }
public List<Parameter> getEncryptParameters(){
return ENCRYPT_PARAMETERS_TL.get();
}
public void setEncryptParameters(List<Parameter> encryptParameters){
ENCRYPT_PARAMETERS_TL.set(encryptParameters);
}
public void clearEncryptParameters(){
ENCRYPT_PARAMETERS_TL.remove();
}
public void init(DataSourceProxy dataSource) { public void init(DataSourceProxy dataSource) {
String dbName = SecurityUtil.findDataBaseNameByUrl(dataSource.getUrl()); String dbName = SecurityUtil.findDataBaseNameByUrl(dataSource.getUrl());
if(StringUtils.isNotBlank(dbName)){ if(StringUtils.isNotBlank(dbName)){
...@@ -281,21 +297,16 @@ public class SecurityFilter extends SecurityFilterEventAdapter { ...@@ -281,21 +297,16 @@ public class SecurityFilter extends SecurityFilterEventAdapter {
return securityFilterContext.getColumnRule(dbName,tableName,columnName); return securityFilterContext.getColumnRule(dbName,tableName,columnName);
} }
protected void executeBefore(FilterChain chain,StatementProxy statement, String sql) { protected String rewritePrepareSql(FilterChain chain,String sql) {
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl()); String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName); Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName);
if(tableRuleMap == null){ if(tableRuleMap == null){
log.debug("过滤非加密db,dbName={},sql={}",dbName, sql); log.debug("过滤非加密db,dbName={},sql={}",dbName, sql);
return; return sql;
}
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("过滤非PreparedStatement,dbName={},sql={}",dbName, sql);
return;
} }
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement; // 解析
// 解析sql List<Parameter> allEncryptParameters = null;
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType()); List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType());
for (SQLStatement stmt : stmtList) { for (SQLStatement stmt : stmtList) {
if (!(stmt instanceof SQLSelectStatement) if (!(stmt instanceof SQLSelectStatement)
...@@ -306,36 +317,116 @@ public class SecurityFilter extends SecurityFilterEventAdapter { ...@@ -306,36 +317,116 @@ public class SecurityFilter extends SecurityFilterEventAdapter {
continue; continue;
} }
MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(tableRuleMap); MySqlEncryptParameterVisitor visitor = new MySqlEncryptParameterVisitor(tableRuleMap);
stmt.accept(visitor); stmt.accept(visitor);
List<Parameter> encryptParameters = visitor.getEncryptParameters(); List<Parameter> encryptParameters = visitor.getEncryptParameters();
// 加密
if(encryptParameters != null && !encryptParameters.isEmpty()){ if(encryptParameters != null && !encryptParameters.isEmpty()){
for(Parameter parameter : encryptParameters){ allEncryptParameters = CollectionUtil.addAll(allEncryptParameters,encryptParameters);
encrypt(parameter.getColumnRule(), preparedStatement, parameter.getJdbcIndex()); }
} }
if(allEncryptParameters == null || allEncryptParameters.isEmpty()){
return sql;
} }
// 重写
StringBuilder newSql = new StringBuilder();
stmtList.forEach(e -> newSql.append(e.toString()));
String newSqlStr = newSql.toString();
log.debug("重写sql={},allEncryptParameters={}",newSql,allEncryptParameters);
this.setEncryptParameters(allEncryptParameters);
return newSqlStr;
}
protected void rewritePrepareParameter(FilterChain chain,StatementProxy statement) {
try {
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName);
if(tableRuleMap == null){
log.debug("过滤非加密db,dbName={}",dbName);
return;
}
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("过滤非PreparedStatement,dbName={}",dbName);
return;
}
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
Map<Integer,JdbcParameter> jdbcParameters = preparedStatement.getParameters();
if(jdbcParameters == null || jdbcParameters.isEmpty()){
return;
}
// 获取加密参数
List<Parameter> allEncryptParameters = this.getEncryptParameters();
if(allEncryptParameters == null || allEncryptParameters.isEmpty()){
return;
}
// 加密明文
allEncryptParameters.forEach(e -> {
JdbcParameter jdbcParameter = jdbcParameters.get(e.getParameterIndex());
if(jdbcParameter != null){
e.setPlainValue(jdbcParameter.getValue());
e.setCipherValue(this.encrypt(e.getColumnRule(), e.getPlainValue()));
}
});
Map<Integer,Parameter> allEncryptParameterMap = allEncryptParameters.stream().collect(Collectors.toMap(Parameter::getParameterIndex, e -> e));
// 重建jdbc参数列表
long addJdbcParameterCount = allEncryptParameters.stream().filter(e -> e.getAddParameterIndex() != null && e.getAddParameterIndex() >= 0).count();
List<JdbcParameter> newJdbcParameters = new ArrayList<>(jdbcParameters.size() + Long.valueOf(addJdbcParameterCount).intValue());
for(Map.Entry<Integer,JdbcParameter> jdbcParameterEntry:jdbcParameters.entrySet()){
Integer index = jdbcParameterEntry.getKey();
JdbcParameter jdbcParameter = jdbcParameterEntry.getValue();
Parameter encryptParameter = allEncryptParameterMap.get(index);
if(encryptParameter == null || encryptParameter.getCipherValue() == null){
newJdbcParameters.add(index,jdbcParameter);
}else{
JdbcParameter newJdbcParameter = null;
if (encryptParameter.getCipherValue() == null) {
newJdbcParameter = JdbcParameterNull.VARCHAR;
}else if (encryptParameter.getCipherValue().length() == 0) {
newJdbcParameter = JdbcParameterString.empty;
}else{
newJdbcParameter = new JdbcParameterString(encryptParameter.getCipherValue());
}
newJdbcParameters.add(index,newJdbcParameter);
}
}
// 新增明文参数
for(int i=0; i<allEncryptParameters.size(); i++){
Parameter parameter = allEncryptParameters.get(i);
if(parameter.getAddParameterIndex() != null && parameter.getAddParameterIndex() >= 0){
newJdbcParameters.add(parameter.getParameterIndex() + i, jdbcParameters.get(parameter.getParameterIndex()));
}
String sql = ((PreparedStatementProxyImpl) statement).getSql();
log.debug("加密sql={}\n加密参数={}",sql,newJdbcParameters);
}
// 重写jdbc参数列表
for(int i=0; i<newJdbcParameters.size(); i++){
preparedStatement.setParameter(i + 1,newJdbcParameters.get(i));
}
} finally {
this.clearEncryptParameters();
} }
} }
/** /**
* 加密 * 加密
* @param columnRule
* @param preparedStatement
* @param index
*/ */
private void encrypt(ColumnRule columnRule, PreparedStatementProxyImpl preparedStatement,int index) { private String encrypt(ColumnRule columnRule, Object plainText) {
JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object plainText = jdbcParameter.getValue();
if (plainText == null) { if (plainText == null) {
return; return null;
} }
try { try {
String cipherText = columnRule.getEncryptAlgorithm().encrypt(plainText); String cipherText = columnRule.getEncryptAlgorithm().encrypt(plainText);
preparedStatement.setObject(index + 1,cipherText);
log.debug("字段加密:columnRule={},plainText={},cipherText={}", columnRule, plainText, cipherText); log.debug("字段加密:columnRule={},plainText={},cipherText={}", columnRule, plainText, cipherText);
return cipherText;
} catch (Exception e) { } catch (Exception e) {
String errorMsg = "字段加密异常:columnRule="+columnRule+",plainText="+plainText; String errorMsg = "字段加密异常:columnRule="+columnRule+",plainText="+plainText;
log.error(errorMsg); log.error(errorMsg);
......
package com.secoo.mall.datasource.security.filter; package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.filter.FilterAdapter;
import com.alibaba.druid.filter.FilterChain; import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.*; import com.alibaba.druid.proxy.jdbc.*;
import java.sql.SQLException; import java.sql.SQLException;
public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { public abstract class SecurityFilterEventAdapter extends FilterAdapter {
@Override @Override
public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql) throws SQLException { public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
statementExecuteBefore(chain, statement, sql); throws SQLException {
String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql);
}
try { @Override
boolean firstResult = super.statement_execute(chain, statement, sql); public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
String sql, int autoGeneratedKeys) throws SQLException {
String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql, autoGeneratedKeys);
}
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
String sql, int resultSetType, int resultSetConcurrency)
throws SQLException {
statementExecuteAfter(statement, sql, firstResult); String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql, resultSetType, resultSetConcurrency);
}
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
String sql, int resultSetType, int resultSetConcurrency,
int resultSetHoldability) throws SQLException {
String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql, resultSetType, resultSetConcurrency,
resultSetHoldability);
}
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
String sql, int[] columnIndexes) throws SQLException {
String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql, columnIndexes);
}
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
String sql, String[] columnNames) throws SQLException {
String rewriteSql = this.rewritePrepareSql(chain,sql);
return chain.connection_prepareStatement(connection, rewriteSql, columnNames);
}
@Override
public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
statementExecuteBefore(chain, statement);
try {
boolean firstResult = super.statement_execute(chain, statement, sql);
return firstResult; return firstResult;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -33,22 +68,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -33,22 +68,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys) public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys)
throws SQLException { throws SQLException {
statementExecuteBefore(chain, statement, sql); statementExecuteBefore(chain, statement);
try { try {
boolean firstResult = super.statement_execute(chain, statement, sql, autoGeneratedKeys); boolean firstResult = super.statement_execute(chain, statement, sql, autoGeneratedKeys);
this.statementExecuteAfter(statement, sql, firstResult);
return firstResult; return firstResult;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -56,22 +81,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -56,22 +81,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[]) public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[])
throws SQLException { throws SQLException {
statementExecuteBefore(chain, statement, sql); statementExecuteBefore(chain, statement);
try { try {
boolean firstResult = super.statement_execute(chain, statement, sql, columnIndexes); boolean firstResult = super.statement_execute(chain, statement, sql, columnIndexes);
this.statementExecuteAfter(statement, sql, firstResult);
return firstResult; return firstResult;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -79,22 +94,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -79,22 +94,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, String columnNames[]) public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, String columnNames[])
throws SQLException { throws SQLException {
statementExecuteBefore(chain, statement, sql); statementExecuteBefore(chain, statement);
try { try {
boolean firstResult = super.statement_execute(chain, statement, sql, columnNames); boolean firstResult = super.statement_execute(chain, statement, sql, columnNames);
this.statementExecuteAfter(statement, sql, firstResult);
return firstResult; return firstResult;
} catch (SQLException error) { } catch (SQLException | Error | RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -105,18 +110,8 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -105,18 +110,8 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
try { try {
int[] result = super.statement_executeBatch(chain, statement); int[] result = super.statement_executeBatch(chain, statement);
statementExecuteBatchAfter(statement, result);
return result; return result;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, statement.getBatchSql(), error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, statement.getBatchSql(), error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, statement.getBatchSql(), error);
throw error; throw error;
} }
} }
...@@ -124,7 +119,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -124,7 +119,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy statement, String sql) public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy statement, String sql)
throws SQLException { throws SQLException {
statementExecuteQueryBefore(chain, statement, sql); statementExecuteQueryBefore(chain, statement);
try { try {
ResultSetProxy resultSet = super.statement_executeQuery(chain, statement, sql); ResultSetProxy resultSet = super.statement_executeQuery(chain, statement, sql);
...@@ -135,36 +130,19 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -135,36 +130,19 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
} }
return resultSet; return resultSet;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
@Override @Override
public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql) throws SQLException { public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
statementExecuteUpdateBefore(chain, statement, sql); statementExecuteUpdateBefore(chain, statement);
try { try {
int updateCount = super.statement_executeUpdate(chain, statement, sql); int updateCount = super.statement_executeUpdate(chain, statement, sql);
statementExecuteUpdateAfter(statement, sql, updateCount);
return updateCount; return updateCount;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -172,22 +150,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -172,22 +150,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys) public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys)
throws SQLException { throws SQLException {
statementExecuteUpdateBefore(chain, statement, sql); statementExecuteUpdateBefore(chain, statement);
try { try {
int updateCount = super.statement_executeUpdate(chain, statement, sql, autoGeneratedKeys); int updateCount = super.statement_executeUpdate(chain, statement, sql, autoGeneratedKeys);
statementExecuteUpdateAfter(statement, sql, updateCount);
return updateCount; return updateCount;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -195,22 +163,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -195,22 +163,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[]) public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[])
throws SQLException { throws SQLException {
statementExecuteUpdateBefore(chain, statement, sql); statementExecuteUpdateBefore(chain, statement);
try { try {
int updateCount = super.statement_executeUpdate(chain, statement, sql, columnIndexes); int updateCount = super.statement_executeUpdate(chain, statement, sql, columnIndexes);
statementExecuteUpdateAfter(statement, sql, updateCount);
return updateCount; return updateCount;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -218,22 +176,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -218,22 +176,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, String columnNames[]) public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, String columnNames[])
throws SQLException { throws SQLException {
statementExecuteUpdateBefore(chain, statement, sql); statementExecuteUpdateBefore(chain, statement);
try { try {
int updateCount = super.statement_executeUpdate(chain, statement, sql, columnNames); int updateCount = super.statement_executeUpdate(chain, statement, sql, columnNames);
statementExecuteUpdateAfter(statement, sql, updateCount);
return updateCount; return updateCount;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, sql, error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, sql, error);
throw error; throw error;
} }
} }
...@@ -241,22 +189,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -241,22 +189,12 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public boolean preparedStatement_execute(FilterChain chain, PreparedStatementProxy statement) throws SQLException { public boolean preparedStatement_execute(FilterChain chain, PreparedStatementProxy statement) throws SQLException {
try { try {
statementExecuteBefore(chain, statement, statement.getSql()); statementExecuteBefore(chain, statement);
boolean firstResult = chain.preparedStatement_execute(statement); boolean firstResult = chain.preparedStatement_execute(statement);
this.statementExecuteAfter(statement, statement.getSql(), firstResult);
return firstResult; return firstResult;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error; throw error;
} }
...@@ -266,7 +204,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -266,7 +204,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
public ResultSetProxy preparedStatement_executeQuery(FilterChain chain, PreparedStatementProxy statement) public ResultSetProxy preparedStatement_executeQuery(FilterChain chain, PreparedStatementProxy statement)
throws SQLException { throws SQLException {
try { try {
statementExecuteQueryBefore(chain, statement, statement.getSql()); statementExecuteQueryBefore(chain, statement);
ResultSetProxy resultSet = chain.preparedStatement_executeQuery(statement); ResultSetProxy resultSet = chain.preparedStatement_executeQuery(statement);
...@@ -277,14 +215,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -277,14 +215,7 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
} }
return resultSet; return resultSet;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error; throw error;
} }
} }
...@@ -292,43 +223,54 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter { ...@@ -292,43 +223,54 @@ public abstract class SecurityFilterEventAdapter extends FilterEventAdapter {
@Override @Override
public int preparedStatement_executeUpdate(FilterChain chain, PreparedStatementProxy statement) throws SQLException { public int preparedStatement_executeUpdate(FilterChain chain, PreparedStatementProxy statement) throws SQLException {
try { try {
statementExecuteUpdateBefore(chain, statement, statement.getSql()); statementExecuteUpdateBefore(chain, statement);
int updateCount = super.preparedStatement_executeUpdate(chain, statement); int updateCount = super.preparedStatement_executeUpdate(chain, statement);
statementExecuteUpdateAfter(statement, statement.getSql(), updateCount);
return updateCount; return updateCount;
} catch (SQLException error) { } catch (SQLException | RuntimeException | Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (RuntimeException error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error;
} catch (Error error) {
statement_executeErrorAfter(statement, statement.getSql(), error);
throw error; throw error;
} }
} }
protected void statementExecuteUpdateBefore(FilterChain chain,StatementProxy statement, String sql) { protected void statementExecuteQueryAfter(StatementProxy statement, String sql, ResultSetProxy resultSet) {
this.executeBefore(chain,statement,sql);
}
protected void resultSetOpenAfter(ResultSetProxy resultSet) {
} }
protected void statementExecuteQueryBefore(FilterChain chain,StatementProxy statement, String sql) { protected void statementExecuteUpdateBefore(FilterChain chain,StatementProxy statement) {
this.executeBefore(chain,statement,sql);
} }
protected void statementExecuteBefore(FilterChain chain,StatementProxy statement, String sql) { protected void statementExecuteQueryBefore(FilterChain chain,StatementProxy statement) {
this.executeBefore(chain,statement,sql); this.rewritePrepareParameter(chain,statement);
}
protected void statementExecuteBefore(FilterChain chain,StatementProxy statement) {
this.rewritePrepareParameter(chain,statement);
} }
protected void statementExecuteBatchBefore(FilterChain chain,StatementProxy statement) { protected void statementExecuteBatchBefore(FilterChain chain,StatementProxy statement) {
this.executeBefore(chain,statement,statement.getBatchSql()); this.rewritePrepareParameter(chain,statement);
}
/**
* 重写sql
* @param chain
* @param sql
* @return
*/
protected String rewritePrepareSql(FilterChain chain,String sql){
return sql;
} }
protected void executeBefore(FilterChain chain,StatementProxy statement,String sql){ /**
this.executeBefore(chain,statement,sql); * 重写参数
* @param chain
* @param statement
*/
protected void rewritePrepareParameter(FilterChain chain,StatementProxy statement){
} }
} }
...@@ -7,6 +7,9 @@ import java.util.Map; ...@@ -7,6 +7,9 @@ import java.util.Map;
public class CollectionUtil { public class CollectionUtil {
public static <T> List<T> add(List<T> list, T t){ public static <T> List<T> add(List<T> list, T t){
if(t == null){
return null;
}
if(list == null){ if(list == null){
list = new ArrayList<>(); list = new ArrayList<>();
} }
...@@ -15,6 +18,9 @@ public class CollectionUtil { ...@@ -15,6 +18,9 @@ public class CollectionUtil {
} }
public static <T> List<T> addAll(List<T> list, List<T> subList){ public static <T> List<T> addAll(List<T> list, List<T> subList){
if(subList == null){
return null;
}
if(list == null){ if(list == null){
list = new ArrayList<>(); list = new ArrayList<>();
} }
...@@ -23,6 +29,9 @@ public class CollectionUtil { ...@@ -23,6 +29,9 @@ public class CollectionUtil {
} }
public static <K,V> Map<K,V> add(Map<K,V> map, K k, V v){ public static <K,V> Map<K,V> add(Map<K,V> map, K k, V v){
if(v == null){
return null;
}
if(map == null){ if(map == null){
map = new HashMap<>(); map = new HashMap<>();
} }
...@@ -32,7 +41,10 @@ public class CollectionUtil { ...@@ -32,7 +41,10 @@ public class CollectionUtil {
public static <K,V> Map<K,V> addAll(Map<K,V> map, Map<K,V> subMap){ public static <K,V> Map<K,V> addAll(Map<K,V> map, Map<K,V> subMap){
if(subMap == null){ if(subMap == null){
subMap = new HashMap<>(); return null;
}
if(map == null){
map = new HashMap<>();
} }
map.putAll(subMap); map.putAll(subMap);
return map; return map;
......
...@@ -10,11 +10,13 @@ import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; ...@@ -10,11 +10,13 @@ import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.parser.Token; import com.alibaba.druid.sql.parser.Token;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.util.CollectionUtil; import com.secoo.mall.datasource.security.util.CollectionUtil;
import com.secoo.mall.datasource.security.visitor.model.Column;
import com.secoo.mall.datasource.security.visitor.model.Parameter;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor { public class MySqlEncryptParameterVisitor extends MySqlSchemaStatVisitor {
/** /**
* 加密规则 * 加密规则
*/ */
...@@ -24,13 +26,13 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor { ...@@ -24,13 +26,13 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor {
*/ */
protected List<Parameter> encryptParameters; protected List<Parameter> encryptParameters;
public MySqlSecurityParameterVisitor(Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptParameters) { public MySqlEncryptParameterVisitor(Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptParameters) {
super(); super();
this.tableRuleMap = tableRuleMap; this.tableRuleMap = tableRuleMap;
this.encryptParameters = encryptParameters; this.encryptParameters = encryptParameters;
} }
public MySqlSecurityParameterVisitor(Map<String,Map<String,ColumnRule>> tableRuleMap) { public MySqlEncryptParameterVisitor(Map<String,Map<String,ColumnRule>> tableRuleMap) {
this(tableRuleMap,new ArrayList<>()); this(tableRuleMap,new ArrayList<>());
} }
...@@ -56,88 +58,13 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor { ...@@ -56,88 +58,13 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor {
repository.resolve(x); repository.resolve(x);
} }
accept(x.getTableSource());
accept(x.getColumns()); accept(x.getColumns());
accept(x.getValuesList()); accept(x.getValuesList());
accept(x.getQuery()); accept(x.getQuery());
accept(x.getDuplicateKeyUpdate()); accept(x.getDuplicateKeyUpdate());
// 插入sql不能省略列名 this.encryptInsert(x);
List<SQLExpr> columns = x.getColumns();
if(columns.isEmpty()){
return false;
}
String tableName = x.getTableName().getSimpleName();
// columns
List<Column> encryptColumns = null;
for(int i=0;i<columns.size();i++){
SQLExpr columnSQLExpr = columns.get(i);
Column column = this.getEncryptColumn(columnSQLExpr);
if(column == null || column.getColumnRule() == null){
continue;
}
column.setColumnIndex(i);
encryptColumns = CollectionUtil.add(encryptColumns,column);
}
// values
if(encryptColumns != null && !encryptColumns.isEmpty()){
Map<Integer,Column> encryptColumnMap = encryptColumns.stream().collect(Collectors.toMap(Column::getColumnIndex, e -> e));
List<MySqlInsertStatement.ValuesClause> valuesClauses = x.getValuesList();
for(SQLInsertStatement.ValuesClause valuesClause:valuesClauses){
List<SQLExpr> values = valuesClause.getValues();
for(int columnIndex=0; columnIndex < values.size();columnIndex++){
SQLExpr valueSQLExpr = values.get(columnIndex);
if (!(valueSQLExpr instanceof SQLVariantRefExpr)) {// ?
continue;
}
SQLVariantRefExpr variantRefExpr = (SQLVariantRefExpr) valueSQLExpr;
if(!Token.QUES.name.equals(variantRefExpr.getName())){
continue;
}
Column column = encryptColumnMap.get(columnIndex);
if(column == null || column.getColumnRule() == null){
continue;
}
// 采集加密参数
Parameter parameter = new Parameter(column,variantRefExpr.getName(),variantRefExpr.getIndex(),null);
this.encryptParameters.add(parameter);
}
}
}
// duplicateKeyUpdate
List<SQLExpr> duplicateKeyUpdate = x.getDuplicateKeyUpdate();
for(int i=0;i<duplicateKeyUpdate.size();i++){
SQLExpr sqlExpr = duplicateKeyUpdate.get(i);
if(!(sqlExpr instanceof SQLBinaryOpExpr)){
sqlExpr.accept(this);
continue;
}
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
SQLExpr left = sqlBinaryOpExpr.getLeft();
SQLExpr right = sqlBinaryOpExpr.getRight();
if (!(right instanceof SQLVariantRefExpr)) {// ?
continue;
}
SQLVariantRefExpr variantRefExpr = (SQLVariantRefExpr) right;
if(!Token.QUES.name.equals(variantRefExpr.getName())){
continue;
}
Column column = this.getEncryptColumn(left);
if(column == null || column.getColumnRule() == null){
continue;
}
column.setColumnIndex(i);
this.encryptParameters.add(new Parameter(column,variantRefExpr.getName(), variantRefExpr.getIndex(), null));
}
return false; return false;
} }
...@@ -191,106 +118,16 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor { ...@@ -191,106 +118,16 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor {
return false; return false;
} }
Column column = this.getEncryptColumn(columnExpr); Column encryptColumn = this.handleEncryptColumn(columnExpr);
if(column == null || column.getColumnRule() == null){ if(encryptColumn == null || encryptColumn.getColumnRule() == null){
return false; return false;
} }
this.encryptParameters.add(new Parameter(column,x.getName(), x.getIndex(), null)); this.encryptParameters.add(new Parameter(encryptColumn, x.getIndex(), null));
return false; return false;
} }
/**
* 加密规则
* @param tableName
* @param columnName
* @return
*/
private ColumnRule getColumnRule(String tableName,String columnName){
if(tableName == null || columnName == null || tableRuleMap == null || tableRuleMap.isEmpty()){
return null;
}
Map<String,ColumnRule> columnRuleMap = tableRuleMap.get(tableName);
if(columnRuleMap == null || columnRuleMap.isEmpty()){
return null;
}
return columnRuleMap.get(columnName);
}
/**
* 解析列信息
* @param columnExpr
* @return
*/
protected Column getEncryptColumn(SQLExpr columnExpr){
// unwrap
columnExpr = unwrapExpr(columnExpr);
String tableName = null;
String columnName = null;
String ownerName = null;
if(columnExpr instanceof SQLIdentifierExpr){// name = ?
SQLIdentifierExpr sqlIdentifierExpr = (SQLIdentifierExpr) columnExpr;
columnName = sqlIdentifierExpr.getName();
SQLTableSource tableSource = sqlIdentifierExpr.getResolvedTableSource();
if (tableSource instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) tableSource).getExpr();
if (tableSourceExpr != null && !(tableSourceExpr instanceof SQLName)) {
tableSourceExpr = this.unwrapExpr(tableSourceExpr);
}
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}else if(columnExpr instanceof SQLPropertyExpr){
SQLPropertyExpr sqlPropertyExpr = (SQLPropertyExpr) columnExpr;
columnName = sqlPropertyExpr.getName();
ownerName = sqlPropertyExpr.getOwnernName();
SQLExpr owner = sqlPropertyExpr.getOwner();
SQLObject resolvedOwnerObject = sqlPropertyExpr.getResolvedOwnerObject();
// a.name = ? 或 t_user.name = ?
if(owner instanceof SQLIdentifierExpr){
tableName = ((SQLName) owner).toString();
if (resolvedOwnerObject instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) resolvedOwnerObject).getExpr();
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}
// secooAbcDB.t_user.name = ?
else if(owner instanceof SQLPropertyExpr){
SQLPropertyExpr ownerSQLPropertyExpr = (SQLPropertyExpr) owner;
if(ownerSQLPropertyExpr.getName() != null){
tableName = ownerSQLPropertyExpr.getName();
}else{
tableName = ownerSQLPropertyExpr.getSimpleName();
}
}
}
if(tableName == null || columnName == null){
return null;
}
ColumnRule columnRule = this.getColumnRule(tableName,columnName);
if(columnRule == null){
return null;
}
return new Column(tableName,ownerName,columnName,"",null,columnRule);
}
private SQLExpr unwrapExpr(SQLExpr expr) { private SQLExpr unwrapExpr(SQLExpr expr) {
SQLExpr original = expr; SQLExpr original = expr;
...@@ -403,4 +240,241 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor { ...@@ -403,4 +240,241 @@ public class MySqlSecurityParameterVisitor extends MySqlSchemaStatVisitor {
return expr; return expr;
} }
/**
* 解析条件表达式信息(where condition | updateItem)
* @param column
* @param operator
* @param valueExprs
* @return
*/
private List<Parameter> handleEncryptCondition(Column column, String operator, SQLExpr... valueExprs) {
if(column == null){
return null;
}
// 采集加密参数
List<Parameter> encryptParameters = null;
for (SQLExpr item : valueExprs) {
SQLVariantRefExpr variantRefExpr = handlePreparedValue(item);
if(variantRefExpr == null){
continue;
}
if(encryptParameters == null){
encryptParameters = new ArrayList<>();
}
encryptParameters.add(new Parameter(column, variantRefExpr.getIndex(),null));
}
if(encryptParameters != null && !encryptParameters.isEmpty()){
this.encryptParameters.addAll(encryptParameters);
}
return encryptParameters;
}
/**
* 解析占位符参数(?,其占位符不支持)
* @param valueExpr
* @return
*/
private SQLVariantRefExpr handlePreparedValue(SQLExpr valueExpr){
valueExpr = unwrapExpr(valueExpr);
if (!(valueExpr instanceof SQLVariantRefExpr)) {// ?
return null;
}
SQLVariantRefExpr variantRefExpr = (SQLVariantRefExpr) valueExpr;
if(!Token.QUES.name.equals(variantRefExpr.getName())){
return null;
}
return variantRefExpr;
}
/**
* 解析encrypt column表达式,替换逻辑列名
* @param columnExpr
* @return
*/
private Column handleEncryptColumn(SQLExpr columnExpr){
Column column = this.handleColumn(columnExpr);
if(column == null){
return null;
}
ColumnRule columnRule = this.getColumnRule(column.getTableName(),column.getColumnName());
if(columnRule == null){
return null;
}
column.setColumnRule(columnRule);
// 替换逻辑列为加密列
if(columnExpr instanceof SQLIdentifierExpr){
((SQLIdentifierExpr) columnExpr).setName(columnRule.getCipherColumn());
}else if(columnExpr instanceof SQLPropertyExpr){
((SQLPropertyExpr) columnExpr).setName(columnRule.getCipherColumn());
}
return column;
}
/**
* 解析列信息
* @param columnExpr
* @return
*/
protected Column handleColumn(SQLExpr columnExpr){
if (columnExpr instanceof SQLCastExpr) {
columnExpr = ((SQLCastExpr) columnExpr).getExpr();
}
// unwrap
columnExpr = unwrapExpr(columnExpr);
String tableName = null;
String columnName = null;
String ownerName = null;
if(columnExpr instanceof SQLIdentifierExpr){// name = ?
SQLIdentifierExpr sqlIdentifierExpr = (SQLIdentifierExpr) columnExpr;
columnName = sqlIdentifierExpr.getName();
SQLTableSource tableSource = sqlIdentifierExpr.getResolvedTableSource();
if (tableSource instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) tableSource).getExpr();
if (tableSourceExpr != null && !(tableSourceExpr instanceof SQLName)) {
tableSourceExpr = this.unwrapExpr(tableSourceExpr);
}
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}else if(columnExpr instanceof SQLPropertyExpr){
SQLPropertyExpr sqlPropertyExpr = (SQLPropertyExpr) columnExpr;
columnName = sqlPropertyExpr.getName();
ownerName = sqlPropertyExpr.getOwnernName();
SQLExpr owner = sqlPropertyExpr.getOwner();
if (owner instanceof SQLName) {
SQLObject resolvedOwnerObject = sqlPropertyExpr.getResolvedOwnerObject();
// a.name = ? 或 t_user.name = ?
if(owner instanceof SQLIdentifierExpr){
tableName = ((SQLName) owner).toString();
if (resolvedOwnerObject instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) resolvedOwnerObject).getExpr();
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}
// secooAbcDB.t_user.name = ?
else if(owner instanceof SQLPropertyExpr){
SQLPropertyExpr ownerSQLPropertyExpr = (SQLPropertyExpr) owner;
if(ownerSQLPropertyExpr.getName() != null){
tableName = ownerSQLPropertyExpr.getName();
}else{
tableName = ownerSQLPropertyExpr.getSimpleName();
}
}
}
}
if(tableName == null || columnName == null){
return null;
}
return new Column(tableName,ownerName,columnName,"",null,null);
}
/**
* 加密规则
* @param tableName
* @param columnName
* @return
*/
private ColumnRule getColumnRule(String tableName,String columnName){
if(tableName == null || columnName == null || tableRuleMap == null || tableRuleMap.isEmpty()){
return null;
}
Map<String,ColumnRule> columnRuleMap = tableRuleMap.get(tableName);
if(columnRuleMap == null || columnRuleMap.isEmpty()){
return null;
}
return columnRuleMap.get(columnName);
}
/**
* 重写insert values + duplicateKeyUpdate
* @param x
*/
private void encryptInsert(MySqlInsertStatement x){
// 插入sql不能省略列名
List<SQLExpr> columns = x.getColumns();
if(columns.isEmpty()){
return;
}
// columns
List<Column> insertEncryptColumns = null;
for(int i=0; i<columns.size(); i++){
SQLExpr columnSQLExpr = columns.get(i);
Column encryptColumn = this.handleEncryptColumn(columnSQLExpr);
if(encryptColumn == null){
continue;
}
encryptColumn.setColumnIndex(i);
insertEncryptColumns = CollectionUtil.add(insertEncryptColumns,encryptColumn);
}
// values
if(insertEncryptColumns != null && !insertEncryptColumns.isEmpty()){
Map<Integer,Column> insertEncryptColumnMap = insertEncryptColumns.stream().collect(Collectors.toMap(Column::getColumnIndex, e -> e));
List<MySqlInsertStatement.ValuesClause> valuesClauses = x.getValuesList();
for(SQLInsertStatement.ValuesClause valuesClause:valuesClauses){
List<SQLExpr> values = valuesClause.getValues();
for(int columnIndex=0; columnIndex < values.size(); columnIndex++){
Column encryptColumn = insertEncryptColumnMap.get(columnIndex);
if(encryptColumn == null){
continue;
}
SQLExpr valueSQLExpr = values.get(columnIndex);
SQLVariantRefExpr variantRefExpr = this.handlePreparedValue(valueSQLExpr);
if(variantRefExpr == null){
continue;
}
// 采集加密参数
Parameter encryptParameter = new Parameter(encryptColumn,variantRefExpr.getIndex(),null);
this.encryptParameters.add(encryptParameter);
}
}
}
// duplicateKeyUpdate
List<SQLExpr> duplicateKeyUpdate = x.getDuplicateKeyUpdate();
for(int i=0; i<duplicateKeyUpdate.size(); i++){
SQLExpr sqlExpr = duplicateKeyUpdate.get(i);
if(!(sqlExpr instanceof SQLBinaryOpExpr)){
sqlExpr.accept(this);
continue;
}
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
SQLExpr left = sqlBinaryOpExpr.getLeft();
SQLExpr right = sqlBinaryOpExpr.getRight();
Column encryptColumn = this.handleEncryptColumn(left);
if(encryptColumn == null){
continue;
}
encryptColumn.setColumnIndex(i);
List<Parameter> encryptParameters = this.handleEncryptCondition(encryptColumn,SQLBinaryOperator.Equality.name,right);
}
}
} }
package com.secoo.mall.datasource.security.visitor;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.sql.parser.Token;
import com.alibaba.druid.sql.repository.SchemaRepository;
import com.alibaba.druid.util.FnvHash;
import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.util.CollectionUtil;
import com.secoo.mall.datasource.security.visitor.model.Column;
import com.secoo.mall.datasource.security.visitor.model.Parameter;
import java.util.*;
import java.util.stream.Collectors;
public class MySqlRewriteParameterVisitor extends MySqlASTVisitorAdapter {
protected SchemaRepository repository;
protected String dbType;
/**
* 加密规则
*/
private Map<String,Map<String,ColumnRule>> tableRuleMap;
private List<Parameter> encryptParameters;
public String getDbType() {
return dbType;
}
public void setDbType(String dbType) {
this.dbType = dbType;
}
public List<Parameter> getEncryptParameters() {
return encryptParameters;
}
public void setEncryptColumnParameters(Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptColumnParameters) {
this.tableRuleMap = tableRuleMap;
this.encryptParameters = encryptColumnParameters;
}
public MySqlRewriteParameterVisitor(){
this((String) null);
}
public MySqlRewriteParameterVisitor(String dbType){
this(new SchemaRepository(dbType),new HashMap<>(), new ArrayList<Parameter>());
this.dbType = dbType;
}
public MySqlRewriteParameterVisitor(Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptParameters){
this((String) null,tableRuleMap, encryptParameters);
}
public MySqlRewriteParameterVisitor(String dbType, Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptParameters){
this(new SchemaRepository(dbType),tableRuleMap, encryptParameters);
this.encryptParameters = encryptParameters;
}
public MySqlRewriteParameterVisitor(SchemaRepository repository, Map<String,Map<String,ColumnRule>> tableRuleMap, List<Parameter> encryptParameters){
this.repository = repository;
this.tableRuleMap = tableRuleMap;
this.encryptParameters = encryptParameters;
if (repository != null) {
String dbType = repository.getDbType();
if (dbType != null && this.dbType == null) {
this.dbType = dbType;
}
}
}
@Override
public boolean visit(MySqlInsertStatement x) {
if (repository != null
&& x.getParent() == null) {
repository.resolve(x);
}
accept(x.getTableSource());
accept(x.getColumns());
accept(x.getValuesList());
accept(x.getQuery());
accept(x.getDuplicateKeyUpdate());
this.rewriteInsert(x);
return false;
}
@Override
public boolean visit(MySqlDeleteStatement x) {
if (repository != null
&& x.getParent() == null) {
repository.resolve(x);
}
SQLTableSource from = x.getFrom();
if (from != null) {
from.accept(this);
}
SQLTableSource using = x.getUsing();
if (using != null) {
using.accept(this);
}
SQLTableSource tableSource = x.getTableSource();
tableSource.accept(this);
accept(x.getWhere());
accept(x.getOrderBy());
accept(x.getLimit());
return false;
}
@Override
public boolean visit(MySqlUpdateStatement x) {
if (repository != null
&& x.getParent() == null) {
repository.resolve(x);
}
SQLTableSource tableSource = x.getTableSource();
if (!(tableSource instanceof SQLExprTableSource)) {
tableSource.accept(this);
}
accept(x.getFrom());
accept(x.getItems());
accept(x.getWhere());
for (SQLExpr item : x.getReturning()) {
item.accept(this);
}
this.rewriteUpdate(x);
return false;
}
@Override
public boolean visit(SQLSelectStatement x) {
if (repository != null
&& x.getParent() == null) {
repository.resolve(x);
}
visit(x.getSelect());
return true;
}
public boolean visit(SQLSelectItem x) {
statExpr(x.getExpr());
// 逻辑列处理
handleEncryptColumn(x.getExpr());
return false;
}
/**
* where - 普通
*/
@Override
public boolean visit(SQLBinaryOpExpr x) {
SQLObject parent = x.getParent();
if (parent instanceof SQLIfStatement) {
return true;
}
final SQLBinaryOperator op = x.getOperator();
final SQLExpr left = x.getLeft();
final SQLExpr right = x.getRight();
switch (op) {
case Equality:
case NotEqual:
case GreaterThan:
case GreaterThanOrEqual:
case LessThan:
case LessThanOrGreater:
case LessThanOrEqual:
case LessThanOrEqualOrGreaterThan:
case Like:
case NotLike:
case Is:
case IsNot:
handleEncryptCondition(left, x.getOperator().name, right);
handleEncryptCondition(right, x.getOperator().name, left);
break;
case BooleanOr: {
List<SQLExpr> list = SQLBinaryOpExpr.split(x, op);
for (SQLExpr item : list) {
if (item instanceof SQLBinaryOpExpr) {
visit((SQLBinaryOpExpr) item);
} else {
item.accept(this);
}
}
return false;
}
case Modulus:
if (right instanceof SQLIdentifierExpr) {
long hashCode64 = ((SQLIdentifierExpr) right).hashCode64();
if (hashCode64 == FnvHash.Constants.ISOPEN) {
left.accept(this);
return false;
}
}
break;
default:
break;
}
statExpr(left);
statExpr(right);
return false;
}
/**
* where - BETWEEN
*/
@Override
public boolean visit(SQLBetweenExpr x) {
SQLObject parent = x.getParent();
SQLExpr test = x.getTestExpr();
SQLExpr begin = x.getBeginExpr();
SQLExpr end = x.getEndExpr();
handleEncryptCondition(test, "BETWEEN", begin, end);
return false;
}
/**
* where - IN
*/
@Override
public boolean visit(SQLInListExpr x) {
if (x.isNot()) {
handleEncryptCondition(x.getExpr(), "NOT IN", x.getTargetList());
} else {
handleEncryptCondition(x.getExpr(), "IN", x.getTargetList());
}
return true;
}
/**
* where - SubQuery
*/
@Override
public boolean visit(SQLInSubQueryExpr x) {
if (x.isNot()) {
handleEncryptCondition(x.getExpr(), "NOT IN");
} else {
handleEncryptCondition(x.getExpr(), "IN");
}
return true;
}
protected void accept(SQLObject x) {
if (x != null) {
x.accept(this);
}
}
protected void accept(List<? extends SQLObject> nodes) {
for (int i = 0, size = nodes.size(); i < size; ++i) {
accept(nodes.get(i));
}
}
protected final void statExpr(SQLExpr x) {
Class<?> clazz = x.getClass();
if (clazz == SQLIdentifierExpr.class) {
visit((SQLIdentifierExpr) x);
} else if (clazz == SQLPropertyExpr.class) {
visit((SQLPropertyExpr) x);
// } else if (clazz == SQLAggregateExpr.class) {
// visit((SQLAggregateExpr) x);
} else if (clazz == SQLBinaryOpExpr.class) {
visit((SQLBinaryOpExpr) x);
// } else if (clazz == SQLCharExpr.class) {
// visit((SQLCharExpr) x);
// } else if (clazz == SQLNullExpr.class) {
// visit((SQLNullExpr) x);
// } else if (clazz == SQLIntegerExpr.class) {
// visit((SQLIntegerExpr) x);
// } else if (clazz == SQLNumberExpr.class) {
// visit((SQLNumberExpr) x);
// } else if (clazz == SQLMethodInvokeExpr.class) {
// visit((SQLMethodInvokeExpr) x);
// } else if (clazz == SQLVariantRefExpr.class) {
// visit((SQLVariantRefExpr) x);
// } else if (clazz == SQLBinaryOpExprGroup.class) {
// visit((SQLBinaryOpExprGroup) x);
} else if (x instanceof SQLLiteralExpr) {
// skip
} else {
x.accept(this);
}
}
private SQLExpr unwrapExpr(SQLExpr expr) {
SQLExpr original = expr;
for (;;) {
if (expr instanceof SQLMethodInvokeExpr) {
SQLMethodInvokeExpr methodInvokeExp = (SQLMethodInvokeExpr) expr;
if (methodInvokeExp.getArguments().size() == 1) {
SQLExpr firstExpr = methodInvokeExp.getArguments().get(0);
expr = firstExpr;
continue;
}
}
if (expr instanceof SQLCastExpr) {
expr = ((SQLCastExpr) expr).getExpr();
continue;
}
if (expr instanceof SQLPropertyExpr) {
SQLPropertyExpr propertyExpr = (SQLPropertyExpr) expr;
SQLTableSource resolvedTableSource = propertyExpr.getResolvedTableSource();
if (resolvedTableSource instanceof SQLSubqueryTableSource) {
SQLSelect select = ((SQLSubqueryTableSource) resolvedTableSource).getSelect();
SQLSelectQueryBlock queryBlock = select.getFirstQueryBlock();
if (queryBlock != null) {
if (queryBlock.getGroupBy() != null) {
if (original.getParent() instanceof SQLBinaryOpExpr) {
SQLExpr other = ((SQLBinaryOpExpr) original.getParent()).other(original);
if (!SQLExprUtils.isLiteralExpr(other)) {
break;
}
}
}
SQLSelectItem selectItem = queryBlock.findSelectItem(propertyExpr
.nameHashCode64());
if (selectItem != null) {
SQLExpr selectItemExpr = selectItem.getExpr();
if (selectItemExpr != expr) {
expr = selectItemExpr;
continue;
}
} else if (queryBlock.selectItemHasAllColumn()) {
SQLTableSource allColumnTableSource = null;
SQLTableSource from = queryBlock.getFrom();
if (from instanceof SQLJoinTableSource) {
SQLSelectItem allColumnSelectItem = queryBlock.findAllColumnSelectItem();
if (allColumnSelectItem != null && allColumnSelectItem.getExpr() instanceof SQLPropertyExpr) {
SQLExpr owner = ((SQLPropertyExpr) allColumnSelectItem.getExpr()).getOwner();
if (owner instanceof SQLName) {
allColumnTableSource = from.findTableSource(((SQLName) owner).nameHashCode64());
}
}
} else {
allColumnTableSource = from;
}
if (allColumnTableSource == null) {
break;
}
propertyExpr = propertyExpr.clone();
propertyExpr.setResolvedTableSource(allColumnTableSource);
if (allColumnTableSource instanceof SQLExprTableSource) {
propertyExpr.setOwner(((SQLExprTableSource) allColumnTableSource).getExpr().clone());
}
expr = propertyExpr;
continue;
}
}
} else if (resolvedTableSource instanceof SQLExprTableSource) {
SQLExprTableSource exprTableSource = (SQLExprTableSource) resolvedTableSource;
if (exprTableSource.getSchemaObject() != null) {
break;
}
SQLTableSource redirectTableSource = null;
SQLExpr tableSourceExpr = exprTableSource.getExpr();
if (tableSourceExpr instanceof SQLIdentifierExpr) {
redirectTableSource = ((SQLIdentifierExpr) tableSourceExpr).getResolvedTableSource();
} else if (tableSourceExpr instanceof SQLPropertyExpr) {
redirectTableSource = ((SQLPropertyExpr) tableSourceExpr).getResolvedTableSource();
}
if (redirectTableSource == resolvedTableSource) {
redirectTableSource = null;
}
if (redirectTableSource != null) {
propertyExpr = propertyExpr.clone();
if (redirectTableSource instanceof SQLExprTableSource) {
propertyExpr.setOwner(((SQLExprTableSource) redirectTableSource).getExpr().clone());
}
propertyExpr.setResolvedTableSource(redirectTableSource);
expr = propertyExpr;
continue;
}
propertyExpr = propertyExpr.clone();
propertyExpr.setOwner(tableSourceExpr);
expr = propertyExpr;
break;
}
}
break;
}
return expr;
}
private List<Parameter> handleEncryptCondition(SQLExpr expr, String operator, List<SQLExpr> values) {
Column encryptColumn = this.handleEncryptColumn(expr);
return handleEncryptCondition(encryptColumn, operator, values == null?(new SQLExpr[0]):values.toArray(new SQLExpr[values.size()]));
}
private List<Parameter> handleEncryptCondition(SQLExpr expr, String operator, SQLExpr... valueExprs) {
Column encryptColumn = this.handleEncryptColumn(expr);
return handleEncryptCondition(encryptColumn, operator, valueExprs);
}
/**
* 解析条件表达式信息(where condition | updateItem)
* @param encryptColumn
* @param operator
* @param valueExprs
* @return
*/
private List<Parameter> handleEncryptCondition(Column encryptColumn, String operator, SQLExpr... valueExprs) {
if(encryptColumn == null){
return null;
}
// 采集加密参数
List<Parameter> encryptParameters = null;
for (SQLExpr item : valueExprs) {
SQLVariantRefExpr variantRefExpr = handlePreparedValue(item);
if(variantRefExpr == null){
continue;
}
if(encryptParameters == null){
encryptParameters = new ArrayList<>();
}
encryptParameters.add(new Parameter(encryptColumn, variantRefExpr.getIndex(),null));
}
if(encryptParameters != null && !encryptParameters.isEmpty()){
this.encryptParameters.addAll(encryptParameters);
}
return encryptParameters;
}
/**
* 解析占位符参数(?,其占位符不支持)
* @param valueExpr
* @return
*/
private SQLVariantRefExpr handlePreparedValue(SQLExpr valueExpr){
valueExpr = unwrapExpr(valueExpr);
if (!(valueExpr instanceof SQLVariantRefExpr)) {// ?
return null;
}
SQLVariantRefExpr variantRefExpr = (SQLVariantRefExpr) valueExpr;
if(!Token.QUES.name.equals(variantRefExpr.getName())){
return null;
}
return variantRefExpr;
}
/**
* 解析encrypt column表达式,替换逻辑列名
* @param columnExpr
* @return
*/
private Column handleEncryptColumn(SQLExpr columnExpr){
Column column = this.handleColumn(columnExpr);
if(column == null){
return null;
}
ColumnRule columnRule = this.getColumnRule(column.getTableName(),column.getColumnName());
if(columnRule == null){
return null;
}
column.setColumnRule(columnRule);
// 替换逻辑列为加密列
if(columnExpr instanceof SQLIdentifierExpr){
((SQLIdentifierExpr) columnExpr).setName(columnRule.getCipherColumn());
}else if(columnExpr instanceof SQLPropertyExpr){
((SQLPropertyExpr) columnExpr).setName(columnRule.getCipherColumn());
}
return column;
}
/**
* 解析column表达式
* @return
*/
private Column handleColumn(SQLExpr columnExpr){
if (columnExpr instanceof SQLCastExpr) {
columnExpr = ((SQLCastExpr) columnExpr).getExpr();
}
// unwrap
columnExpr = unwrapExpr(columnExpr);
String tableName = null;
String columnName = null;
String ownerName = null;
if(columnExpr instanceof SQLIdentifierExpr){// name = ?
SQLIdentifierExpr sqlIdentifierExpr = (SQLIdentifierExpr) columnExpr;
columnName = sqlIdentifierExpr.getName();
SQLTableSource tableSource = sqlIdentifierExpr.getResolvedTableSource();
if (tableSource instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) tableSource).getExpr();
if (tableSourceExpr != null && !(tableSourceExpr instanceof SQLName)) {
tableSourceExpr = this.unwrapExpr(tableSourceExpr);
}
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}else if(columnExpr instanceof SQLPropertyExpr){
SQLPropertyExpr sqlPropertyExpr = (SQLPropertyExpr) columnExpr;
columnName = sqlPropertyExpr.getName();
ownerName = sqlPropertyExpr.getOwnernName();
SQLExpr owner = sqlPropertyExpr.getOwner();
if (owner instanceof SQLName) {
SQLObject resolvedOwnerObject = sqlPropertyExpr.getResolvedOwnerObject();
// a.name = ? 或 t_user.name = ?
if(owner instanceof SQLIdentifierExpr){
tableName = ((SQLName) owner).toString();
if (resolvedOwnerObject instanceof SQLExprTableSource) {
SQLExpr tableSourceExpr = ((SQLExprTableSource) resolvedOwnerObject).getExpr();
if (tableSourceExpr instanceof SQLName) {
tableName = ((SQLName) tableSourceExpr).toString();
}
}
}
// secooAbcDB.t_user.name = ?
else if(owner instanceof SQLPropertyExpr){
SQLPropertyExpr ownerSQLPropertyExpr = (SQLPropertyExpr) owner;
if(ownerSQLPropertyExpr.getName() != null){
tableName = ownerSQLPropertyExpr.getName();
}else{
tableName = ownerSQLPropertyExpr.getSimpleName();
}
}
}
}
if(tableName == null || columnName == null){
return null;
}
return new Column(tableName,ownerName,columnName,"",null,null);
}
/**
* 加密规则
* @param tableName
* @param columnName
* @return
*/
private ColumnRule getColumnRule(String tableName, String columnName){
if(tableName == null || columnName == null || tableRuleMap == null || tableRuleMap.isEmpty()){
return null;
}
Map<String,ColumnRule> columnRuleMap = tableRuleMap.get(tableName);
if(columnRuleMap == null || columnRuleMap.isEmpty()){
return null;
}
return columnRuleMap.get(columnName);
}
/**
* 重写insert values + duplicateKeyUpdate
* @param x
*/
private void rewriteInsert(MySqlInsertStatement x){
// 插入sql不能省略列名
List<SQLExpr> columns = x.getColumns();
if(columns.isEmpty()){
return;
}
// columns
List<Column> insertEncryptColumns = null;
for(int i=0; i<columns.size(); i++){
SQLExpr columnSQLExpr = columns.get(i);
Column encryptColumn = this.handleEncryptColumn(columnSQLExpr);
if(encryptColumn == null){
continue;
}
encryptColumn.setColumnIndex(i);
insertEncryptColumns = CollectionUtil.add(insertEncryptColumns,encryptColumn);
}
// values
if(insertEncryptColumns != null && !insertEncryptColumns.isEmpty()){
Set<Integer> addInsertPlainColumn = new HashSet<>();
Map<Integer,Column> insertEncryptColumnMap = insertEncryptColumns.stream().collect(Collectors.toMap(Column::getColumnIndex, e -> e));
List<MySqlInsertStatement.ValuesClause> valuesClauses = x.getValuesList();
for(SQLInsertStatement.ValuesClause valuesClause:valuesClauses){
// 加密参数处理
Map<Integer,Parameter> insertEncryptParameterMap = null;
List<SQLExpr> values = valuesClause.getValues();
for(int columnIndex=0; columnIndex < values.size(); columnIndex++){
Column encryptColumn = insertEncryptColumnMap.get(columnIndex);
if(encryptColumn == null){
continue;
}
SQLExpr valueSQLExpr = values.get(columnIndex);
SQLVariantRefExpr variantRefExpr = this.handlePreparedValue(valueSQLExpr);
if(variantRefExpr == null){
continue;
}
// 采集加密参数
Parameter encryptParameter = new Parameter(encryptColumn,variantRefExpr.getIndex(),null);
this.encryptParameters.add(encryptParameter);
insertEncryptParameterMap = CollectionUtil.add(insertEncryptParameterMap,columnIndex,encryptParameter);
}
// 增加明文列值(存在加密列但不使用预占符的参数,所以直接根据加密列value复制到明文列)
if(insertEncryptParameterMap != null && !insertEncryptParameterMap.isEmpty()){
for(int i=0;i<insertEncryptColumns.size();i++){
Column encryptColumn = insertEncryptColumns.get(i);
ColumnRule columnRule = encryptColumn.getColumnRule();
Parameter insertEncryptParameter = insertEncryptParameterMap.get(encryptColumn.getColumnIndex());
if(columnRule.getPlainColumn() != null && insertEncryptParameter != null){
addInsertPlainColumn.add(encryptColumn.getColumnIndex());
insertEncryptParameter.setAddParameterIndex(insertEncryptParameter.getParameterIndex());
SQLExpr valueSQLExpr = values.get(encryptColumn.getColumnIndex());
values.add(encryptColumn.getColumnIndex() + i,valueSQLExpr);
}
}
}
}
// 增加明文列名
for(int i=0;i<insertEncryptColumns.size();i++){
Column encryptColumn = insertEncryptColumns.get(i);
ColumnRule columnRule = encryptColumn.getColumnRule();
if(columnRule.getPlainColumn() != null && addInsertPlainColumn.contains(encryptColumn.getColumnIndex())){
SQLExpr plainSQLExpr = SQLUtils.toMySqlExpr((encryptColumn.getOwnerName() == null?"":(encryptColumn.getOwnerName() + ".")) + columnRule.getPlainColumn());
columns.add(encryptColumn.getColumnIndex() + i,plainSQLExpr);
}
}
}
// duplicateKeyUpdate
List<Column> dkpEncryptColumns = null;
Map<Integer,Parameter> dkpEncryptParameterMap = null;
List<SQLExpr> duplicateKeyUpdate = x.getDuplicateKeyUpdate();
for(int i=0; i<duplicateKeyUpdate.size(); i++){
SQLExpr sqlExpr = duplicateKeyUpdate.get(i);
if(!(sqlExpr instanceof SQLBinaryOpExpr)){
sqlExpr.accept(this);
continue;
}
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
SQLExpr left = sqlBinaryOpExpr.getLeft();
SQLExpr right = sqlBinaryOpExpr.getRight();
Column encryptColumn = this.handleEncryptColumn(left);
if(encryptColumn == null){
continue;
}
encryptColumn.setColumnIndex(i);
dkpEncryptColumns = CollectionUtil.add(dkpEncryptColumns,encryptColumn);
List<Parameter> encryptParameters = this.handleEncryptCondition(encryptColumn,SQLBinaryOperator.Equality.name,right);
if(encryptParameters != null && !encryptParameters.isEmpty()){
dkpEncryptParameterMap = CollectionUtil.add(dkpEncryptParameterMap,encryptParameters.get(0).getColumnIndex(),encryptParameters.get(0));
}
}
// 新增明文列
if(dkpEncryptColumns != null && !dkpEncryptColumns.isEmpty()
&& dkpEncryptParameterMap != null && dkpEncryptParameterMap.isEmpty()){
for(int i=0; i<dkpEncryptColumns.size(); i++){
Column encryptColumn = dkpEncryptColumns.get(i);
ColumnRule columnRule = encryptColumn.getColumnRule();
Parameter encryptParameter = dkpEncryptParameterMap.get(encryptColumn.getColumnIndex());
if(columnRule.getPlainColumn() != null && encryptParameter != null){
encryptParameter.setAddParameterIndex(encryptParameter.getParameterIndex());
SQLExpr sqlExpr = duplicateKeyUpdate.get(i);
SQLBinaryOpExpr encryptSqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
SQLBinaryOpExpr plainSqlBinaryOpExpr = new SQLBinaryOpExpr();
plainSqlBinaryOpExpr.setLeft(SQLUtils.toMySqlExpr((encryptColumn.getOwnerName() == null?"":(encryptColumn.getOwnerName()+".")) + encryptColumn.getColumnName()));
plainSqlBinaryOpExpr.setRight(encryptSqlBinaryOpExpr.getRight());
plainSqlBinaryOpExpr.setOperator(encryptSqlBinaryOpExpr.getOperator());
plainSqlBinaryOpExpr.setDbType(encryptSqlBinaryOpExpr.getDbType());
duplicateKeyUpdate.add(encryptColumn.getColumnIndex() + i, plainSqlBinaryOpExpr);
}
}
}
}
/**
* 重写update items
* @param x
*/
private void rewriteUpdate(MySqlUpdateStatement x){
List<Column> updateEncryptColumns = null;
Map<Integer,Parameter> updateEncryptParameterMap = null;
List<SQLUpdateSetItem> items = x.getItems();
for(int i=0; i<items.size(); i++){
SQLUpdateSetItem item = items.get(i);
SQLExpr column = item.getColumn();
SQLExpr value = item.getValue();
Column encryptColumn = this.handleEncryptColumn(column);
if(encryptColumn == null){
continue;
}
encryptColumn.setColumnIndex(i);
updateEncryptColumns = CollectionUtil.add(updateEncryptColumns,encryptColumn);
List<Parameter> encryptParameters = this.handleEncryptCondition(encryptColumn,SQLBinaryOperator.Equality.name,value);
if(encryptParameters != null && !encryptParameters.isEmpty()){
updateEncryptParameterMap = CollectionUtil.add(updateEncryptParameterMap,encryptParameters.get(0).getColumnIndex(),encryptParameters.get(0));
}
}
// 新增明文列
if(updateEncryptColumns != null && !updateEncryptColumns.isEmpty()
&& updateEncryptParameterMap != null && updateEncryptParameterMap.isEmpty()){
for(int i=0; i<updateEncryptColumns.size(); i++){
Column encryptColumn = updateEncryptColumns.get(i);
ColumnRule columnRule = encryptColumn.getColumnRule();
Parameter encryptParameter = updateEncryptParameterMap.get(encryptColumn.getColumnIndex());
if(columnRule.getPlainColumn() != null && encryptParameter != null){
encryptParameter.setAddParameterIndex(encryptParameter.getParameterIndex());
SQLUpdateSetItem item = items.get(i);
SQLUpdateSetItem plainSQLUpdateSetItem = new SQLUpdateSetItem();
SQLExpr plainSQLExpr = SQLUtils.toMySqlExpr((encryptColumn.getOwnerName() == null?"":(encryptColumn.getOwnerName() + ".")) + columnRule.getPlainColumn());
plainSQLUpdateSetItem.setColumn(plainSQLExpr);
plainSQLUpdateSetItem.setValue(item.getValue());
items.add(encryptColumn.getColumnIndex() + i, plainSQLUpdateSetItem);
}
}
}
}
}
package com.secoo.mall.datasource.security.visitor;
import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Parameter extends Column {
private Object value;
private int jdbcIndex;
private Integer addJdbcIndex;
public Parameter(String tableName, String ownerName, String columnName, String columnAlias,Integer columnIndex, ColumnRule columnRule,
Object value, int jdbcIndex, Integer addJdbcIndex) {
super(tableName,ownerName,columnName,columnAlias,columnIndex,columnRule);
this.value = value;
this.jdbcIndex = jdbcIndex;
this.addJdbcIndex = addJdbcIndex;
}
public Parameter(Column column,
Object value, int jdbcIndex, Integer addJdbcIndex) {
this(column.getTableName(),column.getOwnerName(),column.getColumnName(),column.getColumnAlias(),column.getColumnIndex(),column.getColumnRule(),value,jdbcIndex,addJdbcIndex);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public int getJdbcIndex() {
return jdbcIndex;
}
public void setJdbcIndex(int jdbcIndex) {
this.jdbcIndex = jdbcIndex;
}
public Integer getAddJdbcIndex() {
return addJdbcIndex;
}
public void setAddJdbcIndex(Integer addJdbcIndex) {
this.addJdbcIndex = addJdbcIndex;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Parameter)) return false;
if (!super.equals(o)) return false;
Parameter parameter = (Parameter) o;
if (getJdbcIndex() != parameter.getJdbcIndex()) return false;
if (getValue() != null ? !getValue().equals(parameter.getValue()) : parameter.getValue() != null) return false;
return getAddJdbcIndex() != null ? getAddJdbcIndex().equals(parameter.getAddJdbcIndex()) : parameter.getAddJdbcIndex() == null;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (getValue() != null ? getValue().hashCode() : 0);
result = 31 * result + getJdbcIndex();
result = 31 * result + (getAddJdbcIndex() != null ? getAddJdbcIndex().hashCode() : 0);
return result;
}
}
package com.secoo.mall.datasource.security.visitor; package com.secoo.mall.datasource.security.visitor.model;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
......
package com.secoo.mall.datasource.security.visitor.model;
import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Parameter extends Column {
private Object plainValue;
private String cipherValue;
private int parameterIndex;
private Integer addParameterIndex;
public Parameter(String tableName, String ownerName, String columnName, String columnAlias, Integer columnIndex, ColumnRule columnRule,
int parameterIndex, Integer addParameterIndex) {
super(tableName,ownerName,columnName,columnAlias,columnIndex,columnRule);
this.parameterIndex = parameterIndex;
this.addParameterIndex = addParameterIndex;
}
public Parameter(Column column,
int parameterIndex, Integer addParameterIndex) {
this(column.getTableName(),column.getOwnerName(),column.getColumnName(),column.getColumnAlias(),column.getColumnIndex(),column.getColumnRule(), parameterIndex, addParameterIndex);
}
public int getParameterIndex() {
return parameterIndex;
}
public void setParameterIndex(int parameterIndex) {
this.parameterIndex = parameterIndex;
}
public Integer getAddParameterIndex() {
return addParameterIndex;
}
public void setAddParameterIndex(Integer addParameterIndex) {
this.addParameterIndex = addParameterIndex;
}
public Object getPlainValue() {
return plainValue;
}
public void setPlainValue(Object plainValue) {
this.plainValue = plainValue;
}
public String getCipherValue() {
return cipherValue;
}
public void setCipherValue(String cipherValue) {
this.cipherValue = cipherValue;
}
@Override
public String toString() {
return "Parameter{" +
"tableName='" + tableName + '\'' +
", ownerName='" + ownerName + '\'' +
", columnName='" + columnName + '\'' +
", columnAlias='" + columnAlias + '\'' +
", columnIndex=" + columnIndex +
", columnRule=" + columnRule +
", plainValue=" + plainValue +
", cipherValue='" + cipherValue + '\'' +
", parameterIndex=" + parameterIndex +
", addParameterIndex=" + addParameterIndex +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Parameter)) return false;
if (!super.equals(o)) return false;
Parameter parameter = (Parameter) o;
if (getParameterIndex() != parameter.getParameterIndex()) return false;
if (getPlainValue() != null ? !getPlainValue().equals(parameter.getPlainValue()) : parameter.getPlainValue() != null)
return false;
if (getCipherValue() != null ? !getCipherValue().equals(parameter.getCipherValue()) : parameter.getCipherValue() != null)
return false;
return getAddParameterIndex() != null ? getAddParameterIndex().equals(parameter.getAddParameterIndex()) : parameter.getAddParameterIndex() == null;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (getPlainValue() != null ? getPlainValue().hashCode() : 0);
result = 31 * result + (getCipherValue() != null ? getCipherValue().hashCode() : 0);
result = 31 * result + getParameterIndex();
result = 31 * result + (getAddParameterIndex() != null ? getAddParameterIndex().hashCode() : 0);
return result;
}
}
...@@ -7,14 +7,13 @@ import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr; ...@@ -7,14 +7,13 @@ import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem; import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlCharExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor; import com.secoo.mall.datasource.security.visitor.MySqlRewriteParameterVisitor;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor2; import com.secoo.mall.datasource.security.visitor.model.Parameter;
import com.secoo.mall.datasource.security.visitor.Parameter;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -22,13 +21,15 @@ import java.util.List; ...@@ -22,13 +21,15 @@ import java.util.List;
public class SQLParserTest { public class SQLParserTest {
public static void testVisitor(){ public static void testVisitor(){
String sql = "select * from t_user where id=?";
// String sql = "SELECT Sname FROM Student WHERE Sno IN(SELECT Sno FROM SC WHERE Cno='2');"; // String sql = "SELECT Sname FROM Student WHERE Sno IN(SELECT Sno FROM SC WHERE Cno='2');";
String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;"; // String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;";
// String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic"; // String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic";
// String sql = "delete from t_sequence"; // String sql = "delete from t_sequence";
// String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);";
// String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)";
// String sql = "update t_user set name=?,age=10 where id = id and id > 12"; // String sql = "update t_user set name=?,age=10 where id = ? and id > 12";
// String sql = "update t_user set name=?,age=10 where id = ? and id > 12;update t_user set name=?,age=10 where id = ? and id < 12;";
// String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
// String sql = "update t_user u,t_account a set t_user.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set t_user.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
// String sql = "update t_user u,t_account a set secooStoreDB.t_account.name=?,a.age=10 where secooStoreDB.t_account.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set secooStoreDB.t_account.name=?,a.age=10 where secooStoreDB.t_account.id = a.id and u.id > 12 and a.age >?";
...@@ -42,7 +43,7 @@ public class SQLParserTest { ...@@ -42,7 +43,7 @@ public class SQLParserTest {
List<Parameter> encryptColumnParameters = new ArrayList<>(); List<Parameter> encryptColumnParameters = new ArrayList<>();
// MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters); // MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters);
MySqlSecurityParameterVisitor2 visitor = new MySqlSecurityParameterVisitor2(null,encryptColumnParameters); MySqlRewriteParameterVisitor visitor = new MySqlRewriteParameterVisitor(null,encryptColumnParameters);
stmt.accept(visitor); stmt.accept(visitor);
System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters)); System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters));
System.out.println(stmt.toString()); System.out.println(stmt.toString());
...@@ -109,5 +110,17 @@ public class SQLParserTest { ...@@ -109,5 +110,17 @@ public class SQLParserTest {
public static void main(String[] args) { public static void main(String[] args) {
testVisitor(); testVisitor();
// SQLExpr sqlExpr = SQLUtils.toMySqlExpr("?");
// System.out.println(SQLEvalVisitorUtils.eval(JdbcConstants.MYSQL, sqlExpr, null, false));
/*System.out.println((SQLUtils.toMySqlExpr("u").toString()
+ "="
+ SQLUtils.toMySqlExpr("xxx").toString()).toString());
MySqlCharExpr value = new MySqlCharExpr();
value.setCharset("utf-8");
value.setText("xxx");
System.out.println(value.toString());*/
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment