Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
M
matrix
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
CI / CD
CI / CD
Pipelines
Schedules
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Commits
Issue Boards
Open sidebar
mall
arch
matrix
Commits
20f3fa99
Commit
20f3fa99
authored
Aug 18, 2021
by
郑冰晶
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
数据库加密组件
parent
293e7d19
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
116 additions
and
153 deletions
+116
-153
MysqlSecurityFilter.java
matrix-datasource/matrix-datasource-security/matrix-datasource-security-core/src/main/java/com/secoo/mall/datasource/security/filter/MysqlSecurityFilter.java
+116
-153
No files found.
matrix-datasource/matrix-datasource-security/matrix-datasource-security-core/src/main/java/com/secoo/mall/datasource/security/filter/MysqlSecurityFilter.java
View file @
20f3fa99
package
com
.
secoo
.
mall
.
datasource
.
security
.
filter
;
import
com.alibaba.druid.
proxy.jdbc.ResultSetProxy
;
import
com.alibaba.druid.proxy.jdbc.
StatementProxy
;
import
com.alibaba.druid.
filter.AutoLoad
;
import
com.alibaba.druid.proxy.jdbc.
*
;
import
com.alibaba.druid.sql.SQLUtils
;
import
com.alibaba.druid.sql.ast.SQLExpr
;
import
com.alibaba.druid.sql.ast.SQLStatement
;
...
...
@@ -13,156 +13,115 @@ import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import
com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement
;
import
com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor
;
import
com.alibaba.druid.stat.TableStat
;
import
com.alibaba.druid.util.JdbcConstants
;
import
com.mysql.cj.BindValue
;
import
com.mysql.cj.jdbc.ClientPreparedStatement
;
import
com.mysql.cj.jdbc.result.ResultSetImpl
;
import
com.mysql.cj.protocol.ResultsetRows
;
import
com.mysql.cj.result.Field
;
import
com.mysql.cj.result.Row
;
import
com.mysql.cj.util.StringUtils
;
import
com.secoo.mall.datasource.security.rule.ColumnRule
;
import
com.secoo.mall.datasource.security.rule.TableRule
;
import
lombok.extern.slf4j.Slf4j
;
import
java.sql.S
tatement
;
import
java.sql.S
QLException
;
import
java.util.*
;
import
java.util.concurrent.Callable
;
import
java.util.concurrent.Future
;
@Slf4j
@AutoLoad
public
class
MysqlSecurityFilter
extends
AbsSecurityFilter
{
@Override
protected
void
decryptResultSet
(
ResultSetProxy
resultSet
)
{
// 结果集
ResultsetRows
rows
=
((
ResultSetImpl
)
resultSet
.
getRawObject
()).
getRows
();
// 结果集字段描述
Field
[]
fields
=
rows
.
getMetadata
().
getFields
();
List
<
Future
<
Boolean
>>
futureList
=
new
LinkedList
<>();
for
(
final
Field
field:
fields
)
{
Map
<
String
,
ColumnRule
>
columnRuleMap
=
this
.
getTableRuleMap
().
get
(
field
.
getOriginalTableName
());
if
(
columnRuleMap
==
null
||
columnRuleMap
.
isEmpty
())
{
continue
;
}
ColumnRule
columnRule
=
columnRuleMap
.
get
(
field
.
getOriginalName
());
if
(
columnRule
!=
null
)
{
for
(
int
rowIndex
=
0
;
rowIndex
<
rows
.
size
();
rowIndex
++)
{
final
Row
row
=
rows
.
get
(
rowIndex
);
decrypt
(
futureList
,
columnRule
,
field
,
row
);
}
}
}
for
(
Future
<
Boolean
>
future
:
futureList
)
{
try
{
future
.
get
();
}
catch
(
Exception
e
)
{
log
.
error
(
"解密出现异常,异常部分未解密"
,
e
);
}
}
}
@Override
protected
void
encryptStatement
(
StatementProxy
statement
,
String
sql
)
{
// 解析sql
List
<
SQLStatement
>
stmtList
=
SQLUtils
.
parseStatements
(
sql
,
this
.
getDbType
());
Statement
rawObject
=
statement
.
getRawObject
();
if
(!(
rawObject
instanceof
ClientPreparedStatement
))
{
log
.
debug
(
"不需要处理的statement:{}"
,
rawObject
);
if
(!(
statement
instanceof
PreparedStatementProxy
))
{
log
.
debug
(
"不需要处理的statement:{}"
,
sql
);
return
;
}
BindValue
[]
bindValues
=
((
ClientPreparedStatement
)
rawObject
).
getQueryBindings
().
getBindValues
();
// 解析出语句,通常只有一条,不支持超过一条语句的SQL
PreparedStatementProxyImpl
preparedStatement
=
(
PreparedStatementProxyImpl
)
statement
;
// 解析sql
List
<
SQLStatement
>
stmtList
=
SQLUtils
.
parseStatements
(
sql
,
this
.
getDbType
());
for
(
SQLStatement
stmt
:
stmtList
)
{
MySqlSchemaStatVisitor
visitor
=
new
MySqlSchemaStatVisitor
();
stmt
.
accept
(
visitor
);
List
<
Future
<
Boolean
>>
futureList
=
new
LinkedList
<>();
int
index
=
0
;
List
<
Future
<
Boolean
>>
futureList
=
new
LinkedList
<>();
// 查询
语句或删除语句,只有查询条件需要加密
// 查询
| 删除
if
(
stmt
instanceof
SQLSelectStatement
||
stmt
instanceof
MySqlDeleteStatement
)
{
//
遍历
查询条件
// 查询条件
for
(
TableStat
.
Condition
condition
:
visitor
.
getConditions
())
{
//
遍历查询条件值,一般只有一个,但
in/between语句等可能有多个
//
查询条件值,
in/between语句等可能有多个
for
(
Object
conditionValue
:
condition
.
getValues
())
{
// 解析出条件值为空
才是
查询条件
// 解析出条件值为空
为
查询条件
if
(
conditionValue
!=
null
)
{
continue
;
}
TableStat
.
Column
column
=
condition
.
getColumn
();
Map
<
String
,
ColumnRule
>
columnRuleMap
=
this
.
getTableRuleMap
().
get
(
column
.
getTable
());
if
(
columnRuleMap
==
null
||
columnRuleMap
.
isEmpty
())
{
continue
;
if
(
columnRuleMap
!=
null
&&
!
columnRuleMap
.
isEmpty
())
{
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
preparedStatement
,
index
);
}
}
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
bindValues
[
index
]);
}
index
++;
index
++;
}
}
}
// 插入
语句
// 插入
else
if
(
stmt
instanceof
MySqlInsertStatement
)
{
MySqlInsertStatement
insertStmt
=
(
MySqlInsertStatement
)
stmt
;
// 插入语句应该只有一个表
String
tableName
=
insertStmt
.
getTableName
().
getSimpleName
();
Map
<
String
,
ColumnRule
>
columnRuleMap
=
this
.
getTableRuleMap
().
get
(
tableName
);
if
(
columnRuleMap
==
null
||
columnRuleMap
.
isEmpty
())
{
continue
;
}
// valuesSize>1为batch insert语句
int
valuesSize
=
insertStmt
.
getValuesList
().
size
();
Collection
<
TableStat
.
Column
>
columns
=
visitor
.
getColumns
();
// 字段数量
int
columnSize
=
columns
.
size
();
for
(
TableStat
.
Column
column
:
columns
)
{
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
for
(
int
valueIndex
=
0
;
valueIndex
<
valuesSize
;
valueIndex
++)
{
BindValue
bindValue
=
bindValues
[
index
+
valueIndex
*
columnSize
];
encrypt
(
futureList
,
columnRule
,
bindValue
);
if
(
columnRuleMap
!=
null
&&
!
columnRuleMap
.
isEmpty
())
{
int
valuesSize
=
insertStmt
.
getValuesList
().
size
();
Collection
<
TableStat
.
Column
>
columns
=
visitor
.
getColumns
();
int
columnSize
=
columns
.
size
();
for
(
TableStat
.
Column
column
:
columns
)
{
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
for
(
int
valueIndex
=
0
;
valueIndex
<
valuesSize
;
valueIndex
++)
{
encrypt
(
columnRule
,
preparedStatement
,
index
+
valueIndex
*
columnSize
);
}
}
}
index
++;
}
}
else
if
(
stmt
instanceof
MySqlUpdateStatement
)
{
MySqlUpdateStatement
updateStat
=
(
MySqlUpdateStatement
)
stmt
;
// 更新语句应该只有一个表
String
tableName
=
updateStat
.
getTableName
().
getSimpleName
();
index
++;
}
// 更新
else
if
(
stmt
instanceof
MySqlUpdateStatement
)
{
MySqlUpdateStatement
updateStmt
=
(
MySqlUpdateStatement
)
stmt
;
// 更新语句应该只支持单表
String
tableName
=
updateStmt
.
getTableName
().
getSimpleName
();
Map
<
String
,
ColumnRule
>
columnRuleMap
=
this
.
getTableRuleMap
().
get
(
tableName
);
if
(
columnRuleMap
==
null
||
columnRuleMap
.
isEmpty
())
{
continue
;
}
//
先处理set语句
for
(
SQLUpdateSetItem
item
:
updateSt
a
t
.
getItems
())
{
//
处理set
for
(
SQLUpdateSetItem
item
:
updateSt
m
t
.
getItems
())
{
SQLExpr
column
=
item
.
getColumn
();
if
(
item
.
getValue
()
instanceof
SQLVariantRefExpr
&&
column
instanceof
SQLIdentifierExpr
)
{
// 需要加密的字段
String
columnName
=
((
SQLIdentifierExpr
)
column
).
getName
();
ColumnRule
columnRule
=
columnRuleMap
.
get
(
columnName
);
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
bindValues
[
index
]
);
encrypt
(
futureList
,
columnRule
,
preparedStatement
,
index
);
}
index
++;
}
index
++;
}
//
再处理where语句
//
处理where
for
(
TableStat
.
Condition
condition
:
visitor
.
getConditions
())
{
//
遍历查询条件值,一般只有一个,但
in/between语句等可能有多个
//
查询条件值,
in/between语句等可能有多个
for
(
Object
conditionValue
:
condition
.
getValues
())
{
// 解析出条件值为空
才是
查询条件
// 解析出条件值为空
为
查询条件
if
(
conditionValue
==
null
)
{
continue
;
}
...
...
@@ -170,25 +129,22 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
TableStat
.
Column
column
=
condition
.
getColumn
();
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
bindValues
[
index
]
);
encrypt
(
futureList
,
columnRule
,
preparedStatement
,
index
);
}
index
++;
}
}
}
// 其他
,一般没有了
// 其他
else
{
for
(
TableStat
.
Column
column
:
visitor
.
getColumns
())
{
Map
<
String
,
ColumnRule
>
columnRuleMap
=
this
.
getTableRuleMap
().
get
(
column
.
getTable
());
if
(
columnRuleMap
==
null
||
columnRuleMap
.
isEmpty
())
{
continue
;
}
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
bindValues
[
index
]);
if
(
columnRuleMap
!=
null
&&
!
columnRuleMap
.
isEmpty
())
{
// 需要加密的字段
ColumnRule
columnRule
=
columnRuleMap
.
get
(
column
.
getName
());
if
(
columnRule
!=
null
)
{
encrypt
(
futureList
,
columnRule
,
preparedStatement
,
index
);
}
}
index
++;
}
...
...
@@ -199,13 +155,15 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
future
.
get
();
}
catch
(
Exception
e
)
{
log
.
error
(
"加密出现异常,异常部分未加密"
,
e
);
throw
new
SecurityException
(
"加密出现异常,异常部分未加密"
);
}
}
}
}
private
void
encrypt
(
List
<
Future
<
Boolean
>>
futureList
,
final
ColumnRule
columnRule
,
final
BindValue
bindValue
)
{
final
String
origValue
=
getBindValue
(
bindValue
);
private
void
encrypt
(
List
<
Future
<
Boolean
>>
futureList
,
final
ColumnRule
columnRule
,
final
PreparedStatementProxyImpl
preparedStatement
,
int
index
)
{
JdbcParameter
jdbcParameter
=
preparedStatement
.
getParameter
(
index
);
final
Object
origValue
=
jdbcParameter
.
getValue
();
if
(
origValue
==
null
)
{
return
;
}
...
...
@@ -214,72 +172,77 @@ public class MysqlSecurityFilter extends AbsSecurityFilter {
Future
<
Boolean
>
future
=
this
.
getParallelExecutor
().
submit
(
new
Callable
<
Boolean
>()
{
@Override
public
Boolean
call
()
throws
Exception
{
encrypt
(
columnRule
,
origValue
,
bindValue
);
encrypt
(
columnRule
,
preparedStatement
,
index
);
return
true
;
}
});
futureList
.
add
(
future
);
}
else
{
encrypt
(
columnRule
,
origValue
,
bindValue
);
encrypt
(
columnRule
,
preparedStatement
,
index
);
}
}
private
void
encrypt
(
ColumnRule
columnRule
,
String
origValue
,
BindValue
bindValue
)
{
String
encryptValue
=
columnRule
.
getSecurityAlgorithm
().
encrypt
(
origValue
);
encryptValue
=
"'"
+
encryptValue
+
"'"
;
bindValue
.
setByteValue
(
encryptValue
.
getBytes
(
charset
));
log
.
debug
(
"字段加密:columnRule={},origValue={},encryptValue={}"
,
columnRule
,
origValue
,
encryptValue
)
;
}
private
void
encrypt
(
ColumnRule
columnRule
,
PreparedStatementProxyImpl
preparedStatement
,
int
index
)
{
JdbcParameter
jdbcParameter
=
preparedStatement
.
getParameter
(
index
);
final
Object
origValue
=
jdbcParameter
.
getValue
()
;
if
(
origValue
==
null
)
{
return
;
}
private
void
decrypt
(
List
<
Future
<
Boolean
>>
futureList
,
final
ColumnRule
columnRule
,
final
Field
field
,
final
Row
row
)
{
int
index
=
field
.
getCollationIndex
();
byte
[]
bytes
=
row
.
getBytes
(
index
);
if
(
bytes
!=
null
&&
bytes
.
length
>
0
)
{
final
String
origValue
=
StringUtils
.
toString
(
bytes
,
charset
.
name
());
if
(
this
.
isParallelEnabled
())
{
Future
<
Boolean
>
future
=
this
.
getParallelExecutor
().
submit
(
new
Callable
<
Boolean
>()
{
@Override
public
Boolean
call
()
{
decrypt
(
columnRule
,
row
,
origValue
,
index
);
return
true
;
}
});
futureList
.
add
(
future
);
}
else
{
decrypt
(
columnRule
,
row
,
origValue
,
index
);
}
String
encryptValue
=
columnRule
.
getSecurityAlgorithm
().
encrypt
(
origValue
);
try
{
preparedStatement
.
setObject
(
index
+
1
,
encryptValue
);
}
catch
(
SQLException
throwables
)
{
log
.
error
(
"字段加密异常:columnRule={},origValue={},encryptValue={}"
,
columnRule
,
origValue
,
encryptValue
);
throw
new
SecurityException
(
"参数加密异常!"
);
}
log
.
debug
(
"字段加密:columnRule={},origValue={},encryptValue={}"
,
columnRule
,
origValue
,
encryptValue
);
}
private
void
decrypt
(
ColumnRule
columnRule
,
Row
row
,
String
origValue
,
int
index
)
{
String
decryptValue
=
columnRule
.
getSecurityAlgorithm
().
decrypt
(
origValue
);
row
.
setBytes
(
index
,
decryptValue
.
getBytes
(
charset
));
log
.
debug
(
"字段解密:columnRule={},origValue={},decryptValue={}"
,
columnRule
,
origValue
,
decryptValue
);
}
public
static
void
main
(
String
[]
args
)
{
// String sql = "select id,name,age from t_user where id = ? and name = 'tom' and age > 10";
String
sql
=
"update t_user u,t_account a set u.name = 'test', a.age=3 where u.id = 1 and a.age > ?"
;
List
<
SQLStatement
>
stmtList
=
SQLUtils
.
parseStatements
(
sql
,
"mysql"
);
for
(
SQLStatement
stmt
:
stmtList
)
{
MySqlSchemaStatVisitor
visitor
=
new
MySqlSchemaStatVisitor
();
stmt
.
accept
(
visitor
);
private
String
getBindValue
(
BindValue
bindValue
)
{
if
(
bindValue
.
isNull
())
{
return
null
;
}
byte
[]
byteValue
=
bindValue
.
getByteValue
();
if
(
byteValue
==
null
||
byteValue
.
length
==
0
)
{
return
null
;
}
String
origValue
=
StringUtils
.
toString
(
byteValue
,
charset
.
name
());
if
(
"''"
.
equals
(
origValue
)
||
""
.
equals
(
origValue
))
{
return
null
;
}
// 参数可能自带''单引号,需要去掉''单引号
if
(
origValue
.
startsWith
(
"'"
)
&&
origValue
.
endsWith
(
"'"
))
{
origValue
=
origValue
.
substring
(
1
,
origValue
.
length
()
-
1
);
}
MySqlUpdateStatement
updateStmt
=
(
MySqlUpdateStatement
)
stmt
;
return
origValue
;
}
System
.
out
.
println
(
visitor
.
getParameters
());
System
.
out
.
println
(
visitor
.
getColumns
());
System
.
out
.
println
(
visitor
.
getGroupByColumns
());
System
.
out
.
println
(
visitor
.
getOrderByColumns
());
System
.
out
.
println
(
visitor
.
getConditions
());
System
.
out
.
println
(
visitor
.
getTables
());
System
.
out
.
println
(
visitor
.
getRelationships
());
public
void
init
(
Set
<
TableRule
>
tableRules
)
{
super
.
init
(
tableRules
);
this
.
setDbType
(
JdbcConstants
.
MYSQL
);
// 查询
if
(
stmt
instanceof
SQLSelectStatement
)
{
SQLSelectStatement
selectStmt
=
(
SQLSelectStatement
)
stmt
;
// 遍历查询条件
for
(
TableStat
.
Condition
condition
:
visitor
.
getConditions
())
{
// 遍历查询条件值,一般只有一个,但in/between语句等可能有多个
for
(
Object
conditionValue
:
condition
.
getValues
())
{
// 解析出条件值为空才是查询条件
System
.
out
.
println
(
"condition"
+
condition
.
getColumn
()
+
",values="
+
condition
.
getValues
());
}
}
SQLExpr
sqlExpr
=
selectStmt
.
getSelect
().
getQueryBlock
().
getWhere
();
if
(
sqlExpr
instanceof
SQLInListExpr
){
// SQLInListExpr 指 run_id in ('1', '2') 这一情况
SQLInListExpr
inListExpr
=
(
SQLInListExpr
)
sqlExpr
;
List
<
SQLExpr
>
valueExprs
=
inListExpr
.
getTargetList
();
for
(
SQLExpr
expr
:
valueExprs
){
System
.
out
.
print
(
expr
+
"\t"
);
}
}
else
{
// SQLBinaryOpExpr 指 run_id = '1' 这一情况
SQLBinaryOpExpr
binaryOpExpr
=
(
SQLBinaryOpExpr
)
sqlExpr
;
System
.
out
.
println
(
binaryOpExpr
.
getLeft
()
+
" --> "
+
binaryOpExpr
.
getRight
());
}
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment