Commit cb825c03 by 郑冰晶

数据库加密组件

parent 4cf0013f
...@@ -5,6 +5,10 @@ import com.alibaba.druid.filter.FilterChain; ...@@ -5,6 +5,10 @@ import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.proxy.jdbc.*; import com.alibaba.druid.proxy.jdbc.*;
import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.util.Utils; import com.alibaba.druid.util.Utils;
import com.secoo.mall.datasource.security.constant.PropertyProviderType; import com.secoo.mall.datasource.security.constant.PropertyProviderType;
import com.secoo.mall.datasource.security.exception.SecurityBizException; import com.secoo.mall.datasource.security.exception.SecurityBizException;
...@@ -278,22 +282,36 @@ public class SecurityFilter extends SecurityFilterEventAdapter { ...@@ -278,22 +282,36 @@ public class SecurityFilter extends SecurityFilterEventAdapter {
} }
protected void executeBefore(FilterChain chain,StatementProxy statement, String sql) { protected void executeBefore(FilterChain chain,StatementProxy statement, String sql) {
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
Map<String,Map<String,ColumnRule>> tableRuleMap = securityFilterContext.getTableRuleMap(dbName);
if(tableRuleMap == null){
log.debug("过滤非加密db,dbName={},sql={}",dbName, sql);
return;
}
if (!(statement instanceof PreparedStatementProxy)) { if (!(statement instanceof PreparedStatementProxy)) {
log.debug("过滤statement:{}", sql); log.debug("过滤非PreparedStatement,dbName={},sql={}",dbName, sql);
return; return;
} }
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement; PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
// 解析sql // 解析sql
String dbName = securityFilterContext.getDbName(chain.getDataSource().getUrl());
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType()); List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, chain.getDataSource().getDbType());
for (SQLStatement stmt : stmtList) { for (SQLStatement stmt : stmtList) {
List<Parameter> encryptColumnParameters = new ArrayList<>(); if (!(stmt instanceof SQLSelectStatement)
MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(securityFilterContext.getTableRuleMap(dbName),encryptColumnParameters); && !(stmt instanceof MySqlUpdateStatement)
&& !(stmt instanceof MySqlInsertStatement)
&& !(stmt instanceof MySqlDeleteStatement)) {
log.warn("过滤非[insert|delete|update|select]statement,dbName={},sql={}",dbName, sql);
continue;
}
List<Parameter> encryptParameters = new ArrayList<>();
MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(tableRuleMap,encryptParameters);
stmt.accept(visitor); stmt.accept(visitor);
// 加密 // 加密
for(Parameter parameter : encryptColumnParameters){ for(Parameter parameter : encryptParameters){
encrypt(parameter.getColumnRule(), preparedStatement, parameter.getJdbcIndex()); encrypt(parameter.getColumnRule(), preparedStatement, parameter.getJdbcIndex());
} }
} }
......
package com.secoo.mall.datasource.security.util;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class CollectionUtil {
public static <T> List<T> initAdd(List<T> list, T t){
if(list == null){
list = new ArrayList<>();
}
list.add(t);
return list;
}
public static <T> List<T> initAddAll(List<T> list, List<T> subList){
if(list == null){
list = new ArrayList<>();
}
list.addAll(subList);
return list;
}
public static <K,V> Map<K,V> initAdd(Map<K,V> map, K k, V v){
if(map == null){
map = new HashMap<>();
}
map.put(k,v);
return map;
}
public static <K,V> Map<K,V> initAddAll(Map<K,V> map,Map<K,V> subMap){
if(subMap == null){
subMap = new HashMap<>();
}
map.putAll(subMap);
return map;
}
}
package com.secoo.mall.datasource.security.visitor; package com.secoo.mall.datasource.security.visitor;
import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Column { public class Column {
private String tableName; protected String tableName;
private String columnAlias; protected String ownerName;
private String columnName; protected String columnName;
protected String columnAlias;
protected Integer columnIndex;
protected ColumnRule columnRule;
public Column(String tableName, String columnAlias, String columnName) { public Column(String tableName, String ownerName, String columnName, String columnAlias,Integer columnIndex,ColumnRule columnRule) {
this.tableName = tableName; this.tableName = tableName;
this.columnAlias = columnAlias; this.ownerName = ownerName;
this.columnName = columnName; this.columnName = columnName;
this.columnAlias = columnAlias;
this.columnIndex = columnIndex;
this.columnRule = columnRule;
} }
public String getTableName() { public String getTableName() {
...@@ -19,12 +27,12 @@ public class Column { ...@@ -19,12 +27,12 @@ public class Column {
this.tableName = tableName; this.tableName = tableName;
} }
public String getColumnAlias() { public String getOwnerName() {
return columnAlias; return ownerName;
} }
public void setColumnAlias(String columnAlias) { public void setOwnerName(String ownerName) {
this.columnAlias = columnAlias; this.ownerName = ownerName;
} }
public String getColumnName() { public String getColumnName() {
...@@ -35,12 +43,58 @@ public class Column { ...@@ -35,12 +43,58 @@ public class Column {
this.columnName = columnName; this.columnName = columnName;
} }
public String getColumnAlias() {
return columnAlias;
}
public void setColumnAlias(String columnAlias) {
this.columnAlias = columnAlias;
}
public Integer getColumnIndex() {
return columnIndex;
}
public void setColumnIndex(Integer columnIndex) {
this.columnIndex = columnIndex;
}
public ColumnRule getColumnRule() {
return columnRule;
}
public void setColumnRule(ColumnRule columnRule) {
this.columnRule = columnRule;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Column)) return false;
Column column = (Column) o;
if (getTableName() != null ? !getTableName().equals(column.getTableName()) : column.getTableName() != null)
return false;
if (getOwnerName() != null ? !getOwnerName().equals(column.getOwnerName()) : column.getOwnerName() != null)
return false;
if (getColumnName() != null ? !getColumnName().equals(column.getColumnName()) : column.getColumnName() != null)
return false;
if (getColumnAlias() != null ? !getColumnAlias().equals(column.getColumnAlias()) : column.getColumnAlias() != null)
return false;
if (getColumnIndex() != null ? !getColumnIndex().equals(column.getColumnIndex()) : column.getColumnIndex() != null)
return false;
return getColumnRule() != null ? getColumnRule().equals(column.getColumnRule()) : column.getColumnRule() == null;
}
@Override @Override
public String toString() { public int hashCode() {
return "Column{" + int result = getTableName() != null ? getTableName().hashCode() : 0;
"tableName='" + tableName + '\'' + result = 31 * result + (getOwnerName() != null ? getOwnerName().hashCode() : 0);
", columnAlias='" + columnAlias + '\'' + result = 31 * result + (getColumnName() != null ? getColumnName().hashCode() : 0);
", columnName='" + columnName + '\'' + result = 31 * result + (getColumnAlias() != null ? getColumnAlias().hashCode() : 0);
'}'; result = 31 * result + (getColumnIndex() != null ? getColumnIndex().hashCode() : 0);
result = 31 * result + (getColumnRule() != null ? getColumnRule().hashCode() : 0);
return result;
} }
} }
...@@ -2,45 +2,30 @@ package com.secoo.mall.datasource.security.visitor; ...@@ -2,45 +2,30 @@ package com.secoo.mall.datasource.security.visitor;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
public class Parameter { public class Parameter extends Column {
private String tableName;
private String columnAlias;
private String columnName;
private int jdbcIndex;
private Object value; private Object value;
private ColumnRule columnRule; private int jdbcIndex;
private Integer addJdbcIndex;
public Parameter(String tableName, String columnAlias, String columnName, int jdbcIndex, Object value,ColumnRule columnRule) { public Parameter(String tableName, String ownerName, String columnName, String columnAlias,Integer columnIndex, ColumnRule columnRule,
this.tableName = tableName; Object value, int jdbcIndex, Integer addJdbcIndex) {
this.columnAlias = columnAlias; super(tableName,ownerName,columnName,columnAlias,columnIndex,columnRule);
this.columnName = columnName;
this.jdbcIndex = jdbcIndex;
this.value = value; this.value = value;
this.columnRule = columnRule; this.jdbcIndex = jdbcIndex;
} this.addJdbcIndex = addJdbcIndex;
public String getTableName() {
return tableName;
}
public void setTableName(String tableName) {
this.tableName = tableName;
}
public String getColumnAlias() {
return columnAlias;
} }
public void setColumnAlias(String columnAlias) { public Parameter(Column column,
this.columnAlias = columnAlias; Object value, int jdbcIndex, Integer addJdbcIndex) {
this(column.getTableName(),column.getOwnerName(),column.getColumnName(),column.getColumnAlias(),column.getColumnIndex(),column.getColumnRule(),value,jdbcIndex,addJdbcIndex);
} }
public String getColumnName() { public Object getValue() {
return columnName; return value;
} }
public void setColumnName(String columnName) { public void setValue(Object value) {
this.columnName = columnName; this.value = value;
} }
public int getJdbcIndex() { public int getJdbcIndex() {
...@@ -51,31 +36,33 @@ public class Parameter { ...@@ -51,31 +36,33 @@ public class Parameter {
this.jdbcIndex = jdbcIndex; this.jdbcIndex = jdbcIndex;
} }
public Object getValue() { public Integer getAddJdbcIndex() {
return value; return addJdbcIndex;
} }
public void setValue(Object value) { public void setAddJdbcIndex(Integer addJdbcIndex) {
this.value = value; this.addJdbcIndex = addJdbcIndex;
} }
public ColumnRule getColumnRule() { @Override
return columnRule; public boolean equals(Object o) {
} if (this == o) return true;
if (!(o instanceof Parameter)) return false;
if (!super.equals(o)) return false;
Parameter parameter = (Parameter) o;
public void setColumnRule(ColumnRule columnRule) { if (getJdbcIndex() != parameter.getJdbcIndex()) return false;
this.columnRule = columnRule; if (getValue() != null ? !getValue().equals(parameter.getValue()) : parameter.getValue() != null) return false;
return getAddJdbcIndex() != null ? getAddJdbcIndex().equals(parameter.getAddJdbcIndex()) : parameter.getAddJdbcIndex() == null;
} }
@Override @Override
public String toString() { public int hashCode() {
return "Parameter{" + int result = super.hashCode();
"tableName='" + tableName + '\'' + result = 31 * result + (getValue() != null ? getValue().hashCode() : 0);
", columnAlias='" + columnAlias + '\'' + result = 31 * result + getJdbcIndex();
", columnName='" + columnName + '\'' + result = 31 * result + (getAddJdbcIndex() != null ? getAddJdbcIndex().hashCode() : 0);
", jdbcIndex=" + jdbcIndex + return result;
", value=" + value +
", columnRule=" + columnRule +
'}';
} }
} }
...@@ -10,8 +10,10 @@ import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem; ...@@ -10,8 +10,10 @@ import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor; import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor;
import com.secoo.mall.datasource.security.visitor.MySqlSecurityParameterVisitor2;
import com.secoo.mall.datasource.security.visitor.Parameter; import com.secoo.mall.datasource.security.visitor.Parameter;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
...@@ -20,10 +22,11 @@ import java.util.List; ...@@ -20,10 +22,11 @@ import java.util.List;
public class SQLParserTest { public class SQLParserTest {
public static void testVisitor(){ public static void testVisitor(){
// String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;"; // String sql = "SELECT Sname FROM Student WHERE Sno IN(SELECT Sno FROM SC WHERE Cno='2');";
String sql = "select u.name,a.age from t_user u,t_account a where u.id = a.id and u.id > 12 and u.height between ? and ? and t_account.age >? and a.name in('tom', ? ,'john') order by id desc;";
// String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic"; // String sql = "select secooStoreDB.t_sequence.`code` from secooStoreDB.t_sequence,secooStoreDB.t_store_fail_mq where secooStoreDB.t_sequence.`code`=secooStoreDB.t_store_fail_mq.topic";
// String sql = "delete from t_sequence"; // String sql = "delete from t_sequence";
String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence`(`secooStoreDB`.`t_sequence`.`name`, `t_sequence`.`current_value`, `increment`, `code`) VALUES (?, 4551, 1, 2),('t_store_category', ?, 1, ?);";
// String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)"; // String sql = "INSERT INTO `secooStoreDB`.`t_sequence` VALUES ('t_store', 4551, 1, 2),('t_store_category', 65015, 1, 1)";
// String sql = "update t_user set name=?,age=10 where id = id and id > 12"; // String sql = "update t_user set name=?,age=10 where id = id and id > 12";
// String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?"; // String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
...@@ -38,10 +41,12 @@ public class SQLParserTest { ...@@ -38,10 +41,12 @@ public class SQLParserTest {
// MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); // MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
List<Parameter> encryptColumnParameters = new ArrayList<>(); List<Parameter> encryptColumnParameters = new ArrayList<>();
MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters); // MySqlSecurityParameterVisitor visitor = new MySqlSecurityParameterVisitor(null,encryptColumnParameters);
MySqlSecurityParameterVisitor2 visitor = new MySqlSecurityParameterVisitor2(null,encryptColumnParameters);
stmt.accept(visitor); stmt.accept(visitor);
System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters)); System.out.println("encryptColumnParameters=" + JSON.toJSONString(encryptColumnParameters));
// System.out.println(visitor.getParameters()); System.out.println(stmt.toString());
// System.out.println(visitor.getConditions());
if(CollectionUtils.isNotEmpty(encryptColumnParameters)){ if(CollectionUtils.isNotEmpty(encryptColumnParameters)){
return; return;
} }
...@@ -56,7 +61,7 @@ public class SQLParserTest { ...@@ -56,7 +61,7 @@ public class SQLParserTest {
}else if(stmt instanceof MySqlDeleteStatement){ }else if(stmt instanceof MySqlDeleteStatement){
MySqlDeleteStatement deleteStmt = (MySqlDeleteStatement) stmt; MySqlDeleteStatement deleteStmt = (MySqlDeleteStatement) stmt;
System.out.println("---"); System.out.println("---");
}else if (stmt instanceof MySqlInsertStatement) { }else if (stmt instanceof MySqlInsertStatement){
MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt; MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt;
String tableName = insertStmt.getTableName().getSimpleName(); String tableName = insertStmt.getTableName().getSimpleName();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment