Commit 20f3fa99 by 郑冰晶

数据库加密组件

parent 293e7d19
package com.secoo.mall.datasource.security.filter; package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.proxy.jdbc.ResultSetProxy; import com.alibaba.druid.filter.AutoLoad;
import com.alibaba.druid.proxy.jdbc.StatementProxy; import com.alibaba.druid.proxy.jdbc.*;
import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.sql.ast.SQLStatement;
...@@ -13,156 +13,115 @@ import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; ...@@ -13,156 +13,115 @@ import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat; import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.mysql.cj.BindValue;
import com.mysql.cj.jdbc.ClientPreparedStatement;
import com.mysql.cj.jdbc.result.ResultSetImpl;
import com.mysql.cj.protocol.ResultsetRows;
import com.mysql.cj.result.Field;
import com.mysql.cj.result.Row;
import com.mysql.cj.util.StringUtils;
import com.secoo.mall.datasource.security.rule.ColumnRule; import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.rule.TableRule;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.sql.Statement; import java.sql.SQLException;
import java.util.*; import java.util.*;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.Future; import java.util.concurrent.Future;
@Slf4j @Slf4j
@AutoLoad
public class MysqlSecurityFilter extends AbsSecurityFilter { public class MysqlSecurityFilter extends AbsSecurityFilter {
@Override @Override
protected void decryptResultSet(ResultSetProxy resultSet) { protected void decryptResultSet(ResultSetProxy resultSet) {
// 结果集
ResultsetRows rows = ((ResultSetImpl) resultSet.getRawObject()).getRows();
// 结果集字段描述
Field[] fields = rows.getMetadata().getFields();
List<Future<Boolean>> futureList = new LinkedList<>();
for (final Field field:fields) {
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(field.getOriginalTableName());
if (columnRuleMap == null || columnRuleMap.isEmpty()) {
continue;
}
ColumnRule columnRule = columnRuleMap.get(field.getOriginalName());
if (columnRule != null) {
for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) {
final Row row = rows.get(rowIndex);
decrypt(futureList,columnRule, field, row);
}
}
}
for (Future<Boolean> future : futureList) {
try {
future.get();
} catch (Exception e) {
log.error("解密出现异常,异常部分未解密", e);
}
}
} }
@Override @Override
protected void encryptStatement(StatementProxy statement, String sql) { protected void encryptStatement(StatementProxy statement, String sql) {
// 解析sql if (!(statement instanceof PreparedStatementProxy)) {
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, this.getDbType()); log.debug("不需要处理的statement:{}", sql);
Statement rawObject = statement.getRawObject();
if (!(rawObject instanceof ClientPreparedStatement)) {
log.debug("不需要处理的statement:{}", rawObject);
return; return;
} }
BindValue[] bindValues = ((ClientPreparedStatement) rawObject).getQueryBindings().getBindValues();
// 解析出语句,通常只有一条,不支持超过一条语句的SQL PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
// 解析sql
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, this.getDbType());
for (SQLStatement stmt : stmtList) { for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor); stmt.accept(visitor);
List<Future<Boolean>> futureList = new LinkedList<>();
int index = 0; int index = 0;
List<Future<Boolean>> futureList = new LinkedList<>();
// 查询语句或删除语句,只有查询条件需要加密 // 查询 | 删除
if (stmt instanceof SQLSelectStatement || stmt instanceof MySqlDeleteStatement) { if (stmt instanceof SQLSelectStatement || stmt instanceof MySqlDeleteStatement) {
// 遍历查询条件 // 查询条件
for (TableStat.Condition condition : visitor.getConditions()) { for (TableStat.Condition condition : visitor.getConditions()) {
// 遍历查询条件值,一般只有一个,但in/between语句等可能有多个 // 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) { for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空才是查询条件 // 解析出条件值为空查询条件
if (conditionValue != null) { if (conditionValue != null) {
continue; continue;
} }
TableStat.Column column = condition.getColumn(); TableStat.Column column = condition.getColumn();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable()); Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap == null || columnRuleMap.isEmpty()) { if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
continue; // 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(futureList, columnRule, preparedStatement, index);
}
} }
index ++;
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(futureList, columnRule, bindValues[index]);
}
index++;
} }
} }
} }
// 插入语句 // 插入
else if (stmt instanceof MySqlInsertStatement) { else if (stmt instanceof MySqlInsertStatement) {
MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt; MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt;
// 插入语句应该只有一个表
String tableName = insertStmt.getTableName().getSimpleName(); String tableName = insertStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName); Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap == null || columnRuleMap.isEmpty()) { if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
continue; int valuesSize = insertStmt.getValuesList().size();
} Collection<TableStat.Column> columns = visitor.getColumns();
int columnSize = columns.size();
// valuesSize>1为batch insert语句 for (TableStat.Column column : columns) {
int valuesSize = insertStmt.getValuesList().size(); // 需要加密的字段
Collection<TableStat.Column> columns = visitor.getColumns(); ColumnRule columnRule = columnRuleMap.get(column.getName());
// 字段数量 if (columnRule != null) {
int columnSize = columns.size(); for (int valueIndex = 0; valueIndex < valuesSize; valueIndex++) {
for (TableStat.Column column : columns) { encrypt(columnRule, preparedStatement, index + valueIndex * columnSize);
// 需要加密的字段 }
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
for (int valueIndex = 0; valueIndex < valuesSize; valueIndex++) {
BindValue bindValue = bindValues[index + valueIndex * columnSize];
encrypt(futureList, columnRule, bindValue);
} }
} }
index++;
} }
} else if (stmt instanceof MySqlUpdateStatement) { index ++;
MySqlUpdateStatement updateStat = (MySqlUpdateStatement) stmt; }
// 更新语句应该只有一个表 // 更新
String tableName = updateStat.getTableName().getSimpleName(); else if (stmt instanceof MySqlUpdateStatement) {
MySqlUpdateStatement updateStmt = (MySqlUpdateStatement) stmt;
// 更新语句应该只支持单表
String tableName = updateStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName); Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap == null || columnRuleMap.isEmpty()) { if (columnRuleMap == null || columnRuleMap.isEmpty()) {
continue; continue;
} }
// 先处理set语句 // 处理set
for (SQLUpdateSetItem item : updateStat.getItems()) { for (SQLUpdateSetItem item : updateStmt.getItems()) {
SQLExpr column = item.getColumn(); SQLExpr column = item.getColumn();
if (item.getValue() instanceof SQLVariantRefExpr && column instanceof SQLIdentifierExpr) { if (item.getValue() instanceof SQLVariantRefExpr && column instanceof SQLIdentifierExpr) {
// 需要加密的字段 // 需要加密的字段
String columnName = ((SQLIdentifierExpr) column).getName(); String columnName = ((SQLIdentifierExpr) column).getName();
ColumnRule columnRule = columnRuleMap.get(columnName); ColumnRule columnRule = columnRuleMap.get(columnName);
if (columnRule != null) { if (columnRule != null) {
encrypt(futureList, columnRule, bindValues[index]); encrypt(futureList, columnRule, preparedStatement,index);
} }
index++;
} }
index++;
} }
// 再处理where语句 // 处理where
for (TableStat.Condition condition : visitor.getConditions()) { for (TableStat.Condition condition : visitor.getConditions()) {
// 遍历查询条件值,一般只有一个,但in/between语句等可能有多个 // 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) { for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空才是查询条件 // 解析出条件值为空查询条件
if (conditionValue == null) { if (conditionValue == null) {
continue; continue;
} }
...@@ -170,25 +129,22 @@ public class MysqlSecurityFilter extends AbsSecurityFilter { ...@@ -170,25 +129,22 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
TableStat.Column column = condition.getColumn(); TableStat.Column column = condition.getColumn();
ColumnRule columnRule = columnRuleMap.get(column.getName()); ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) { if (columnRule != null) {
encrypt(futureList, columnRule, bindValues[index]); encrypt(futureList, columnRule, preparedStatement,index);
} }
index++; index++;
} }
} }
} }
// 其他,一般没有了 // 其他
else { else {
for (TableStat.Column column : visitor.getColumns()) { for (TableStat.Column column : visitor.getColumns()) {
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable()); Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap == null || columnRuleMap.isEmpty()) { if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
continue; // 需要加密的字段
} ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
// 需要加密的字段 encrypt(futureList, columnRule, preparedStatement,index);
ColumnRule columnRule = columnRuleMap.get(column.getName()); }
if (columnRule != null) {
encrypt(futureList, columnRule, bindValues[index]);
} }
index++; index++;
} }
...@@ -199,13 +155,15 @@ public class MysqlSecurityFilter extends AbsSecurityFilter { ...@@ -199,13 +155,15 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
future.get(); future.get();
} catch (Exception e) { } catch (Exception e) {
log.error("加密出现异常,异常部分未加密", e); log.error("加密出现异常,异常部分未加密", e);
throw new SecurityException("加密出现异常,异常部分未加密");
} }
} }
} }
} }
private void encrypt(List<Future<Boolean>> futureList,final ColumnRule columnRule, final BindValue bindValue) { private void encrypt(List<Future<Boolean>> futureList,final ColumnRule columnRule, final PreparedStatementProxyImpl preparedStatement,int index) {
final String origValue = getBindValue(bindValue); JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object origValue = jdbcParameter.getValue();
if (origValue == null) { if (origValue == null) {
return; return;
} }
...@@ -214,72 +172,77 @@ public class MysqlSecurityFilter extends AbsSecurityFilter { ...@@ -214,72 +172,77 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
Future<Boolean> future = this.getParallelExecutor().submit(new Callable<Boolean>() { Future<Boolean> future = this.getParallelExecutor().submit(new Callable<Boolean>() {
@Override @Override
public Boolean call() throws Exception { public Boolean call() throws Exception {
encrypt(columnRule, origValue, bindValue); encrypt(columnRule,preparedStatement, index);
return true; return true;
} }
}); });
futureList.add(future); futureList.add(future);
} else { } else {
encrypt(columnRule, origValue, bindValue); encrypt(columnRule, preparedStatement, index);
} }
} }
private void encrypt(ColumnRule columnRule, String origValue, BindValue bindValue) { private void encrypt(ColumnRule columnRule, PreparedStatementProxyImpl preparedStatement,int index) {
String encryptValue = columnRule.getSecurityAlgorithm().encrypt(origValue); JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
encryptValue = "'" + encryptValue + "'"; final Object origValue = jdbcParameter.getValue();
bindValue.setByteValue(encryptValue.getBytes(charset)); if (origValue == null) {
log.debug("字段加密:columnRule={},origValue={},encryptValue={}", columnRule, origValue, encryptValue); return;
} }
private void decrypt(List<Future<Boolean>> futureList,final ColumnRule columnRule,final Field field, final Row row) { String encryptValue = columnRule.getSecurityAlgorithm().encrypt(origValue);
int index = field.getCollationIndex(); try {
byte[] bytes = row.getBytes(index); preparedStatement.setObject(index + 1,encryptValue);
if (bytes != null && bytes.length > 0) { } catch (SQLException throwables) {
final String origValue = StringUtils.toString(bytes, charset.name()); log.error("字段加密异常:columnRule={},origValue={},encryptValue={}", columnRule, origValue, encryptValue);
if (this.isParallelEnabled()) { throw new SecurityException("参数加密异常!");
Future<Boolean> future = this.getParallelExecutor().submit(new Callable<Boolean>() {
@Override
public Boolean call() {
decrypt(columnRule, row, origValue, index);
return true;
}
});
futureList.add(future);
} else {
decrypt(columnRule, row, origValue, index);
}
} }
log.debug("字段加密:columnRule={},origValue={},encryptValue={}", columnRule, origValue, encryptValue);
} }
private void decrypt(ColumnRule columnRule, Row row, String origValue, int index) { public static void main(String[] args) {
String decryptValue = columnRule.getSecurityAlgorithm().decrypt(origValue); // String sql = "select id,name,age from t_user where id = ? and name = 'tom' and age > 10";
row.setBytes(index, decryptValue.getBytes(charset)); String sql = "update t_user u,t_account a set u.name = 'test', a.age=3 where u.id = 1 and a.age > ?";
log.debug("字段解密:columnRule={},origValue={},decryptValue={}", columnRule, origValue, decryptValue); List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
} for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor);
private String getBindValue(BindValue bindValue) { MySqlUpdateStatement updateStmt = (MySqlUpdateStatement) stmt;
if (bindValue.isNull()) {
return null;
}
byte[] byteValue = bindValue.getByteValue();
if (byteValue == null || byteValue.length == 0) {
return null;
}
String origValue = StringUtils.toString(byteValue, charset.name());
if ("''".equals(origValue) || "".equals(origValue)) {
return null;
}
// 参数可能自带''单引号,需要去掉''单引号
if (origValue.startsWith("'") && origValue.endsWith("'")) {
origValue = origValue.substring(1, origValue.length() - 1);
}
return origValue; System.out.println(visitor.getParameters());
} System.out.println(visitor.getColumns());
System.out.println(visitor.getGroupByColumns());
System.out.println(visitor.getOrderByColumns());
System.out.println(visitor.getConditions());
System.out.println(visitor.getTables());
System.out.println(visitor.getRelationships());
public void init(Set<TableRule> tableRules) { // 查询
super.init(tableRules); if (stmt instanceof SQLSelectStatement) {
this.setDbType(JdbcConstants.MYSQL); SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
// 遍历查询条件
for (TableStat.Condition condition : visitor.getConditions()) {
// 遍历查询条件值,一般只有一个,但in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空才是查询条件
System.out.println("condition" + condition.getColumn() + ",values=" + condition.getValues());
}
}
SQLExpr sqlExpr = selectStmt.getSelect().getQueryBlock().getWhere();
if(sqlExpr instanceof SQLInListExpr){
// SQLInListExpr 指 run_id in ('1', '2') 这一情况
SQLInListExpr inListExpr = (SQLInListExpr)sqlExpr;
List<SQLExpr> valueExprs = inListExpr.getTargetList();
for(SQLExpr expr : valueExprs){
System.out.print(expr + "\t");
}
} else {
// SQLBinaryOpExpr 指 run_id = '1' 这一情况
SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
System.out.println(binaryOpExpr.getLeft() + " --> " + binaryOpExpr.getRight());
}
}
}
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment