Commit 444c3842 by 郑冰晶

数据库加密组件

parent 20f3fa99
package com.secoo.mall.datasource.security.constant;
public class DBType {
public static final String MYSQL = "MYSQL";
public static final String ORACLE = "ORACLE";
}
package com.secoo.mall.datasource.security.constant;
public class SecurityType {
public static final String DES = "DES";
public static final String AES = "AES";
}
package com.secoo.mall.datasource.security.exception;
public class TableColumnException extends RuntimeException {
private String message;
public TableColumnException(String message) {
this.message = message;
}
public TableColumnException(String message, String message1) {
super(message);
this.message = message1;
}
public TableColumnException(String message, Throwable cause, String message1) {
super(message, cause);
this.message = message1;
}
public TableColumnException(Throwable cause, String message) {
super(cause);
this.message = message;
}
public TableColumnException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace, String message1) {
super(message, cause, enableSuppression, writableStackTrace);
this.message = message1;
}
@Override
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
}
\ No newline at end of file
package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.ResultSetProxy;
import com.alibaba.druid.proxy.jdbc.StatementProxy;
import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.rule.TableRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class AbsSecurityFilter extends FilterEventAdapter {
private static final Logger log = LoggerFactory.getLogger(AbsSecurityFilter.class);
public static Charset charset = StandardCharsets.UTF_8;
private String dbType;
/**
* 加密是否开启
*/
private boolean enabled = true;
/**
* 是否启用并行处理
*/
private boolean parallelEnabled;
/**
* 建议core=max
*/
private int corePoolSize;
/**
* 建议core=max
*/
private int maxPoolSize;
private Map<String, Map<String, ColumnRule>> tableRuleMap;
private ExecutorService parallelExecutor;
public String getDbType() {
return dbType;
}
public void setDbType(String dbType) {
this.dbType = dbType;
}
public boolean isEnabled() {
return enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
public boolean isParallelEnabled() {
return parallelEnabled;
}
public void setParallelEnabled(boolean parallelEnabled) {
this.parallelEnabled = parallelEnabled;
}
public int getCorePoolSize() {
return corePoolSize;
}
public void setCorePoolSize(int corePoolSize) {
this.corePoolSize = corePoolSize;
}
public int getMaxPoolSize() {
return maxPoolSize;
}
public void setMaxPoolSize(int maxPoolSize) {
this.maxPoolSize = maxPoolSize;
}
public Map<String, Map<String, ColumnRule>> getTableRuleMap() {
return tableRuleMap;
}
public ExecutorService getParallelExecutor() {
return parallelExecutor;
}
protected void resultSetOpenAfter(ResultSetProxy resultSet) {
if (!enabled) {
return;
}
decryptResultSet(resultSet);
}
protected abstract void decryptResultSet(ResultSetProxy resultSet);
protected void statementExecuteBefore(StatementProxy statement, String sql) {
if (!enabled) {
return;
}
this.encryptStatement(statement, sql);
}
protected abstract void encryptStatement(StatementProxy statement,String sql);
private Map<String, Map<String, ColumnRule>> parseTableRules(Set<TableRule> tableRules) {
if (tableRules == null || tableRules.size() == 0) {
return null;
}
Map<String, Map<String, ColumnRule>> tableRuleMap = new HashMap<>();
for(TableRule tableRule: tableRules){
if(tableRule == null || tableRule.getColumnRules() == null || tableRule.getColumnRules().isEmpty()){
continue;
}
Map<String, ColumnRule> columnRuleMap = new HashMap<>();
for(ColumnRule columnRule:tableRule.getColumnRules()){
columnRuleMap.put(columnRule.getLogicColumn(),columnRule);
}
tableRuleMap.put(tableRule.getTableName(),columnRuleMap);
}
return tableRuleMap;
}
public void init(Set<TableRule> tableRules) {
if (enabled) {
return;
}
if (tableRules == null || tableRules.isEmpty()) {
throw new SecurityException("security tableRules is null");
}
this.tableRuleMap = this.parseTableRules(tableRules);
log.info("security tableRules is enable:{}", tableRuleMap);
if (parallelEnabled) {
if (corePoolSize <= 0 || maxPoolSize <= 0) {
throw new SecurityException("corePoolSize(" + corePoolSize + ") or maxPoolSize(" + maxPoolSize + ") is invalid");
}
log.debug("init security parallelExecutors : corePoolSize={}, maxPoolSize={}", this.corePoolSize, this.maxPoolSize);
this.parallelExecutor = new ThreadPoolExecutor(
this.corePoolSize,
this.maxPoolSize,
300,
TimeUnit.SECONDS,
new SynchronousQueue<Runnable>(),
new NamedThreadFactory("matrix-security-"),
new ThreadPoolExecutor.CallerRunsPolicy()
);
}
}
public void destroy() {
if(this.parallelExecutor != null){
try {
parallelExecutor.shutdown();
parallelExecutor.awaitTermination(1, TimeUnit.MINUTES);
} catch (InterruptedException e) {
log.warn("interrupted when shutdown the executor:", e);
}
}
}
static class NamedThreadFactory implements ThreadFactory {
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final String namePrefix;
NamedThreadFactory(String namePrefix) {
this.namePrefix = namePrefix;
}
public Thread newThread(Runnable runnable) {
return new Thread(runnable, namePrefix + threadNumber.getAndIncrement());
}
}
}
package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.filter.AutoLoad;
import com.alibaba.druid.proxy.jdbc.*;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.secoo.mall.datasource.security.rule.ColumnRule;
import lombok.extern.slf4j.Slf4j;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
@Slf4j
@AutoLoad
public class MysqlSecurityFilter extends AbsSecurityFilter {
@Override
protected void decryptResultSet(ResultSetProxy resultSet) {
}
@Override
protected void encryptStatement(StatementProxy statement, String sql) {
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("不需要处理的statement:{}", sql);
return;
}
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
// 解析sql
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, this.getDbType());
for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor);
int index = 0;
List<Future<Boolean>> futureList = new LinkedList<>();
// 查询 | 删除
if (stmt instanceof SQLSelectStatement || stmt instanceof MySqlDeleteStatement) {
// 查询条件
for (TableStat.Condition condition : visitor.getConditions()) {
// 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空为查询条件
if (conditionValue != null) {
continue;
}
TableStat.Column column = condition.getColumn();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(futureList, columnRule, preparedStatement, index);
}
}
index ++;
}
}
}
// 插入
else if (stmt instanceof MySqlInsertStatement) {
MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt;
String tableName = insertStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
int valuesSize = insertStmt.getValuesList().size();
Collection<TableStat.Column> columns = visitor.getColumns();
int columnSize = columns.size();
for (TableStat.Column column : columns) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
for (int valueIndex = 0; valueIndex < valuesSize; valueIndex++) {
encrypt(columnRule, preparedStatement, index + valueIndex * columnSize);
}
}
}
}
index ++;
}
// 更新
else if (stmt instanceof MySqlUpdateStatement) {
MySqlUpdateStatement updateStmt = (MySqlUpdateStatement) stmt;
// 更新语句应该只支持单表
String tableName = updateStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap == null || columnRuleMap.isEmpty()) {
continue;
}
// 处理set
for (SQLUpdateSetItem item : updateStmt.getItems()) {
SQLExpr column = item.getColumn();
if (item.getValue() instanceof SQLVariantRefExpr && column instanceof SQLIdentifierExpr) {
// 需要加密的字段
String columnName = ((SQLIdentifierExpr) column).getName();
ColumnRule columnRule = columnRuleMap.get(columnName);
if (columnRule != null) {
encrypt(futureList, columnRule, preparedStatement,index);
}
}
index++;
}
// 处理where
for (TableStat.Condition condition : visitor.getConditions()) {
// 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空为查询条件
if (conditionValue == null) {
continue;
}
TableStat.Column column = condition.getColumn();
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(futureList, columnRule, preparedStatement,index);
}
index++;
}
}
}
// 其他
else {
for (TableStat.Column column : visitor.getColumns()) {
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(futureList, columnRule, preparedStatement,index);
}
}
index++;
}
}
for (Future<Boolean> future : futureList) {
try {
future.get();
} catch (Exception e) {
log.error("加密出现异常,异常部分未加密", e);
throw new SecurityException("加密出现异常,异常部分未加密");
}
}
}
}
private void encrypt(List<Future<Boolean>> futureList,final ColumnRule columnRule, final PreparedStatementProxyImpl preparedStatement,int index) {
JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object origValue = jdbcParameter.getValue();
if (origValue == null) {
return;
}
if (this.isParallelEnabled()) {
Future<Boolean> future = this.getParallelExecutor().submit(new Callable<Boolean>() {
@Override
public Boolean call() throws Exception {
encrypt(columnRule,preparedStatement, index);
return true;
}
});
futureList.add(future);
} else {
encrypt(columnRule, preparedStatement, index);
}
}
private void encrypt(ColumnRule columnRule, PreparedStatementProxyImpl preparedStatement,int index) {
JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object origValue = jdbcParameter.getValue();
if (origValue == null) {
return;
}
String encryptValue = columnRule.getSecurityAlgorithm().encrypt(origValue);
try {
preparedStatement.setObject(index + 1,encryptValue);
} catch (SQLException throwables) {
log.error("字段加密异常:columnRule={},origValue={},encryptValue={}", columnRule, origValue, encryptValue);
throw new SecurityException("参数加密异常!");
}
log.debug("字段加密:columnRule={},origValue={},encryptValue={}", columnRule, origValue, encryptValue);
}
public static void main(String[] args) {
// String sql = "select id,name,age from t_user where id = ? and name = 'tom' and age > 10";
String sql = "update t_user u,t_account a set u.name = 'test', a.age=3 where u.id = 1 and a.age > ?";
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor);
MySqlUpdateStatement updateStmt = (MySqlUpdateStatement) stmt;
System.out.println(visitor.getParameters());
System.out.println(visitor.getColumns());
System.out.println(visitor.getGroupByColumns());
System.out.println(visitor.getOrderByColumns());
System.out.println(visitor.getConditions());
System.out.println(visitor.getTables());
System.out.println(visitor.getRelationships());
// 查询
if (stmt instanceof SQLSelectStatement) {
SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
// 遍历查询条件
for (TableStat.Condition condition : visitor.getConditions()) {
// 遍历查询条件值,一般只有一个,但in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空才是查询条件
System.out.println("condition" + condition.getColumn() + ",values=" + condition.getValues());
}
}
SQLExpr sqlExpr = selectStmt.getSelect().getQueryBlock().getWhere();
if(sqlExpr instanceof SQLInListExpr){
// SQLInListExpr 指 run_id in ('1', '2') 这一情况
SQLInListExpr inListExpr = (SQLInListExpr)sqlExpr;
List<SQLExpr> valueExprs = inListExpr.getTargetList();
for(SQLExpr expr : valueExprs){
System.out.print(expr + "\t");
}
} else {
// SQLBinaryOpExpr 指 run_id = '1' 这一情况
SQLBinaryOpExpr binaryOpExpr = (SQLBinaryOpExpr) sqlExpr;
System.out.println(binaryOpExpr.getLeft() + " --> " + binaryOpExpr.getRight());
}
}
}
}
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>matrix-datasource-security</artifactId>
<groupId>com.secoo.mall</groupId>
<version>2.0.17.RELEASE</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>matrix-datasource-security-demo</artifactId>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
</properties>
<dependencies>
<!--lombok-->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!-- log -->
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>logger-starter</artifactId>
</dependency>
<!-- 配置中心 -->
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>config-starter</artifactId>
</dependency>
<!--mybatis-->
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>2.1.1</version>
</dependency>
<!--data source-->
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>sharding-jdbc-core</artifactId>
<version>4.1.1</version>
</dependency>
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>sharding-jdbc-spring-namespace</artifactId>
<version>4.1.1</version>
</dependency>
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>matrix-datasource-druid</artifactId>
</dependency>
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>matrix-datasource-security-druid</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-autoconfigure</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid-spring-boot-starter</artifactId>
<version>1.1.10</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<finalName>matrix-datasource-security-demo</finalName>
<resources>
<resource>
<directory>src/main/java</directory>
<includes>
<include>**/*.xml</include>
</includes>
<filtering>false</filtering>
</resource>
<resource>
<directory>${project.basedir}/src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<fork>true</fork>
</configuration>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- 解决资源文件的编码问题 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<version>2.4</version>
<configuration>
<encoding>UTF-8</encoding>
<!-- xlsx 不转码-->
<nonFilteredFileExtensions>
<nonFilteredFileExtension>xlsx</nonFilteredFileExtension>
</nonFilteredFileExtensions>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
package com.secoo.mall.datasource.security.demo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@Slf4j
@SpringBootApplication(scanBasePackages = "com.secoo.mall.datasource.security")
public class SecurityDemoApplication {
public static void main(String[] args) {
System.setProperty("env","local");
SpringApplication.run(SecurityDemoApplication.class, args);
log.info("SecurityDemoApplication SpringBoot Start Success");
}
}
package com.secoo.mall.datasource.security.demo.bean;
import com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask;
import lombok.Data;
/**
* 定价规则任务记录表
*
* @Author like
* @Date 2020/4/1
*/
@Data
public class PriceRuleTaskBean extends PriceRuleTask {
private static final long serialVersionUID = 3992045013900981597L;
private String cnName;
private String enName;
private String taskTypeDesc;
private String taskStatusDesc;
private String triggerTypeDesc;
private String checkStatusDesc;
private String pushStatusDesc;
private String europePushStatusDesc;
private int tableShardNo;
}
\ No newline at end of file
package com.secoo.mall.datasource.security.demo.bean;
import com.secoo.mall.datasource.security.demo.bean.common.PageCriteria;
import lombok.Data;
import java.util.Date;
import java.util.List;
/**
* 定价规则任务记录表
*
* @Author like
* @Date 2020/4/1
*/
@Data
public class PriceRuleTaskCriteria extends PageCriteria {
private Long id;
private Long brandId;
// spark任务ID
private String taskId;
/**
* spark任务合并ID
*/
private String mergeTaskId;
// 任务类型(1干预价定时 2干预价手动 3品牌规则定时 4品牌规则手动)
private Integer taskType;
// 状态(0=未开始,1=进行中,2=已结束)
private Integer taskStatus;
// 触发类型(1手动 2定时)
private Integer triggerType;
// 校验状态(0=通过,1=失败)
private Integer checkStatus;
// 定价配置时间点
private Date checkPointDate;
// 推送状态0=未开始,1=进行中,2=已结束
private Integer pushStatus;
// 推送状态0=未开始,1=进行中,2=已结束
private Integer europePushStatus;
// 计算时间
private Date caclTime;
// 分片
private Integer shardNo;
// 创建日期
private Date createDate;
// 创建人
private String creator;
// 修改时间
private Date modifyDate;
// 修改人
private String modifier;
// 版本号
private Long version;
// --------- 扩展-------------
private String orderByClause = "id desc";
private String groupByClause;
private String fields;
private List<Long> ids;
private List<Long> brandIds;
private Long excludeBrandId;
private List<Long> excludeBrandIds;
private String taskIdLike;
private List<String> taskIds;
private List<String> mergeTaskIds;
private List<Integer> triggerTypes;
private List<Integer> taskTypes;
private List<Integer> taskStatuss;
private List<Integer> pushStatuss;
private List<Integer> europePushStatuss;
private List<Integer> shardNos;
private Date minCreateDate;
private Date maxCreateDate;
private Date minCaclTime;
private Date maxCaclTime;
}
\ No newline at end of file
package com.secoo.mall.datasource.security.demo.bean.common;
import lombok.Data;
/**
* @program: price-activity
* @description: 分页查询公共条件
* @author: jiangshuaiguang
* @create: 2020-03-17 18:23
**/
@Data
public class PageCriteria {
private int pageNumber = 1;
private int pageSize = 10;
/**
* 计算开始索引号
*
* @return
*/
public int getStartIndex() {
return getPageSize() * (pageNumber - 1);
}
public int getPageNumber() {
return pageNumber;
}
public void setPageNumber(int pageNumber) {
this.pageNumber = pageNumber;
}
public int getPageSize() {
return pageSize;
}
public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}
}
package com.secoo.mall.datasource.security.config;
package com.secoo.mall.datasource.security.demo.config;
import com.alibaba.druid.filter.Filter;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.fastjson.JSON;
import com.secoo.mall.datasource.bean.MatrixDataSource;
import com.secoo.mall.datasource.security.constant.DBType;
import com.secoo.mall.datasource.security.filter.MysqlSecurityFilter;
import com.secoo.mall.datasource.security.filter.SecurityFilter;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.shardingjdbc.jdbc.core.datasource.ShardingDataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
......@@ -15,11 +20,15 @@ import org.springframework.context.annotation.Configuration;
import javax.sql.DataSource;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
@Slf4j
@Configuration
@EnableConfigurationProperties(DataSourceSecurityProperties.class)
public class DataSourceSecurityAutoConfiguration implements ApplicationContextAware, SmartInitializingSingleton {
private static Logger logger = LoggerFactory.getLogger(DataSourceSecurityAutoConfiguration.class);
private ConfigurableApplicationContext applicationContext;
......@@ -31,41 +40,51 @@ public class DataSourceSecurityAutoConfiguration implements ApplicationContextAw
@Override
public void afterSingletonsInstantiated() {
if(properties == null || properties.getDatasourceRules() == null || properties.getDatasourceRules().size() == 0){
log.debug(">>>>>>>>>>>>>"+JSON.toJSONString(properties));
if(properties == null || properties.getSecurityRules() == null || properties.getSecurityRules().size() == 0){
throw new RuntimeException("DataSourceSecurityProperties is null!");
}
Map<String,Map<String,Map<String,String>>> dbRules = properties.getSecurityRules().get("securityRules");
Map<String, DataSource> dataSourceBeans = this.applicationContext.getBeansOfType(DataSource.class);
properties.getDatasourceRules().forEach(datasourceRule -> {
DataSource dataSource = dataSourceBeans.get(datasourceRule.getDatasourceName());
if(dataSource == null){
return;
}
DruidDataSource druidDataSource = null;
dataSourceBeans.forEach(new BiConsumer<String, DataSource>() {
@Override
public void accept(String dataSourceName, DataSource dataSource) {
dbRules.forEach(new BiConsumer<String, Map<String, Map<String, String>>>() {
@Override
public void accept(String db, Map<String, Map<String, String>> dbRule) {
List<DruidDataSource> druidDataSources = new ArrayList<>();
if(dataSource instanceof DruidDataSource){
druidDataSource = (DruidDataSource) dataSource;
druidDataSources.add((DruidDataSource) dataSource);
}else if(dataSource instanceof MatrixDataSource){
druidDataSource = (DruidDataSource) ((MatrixDataSource) dataSource).getTargetDataSource(((MatrixDataSource) dataSource).getDsName());
druidDataSources.add((DruidDataSource) ((MatrixDataSource) dataSource).getTargetDataSource(((MatrixDataSource) dataSource).getDsName()));
}else if(dataSource instanceof ShardingDataSource){
((ShardingDataSource) dataSource).getDataSourceMap().forEach(new BiConsumer<String, DataSource>() {
@Override
public void accept(String innerDataSourceName, DataSource innerDataSource) {
druidDataSources.add((DruidDataSource) innerDataSource);
}
if(druidDataSource == null){
return;
});
}
if(druidDataSource.getDbType().equals(DBType.MYSQL)){
MysqlSecurityFilter securityFilter = new MysqlSecurityFilter();
securityFilter.setEnabled(properties.getEnabled() != null && properties.getEnabled());
securityFilter.setParallelEnabled(properties.getParallelEnabled() != null && properties.getParallelEnabled());
if(securityFilter.isParallelEnabled()){
securityFilter.setCorePoolSize(properties.getCorePoolSize() == null?0:properties.getCorePoolSize());
securityFilter.setMaxPoolSize(properties.getMaxPoolSize() == null?Integer.MAX_VALUE:properties.getMaxPoolSize());
if(druidDataSources.isEmpty()){
logger.warn("dataSource type is not support!");
return;
}
securityFilter.init(datasourceRule.getTableRules());
druidDataSources.forEach(druidDataSource -> {
if(druidDataSource.getDbType().equalsIgnoreCase(JdbcConstants.MYSQL)){
SecurityFilter securityFilter = new SecurityFilter();
// TODO init rule config
druidDataSource.setProxyFilters(new ArrayList<Filter>(){{add(securityFilter);}});
}else{
logger.warn("database type({}) is not support!",druidDataSource.getDbType());
}
});
}
});
}
});
}
@Override
......
package com.secoo.mall.datasource.security.demo.config;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.Map;
@ConfigurationProperties(prefix = DataSourceSecurityProperties.PREFIX)
public class DataSourceSecurityProperties {
public static final String PREFIX = "spring.datasource.security";
/**
* 加解密规则
*/
private Map<String, Map<String,Map<String,Map<String,String>>>> securityRules;
public Map<String, Map<String, Map<String, Map<String,String>>>> getSecurityRules() {
return securityRules;
}
public void setSecurityRules(Map<String, Map<String, Map<String, Map<String,String>>>> securityRules) {
this.securityRules = securityRules;
}
}
package com.secoo.mall.datasource.security.demo.config;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.boot.autoconfigure.AutoConfigureBefore;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Configuration;
import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import java.util.Map;
import java.util.function.BiConsumer;
@Slf4j
@Configuration
@AutoConfigureBefore(DataSource.class)
public class TestConfig implements ApplicationContextAware {
private ConfigurableApplicationContext applicationContext;
@PostConstruct
public void afterSingletonsInstantiated() {
Map<String, DataSource> beans = this.applicationContext.getBeansOfType(DataSource.class);
beans.forEach(new BiConsumer<String, DataSource>() {
@Override
public void accept(String s, DataSource dataSource) {
log.info("========>" + s);
}
});
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = (ConfigurableApplicationContext) applicationContext;
}
}
package com.secoo.mall.datasource.security.demo.config.datasource;
import com.google.common.collect.Lists;
import com.secoo.mall.datasource.bean.MatrixDataSource;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.logging.stdout.StdOutImpl;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.shardingsphere.api.config.masterslave.MasterSlaveRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.ShardingRuleConfiguration;
import org.apache.shardingsphere.shardingjdbc.api.ShardingDataSourceFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.SqlSessionTemplate;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.jdbc.metadata.DataSourcePoolMetadataProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import javax.sql.DataSource;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
/**
* @description: 数据源配置
* @author: jiangshuaiguan
* @create: 2018-08-24 15:45
**/
@Slf4j
@Configuration
@MapperScan(basePackages = "com.secoo.mall.datasource.security.demo.dao.mapper", sqlSessionFactoryRef = "activitySqlSessionFactory")
public class ActivityDBConfig {
@Value("${sql.show}")
private Boolean sqlShow;
@Bean("masterDB")
public MatrixDataSource masterDB(){
return new MatrixDataSource("activity-master");
}
@Bean("slaveDB")
public MatrixDataSource slaveDB(){
return new MatrixDataSource("activity-slave");
}
@Bean("activityDataSource")
public DataSource masterSlaveDataSource(@Qualifier("masterDB") MatrixDataSource masterDB,@Qualifier("slaveDB")MatrixDataSource slaveDB) {
DataSource source = null;
try {
MasterSlaveRuleConfiguration masterSlaveRuleConfig = new MasterSlaveRuleConfiguration("activity", "ds_master", Arrays.asList("ds_slave0"));
ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();
shardingRuleConfig.setMasterSlaveRuleConfigs(Lists.newArrayList(masterSlaveRuleConfig));
shardingRuleConfig.setTableRuleConfigs(TableShardingConfig.gettableRules());
Properties prop = new Properties();
Map<String, DataSource> dataSourceMap = new HashMap<>();
dataSourceMap.put("ds_master", masterDB);
dataSourceMap.put("ds_slave0", slaveDB);
source = ShardingDataSourceFactory.createDataSource(dataSourceMap, shardingRuleConfig, prop);
} catch (Exception e) {
log.error("配置shard-jdbc失败", e);
}
return source;
}
@Bean
public DataSourcePoolMetadataProvider dataSourcePoolMetadataProvider(@Qualifier("activityDataSource") DataSource activityDataSource) {
DataSourcePoolMetadataProvider poolMetadataProvider = dataSource -> new DataSourcePoolMetadata(activityDataSource, "select 1");
return poolMetadataProvider;
}
@Bean("activityTransactionManager")
public DataSourceTransactionManager activityTransactionManager(@Qualifier(value = "activityDataSource") DataSource dataSource) {
return new DataSourceTransactionManager(dataSource);
}
@Bean("activitySqlSessionFactory")
public SqlSessionFactory sqlSessionFactory(@Qualifier("activityDataSource") DataSource dataSource) throws Exception {
final SqlSessionFactoryBean sessionFactoryBean = new SqlSessionFactoryBean();
sessionFactoryBean.setDataSource(dataSource);
org.apache.ibatis.session.Configuration config = new org.apache.ibatis.session.Configuration();
config.setMapUnderscoreToCamelCase(true);
if(sqlShow){
config.setLogImpl(StdOutImpl.class);
}
sessionFactoryBean.setConfiguration(config);
return sessionFactoryBean.getObject();
}
@Bean
public SqlSessionTemplate sqlSessionTemplate(
@Qualifier("activitySqlSessionFactory") SqlSessionFactory factory) {
return new SqlSessionTemplate(factory);
}
}
package com.secoo.mall.datasource.security.demo.config.datasource;
import org.springframework.boot.jdbc.metadata.AbstractDataSourcePoolMetadata;
import javax.sql.DataSource;
/**
* @ClassName DruidDataSourcePoolMetadata
* @Author QIANGLU
* @Date 2020/12/29 14:28
* @Version 1.0
*/
public class DataSourcePoolMetadata extends AbstractDataSourcePoolMetadata {
private String query;
/**
* Create an instance with the data source to use.
*
* @param dataSource the data source
*/
protected DataSourcePoolMetadata(DataSource dataSource, String query) {
super(dataSource);
this.query = query;
}
@Override
public Integer getActive() {
return null;
}
@Override
public Integer getMax() {
return null;
}
@Override
public Integer getMin() {
return null;
}
@Override
public String getValidationQuery() {
return this.query;
}
@Override
public Boolean getDefaultAutoCommit() {
return null;
}
}
package com.secoo.mall.datasource.security.demo.config.datasource;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import org.apache.shardingsphere.api.config.sharding.TableRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.strategy.StandardShardingStrategyConfiguration;
/**
* @program: sharingjdbc-demo
* @description:分库分表配置文件
* @author: jiangshuaiguang
* @create: 2019-03-06 09:49
**/
@Slf4j
public class TableShardingConfig {
public static List<TableRuleConfiguration> gettableRules() {
List<TableRuleConfiguration> rules = new ArrayList<>();
rules.add(getBrandRuleSkuRule());
return rules;
}
/**
* t_brand_rule_sku
*
* @return
*/
private static TableRuleConfiguration getBrandRuleSkuRule() {
TableRuleConfiguration demoConfig = new TableRuleConfiguration("t_brand_rule_sku", "activity.t_brand_rule_sku_${0..15}");
//demoConfig.setDatabaseShardingStrategyConfig(new InlineShardingStrategyConfiguration("store_id", "demo${store_id % 2}DB"));
//设置分片策略,这里简单起见直接取模,也可以使用自定义算法来实现分片规则
StandardShardingStrategyConfiguration config = new StandardShardingStrategyConfiguration("task_id", (availableTargetNames, shardingValue) -> {
int index = Math.abs(shardingValue.getValue().hashCode()) % 16;
for (Object e : availableTargetNames) {
String tableName = (String) e;
if(tableName.equals("t_brand_rule_sku_"+index)){
return tableName;
}
}
return null;
});
demoConfig.setTableShardingStrategyConfig(config);
return demoConfig;
}
}
package com.secoo.mall.datasource.security.demo.controller;
import com.alibaba.fastjson.JSON;
import com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria;
import com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask;
import com.secoo.mall.datasource.security.demo.dao.mapper.PriceRuleTaskMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@Slf4j
@RestController
@RequestMapping("test")
public class TestController {
@Autowired
private PriceRuleTaskMapper priceRuleTaskMapper;
@RequestMapping("select")
public String select(Long id,String creator) {
PriceRuleTaskCriteria priceRuleTaskCriteria = new PriceRuleTaskCriteria();
priceRuleTaskCriteria.setId(id);
priceRuleTaskCriteria.setCreator(creator);
priceRuleTaskCriteria.setPageSize(1);
List<PriceRuleTask> priceRuleTasks = this.priceRuleTaskMapper.queryPriceRuleTaskList(priceRuleTaskCriteria);
String data = JSON.toJSONString(priceRuleTasks);
log.debug(">>>>>>>>>>>>{}", data);
return new String(data);
}
@RequestMapping("update")
public String update(Long id,String creator) {
PriceRuleTask priceRuleTask = new PriceRuleTask();
priceRuleTask.setId(id);
priceRuleTask.setCreator(creator);
int rows = this.priceRuleTaskMapper.updateByPrimaryKeySelective(priceRuleTask);
log.debug(">>>>>>>>>>>>{}", rows);
return String.valueOf(rows);
}
}
package com.secoo.mall.datasource.security.demo.dao.entity;
import lombok.Data;
import java.io.Serializable;
import java.util.Date;
/**
* 定价规则任务记录表
*
* @Author like
* @Date 2020/4/1
*/
@Data
public class PriceRuleTask implements Serializable{
private static final long serialVersionUID = 2560720814472462192L;
/**
* 主键ID
**/
private Long id;
private Long brandId;
/**
* spark任务ID
**/
private String taskId;
/**
* spark任务合并ID
*/
private String mergeTaskId;
/**
* 任务类型(1干预价定时 2干预价手动 3品牌规则定时 4品牌规则手动)
**/
private Integer taskType;
/**
* 状态(0=未开始,1=进行中,2=已结束)
**/
private Integer taskStatus;
/**
* 触发类型(1手动 2定时)
*/
private Integer triggerType;
/**
* 校验状态(0=通过,1=失败)
**/
private Integer checkStatus;
/**
* 定价配置时间点
**/
private Date checkPointDate;
/**
* 推送状态0=未开始,1=进行中,2=已结束
*/
private Integer pushStatus;
/**
* 推送状态0=未开始,1=进行中,2=已结束
*/
private Integer europePushStatus;
private String remark;
/**
* 计算时间
*/
private Date caclTime;
private Integer shardNo;
/**
* 创建日期
**/
private Date createDate;
private Long creatorId;
/**
* 创建人
**/
private String creator;
/**
* 修改时间
**/
private Date modifyDate;
private Long modifierId;
/**
* 修改人
**/
private String modifier;
/**
* 版本号
**/
private Long version;
}
\ No newline at end of file
package com.secoo.mall.datasource.security.demo.dao.mapper;
import com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria;
import com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask;
import org.apache.ibatis.annotations.Param;
import java.util.List;
public interface PriceRuleTaskMapper {
int insert(PriceRuleTask record);
int insertSelective(PriceRuleTask record);
void batchInsert(List<PriceRuleTask> records);
int updateByPrimaryKeySelective(PriceRuleTask record);
int batchUpdate(List<PriceRuleTask> records);
int updateByIdsAndStatus(@Param("entity") PriceRuleTask record,@Param("ids") List<Long> ids,@Param("taskStatus") Integer taskStatus);
int deleteByPrimaryKey(Long id);
PriceRuleTask selectByPrimaryKey(Long id);
List<PriceRuleTask> queryPriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria);
int queryPriceRuleTaskCount(PriceRuleTaskCriteria priceRuleTaskCriteria);
List<PriceRuleTask> queryExecutePriceRuleTaskList(PriceRuleTaskCriteria priceRuleTaskCriteria);
}
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd" >
<mapper namespace="com.secoo.mall.datasource.security.demo.dao.mapper.PriceRuleTaskMapper">
<resultMap id="BaseResultMap" type="com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask">
<id column="id" property="id" jdbcType="BIGINT"/>
<result column="brand_id" property="brandId" jdbcType="BIGINT"/>
<result column="task_id" property="taskId" jdbcType="VARCHAR"/>
<result column="merge_task_id" property="mergeTaskId" jdbcType="VARCHAR"/>
<result column="task_type" property="taskType" jdbcType="TINYINT"/>
<result column="task_status" property="taskStatus" jdbcType="INTEGER"/>
<result column="trigger_type" property="triggerType" jdbcType="TINYINT"/>
<result column="check_status" property="checkStatus" jdbcType="INTEGER"/>
<result column="check_point_date" property="checkPointDate" jdbcType="TIMESTAMP"/>
<result column="push_status" property="pushStatus" jdbcType="INTEGER"/>
<result column="europe_push_status" property="europePushStatus" jdbcType="INTEGER"/>
<result column="shard_no" jdbcType="TINYINT" property="shardNo"/>
<result column="remark" property="remark" jdbcType="VARCHAR"/>
<result column="cacl_time" property="caclTime" jdbcType="TIMESTAMP"/>
<result column="create_date" property="createDate" jdbcType="TIMESTAMP"/>
<result column="creator_id" property="creatorId" jdbcType="BIGINT"/>
<result column="creator" property="creator" jdbcType="VARCHAR"/>
<result column="modify_date" property="modifyDate" jdbcType="TIMESTAMP"/>
<result column="modify_id" property="modifierId" jdbcType="BIGINT"/>
<result column="modifier" property="modifier" jdbcType="VARCHAR"/>
<result column="version" property="version" jdbcType="BIGINT"/>
</resultMap>
<sql id="Base_Column_List">
id, brand_id, task_id, merge_task_id, task_type, task_status,trigger_type,check_status, check_point_date, push_status,europe_push_status,remark, cacl_time, shard_no, create_date,creator_id, creator, modify_date, modifier_id,modifier,version
</sql>
<sql id="Example_Where_Clause">
<where>
<if test="id != null ">
and id = #{id,jdbcType=BIGINT}
</if>
<if test="ids != null and ids.size()>0">
and id in
<foreach collection="ids" item="item" open="(" close=")" separator=",">#{item}</foreach>
</if>
<if test="brandId != null ">
and brand_id = #{brandId,jdbcType=BIGINT}
</if>
<if test="brandIds != null and brandIds.size()>0">
and brand_id in <foreach collection="brandIds" item="item" open="(" close=")" separator=",">#{item}</foreach>
</if>
<if test="excludeBrandId != null ">
and brand_id != #{excludeBrandId,jdbcType=BIGINT}
</if>
<if test="excludeBrandIds != null and excludeBrandIds.size()>0">
and brand_id not in <foreach collection="excludeBrandIds" item="item" open="(" close=")" separator=",">#{item}</foreach>
</if>
<if test="taskId != null and taskId != ''">
and task_id = #{taskId,jdbcType=VARCHAR}
</if>
<if test="taskIds != null and taskIds.size()>0">
and task_id in <foreach collection="taskIds" item="item" open="(" close=")" separator=",">#{item}</foreach>
</if>
<if test="taskIdLike != null and taskIdLike != ''">
and task_id like concat(#{taskIdLike,jdbcType=VARCHAR},'%')
</if>
<if test="mergeTaskId != null and mergeTaskId != ''">
and merge_task_id = #{mergeTaskId,jdbcType=VARCHAR}
</if>
<if test="mergeTaskIds != null and mergeTaskIds.size()>0">
and merge_task_id in <foreach collection="mergeTaskIds" item="item" open="(" close=")" separator=",">#{item}</foreach>
</if>
<if test="triggerType != null ">
and trigger_type = #{triggerType,jdbcType=INTEGER}
</if>
<if test="triggerTypes != null and triggerTypes.size > 0">
and trigger_type in (
<foreach collection="triggerTypes" item="item" separator=",">#{item}</foreach>
)
</if>
<if test="taskType != null ">
and task_type = #{taskType,jdbcType=INTEGER}
</if>
<if test="taskTypes != null and taskTypes.size > 0">
and task_type in (
<foreach collection="taskTypes" item="item" separator=",">#{item}</foreach>
)
</if>
<if test="taskStatus != null ">
and task_status = #{taskStatus,jdbcType=INTEGER}
</if>
<if test="taskStatuss != null and taskStatuss.size() > 0">
and task_status in (
<foreach collection="taskStatuss" item="item" separator=",">#{item}</foreach>
)
</if>
<if test="checkStatus != null ">
and check_status = #{checkStatus,jdbcType=INTEGER}
</if>
<if test="pushStatus != null ">
and push_status = #{pushStatus,jdbcType=INTEGER}
</if>
<if test="pushStatuss != null and pushStatuss.size()>0">
and push_status in (
<foreach collection="pushStatuss" item="item" separator=",">#{item}</foreach>
)
</if>
<if test="europePushStatus != null ">
and europe_push_status = #{europePushStatus,jdbcType=INTEGER}
</if>
<if test="europePushStatuss != null and europePushStatuss.size()>0">
and europe_push_status in (
<foreach collection="europePushStatuss" item="item" separator=",">#{item}</foreach>
)
</if>
<if test="shardNo != null">
and shard_no=#{shardNo}
</if>
<if test="shardNos != null and shardNos.size() != 0">
and shard_no in
<foreach collection="shardNos" item="item" open="(" close=")" separator=",">#{item}
</foreach>
</if>
<if test="minCreateDate != null">
<![CDATA[ AND create_date >= #{minCreateDate,jdbcType=TIMESTAMP} ]]>
</if>
<if test="maxCreateDate != null ">
<![CDATA[ AND create_date <= #{maxCreateDate,jdbcType=TIMESTAMP} ]]>
</if>
<if test="creator != null and creator != ''">
and creator = #{creator}
</if>
<if test="minCaclTime != null">
<![CDATA[ AND cacl_time >= #{minCaclTime,jdbcType=TIMESTAMP} ]]>
</if>
<if test="maxCaclTime != null ">
<![CDATA[ AND cacl_time <= #{maxCaclTime,jdbcType=TIMESTAMP} ]]>
</if>
</where>
</sql>
<insert id="insert" parameterType="com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask"
useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into t_price_rule_task (brand_id,task_id, merge_task_id, task_type, task_status, trigger_type,
check_status, check_point_date,push_status,europe_push_status,remark,cacl_time,shard_no, create_date,creator_id, creator, version)
values (#{brandId,jdbcType=BIGINT},#{taskId,jdbcType=VARCHAR}, #{merge_task_id,jdbcType=VARCHAR}, #{taskType,jdbcType=TINYINT}, #{taskStatus,jdbcType=INTEGER},#{triggerType,jdbcType=TINYINT},
#{checkStatus,jdbcType=INTEGER}, #{checkPointDate,jdbcType=TIMESTAMP},#{pushStatus,jdbcType=INTEGER},#{europePushStatus,jdbcType=INTEGER},#{remark,jdbcType=VARCHAR},#{caclTime},
#{shardNo,jdbcType=TINYINT},now(),#{creatorId,jdbcType=BIGINT},#{creator,jdbcType=VARCHAR}, 1)
</insert>
<insert id="insertSelective" parameterType="com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask"
useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into t_price_rule_task
<trim prefix="(" suffix=")" suffixOverrides=",">
<if test="brandId != null">
brand_id,
</if>
<if test="taskId != null">
task_id,
</if>
<if test="mergeTaskId != null">
merge_task_id,
</if>
<if test="taskType != null">
task_type,
</if>
<if test="taskStatus != null">
task_status,
</if>
<if test="triggerType != null">
trigger_type,
</if>
<if test="checkStatus != null">
check_status,
</if>
<if test="checkPointDate != null">
check_point_date,
</if>
<if test="pushStatus != null">
push_status,
</if>
<if test="europePushStatus != null">
europe_push_status,
</if>
<if test="remark != null">
remark,
</if>
<if test="caclTime != null">
cacl_time,
</if>
<if test="shardNo != null">
shard_no,
</if>
create_date,
<if test="creatorId != null">
creator_id,
</if>
<if test="creator != null">
creator,
</if>
version
</trim>
<trim prefix="values (" suffix=")" suffixOverrides=",">
<if test="brandId != null">
#{brandId,jdbcType=BIGINT},
</if>
<if test="taskId != null">
#{taskId,jdbcType=VARCHAR},
</if>
<if test="mergeTaskId != null">
#{mergeTaskId,jdbcType=VARCHAR},
</if>
<if test="taskType != null">
#{taskType,jdbcType=TINYINT},
</if>
<if test="taskStatus != null">
#{taskStatus,jdbcType=INTEGER},
</if>
<if test="triggerType != null">
#{triggerType,jdbcType=TINYINT},
</if>
<if test="checkStatus != null">
#{checkStatus,jdbcType=INTEGER},
</if>
<if test="checkPointDate != null">
#{checkPointDate,jdbcType=TIMESTAMP},
</if>
<if test="pushStatus != null">
#{pushStatus,jdbcType=INTEGER},
</if>
<if test="europePushStatus != null">
#{europePushStatus,jdbcType=INTEGER},
</if>
<if test="remark != null">
#{remark,jdbcType=VARCHAR},
</if>
<if test="caclTime != null">
#{cacl_time},
</if>
<if test="shardNo != null">
#{shardNo,jdbcType=TINYINT},
</if>
now(),
<if test="creatorId != null">
#{creatorId,jdbcType=BIGINT},
</if>
<if test="creator != null">
#{creator,jdbcType=VARCHAR},
</if>
1
</trim>
</insert>
<insert id="batchInsert" keyProperty="id">
insert into t_price_rule_task (brand_id,task_id, merge_task_id, task_type, task_status, trigger_type,
check_status, check_point_date,push_status,europe_push_status,remark,cacl_time,shard_no, create_date,creator_id, creator, version)
values
<foreach item="item" collection="list" separator=",">
(#{item.brandId,jdbcType=BIGINT},#{item.taskId,jdbcType=VARCHAR}, #{item.mergeTaskId,jdbcType=VARCHAR}, #{item.taskType,jdbcType=TINYINT}, #{item.taskStatus,jdbcType=INTEGER},#{item.triggerType,jdbcType=TINYINT},
#{item.checkStatus,jdbcType=INTEGER}, #{item.checkPointDate,jdbcType=TIMESTAMP},#{item.pushStatus,jdbcType=INTEGER},#{item.europePushStatus,jdbcType=INTEGER},#{item.remark,jdbcType=VARCHAR},#{item.caclTime},
#{item.shardNo,jdbcType=TINYINT}, now(),#{item.creatorId,jdbcType=BIGINT},#{item.creator,jdbcType=VARCHAR}, 1)
</foreach>
</insert>
<update id="updateByIdsAndStatus">
update t_price_rule_task
<set>
<if test="entity.brandId != null">
brand_id = #{entity.brandId,jdbcType=BIGINT},
</if>
<if test="entity.taskId != null">
task_id = #{entity.taskId,jdbcType=VARCHAR},
</if>
<if test="entity.mergeTaskId != null">
merge_task_id = #{entity.mergeTaskId,jdbcType=VARCHAR},
</if>
<if test="entity.taskType != null">
task_type = #{entity.taskType,jdbcType=TINYINT},
</if>
<if test="entity.taskStatus != null">
task_status = #{entity.taskStatus,jdbcType=INTEGER},
</if>
<if test="entity.triggerType != null">
trigger_type = #{entity.triggerType,jdbcType=TINYINT},
</if>
<if test="entity.checkStatus != null">
check_status = #{entity.checkStatus,jdbcType=INTEGER},
</if>
<if test="entity.checkPointDate != null">
check_point_date = #{entity.checkPointDate,jdbcType=TIMESTAMP},
</if>
<if test="entity.pushStatus != null">
push_status = #{entity.pushStatus,jdbcType=INTEGER},
</if>
<if test="entity.europePushStatus != null">
europe_push_status = #{entity.europePushStatus,jdbcType=INTEGER},
</if>
<if test="entity.remark != null">
remark = #{entity.remark,jdbcType=VARCHAR},
</if>
<if test="entity.caclTime != null">
cacl_time = #{entity.caclTime},
</if>
<if test="entity.shardNo != null">
shard_no = #{entity.shardNo,jdbcType=TINYINT},
</if>
<if test="entity.modifierId != null">
modifier_id = #{entity.modifierId,jdbcType=BIGINT},
</if>
<if test="entity.modifier != null">
modifier = #{entity.modifier,jdbcType=VARCHAR},
</if>
modify_date = now(),
version = version + 1
</set>
where id in (<foreach collection="ids" item="item" separator=",">#{item}</foreach>)
<if test="taskStatus != null">
and task_status = #{taskStatus,jdbcType=TINYINT}
</if>
</update>
<update id="updateByPrimaryKeySelective"
parameterType="com.secoo.mall.datasource.security.demo.dao.entity.PriceRuleTask">
update t_price_rule_task
<set>
<if test="brandId != null">
brand_id = #{brandId,jdbcType=BIGINT},
</if>
<if test="taskId != null">
task_id = #{taskId,jdbcType=VARCHAR},
</if>
<if test="mergeTaskId != null">
merge_task_id = #{mergeTaskId,jdbcType=VARCHAR},
</if>
<if test="taskType != null">
task_type = #{taskType,jdbcType=TINYINT},
</if>
<if test="taskStatus != null">
task_status = #{taskStatus,jdbcType=INTEGER},
</if>
<if test="triggerType != null">
trigger_type = #{triggerType,jdbcType=TINYINT},
</if>
<if test="checkStatus != null">
check_status = #{checkStatus,jdbcType=INTEGER},
</if>
<if test="checkPointDate != null">
check_point_date = #{checkPointDate,jdbcType=TIMESTAMP},
</if>
<if test="pushStatus != null">
push_status = #{pushStatus,jdbcType=INTEGER},
</if>
<if test="europePushStatus != null">
europe_push_status = #{europePushStatus,jdbcType=INTEGER},
</if>
<if test="remark != null">
remark = #{remark,jdbcType=VARCHAR},
</if>
<if test="caclTime != null">
cacl_time = #{caclTime},
</if>
<if test="shardNo != null">
shard_no = #{shardNo,jdbcType=TINYINT},
</if>
<if test="modifierId != null">
modifier_id = #{modifierId,jdbcType=BIGINT},
</if>
<if test="creator != null">
creator = #{creator,jdbcType=VARCHAR},
</if>
<if test="modifier != null">
modifier = #{modifier,jdbcType=VARCHAR},
</if>
modify_date = now(),
version = version + 1
</set>
where id = #{id,jdbcType=BIGINT}
<if test="version != null">
and version = #{version,jdbcType=BIGINT}
</if>
</update>
<update id="batchUpdate">
<foreach collection="list" item="item" separator=";">
update t_price_rule_task
<set>
<if test="item.brandId != null">
brand_id = #{item.brandId,jdbcType=BIGINT},
</if>
<if test="item.taskId != null">
task_id = #{item.taskId,jdbcType=VARCHAR},
</if>
<if test="item.mergeTaskId != null">
merge_task_id = #{item.mergeTaskId,jdbcType=VARCHAR},
</if>
<if test="item.taskType != null">
task_type = #{item.taskType,jdbcType=TINYINT},
</if>
<if test="item.taskStatus != null">
task_status = #{item.taskStatus,jdbcType=INTEGER},
</if>
<if test="item.triggerType != null">
trigger_type = #{item.triggerType,jdbcType=TINYINT},
</if>
<if test="item.checkStatus != null">
check_status = #{item.checkStatus,jdbcType=INTEGER},
</if>
<if test="item.checkPointDate != null">
check_point_date = #{item.checkPointDate,jdbcType=TIMESTAMP},
</if>
<if test="item.pushStatus != null">
push_status = #{item.pushStatus,jdbcType=INTEGER},
</if>
<if test="item.europePushStatus != null">
europe_push_status = #{item.europePushStatus,jdbcType=INTEGER},
</if>
<if test="item.remark != null">
remark = #{item.remark,jdbcType=VARCHAR},
</if>
<if test="item.caclTime != null">
cacl_time = #{item.caclTime},
</if>
<if test="item.shardNo != null">
shard_no = #{item.shardNo,jdbcType=TINYINT},
</if>
<if test="item.modifierId != null">
modifier_id = #{item.modifierId,jdbcType=BIGINT},
</if>
<if test="item.modifier != null">
modifier = #{item.modifier,jdbcType=VARCHAR},
</if>
modify_date = now(),
version = version + 1
</set>
where id = #{item.id,jdbcType=BIGINT}
</foreach>
</update>
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
delete
from t_price_rule_task
where id = #{id,jdbcType=BIGINT}
</delete>
<select id="selectByPrimaryKey" resultMap="BaseResultMap" parameterType="java.lang.Long">
select
<include refid="Base_Column_List"/>
from t_price_rule_task
where id = #{id,jdbcType=BIGINT}
</select>
<select id="queryPriceRuleTaskCount" resultType="int" parameterType="com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria">
select count(*)
from t_price_rule_task
<include refid="Example_Where_Clause"></include>
</select>
<select id="queryPriceRuleTaskList" resultMap="BaseResultMap" parameterType="com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria">
select
<if test="fields != null and fields != ''">
${fields}
</if>
<if test="fields == null or fields == ''">
<include refid="Base_Column_List"/>
</if>
from t_price_rule_task
<include refid="Example_Where_Clause"></include>
<if test="orderByClause != null">
order by ${orderByClause}
</if>
<if test="startIndex!=null and pageSize!=null and pageSize>0">
limit #{startIndex},#{pageSize}
</if>
</select>
<select id="queryExecutePriceRuleTaskList" resultMap="BaseResultMap" parameterType="com.secoo.mall.datasource.security.demo.bean.PriceRuleTaskCriteria">
select
<if test="fields != null and fields != ''">
${fields}
</if>
<if test="fields == null or fields == ''">
<include refid="Base_Column_List"/>
</if>
from t_price_rule_task
<include refid="Example_Where_Clause"></include>
<if test="groupByClause != null">
group by ${groupByClause}
</if>
<if test="orderByClause != null">
order by ${orderByClause}
</if>
</select>
</mapper>
\ No newline at end of file
spring:
application:
name: price-activity
datasource:
security:
rules:
secooActivityDB:
t_price_rule_task:
creator:
cipherColumn: creator
encryptKey: 123
encryptType: AES
logicColumn: creator
plainColumn: creator
modifior:
cipherColumn: modifior
encryptKey: 123
encryptType: AES
logicColumn: modifior
plainColumn: modifior
t_smart_batch:
creator:
cipherColumn: creator
encryptKey: 456
encryptType: AES
logicColumn: creator
plainColumn: creator
modifior:
cipherColumn: modifior
encryptKey: 456
encryptType: AES
logicColumn: modifior
plainColumn: modifior
server:
port: 6080
servlet:
context-path : /price-activity
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
......@@ -9,7 +8,7 @@
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>matrix-datasource-security-core</artifactId>
<artifactId>matrix-datasource-security-druid</artifactId>
<packaging>jar</packaging>
<dependencies>
......@@ -18,14 +17,13 @@
<artifactId>druid</artifactId>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.13</version>
<groupId>com.ctrip.framework.apollo</groupId>
<artifactId>apollo-client</artifactId>
</dependency>
</dependencies>
<build>
<finalName>matrix-datasource-security-starter</finalName>
<finalName>matrix-datasource-security-druid</finalName>
<plugins>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
......
package com.secoo.mall.datasource.security.algorithm;
package com.secoo.mall.datasource.security.algorithm.encrypt;
import com.secoo.mall.datasource.security.constant.SecurityType;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import com.secoo.mall.datasource.security.constant.EncryptType;
import org.apache.commons.codec.digest.DigestUtils;
import javax.crypto.Cipher;
......@@ -15,33 +12,19 @@ import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Properties;
/**
* AES encrypt algorithm.
*/
@Getter
@Setter
public final class AESSecurityAlgorithm implements SecurityAlgorithm {
public final class AESEncryptAlgorithm implements EncryptAlgorithm {
private String aesKey;
private Properties props = new Properties();
private byte[] secretKey;
public AESSecurityAlgorithm(String aesKey){
if(aesKey == null || aesKey.length() == 0){
throw new SecurityException("aesKey can not be null!");
}
this.aesKey = aesKey;
this.secretKey = createSecretKey();
}
private byte[] createSecretKey() {
return Arrays.copyOf(DigestUtils.sha1(aesKey), 16);
}
@SneakyThrows(GeneralSecurityException.class)
@Override
public String encrypt(final Object plaintext) {
public String encrypt(final Object plaintext) throws GeneralSecurityException {
if (null == plaintext) {
return null;
}
......@@ -49,9 +32,8 @@ public final class AESSecurityAlgorithm implements SecurityAlgorithm {
return DatatypeConverter.printBase64Binary(result);
}
@SneakyThrows(GeneralSecurityException.class)
@Override
public String decrypt(final String ciphertext) {
public String decrypt(final String ciphertext) throws GeneralSecurityException {
if (null == ciphertext) {
return null;
}
......@@ -59,13 +41,36 @@ public final class AESSecurityAlgorithm implements SecurityAlgorithm {
return new String(result, StandardCharsets.UTF_8);
}
@Override
public String getType() {
return EncryptType.AES;
}
@Override
public void init(){
this.secretKey = createSecretKey();
}
public Properties getProps() {
return this.props;
}
/**
* Set properties.
*
* @param props properties
*/
public void setProps(final Properties props) {
this.props = props;
}
private byte[] createSecretKey() {
return Arrays.copyOf(DigestUtils.sha1(props.getProperty(ENCRYPT_KEY)), 16);
}
private Cipher getCipher(final int decryptMode) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException {
Cipher result = Cipher.getInstance(getType());
result.init(decryptMode, new SecretKeySpec(secretKey, getType()));
return result;
}
public String getType() {
return SecurityType.AES;
}
}
package com.secoo.mall.datasource.security.algorithm;
package com.secoo.mall.datasource.security.algorithm.encrypt;
import com.secoo.mall.datasource.security.spi.TypedSPI;
import java.security.GeneralSecurityException;
/**
* Encrypt|decrypt algorithm for SPI.
*/
public interface SecurityAlgorithm {
public interface EncryptAlgorithm extends TypedSPI {
String ENCRYPT_KEY = "encrypt-key-value";
void init();
/**
* Encode.
......@@ -11,7 +19,7 @@ public interface SecurityAlgorithm {
* @param plainText plainText
* @return cipherText
*/
String encrypt(Object plainText);
String encrypt(Object plainText) throws GeneralSecurityException;
/**
* Decode.
......@@ -19,6 +27,6 @@ public interface SecurityAlgorithm {
* @param cipherText cipherText
* @return plainText
*/
String decrypt(String cipherText);
String decrypt(String cipherText) throws GeneralSecurityException;
}
package com.secoo.mall.datasource.security.algorithm.property;
import com.ctrip.framework.apollo.Config;
import com.ctrip.framework.apollo.ConfigService;
import com.secoo.mall.common.util.colletion.CollectionUtil;
import com.secoo.mall.common.util.reflect.FieldUtil;
import com.secoo.mall.common.util.string.StringUtil;
import com.secoo.mall.datasource.security.algorithm.encrypt.EncryptAlgorithm;
import com.secoo.mall.datasource.security.config.DataSourceSecurityProperties;
import com.secoo.mall.datasource.security.constant.PropertyProviderType;
import com.secoo.mall.datasource.security.constant.SymbolConstants;
import com.secoo.mall.datasource.security.exception.SecurityException;
import com.secoo.mall.datasource.security.factory.EncryptAlgorithmFactory;
import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.rule.DbRule;
import com.secoo.mall.datasource.security.rule.TableRule;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
import java.util.stream.Collectors;
/**
* security property provider
*/
public class ApolloPropertyProviderAlgorithm implements PropertyProviderAlgorithm {
private static Logger log = LoggerFactory.getLogger(ApolloPropertyProviderAlgorithm.class);
private static String DATASOURCE_SECURITY_APOLLO_NAMESPACE = "arch.db_config";
private Properties props = new Properties();
@Override
public DataSourceSecurityProperties load() {
log.info("load security config from apollo...");
Config appConfig = ConfigService.getConfig(DATASOURCE_SECURITY_APOLLO_NAMESPACE);
Set<String> propertyNames = appConfig.getPropertyNames();
if (CollectionUtil.isEmpty(propertyNames)) {
throw new SecurityException("!!! Can not find apollo security rules !!!");
}
//用数据源名字为key,进行属性分组
List<String[]> propertyNameSections = propertyNames.stream()
.filter(key -> key.contains(DataSourceSecurityProperties.PREFIX + SymbolConstants.PROPERTY_SPLIT_CHAR))
.map(key -> key.replace(DataSourceSecurityProperties.PREFIX + SymbolConstants.PROPERTY_SPLIT_CHAR, "")
.split("\\"+SymbolConstants.PROPERTY_SPLIT_CHAR))
.collect(Collectors.toList());
propertyNameSections = propertyNameSections.stream().filter(key -> key.length == 4).collect(Collectors.toList());
Map<String, Class> fieldMap = FieldUtil.getAllFieldsList(ColumnRule.class).stream().collect(Collectors.toMap(Field::getName, Field::getType, (key1, key2) -> key1));
Set<String> notSupportPropertySet = new HashSet<>();
// db
Set<DbRule> dbRules = new HashSet<>();
Map<String, List<String[]>> dbSectionMap = propertyNameSections.stream().collect(Collectors.groupingBy(dbSection -> dbSection[0]));
for (Map.Entry<String, List<String[]>> dbSectionEntry : dbSectionMap.entrySet()) {
String dbName = dbSectionEntry.getKey();
if(StringUtils.isBlank(dbName)){
continue;
}
DbRule dbRule = new DbRule();
dbRule.setDbName(dbName);
// table
List<String[]> tableSections = dbSectionEntry.getValue();
if(tableSections == null || tableSections.isEmpty()){
continue;
}
Map<String, List<String[]>> tableSectionMap = tableSections.stream().collect(Collectors.groupingBy(tableSection -> tableSection[1]));
try {
for(Map.Entry<String,List<String[]>> tableSectionEntry:tableSectionMap.entrySet()){
String tableName = tableSectionEntry.getKey();
if(StringUtils.isBlank(tableName)){
continue;
}
TableRule tableRule = new TableRule();
tableRule.setTableName(tableName);
// column
List<String[]> columnSections = tableSectionEntry.getValue();
if(columnSections == null || columnSections.isEmpty()){
continue;
}
Map<String, List<String[]>> columnSectionMap = columnSections.stream().collect(Collectors.groupingBy(columnSection -> columnSection[2]));
for(Map.Entry<String,List<String[]>> columnSectionEntry:columnSectionMap.entrySet()){
String columnName = columnSectionEntry.getKey();
if(StringUtils.isBlank(columnName)){
continue;
}
ColumnRule columnRule = new ColumnRule();
// field
List<String[]> fieldSections = columnSectionEntry.getValue();
if(fieldSections == null || fieldSections.isEmpty()){
continue;
}
for(String[] fieldSection:fieldSections){
String configPropertyFullPath = DataSourceSecurityProperties.PREFIX + SymbolConstants.PROPERTY_SPLIT_CHAR + Arrays.stream(fieldSection).collect(Collectors.joining(SymbolConstants.PROPERTY_SPLIT_CHAR,"",""));
Object value = appConfig.getProperty(configPropertyFullPath, "");
//得到实际类的字段
String beanFieldName = fieldSection[3];
if (!fieldMap.containsKey(beanFieldName)) {
notSupportPropertySet.add(beanFieldName);
continue;
}
//非String类型的需要进行真实类型数据转换
if (!fieldMap.get(beanFieldName).isAssignableFrom(String.class)) {
value = MethodUtils.invokeExactStaticMethod(fieldMap.get(beanFieldName), "valueOf", value);
}
MethodUtils.invokeExactMethod(columnRule, "set" + StringUtil.capitalize(beanFieldName), value);
}
if(StringUtils.isBlank(columnRule.getEncryptType()) || StringUtils.isBlank(columnRule.getEncryptKey()) || StringUtils.isBlank(columnRule.getCipherColumn())){
continue;
}
Properties properties = new Properties();
properties.setProperty(EncryptAlgorithm.ENCRYPT_KEY,columnRule.getEncryptKey());
EncryptAlgorithm encryptAlgorithm = EncryptAlgorithmFactory.getObject(columnRule.getEncryptType().toLowerCase(), properties);
encryptAlgorithm.init();
columnRule.setEncryptAlgorithm(encryptAlgorithm);
if(tableRule.getColumnRules() == null){
tableRule.setColumnRules(new HashSet<>());
}
tableRule.getColumnRules().add(columnRule);
}
if(dbRule.getTableRules() == null){
dbRule.setTableRules(new HashSet<>());
}
dbRule.getTableRules().add(tableRule);
}
dbRules.add(dbRule);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new SecurityException("!!! Load security rule from apollo error !!!",e);
}
}
if (CollectionUtil.isNotEmpty(notSupportPropertySet)) {
log.warn("security rule not support the property:{} !!!", notSupportPropertySet);
}
DataSourceSecurityProperties dataSourceSecurityProperties = new DataSourceSecurityProperties();
dataSourceSecurityProperties.setRules(dbRules);
return dataSourceSecurityProperties;
}
@Override
public String getType() {
return PropertyProviderType.APOLLO;
}
public static void main(String[] args) {
// String x = "a-b-c";
// System.out.println(JSON.toJSONString(x.split("-",2)));
String x = "spring.datasource.security.secooActivityDB.t_price_rule_task.creator.plainColumn";
System.out.println(x.contains("spring.datasource.security\\."));
}
}
package com.secoo.mall.datasource.security.algorithm.property;
import com.secoo.mall.datasource.security.config.DataSourceSecurityProperties;
import com.secoo.mall.datasource.security.spi.TypedSPI;
/**
* Encrypt|decrypt algorithm for SPI.
*/
public interface PropertyProviderAlgorithm extends TypedSPI {
DataSourceSecurityProperties load();
}
package com.secoo.mall.datasource.security.config;
import com.secoo.mall.datasource.security.rule.DbRule;
import java.util.Set;
public class DataSourceSecurityProperties {
public static final String PREFIX = "spring.datasource.security";
/**
* 加解密规则
*/
private Set<DbRule> rules;
public Set<DbRule> getRules() {
return rules;
}
public void setRules(Set<DbRule> rules) {
this.rules = rules;
}
}
package com.secoo.mall.datasource.security.constant;
public class EncryptType {
public static final String DES = "des";
public static final String AES = "aes";
}
package com.secoo.mall.datasource.security.constant;
public class PropertyProviderType {
public static final String APOLLO = "apollo";
}
......@@ -4,4 +4,8 @@ public class SymbolConstants {
public static final String SEPARATOR = ",";
public static final String EQUAL = "=";
public static final String SEPARATOR_FEN = ";";
/**
* 属性配置分割符
*/
public static final String PROPERTY_SPLIT_CHAR = ".";
}
package com.secoo.mall.datasource.security.factory;
import com.secoo.mall.datasource.security.algorithm.encrypt.EncryptAlgorithm;
import com.secoo.mall.datasource.security.spi.SecurityServiceLoader;
import com.secoo.mall.datasource.security.spi.TypedSPIRegistry;
import java.util.Properties;
public class EncryptAlgorithmFactory {
static {
SecurityServiceLoader.register(EncryptAlgorithm.class);
}
public static EncryptAlgorithm getObject(String type, Properties props) {
return TypedSPIRegistry.getRegisteredService(EncryptAlgorithm.class, type, props);
}
}
package com.secoo.mall.datasource.security.factory;
import com.secoo.mall.datasource.security.algorithm.property.PropertyProviderAlgorithm;
import com.secoo.mall.datasource.security.spi.SecurityServiceLoader;
import com.secoo.mall.datasource.security.spi.TypedSPIRegistry;
import java.util.Properties;
public class PropertyProviderAlgorithmFactory {
static {
SecurityServiceLoader.register(PropertyProviderAlgorithm.class);
}
public static PropertyProviderAlgorithm getObject(String type, Properties props) {
return TypedSPIRegistry.getRegisteredService(PropertyProviderAlgorithm.class, type, props);
}
}
package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.filter.AutoLoad;
import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.*;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.Utils;
import com.secoo.mall.datasource.security.rule.ColumnRule;
import com.secoo.mall.datasource.security.rule.DbRule;
import com.secoo.mall.datasource.security.rule.TableRule;
import com.secoo.mall.datasource.security.util.SecurityUtil;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Reader;
import java.io.StringReader;
import java.security.GeneralSecurityException;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.*;
@AutoLoad
public class SecurityFilter extends FilterEventAdapter {
private static final Logger log = LoggerFactory.getLogger(SecurityFilter.class);
private String dbType;
private String dbName;
private Map<String, Map<String, ColumnRule>> tableRuleMap;
public String getDbType() {
return dbType;
}
public String getDbName() {
return dbName;
}
public Map<String, Map<String, ColumnRule>> getTableRuleMap() {
return tableRuleMap;
}
public void init(DataSourceProxy dataSource) {
this.dbType = dataSource.getDbType();
this.dbName = SecurityUtil.findDataBaseNameByUrl(dataSource.getUrl());
if(StringUtils.isBlank(this.dbName)){
return;
}
Set<DbRule> dbRules = SecurityFilterContext.getInstance().getDbRules();
this.setSecurityRules(this.dbName, dbRules);
}
@Override
public String resultSet_getString(FilterChain chain, ResultSetProxy result, int columnIndex) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
ResultSetMetaData metadata = rawResultSet.getMetaData();
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getString(chain, result, columnIndex);
}
String value = super.resultSet_getString(chain, result, columnIndex);
return decrypt(columnRule, value);
}
@Override
public String resultSet_getString(FilterChain chain, ResultSetProxy result, String columnLabel) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
int columnIndex = rawResultSet.findColumn(columnLabel);
ResultSetMetaData metadata = rawResultSet.getMetaData();
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getString(chain, result, columnLabel);
}
String value = super.resultSet_getString(chain, result, columnLabel);
return decrypt(columnRule, value);
}
@Override
public Object resultSet_getObject(FilterChain chain, ResultSetProxy result, int columnIndex) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnIndex);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
case Types.CLOB:
case Types.LONGVARCHAR:
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnIndex);
break;
default:
value = super.resultSet_getObject(chain, result, columnIndex);
}
return decryptObject(columnRule, value);
}
@Override
public <T> T resultSet_getObject(FilterChain chain, ResultSetProxy result, int columnIndex, Class<T> type) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnIndex,type);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
case Types.CLOB:
case Types.LONGVARCHAR:
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnIndex);
break;
default:
value = super.resultSet_getObject(chain, result, columnIndex, type);
}
return (T) decryptObject(columnRule, value);
}
@Override
public Object resultSet_getObject(FilterChain chain, ResultSetProxy result, int columnIndex,
java.util.Map<String, Class<?>> map) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnIndex,map);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
case Types.CLOB:
case Types.LONGVARCHAR:
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnIndex);
break;
default:
value = super.resultSet_getObject(chain, result, columnIndex, map);
}
return decryptObject(columnRule, value);
}
@Override
public Object resultSet_getObject(FilterChain chain, ResultSetProxy result, String columnLabel) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
int columnIndex = rawResultSet.findColumn(columnLabel);
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnLabel);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
case Types.CLOB:
case Types.LONGVARCHAR:
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnLabel);
break;
default:
value = super.resultSet_getObject(chain, result, columnLabel);
}
return decryptObject(columnRule, value);
}
@Override
public <T> T resultSet_getObject(FilterChain chain, ResultSetProxy result, String columnLabel, Class<T> type) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
int columnIndex = rawResultSet.findColumn(columnLabel);
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnLabel,type);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
case Types.CLOB:
case Types.LONGVARCHAR:
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnLabel);
break;
default:
value = super.resultSet_getObject(chain, result, columnLabel, type);
}
return (T) decryptObject(columnRule, value);
}
@Override
public Object resultSet_getObject(FilterChain chain, ResultSetProxy result, String columnLabel,
java.util.Map<String, Class<?>> map) throws SQLException {
ResultSet rawResultSet = result.getResultSetRaw();
int columnIndex = rawResultSet.findColumn(columnLabel);
ResultSetMetaData metadata = rawResultSet.getMetaData();
int columnType = metadata.getColumnType(columnIndex);
ColumnRule columnRule = this.getColumnRule(metadata,columnIndex);
if(columnRule == null){
return super.resultSet_getObject(chain, result, columnLabel, map);
}
Object value = null;
switch (columnType) {
case Types.CHAR:
value = super.resultSet_getString(chain, result, columnLabel);
break;
case Types.CLOB:
value = super.resultSet_getString(chain, result, columnLabel);
break;
case Types.LONGVARCHAR:
value = super.resultSet_getString(chain, result, columnLabel);
break;
case Types.VARCHAR:
value = super.resultSet_getString(chain, result, columnLabel);
break;
default:
value = super.resultSet_getObject(chain, result, columnLabel, map);
}
return decryptObject(columnRule, value);
}
public Object decryptObject(ColumnRule columnRule, Object object) {
if (object instanceof String) {
return decrypt(columnRule, (String) object);
}
if (object instanceof Reader) {
Reader reader = (Reader) object;
String text = Utils.read(reader);
return new StringReader(decrypt(columnRule, text));
}
return object;
}
/**
* 解密
* @param columnRule
* @param cipherText
* @return
*/
public String decrypt(ColumnRule columnRule, String cipherText) {
try {
String plainText = columnRule.getEncryptAlgorithm().decrypt(cipherText);
log.debug("字段解密:columnRule={},cipherText={},plainText={}", columnRule, cipherText, plainText);
return plainText;
} catch (GeneralSecurityException e) {
String errorMsg = "字段解密异常:columnRule="+columnRule.toString()+",cipherText="+cipherText;
throw new SecurityException(errorMsg,e);
}
}
/**
* 获取解密列
* @param metadata
* @param columnIndex
* @return
* @throws SQLException
*/
private ColumnRule getColumnRule(ResultSetMetaData metadata,int columnIndex) throws SQLException {
if(this.getTableRuleMap() == null || this.getTableRuleMap().isEmpty()){
return null;
}
String columnName = metadata.getColumnName(columnIndex);
String tableName = metadata.getTableName(columnIndex);
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if(columnRuleMap == null || columnRuleMap.isEmpty()){
return null;
}
ColumnRule columnRule = columnRuleMap.get(columnName);
if(columnRule == null){
return null;
}
return columnRule;
}
/**
* 加密
* @param statement
* @param sql
*/
protected void statementExecuteBefore(StatementProxy statement, String sql) {
if (getTableRuleMap() == null || getTableRuleMap().isEmpty()) {
return;
}
this.encryptStatement(statement, sql);
}
protected void encryptStatement(StatementProxy statement, String sql) {
if (!(statement instanceof PreparedStatementProxy)) {
log.debug("不需要处理的statement:{}", sql);
return;
}
PreparedStatementProxyImpl preparedStatement = (PreparedStatementProxyImpl) statement;
// 解析sql
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, this.getDbType());
for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor);
int index = 0;
// 查询 | 删除
if (stmt instanceof SQLSelectStatement || stmt instanceof MySqlDeleteStatement) {
// 查询条件
for (TableStat.Condition condition : visitor.getConditions()) {
// 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空为查询条件
if (conditionValue != null) {
continue;
}
TableStat.Column column = condition.getColumn();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(columnRule, preparedStatement, index);
}
}
index ++;
}
}
}
// 插入
else if (stmt instanceof MySqlInsertStatement) {
MySqlInsertStatement insertStmt = (MySqlInsertStatement) stmt;
String tableName = insertStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
int valuesSize = insertStmt.getValuesList().size();
Collection<TableStat.Column> columns = visitor.getColumns();
int columnSize = columns.size();
for (TableStat.Column column : columns) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
for (int valueIndex = 0; valueIndex < valuesSize; valueIndex++) {
encrypt(columnRule, preparedStatement, index + valueIndex * columnSize);
}
}
}
}
index ++;
}
// 更新
else if (stmt instanceof MySqlUpdateStatement) {
MySqlUpdateStatement updateStmt = (MySqlUpdateStatement) stmt;
// 更新语句应该只支持单表
String tableName = updateStmt.getTableName().getSimpleName();
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(tableName);
if (columnRuleMap == null || columnRuleMap.isEmpty()) {
continue;
}
// 处理set
for (SQLUpdateSetItem item : updateStmt.getItems()) {
SQLExpr column = item.getColumn();
if (item.getValue() instanceof SQLVariantRefExpr && column instanceof SQLIdentifierExpr) {
// 需要加密的字段
String columnName = ((SQLIdentifierExpr) column).getName();
ColumnRule columnRule = columnRuleMap.get(columnName);
if (columnRule != null) {
encrypt(columnRule, preparedStatement,index);
}
}
index++;
}
// 处理where
for (TableStat.Condition condition : visitor.getConditions()) {
// 查询条件值,in/between语句等可能有多个
for (Object conditionValue : condition.getValues()) {
// 解析出条件值为空为查询条件
if (conditionValue == null) {
continue;
}
TableStat.Column column = condition.getColumn();
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(columnRule, preparedStatement,index);
}
index++;
}
}
}
// 其他
else {
for (TableStat.Column column : visitor.getColumns()) {
Map<String, ColumnRule> columnRuleMap = this.getTableRuleMap().get(column.getTable());
if (columnRuleMap != null && !columnRuleMap.isEmpty()) {
// 需要加密的字段
ColumnRule columnRule = columnRuleMap.get(column.getName());
if (columnRule != null) {
encrypt(columnRule, preparedStatement,index);
}
}
index++;
}
}
}
}
/**
* 加密
* @param columnRule
* @param preparedStatement
* @param index
*/
private void encrypt(ColumnRule columnRule, PreparedStatementProxyImpl preparedStatement,int index) {
JdbcParameter jdbcParameter = preparedStatement.getParameter(index);
final Object plainText = jdbcParameter.getValue();
if (plainText == null) {
return;
}
try {
String cipherText = columnRule.getEncryptAlgorithm().encrypt(plainText);
preparedStatement.setObject(index + 1,cipherText);
log.debug("字段加密:columnRule={},plainText={},cipherText={}", columnRule, plainText, cipherText);
} catch (SQLException | GeneralSecurityException e) {
String errorMsg = "字段加密异常:columnRule="+columnRule+",plainText="+plainText;
throw new SecurityException(errorMsg,e);
}
}
/**
* 初始化规则
* @param dbName
* @param dbRules
* @return
*/
private void setSecurityRules(String dbName,Set<DbRule> dbRules) {
if(StringUtils.isBlank(dbName) || dbRules == null || dbRules.isEmpty()){
return;
}
DbRule dbRule = null;
for(DbRule e:dbRules){
if(dbName.equals(e.getDbName())){
dbRule = e;
}
}
if(dbRule == null || dbRule.getTableRules() == null || dbRule.getTableRules().isEmpty()){
return;
}
Set<TableRule> tableRules = dbRule.getTableRules();
if (tableRules == null || tableRules.isEmpty()) {
return;
}
Map<String, Map<String, ColumnRule>> tableRuleMap = new HashMap<>();
for(TableRule tableRule: tableRules){
if(tableRule == null || tableRule.getColumnRules() == null || tableRule.getColumnRules().isEmpty()){
continue;
}
Map<String, ColumnRule> columnRuleMap = new HashMap<>();
for(ColumnRule columnRule:tableRule.getColumnRules()){
columnRuleMap.put(columnRule.getLogicColumn(),columnRule);
}
tableRuleMap.put(tableRule.getTableName(),columnRuleMap);
}
this.tableRuleMap = tableRuleMap;
}
public static void main(String[] args) {
String sql = "update t_user u,t_account a set u.name=?,a.age=10 where u.id = a.id and u.id > 12 and a.age >?";
// 解析sql
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql,"mysql");
for (SQLStatement stmt : stmtList) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
stmt.accept(visitor);
SQLUpdateStatement _stmt = (SQLUpdateStatement) stmt;
System.out.println(visitor.getColumns());
System.out.println(_stmt.getItems());
SQLExpr sqlExpr = SQLUtils.toSQLExpr(sql,"sql");
System.out.println(sqlExpr);
}
}
}
package com.secoo.mall.datasource.security.filter;
import com.alibaba.druid.filter.AutoLoad;
import com.secoo.mall.datasource.security.algorithm.property.PropertyProviderAlgorithm;
import com.secoo.mall.datasource.security.config.DataSourceSecurityProperties;
import com.secoo.mall.datasource.security.constant.PropertyProviderType;
import com.secoo.mall.datasource.security.exception.SecurityException;
import com.secoo.mall.datasource.security.factory.PropertyProviderAlgorithmFactory;
import com.secoo.mall.datasource.security.rule.DbRule;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Properties;
import java.util.Set;
@AutoLoad
public class SecurityFilterContext {
private static final Logger log = LoggerFactory.getLogger(SecurityFilterContext.class);
private static final String DATASOURCE_SECURITY_PROPERTY_PROVIDER = "datasource.security.propertyProvider";
private static final String DEFAULT_PROPERTY_PROVIDER = PropertyProviderType.APOLLO;
private static class LazyHolder{
private static final SecurityFilterContext INSTANCE = new SecurityFilterContext();
}
public static SecurityFilterContext getInstance() {
return LazyHolder.INSTANCE;
}
public SecurityFilterContext(){
this.init();
}
private void init(){
String propertyProvider = System.getProperty(DATASOURCE_SECURITY_PROPERTY_PROVIDER);
if(StringUtils.isBlank(propertyProvider)){
propertyProvider = System.getenv(DATASOURCE_SECURITY_PROPERTY_PROVIDER);
}
if(StringUtils.isBlank(propertyProvider)){
propertyProvider =DEFAULT_PROPERTY_PROVIDER;
}
PropertyProviderAlgorithm propertyProviderAlgorithm = PropertyProviderAlgorithmFactory.getObject(propertyProvider,new Properties());
DataSourceSecurityProperties dataSourceSecurityProperties = propertyProviderAlgorithm.load();
if(dataSourceSecurityProperties == null){
// log.warn("!!! Can not find security rules !!!");
throw new SecurityException("!!! Can not find security rules !!!");
}
}
private Set<DbRule> dbRules;
public Set<DbRule> getDbRules() {
return dbRules;
}
public void setDbRules(Set<DbRule> dbRules) {
this.dbRules = dbRules;
}
}
package com.secoo.mall.datasource.security.rule;
import com.secoo.mall.datasource.security.algorithm.AESSecurityAlgorithm;
import com.secoo.mall.datasource.security.algorithm.SecurityAlgorithm;
import com.secoo.mall.datasource.security.algorithm.encrypt.EncryptAlgorithm;
public class ColumnRule {
/**
* 加密类型
*/
private String encryptType;
/**
* 加密器配置
......@@ -25,7 +27,15 @@ public class ColumnRule {
/**
* 加密器
*/
private SecurityAlgorithm securityAlgorithm;
private EncryptAlgorithm encryptAlgorithm;
public String getEncryptType() {
return encryptType;
}
public void setEncryptType(String encryptType) {
this.encryptType = encryptType;
}
public String getEncryptKey() {
return encryptKey;
......@@ -33,7 +43,6 @@ public class ColumnRule {
public void setEncryptKey(String encryptKey) {
this.encryptKey = encryptKey;
this.securityAlgorithm = new AESSecurityAlgorithm(encryptKey);
}
public String getLogicColumn() {
......@@ -60,11 +69,23 @@ public class ColumnRule {
this.cipherColumn = cipherColumn;
}
public SecurityAlgorithm getSecurityAlgorithm() {
return securityAlgorithm;
public EncryptAlgorithm getEncryptAlgorithm() {
return encryptAlgorithm;
}
public void setEncryptAlgorithm(EncryptAlgorithm encryptAlgorithm) {
this.encryptAlgorithm = encryptAlgorithm;
}
public void setSecurityAlgorithm(SecurityAlgorithm securityAlgorithm) {
this.securityAlgorithm = securityAlgorithm;
@Override
public String toString() {
return "ColumnRule{" +
"encryptType='" + encryptType + '\'' +
", encryptKey='" + encryptKey + '\'' +
", logicColumn='" + logicColumn + '\'' +
", plainColumn='" + plainColumn + '\'' +
", cipherColumn='" + cipherColumn + '\'' +
", encryptAlgorithm=" + encryptAlgorithm +
'}';
}
}
......@@ -2,22 +2,22 @@ package com.secoo.mall.datasource.security.rule;
import java.util.Set;
public class DataSourceRule {
public class DbRule {
/**
* 数据库名
*/
private String datasourceName;
private String dbName;
/**
* 表规则
*/
private Set<TableRule> tableRules;
public String getDatasourceName() {
return datasourceName;
public String getDbName() {
return dbName;
}
public void setDatasourceName(String datasourceName) {
this.datasourceName = datasourceName;
public void setDbName(String dbName) {
this.dbName = dbName;
}
public Set<TableRule> getTableRules() {
......@@ -27,4 +27,12 @@ public class DataSourceRule {
public void setTableRules(Set<TableRule> tableRules) {
this.tableRules = tableRules;
}
@Override
public String toString() {
return "DbRule{" +
"dbName='" + dbName + '\'' +
", tableRules=" + tableRules +
'}';
}
}
......@@ -27,4 +27,12 @@ public class TableRule {
public void setColumnRules(Set<ColumnRule> columnRules) {
this.columnRules = columnRules;
}
@Override
public String toString() {
return "TableRule{" +
"tableName='" + tableName + '\'' +
", columnRules=" + columnRules +
'}';
}
}
package com.secoo.mall.datasource.security.spi;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
public class SecurityServiceLoader {
private static final Map<Class<?>, Collection<Object>> SERVICES = new ConcurrentHashMap<>();
/**
* Register service.
*
* @param serviceInterface service interface
*/
public static void register(final Class<?> serviceInterface) {
if (!SERVICES.containsKey(serviceInterface)) {
SERVICES.put(serviceInterface, load(serviceInterface));
}
}
private static <T> Collection<Object> load(final Class<T> serviceInterface) {
Collection<Object> result = new LinkedList<>();
for (T each : ServiceLoader.load(serviceInterface)) {
result.add(each);
}
return result;
}
/**
* Get singleton service instances.
*
* @param service service class
* @param <T> type of service
* @return service instances
*/
@SuppressWarnings("unchecked")
public static <T> Collection<T> getSingletonServiceInstances(final Class<T> service) {
return (Collection<T>) SERVICES.getOrDefault(service, Collections.emptyList());
}
/**
* New service instances.
*
* @param service service class
* @param <T> type of service
* @return service instances
*/
@SuppressWarnings("unchecked")
public static <T> Collection<T> newServiceInstances(final Class<T> service) {
return SERVICES.containsKey(service) ? SERVICES.get(service).stream().map(each -> (T) newServiceInstance(each.getClass())).collect(Collectors.toList()) : Collections.emptyList();
}
private static Object newServiceInstance(final Class<?> clazz) {
try {
return clazz.newInstance();
} catch (final InstantiationException | IllegalAccessException ex) {
throw new SecurityException(String.format("Can not find public default constructor for SPI class `%s`", clazz.getName()), ex);
}
}
}
package com.secoo.mall.datasource.security.spi;
import java.util.Properties;
public interface TypedSPI {
/**
* Get type.
*
* @return type
*/
String getType();
/**
* Get properties.
*
* @return properties
*/
default Properties getProps() {
return new Properties();
}
/**
* Set properties.
*
* @param props properties
*/
default void setProps(final Properties props) {}
}
package com.secoo.mall.datasource.security.spi;
import java.util.Optional;
import java.util.Properties;
public final class TypedSPIRegistry {
/**
* Find registered service.
*
* @param typedSPIClass typed SPI class
* @param type type
* @param props properties
* @param <T> type
* @return registered service
*/
public static <T extends TypedSPI> Optional<T> findRegisteredService(final Class<T> typedSPIClass, final String type, final Properties props) {
Optional<T> serviceInstance = SecurityServiceLoader.newServiceInstances(typedSPIClass).stream().filter(each -> each.getType().equalsIgnoreCase(type)).findFirst();
if (serviceInstance.isPresent()) {
T result = serviceInstance.get();
convertPropertiesValueType(props, result);
return Optional.of(result);
}
return Optional.empty();
}
/**
* Find registered service.
*
* @param typedSPIClass typed SPI class
* @param <T> type
* @return registered service
*/
public static <T extends TypedSPI> Optional<T> findRegisteredService(final Class<T> typedSPIClass) {
return SecurityServiceLoader.newServiceInstances(typedSPIClass).stream().findFirst();
}
/**
* Get registered service.
*
* @param typedSPIClass typed SPI class
* @param type type
* @param props properties
* @param <T> type
* @return registered service
*/
public static <T extends TypedSPI> T getRegisteredService(final Class<T> typedSPIClass, final String type, final Properties props) {
Optional<T> result = findRegisteredService(typedSPIClass, type, props);
if (result.isPresent()) {
return result.get();
}
throw new SecurityException(String.format("No implementation class load from SPI `%s` with type `%s`.", typedSPIClass.getName(), type));
}
/**
* Get registered service.
*
* @param typedSPIClass typed SPI class
* @param <T> type
* @return registered service
*/
public static <T extends TypedSPI> T getRegisteredService(final Class<T> typedSPIClass) {
Optional<T> serviceInstance = SecurityServiceLoader.newServiceInstances(typedSPIClass).stream().findFirst();
if (serviceInstance.isPresent()) {
return serviceInstance.get();
}
throw new SecurityException(String.format("No implementation class load from SPI `%s`.", typedSPIClass.getName()));
}
private static <T extends TypedSPI> void convertPropertiesValueType(final Properties props, final T service) {
if (null != props) {
Properties newProps = new Properties();
props.forEach((key, value) -> newProps.setProperty(key.toString(), null == value ? null : value.toString()));
service.setProps(newProps);
}
}
}
package com.secoo.mall.datasource.security.util;
import org.apache.commons.lang3.StringUtils;
public class SecurityUtil {
public static String findDataBaseNameByUrl(String jdbcUrl) {
String database = null;
int pos;
String tmpJdbcUrl;
if (StringUtils.isBlank(jdbcUrl)) {
throw new IllegalArgumentException("Invalid JDBC url !");
}
if (jdbcUrl.startsWith("jdbc:impala")) {
jdbcUrl = jdbcUrl.replace(":impala", "");
}
if (!jdbcUrl.startsWith("jdbc:") || (pos = jdbcUrl.indexOf(':', 5)) == -1) {
throw new IllegalArgumentException("Invalid JDBC url !");
}
tmpJdbcUrl = jdbcUrl.substring(pos + 1);
if (!tmpJdbcUrl.startsWith("//")) {
pos = tmpJdbcUrl.indexOf("//");
if(pos == -1){
throw new IllegalArgumentException("Invalid JDBC url !");
}
tmpJdbcUrl = tmpJdbcUrl.substring(pos);
}
if ((pos = tmpJdbcUrl.indexOf('/', 2)) != -1) {
database = tmpJdbcUrl.substring(pos + 1);
}
if (database.contains("?")) {
database = database.substring(0, database.indexOf("?"));
}
if (database.contains(";")) {
database = database.substring(0, database.indexOf(";"));
}
if (StringUtils.isBlank(database)) {
throw new IllegalArgumentException("Invalid JDBC url !");
}
return database;
}
public static void main(String[] args) {
System.out.println(findDataBaseNameByUrl("jdbc:mysql://abcDB.slave.com/abcDB;xxx?useUnicode=true&allowMultiQueries=true&characterEncoding=utf8&noAccessToProcedureBodies=true&zeroDateTimeBehavior=convertToNull&useSSL=false&autoReconnect=true&failOverReadOnly=false&serverTimezone=GMT%2B8"));
}
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>matrix-datasource-security</artifactId>
<groupId>com.secoo.mall</groupId>
<version>2.0.17.RELEASE</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>matrix-datasource-security-starter</artifactId>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>matrix-datasource-security-core</artifactId>
</dependency>
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>matrix-datasource-core</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
</dependency>
</dependencies>
<build>
<finalName>matrix-datasource-security-starter</finalName>
<plugins>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
package com.secoo.mall.datasource.security.config;
import com.secoo.mall.datasource.security.rule.DataSourceRule;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.Set;
@ConfigurationProperties(prefix = DataSourceSecurityProperties.PREFIX)
public class DataSourceSecurityProperties {
public static final String PREFIX = "spring.matrix.security";
private Boolean enabled = true;
/**
* 是否启用并行处理,默认不启用
*/
private Boolean parallelEnabled = false;
/**
* 建议core=max
*/
private Integer corePoolSize;
/**
* 建议core=max
*/
private Integer maxPoolSize;
/**
* 加解密规则
*/
public Set<DataSourceRule> datasourceRules;
public Boolean getEnabled() {
return enabled;
}
public void setEnabled(Boolean enabled) {
this.enabled = enabled;
}
public Boolean getParallelEnabled() {
return parallelEnabled;
}
public void setParallelEnabled(Boolean parallelEnabled) {
this.parallelEnabled = parallelEnabled;
}
public Integer getCorePoolSize() {
return corePoolSize;
}
public void setCorePoolSize(Integer corePoolSize) {
this.corePoolSize = corePoolSize;
}
public Integer getMaxPoolSize() {
return maxPoolSize;
}
public void setMaxPoolSize(Integer maxPoolSize) {
this.maxPoolSize = maxPoolSize;
}
public Set<DataSourceRule> getDatasourceRules() {
return datasourceRules;
}
public void setDatasourceRules(Set<DataSourceRule> datasourceRules) {
this.datasourceRules = datasourceRules;
}
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
......@@ -12,15 +11,15 @@
<artifactId>matrix-datasource-security</artifactId>
<packaging>pom</packaging>
<modules>
<module>matrix-datasource-security-core</module>
<module>matrix-datasource-security-starter</module>
<module>matrix-datasource-security-druid</module>
<module>matrix-datasource-security-demo</module>
</modules>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.secoo.mall</groupId>
<artifactId>matrix-datasource-security-core</artifactId>
<artifactId>matrix-datasource-security-druid</artifactId>
<version>2.0.17.RELEASE</version>
</dependency>
</dependencies>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment