Commit 16693fb4 by 郑冰晶

数据库加密组件

parent 53c87e57
...@@ -107,6 +107,7 @@ public class ActivityController { ...@@ -107,6 +107,7 @@ public class ActivityController {
PriceRuleTask updatePriceRuleTask = new PriceRuleTask(); PriceRuleTask updatePriceRuleTask = new PriceRuleTask();
updatePriceRuleTask.setId(id); updatePriceRuleTask.setId(id);
updatePriceRuleTask.setCreator(priceRuleTaskCriteria.getCreator()); updatePriceRuleTask.setCreator(priceRuleTaskCriteria.getCreator());
updatePriceRuleTask.setModifier(priceRuleTaskCriteria.getModifier());
updatePriceRuleTasks.add(updatePriceRuleTask); updatePriceRuleTasks.add(updatePriceRuleTask);
} }
...@@ -114,4 +115,10 @@ public class ActivityController { ...@@ -114,4 +115,10 @@ public class ActivityController {
log.debug("rows={}", rows); log.debug("rows={}", rows);
return String.valueOf(rows); return String.valueOf(rows);
} }
@RequestMapping("testJoin")
public String testJoin(@RequestBody PriceRuleTaskCriteria priceRuleTaskCriteria) {
this.priceRuleTaskMapper.testJoin(priceRuleTaskCriteria);
return null;
}
} }
...@@ -29,5 +29,7 @@ public interface PriceRuleTaskMapper { ...@@ -29,5 +29,7 @@ public interface PriceRuleTaskMapper {
int queryPriceRuleTaskCount(PriceRuleTaskCriteria priceRuleTaskCriteria); int queryPriceRuleTaskCount(PriceRuleTaskCriteria priceRuleTaskCriteria);
List<PriceRuleTask> queryExecutePriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria); List<PriceRuleTask> queryExecutePriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria);
List<PriceRuleTask> testJoin(PriceRuleTaskCriteria priceRuleTaskCriteria);
} }
\ No newline at end of file
...@@ -482,4 +482,10 @@ ...@@ -482,4 +482,10 @@
</select> </select>
<select id="testJoin" resultMap="BaseResultMap" parameterType="com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria">
select prt.id,br.brand_id,prt.remark from t_price_rule_task prt
inner join t_brand_rule br on prt.brand_id = br.brand_id
where prt.task_id=#{taskId}
</select>
</mapper> </mapper>
\ No newline at end of file
...@@ -57,7 +57,7 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -57,7 +57,7 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
.collect(Collectors.toList()); .collect(Collectors.toList());
propertyNameSections = propertyNameSections.stream().filter(key -> key.length == 4).collect(Collectors.toList()); propertyNameSections = propertyNameSections.stream().filter(key -> key.length == 4).collect(Collectors.toList());
Map<String, Class> fieldMap = FieldUtil.getAllFieldsList(ColumnRule.class).stream().collect(Collectors.toMap(Field::getName, Field::getType, (key1, key2) -> key1)); Map<String, Class<?>> fieldMap = FieldUtil.getAllFieldsList(ColumnRule.class).stream().collect(Collectors.toMap(Field::getName, Field::getType, (key1, key2) -> key1));
Set<String> notSupportPropertySet = new HashSet<>(); Set<String> notSupportPropertySet = new HashSet<>();
// db // db
...@@ -124,10 +124,14 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -124,10 +124,14 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
if(StringUtils.isBlank(columnRule.getEncryptType()) || StringUtils.isBlank(columnRule.getEncryptKey()) || StringUtils.isBlank(columnRule.getCipherColumn())){ if(StringUtils.isBlank(columnRule.getEncryptType()) || StringUtils.isBlank(columnRule.getEncryptKey()) || StringUtils.isBlank(columnRule.getCipherColumn())){
continue; continue;
} }
// 加密器
Properties properties = new Properties(); Properties properties = new Properties();
properties.setProperty(EncryptAlgorithm.ENCRYPT_KEY,columnRule.getEncryptKey()); properties.setProperty(EncryptAlgorithm.ENCRYPT_KEY,columnRule.getEncryptKey());
EncryptAlgorithm encryptAlgorithm = SecurityAlgorithmFactory.getObject(EncryptAlgorithm.class,columnRule.getEncryptType().toLowerCase(), properties); EncryptAlgorithm encryptAlgorithm = SecurityAlgorithmFactory.getObject(EncryptAlgorithm.class,columnRule.getEncryptType().toLowerCase(), properties);
encryptAlgorithm.init(); if(encryptAlgorithm == null){
continue;
}
columnRule.setEncryptAlgorithm(encryptAlgorithm); columnRule.setEncryptAlgorithm(encryptAlgorithm);
if(tableRule.getColumnRules() == null){ if(tableRule.getColumnRules() == null){
...@@ -135,11 +139,19 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith ...@@ -135,11 +139,19 @@ public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorith
} }
tableRule.getColumnRules().add(columnRule); tableRule.getColumnRules().add(columnRule);
} }
if(tableRule.getColumnRules() == null || tableRule.getColumnRules().isEmpty()){
continue;
}
if(dbRule.getTableRules() == null){ if(dbRule.getTableRules() == null){
dbRule.setTableRules(new HashSet<>()); dbRule.setTableRules(new HashSet<>());
} }
dbRule.getTableRules().add(tableRule); dbRule.getTableRules().add(tableRule);
} }
if(dbRule.getTableRules() == null || dbRule.getTableRules().isEmpty()){
continue;
}
dbRules.add(dbRule); dbRules.add(dbRule);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new SecurityBizException("!!! Load security rule from apollo error !!!",e); throw new SecurityBizException("!!! Load security rule from apollo error !!!",e);
......
...@@ -13,9 +13,10 @@ import com.alibaba.druid.util.Utils; ...@@ -13,9 +13,10 @@ import com.alibaba.druid.util.Utils;
import com.secoo.mall.datasource.security.constant.PropertyProviderType; import com.secoo.mall.datasource.security.constant.PropertyProviderType;
import com.secoo.mall.datasource.security.exception.SecurityBizException; import com.secoo.mall.datasource.security.exception.SecurityBizException;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.util.CollectionUtil;
import com.secoo.mall.datasource.security.util.SecurityUtil; import com.secoo.mall.datasource.security.util.SecurityUtil;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor; import com.secoo.mall.datasource.security.visitor.MySqlEncryptParameterVisitor;
import com.secoo.mall.datasource.security.visitor.Parameter; import com.secoo.mall.datasource.security.visitor.model.Parameter;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
...@@ -27,18 +28,33 @@ import java.sql.ResultSetMetaData; ...@@ -27,18 +28,33 @@ import java.sql.ResultSetMetaData;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Types; import java.sql.Types;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
@AutoLoad @AutoLoad
public class SecurityFilter extends SecurityFilterEventAdapter { public class SecurityFilter extends SecurityFilterEventAdapter {
private static final Logger log = LoggerFactory.getLogger(SecurityFilter.class); private static final Logger log = LoggerFactory.getLogger(SecurityFilter.class);
private SecurityFilterContext securityFilterContext; private static final ThreadLocal<List<Parameter>> ENCRYPT_PARAMETERS_TL = new ThreadLocal<>();
private final SecurityFilterContext securityFilterContext;
public SecurityFilter(){ public SecurityFilter(){
super(); super();
this.securityFilterContext = new SecurityFilterContext(PropertyProviderType.APOLLO); this.securityFilterContext = new SecurityFilterContext(PropertyProviderType.APOLLO);
} }
public List<Parameter> getEncryptParameters(){
return ENCRYPT_PARAMETERS_TL.get();
}
public void setEncryptParameters(List<Parameter> encryptParameters){
ENCRYPT_PARAMETERS_TL.set(encryptParameters);
}
public void clearEncryptParameters(){
ENCRYPT_PARAMETERS_TL.remove();
}
public void init(DataSourceProxy dataSource) { public void init(DataSourceProxy dataSource) {
String dbName = SecurityUtil.findDataBaseNameByUrl(dataSource.getUrl()); String dbName = SecurityUtil.findDataBaseNameByUrl(dataSource.getUrl());
if(StringUtils.isNotBlank(dbName)){ if(StringUtils.isNotBlank(dbName)){
...@@ -281,21 +297,16 @@ public class SecurityFilter extends SecurityFilterEventAdapter { ...@@ -281,21 +297,16 @@ public class SecurityFilter extends SecurityFilterEventAdapter {
return securityFilterContext.getColumnRule(dbName,tableName,columnName); return securityFilterContext.getColumnRule(dbName,tableName,columnName);
} }
protected void executeBefore(FilterChain chain,StatementProxy statement, String sql) { protected String rewritePrepareSql(FilterChain chain,String sql) {
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl()); String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName); Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName);
if(tableRuleMap == null){ if(tableRuleMap == null){
log.debug("过滤非加密db,dbName={},sql={}",dbName, sql); log.debug("过滤非加密db,dbName={},sql={}",dbName, sql);
return; return sql;
}
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("过滤非PreparedStatement,dbName={},sql={}",dbName, sql);
return;
} }
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement; // 解析
// 解析sql List<Parameter> allEncryptParameters = null;
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType()); List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType());
for (SQLStatement stmt : stmtList) { for (SQLStatement stmt : stmtList) {
if (!(stmt instanceof SQLSelectStatement) if (!(stmt instanceof SQLSelectStatement)
...@@ -306,36 +317,116 @@ public class SecurityFilter extends SecurityFilterEventAdapter { ...@@ -306,36 +317,116 @@ public class SecurityFilter extends SecurityFilterEventAdapter {
continue; continue;
} }
MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(tableRuleMap); MySqlEncryptParameterVisitor visitor = new MySqlEncryptParameterVisitor(tableRuleMap);
stmt.accept(visitor); stmt.accept(visitor);
List<Parameter> encryptParameters = visitor.getEncryptParameters(); List<Parameter> encryptParameters = visitor.getEncryptParameters();
// 加密
if(encryptParameters != null && !encryptParameters.isEmpty()){ if(encryptParameters != null && !encryptParameters.isEmpty()){
for(Parameter parameter : encryptParameters){ allEncryptParameters = CollectionUtil.addAll(allEncryptParameters,encryptParameters);
encrypt(parameter.getColumnRule(), preparedStatement, parameter.getJdbcIndex()); }
}
if(allEncryptParameters == null || allEncryptParameters.isEmpty()){
return sql;
}
// 重写
StringBuilder newSql = new StringBuilder();
stmtList.forEach(e -> newSql.append(e.toString()));
String newSqlStr = newSql.toString();
log.debug("重写sql={},allEncryptParameters={}",newSql,allEncryptParameters);
this.setEncryptParameters(allEncryptParameters);
return newSqlStr;
}
protected void rewritePrepareParameter(FilterChain chain,StatementProxy statement) {
try {
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName);
if(tableRuleMap == null){
log.debug("过滤非加密db,dbName={}",dbName);
return;
}
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("过滤非PreparedStatement,dbName={}",dbName);
return;
}
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
Map<Integer,JdbcParameter> jdbcParameters = preparedStatement.getParameters();
if(jdbcParameters == null || jdbcParameters.isEmpty()){
return;
}
// 获取加密参数
List<Parameter> allEncryptParameters = this.getEncryptParameters();
if(allEncryptParameters == null || allEncryptParameters.isEmpty()){
return;
}
// 加密明文
allEncryptParameters.forEach(e -> {
JdbcParameter jdbcParameter = jdbcParameters.get(e.getParameterIndex());
if(jdbcParameter != null){
e.setPlainValue(jdbcParameter.getValue());
e.setCipherValue(this.encrypt(e.getColumnRule(), e.getPlainValue()));
}
});
Map<Integer,Parameter> allEncryptParameterMap = allEncryptParameters.stream().collect(Collectors.toMap(Parameter::getParameterIndex, e -> e));
// 重建jdbc参数列表
long addJdbcParameterCount = allEncryptParameters.stream().filter(e -> e.getAddParameterIndex() != null && e.getAddParameterIndex() >= 0).count();
List<JdbcParameter> newJdbcParameters = new ArrayList<>(jdbcParameters.size() + Long.valueOf(addJdbcParameterCount).intValue());
for(Map.Entry<Integer,JdbcParameter> jdbcParameterEntry:jdbcParameters.entrySet()){
Integer index = jdbcParameterEntry.getKey();
JdbcParameter jdbcParameter = jdbcParameterEntry.getValue();
Parameter encryptParameter = allEncryptParameterMap.get(index);
if(encryptParameter == null || encryptParameter.getCipherValue() == null){
newJdbcParameters.add(index,jdbcParameter);
}else{
JdbcParameter newJdbcParameter = null;
if (encryptParameter.getCipherValue() == null) {
newJdbcParameter = JdbcParameterNull.VARCHAR;
}else if (encryptParameter.getCipherValue().length() == 0) {
newJdbcParameter = JdbcParameterString.empty;
}else{
newJdbcParameter = new JdbcParameterString(encryptParameter.getCipherValue());
}
newJdbcParameters.add(index,newJdbcParameter);
}
}
// 新增明文参数
for(int i=0; i<allEncryptParameters.size(); i++){
Parameter parameter = allEncryptParameters.get(i);
if(parameter.getAddParameterIndex() != null && parameter.getAddParameterIndex() >= 0){
newJdbcParameters.add(parameter.getParameterIndex() + i, jdbcParameters.get(parameter.getParameterIndex()));
} }
String sql = ((PreparedStatementProxyImpl) statement).getSql();
log.debug("加密sql={}\n加密参数={}",sql,newJdbcParameters);
}
// 重写jdbc参数列表
for(int i=0; i<newJdbcParameters.size(); i++){
preparedStatement.setParameter(i + 1,newJdbcParameters.get(i));
} }
} finally {
this.clearEncryptParameters();
} }
} }
/** /**
* 加密 * 加密
* @param columnRule
* @param preparedStatement
* @param index
*/ */
private void encrypt(ColumnRule columnRule, PreparedStatementProxyImpl preparedStatement,int index) { private String encrypt(ColumnRule columnRule, Object plainText) {
JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object plainText = jdbcParameter.getValue();
if (plainText == null) { if (plainText == null) {
return; return null;
} }
try { try {
String cipherText = columnRule.getEncryptAlgorithm().encrypt(plainText); String cipherText = columnRule.getEncryptAlgorithm().encrypt(plainText);
preparedStatement.setObject(index + 1,cipherText);
log.debug("字段加密:columnRule={},plainText={},cipherText={}", columnRule, plainText, cipherText); log.debug("字段加密:columnRule={},plainText={},cipherText={}", columnRule, plainText, cipherText);
return cipherText;
} catch (Exception e) { } catch (Exception e) {
String errorMsg = "字段加密异常:columnRule="+columnRule+",plainText="+plainText; String errorMsg = "字段加密异常:columnRule="+columnRule+",plainText="+plainText;
log.error(errorMsg); log.error(errorMsg);
......
...@@ -7,6 +7,9 @@ import java.util.Map; ...@@ -7,6 +7,9 @@ import java.util.Map;
public class CollectionUtil { public class CollectionUtil {
public static <T> List<T> add(List<T> list, T t){ public static <T> List<T> add(List<T> list, T t){
if(t == null){
return null;
}
if(list == null){ if(list == null){
list = new ArrayList<>(); list = new ArrayList<>();
} }
...@@ -15,6 +18,9 @@ public class CollectionUtil { ...@@ -15,6 +18,9 @@ public class CollectionUtil {
} }
public static <T> List<T> addAll(List<T> list, List<T> subList){ public static <T> List<T> addAll(List<T> list, List<T> subList){
if(subList == null){
return null;
}
if(list == null){ if(list == null){
list = new ArrayList<>(); list = new ArrayList<>();
} }
...@@ -23,6 +29,9 @@ public class CollectionUtil { ...@@ -23,6 +29,9 @@ public class CollectionUtil {
} }
public static <K,V> Map<K,V> add(Map<K,V> map, K k, V v){ public static <K,V> Map<K,V> add(Map<K,V> map, K k, V v){
if(v == null){
return null;
}
if(map == null){ if(map == null){
map = new HashMap<>(); map = new HashMap<>();
} }
...@@ -32,7 +41,10 @@ public class CollectionUtil { ...@@ -32,7 +41,10 @@ public class CollectionUtil {
public static <K,V> Map<K,V> addAll(Map<K,V> map, Map<K,V> subMap){ public static <K,V> Map<K,V> addAll(Map<K,V> map, Map<K,V> subMap){
if(subMap == null){ if(subMap == null){
subMap = new HashMap<>(); return null;
}
if(map == null){
map = new HashMap<>();
} }
map.putAll(subMap); map.putAll(subMap);
return map; return map;
......
package com.secoo.mall.datasource.security.visitor;
import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Parameter extends Column {
private Object value;
private int jdbcIndex;
private Integer addJdbcIndex;
public Parameter(String tableName, String ownerName, String columnName, String columnAlias,Integer columnIndex, ColumnRule columnRule,
Object value, int jdbcIndex, Integer addJdbcIndex) {
super(tableName,ownerName,columnName,columnAlias,columnIndex,columnRule);
this.value = value;
this.jdbcIndex = jdbcIndex;
this.addJdbcIndex = addJdbcIndex;
}
public Parameter(Column column,
Object value, int jdbcIndex, Integer addJdbcIndex) {
this(column.getTableName(),column.getOwnerName(),column.getColumnName(),column.getColumnAlias(),column.getColumnIndex(),column.getColumnRule(),value,jdbcIndex,addJdbcIndex);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public int getJdbcIndex() {
return jdbcIndex;
}
public void setJdbcIndex(int jdbcIndex) {
this.jdbcIndex = jdbcIndex;
}
public Integer getAddJdbcIndex() {
return addJdbcIndex;
}
public void setAddJdbcIndex(Integer addJdbcIndex) {
this.addJdbcIndex = addJdbcIndex;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Parameter)) return false;
if (!super.equals(o)) return false;
Parameter parameter = (Parameter) o;
if (getJdbcIndex() != parameter.getJdbcIndex()) return false;
if (getValue() != null ? !getValue().equals(parameter.getValue()) : parameter.getValue() != null) return false;
return getAddJdbcIndex() != null ? getAddJdbcIndex().equals(parameter.getAddJdbcIndex()) : parameter.getAddJdbcIndex() == null;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (getValue() != null ? getValue().hashCode() : 0);
result = 31 * result + getJdbcIndex();
result = 31 * result + (getAddJdbcIndex() != null ? getAddJdbcIndex().hashCode() : 0);
return result;
}
}
package com.secoo.mall.datasource.security.visitor; package com.secoo.mall.datasource.security.visitor.model;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
......
package com.secoo.mall.datasource.security.visitor.model;
import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Parameter extends Column {
private Object plainValue;
private String cipherValue;
private int parameterIndex;
private Integer addParameterIndex;
public Parameter(String tableName, String ownerName, String columnName, String columnAlias, Integer columnIndex, ColumnRule columnRule,
int parameterIndex, Integer addParameterIndex) {
super(tableName,ownerName,columnName,columnAlias,columnIndex,columnRule);
this.parameterIndex = parameterIndex;
this.addParameterIndex = addParameterIndex;
}
public Parameter(Column column,
int parameterIndex, Integer addParameterIndex) {
this(column.getTableName(),column.getOwnerName(),column.getColumnName(),column.getColumnAlias(),column.getColumnIndex(),column.getColumnRule(), parameterIndex, addParameterIndex);
}
public int getParameterIndex() {
return parameterIndex;
}
public void setParameterIndex(int parameterIndex) {
this.parameterIndex = parameterIndex;
}
public Integer getAddParameterIndex() {
return addParameterIndex;
}
public void setAddParameterIndex(Integer addParameterIndex) {
this.addParameterIndex = addParameterIndex;
}
public Object getPlainValue() {
return plainValue;
}
public void setPlainValue(Object plainValue) {
this.plainValue = plainValue;
}
public String getCipherValue() {
return cipherValue;
}
public void setCipherValue(String cipherValue) {
this.cipherValue = cipherValue;
}
@Override
public String toString() {
return "Parameter{" +
"tableName='" + tableName + '\'' +
", ownerName='" + ownerName + '\'' +
", columnName='" + columnName + '\'' +
", columnAlias='" + columnAlias + '\'' +
", columnIndex=" + columnIndex +
", columnRule=" + columnRule +
", plainValue=" + plainValue +
", cipherValue='" + cipherValue + '\'' +
", parameterIndex=" + parameterIndex +
", addParameterIndex=" + addParameterIndex +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Parameter)) return false;
if (!super.equals(o)) return false;
Parameter parameter = (Parameter) o;
if (getParameterIndex() != parameter.getParameterIndex()) return false;
if (getPlainValue() != null ? !getPlainValue().equals(parameter.getPlainValue()) : parameter.getPlainValue() != null)
return false;
if (getCipherValue() != null ? !getCipherValue().equals(parameter.getCipherValue()) : parameter.getCipherValue() != null)
return false;
return getAddParameterIndex() != null ? getAddParameterIndex().equals(parameter.getAddParameterIndex()) : parameter.getAddParameterIndex() == null;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (getPlainValue() != null ? getPlainValue().hashCode() : 0);
result = 31 * result + (getCipherValue() != null ? getCipherValue().hashCode() : 0);
result = 31 * result + getParameterIndex();
result = 31 * result + (getAddParameterIndex() != null ? getAddParameterIndex().hashCode() : 0);
return result;
}
}
...@@ -7,14 +7,13 @@ import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr; ...@@ -7,14 +7,13 @@ import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem; import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlCharExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor; import com.secoo.mall.datasource.security.visitor.MySqlRewriteParameterVisitor;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor2; import com.secoo.mall.datasource.security.visitor.model.Parameter;
import com.secoo.mall.datasource.security.visitor.Parameter;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -22,13 +21,15 @@ import java.util.List; ...@@ -22,13 +21,15 @@ import java.util.List;
public class SQLParserTest { public class SQLParserTest {
public static void testVisitor(){ public static void testVisitor(){
String sql = "select * from t_user where id=?";
// String sql = "SELECT Sname FROM Student WHERE Sno IN(SELECT Sno FROM SC WHERE Cno='2');"; // String sql = "SELECT Sname FROM Student WHERE Sno IN(SELECT Sno FROM SC WHERE Cno='2');";
String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;"; // String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;";
// String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic"; // String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic";
// String sql = "delete from t_sequence"; // String sql = "delete from t_sequence";
// String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);";
// String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)";
// String sql = "update t_user set name=?,age=10 where id = id and id > 12"; // String sql = "update t_user set name=?,age=10 where id = ? and id > 12";
// String sql = "update t_user set name=?,age=10 where id = ? and id > 12;update t_user set name=?,age=10 where id = ? and id < 12;";
// String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
// String sql = "update t_user u,t_account a set t_user.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set t_user.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
// String sql = "update t_user u,t_account a set secooStoreDB.t_account.name=?,a.age=10 where secooStoreDB.t_account.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set secooStoreDB.t_account.name=?,a.age=10 where secooStoreDB.t_account.id = a.id and u.id > 12 and a.age >?";
...@@ -42,7 +43,7 @@ public class SQLParserTest { ...@@ -42,7 +43,7 @@ public class SQLParserTest {
List<Parameter> encryptColumnParameters = new ArrayList<>(); List<Parameter> encryptColumnParameters = new ArrayList<>();
// MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters); // MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters);
MySqlSecurityParameterVisitor2 visitor = new MySqlSecurityParameterVisitor2(null,encryptColumnParameters); MySqlRewriteParameterVisitor visitor = new MySqlRewriteParameterVisitor(null,encryptColumnParameters);
stmt.accept(visitor); stmt.accept(visitor);
System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters)); System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters));
System.out.println(stmt.toString()); System.out.println(stmt.toString());
...@@ -109,5 +110,17 @@ public class SQLParserTest { ...@@ -109,5 +110,17 @@ public class SQLParserTest {
public static void main(String[] args) { public static void main(String[] args) {
testVisitor(); testVisitor();
// SQLExpr sqlExpr = SQLUtils.toMySqlExpr("?");
// System.out.println(SQLEvalVisitorUtils.eval(JdbcConstants.MYSQL, sqlExpr, null, false));
/*System.out.println((SQLUtils.toMySqlExpr("u").toString()
+ "="
+ SQLUtils.toMySqlExpr("xxx").toString()).toString());
MySqlCharExpr value = new MySqlCharExpr();
value.setCharset("utf-8");
value.setText("xxx");
System.out.println(value.toString());*/
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment