Spring Boot MVC 自定义参数解析器
Closed this issue · 1 comments
Xuguozong commented
背景
遗留系统
: 遗留接口系统为 Golang + Gin + MongoDB 实现, MongoDB collection 的 field 字段为中文名而且不固定(可以通过 excel 文件自定义表头的形式添加), 查询参数是直接传的中文形式, 而且传参不固定(query参数的 key 和 value 都不固定),如:/api/v1/customers?姓名=张三&录入时间=2022-09-01 00:00:00&录入时间=2022-09-02 23:59:59
问题
: go 工程师离职, 后端只有 java 工程师的情况下,将此接口改用 java 重写一遍,但传参方式不变(query key 还是为中文)
实现
可选方案
- 方案一:不用 java 改写,java 工程师维护 golang 项目
- 优点:代码改动最小,出现 bug 的概率也最小
- 缺点:需要花时间学习 go 语言和相关项目
- 方案二:获取 HttpServletRequset 对象, 在相应的接口入口层进行参数处理
- 优点:能够获取到所有参数,快速实现功能
- 缺点:针对特定接口实现,不利于代码扩展
- 方案三:自定义 MVC 层参数解析器解析请求参数并绑定
- 优点:能够抽取此类需求的公共参数,结合反射和注解机制利于相似需求的扩展
- 缺点:相对于方案二性能可能会差一点。
- 涉及到反射处理
方案三实现
- MVC 方法参数处理器接口 HandlerMethodArgumentResolver
public interface HandlerMethodArgumentResolver {
/**
* 此解析器是否支持该方法参数的解析
*/
boolean supportsParameter(MethodParameter parameter);
/**
* 将拿到的原始数据解析成想要的参数对象
* A {@link ModelAndViewContainer} provides access to the model for the
* request. A {@link WebDataBinderFactory} provides a way to create
* a {@link WebDataBinder} instance when needed for data binding and
* type conversion purposes.
* @param parameter the method parameter to resolve. This parameter must
* have previously been passed to {@link #supportsParameter} which must
* have returned {@code true}.
* @param mavContainer the ModelAndViewContainer for the current request
* @param webRequest the current request
* @param binderFactory a factory for creating {@link WebDataBinder} instances
* @return the resolved argument value, or {@code null} if not resolvable
* @throws Exception in case of errors with the preparation of argument values
*/
@Nullable
Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception;
}
- 实现自定义方法参数处理器
@Slf4j
public class ZhFieldRequestResolver implements HandlerMethodArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
// 有 ZhBindConvertor 注解的参数才进行解析转换
return parameter.hasParameterAnnotation(ZhBindConvertor.class);
}
@Override
public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer,
NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);
assert request != null;
// 获取参数类型,根据参数类型反射创建类型实例
Class<?> resultType = parameter.getParameterType();
return buildResultObject(resultType, request);
}
private Object buildResultObject(Class<?> resultType, HttpServletRequest request) throws InvocationTargetException,
NoSuchMethodException, InstantiationException, IllegalAccessException, IOException {
String method = request.getMethod();
// 根据不同的 http 方法,使用不同的参数构建模式
return switch (method) {
case "POST" -> buildResultObjectForPost(resultType, request);
case "GET" -> buildResultObjectForGet(resultType, request);
default -> throw new IllegalStateException("不支持的 http 方法类型: " + method);
};
}
/**
* GET 方法直接通过 {@link ServletRequest#getParameterMap()} 方法获取请求参数
*/
private Object buildResultObjectForGet(Class<?> resultType, HttpServletRequest request) throws NoSuchMethodException,
InvocationTargetException, InstantiationException, IllegalAccessException {
Map<String, String[]> parameterMap = request.getParameterMap();
Map<String, List<String>> pMap = new HashMap<>();
parameterMap.forEach((k, v) -> pMap.put(k, List.of(v)));
Field[] fields = resultType.getDeclaredFields();
Class<?> superclass = resultType.getSuperclass();
Field[] superFields = null;
if (!superclass.equals(Object.class)) {
superFields = superclass.getDeclaredFields();
}
// 反射实例化参数对象
Object instance = resultType.getDeclaredConstructor(null).newInstance(null);
// 填充对象字段信息
setFieldsForGet(fields, instance, parameterMap, pMap);
if (superFields != null) {
setFieldsForGet(superFields, instance, parameterMap, pMap);
}
return instance;
}
/**
* 通过 field type 来设置对应的字段值
* https://docs.oracle.com/javase/tutorial/reflect/member/fieldTypes.html
*/
private void setFieldsForGet(Field[] fields, Object instance, Map<String, String[]> parameterMap, Map<String,
List<String>> pMap) throws IllegalAccessException {
for (Field f : fields) {
f.setAccessible(true);
String typeName = f.getGenericType().getTypeName();
ZhBindAlias bindAlias = f.getAnnotation(ZhBindAlias.class);
if (Objects.nonNull(bindAlias)) {
String name = bindAlias.value();
int index = bindAlias.index();
String[] values = parameterMap.get(name);
if (values != null && values.length > 0) {
Object convertValue = FieldTypeConvertor.FieldType.of(typeName).convert(values, index);
f.set(instance, convertValue);
}
pMap.remove(name);
if ("extras".equals(f.getName()) && !pMap.isEmpty()) {
f.set(instance, pMap);
}
} else {
String[] values = parameterMap.get(f.getName());
if (values != null && values.length > 0) {
Object convertValue = FieldTypeConvertor.FieldType.of(typeName).convertFirst(values);
f.set(instance, convertValue);
}
pMap.remove(f.getName());
}
}
}
/**
* POST 方法通过获取请求体 json 字符串转请求对象的方式获取参数信息
*/
private Object buildResultObjectForPost(Class<?> resultType, HttpServletRequest request) throws NoSuchMethodException,
InvocationTargetException, InstantiationException, IllegalAccessException, IOException {
// 验证 header 的 Content-Type 为 application/json 才能进行后续操作
String contentTypeHeader = request.getHeader(HttpHeaders.CONTENT_TYPE);
if (StringUtils.isBlank(contentTypeHeader) || !contentTypeHeader.equals(MediaType.APPLICATION_JSON_VALUE)) {
throw new IllegalStateException("请设置 Content-Type 值为 application/json");
}
Field[] fields = resultType.getDeclaredFields();
Class<?> superclass = resultType.getSuperclass();
Field[] superFields = null;
if (!superclass.equals(Object.class)) {
superFields = superclass.getDeclaredFields();
}
Object instance = resultType.getDeclaredConstructor(null).newInstance(null);
StringBuilder sb = new StringBuilder();
try (BufferedReader reader = request.getReader()) {
String line = reader.readLine();
while (StringUtils.isNotBlank(line)) {
sb.append(line);
line = reader.readLine();
}
log.info("uri:{},请求体参数:{}", request.getRequestURI(), sb);
JSONObject body = JSON.parseObject(sb.toString());
setFields(fields, body, instance);
setFields(superFields, body, instance);
}
return instance;
}
private void setFields(Field[] fields, JSONObject source, Object target) throws IllegalAccessException {
if (Objects.isNull(fields) || fields.length == 0) return;
for (Field f : fields) {
f.setAccessible(true);
String fName = f.getName();
String typeName = f.getGenericType().getTypeName();
ZhBindAlias bindAlias = f.getAnnotation(ZhBindAlias.class);
if (Objects.nonNull(bindAlias)) {
String name = bindAlias.value();
int index = bindAlias.index();
Object o = source.get(name);
if ("extras".equals(fName)) {
f.set(target, jsonObjectToMap(source));
}
if (Objects.isNull(o)) continue;
Object convertValue = FieldTypeConvertor.FieldType.of(typeName).convertJsonObject(o, index, name);
f.set(target, convertValue);
source.remove(name);
} else {
Object o = source.get(fName);
if (Objects.isNull(o)) continue;
Object convertValue = FieldTypeConvertor.FieldType.of(typeName).convertJsonObject(o, 0, fName);
f.set(target, convertValue);
source.remove(fName);
}
}
}
private Map<String, List<String>> jsonObjectToMap(JSONObject object) {
TypeReference<List<String>> type = new TypeReference<>(){};
Map<String, List<String>> value = new HashMap<>();
object.forEach((k, v) -> {
List<String> list = JSON.parseObject(v.toString(), type);
value.put(k, list);
});
return value;
}
}
- 配置 WebMvcConfigurer 使自定义参数解析器生效
@Configuration
public class WebConfig implements WebMvcConfigurer {
@Override
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> resolvers) {
resolvers.add(new ZhFieldRequestResolver());
}
}
- 自定义注解用于标注解析后的参数接收类
/**
* 中文请求参数转换,请求实体字段配合 {@link ZhBindAlias} 使用
* 实现原理:自定义实现 mvc 参数转换器 {@link org.springframework.web.method.support.HandlerMethodArgumentResolver}
*/
@Documented
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
public @interface ZhBindConvertor {
boolean required() default true;
}
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ZhBindAlias {
/** 字段的中文别名 */
String value();
/** 同名的情况下绑定到第几个参数 */
int index() default 0;
/** 是否默认添加到 query 条件构建中 */
boolean includeQuery() default true;
}
- 抽取公用查询参数类
/**
* 公共查询参数
*/
@Data
public abstract class ZhSearchReq {
@ZhBindAlias("录入时间")
private String startTime;
@ZhBindAlias(value = "录入时间", index = 1)
private String endTime;
/** 自定义查询字段(对应变化的部分) */
@ZhBindAlias(value = "extra", includeQuery = false)
private Map<String, List<String>> extras;
/**
* 排序字段, json 字符串,示例:
* {“录入时间”: "ascend", "分配时间": "descend"}
* 按录入时间升序,分配时间降序排列,默认升序
*/
private String sorter;
@Min(1)
@ZhBindAlias(value = "current", includeQuery = false)
private Integer current = 1;
@Min(1)
@ZhBindAlias(value = "pageSize", includeQuery = false)
private Integer pageSize = 50;
/** 默认按录入时间降序排列 */
private Sort defaultSort() {
return Sort.by("录入时间").descending();
}
/**
* 获取子类字段的值,子类需要有 getter 方法
* @deprecated
*/
private void filedCriteriaForSubClass(Criteria criteria, Field f) {
Method[] methods = this.getClass().getDeclaredMethods();
String name = f.getName();
Stream.of(methods)
.filter(m -> m.getName().startsWith("get") && m.getName().toLowerCase().contains(name.toLowerCase()))
.findAny()
.ifPresent(m -> {
try {
Object value = m.invoke(this, null);
// filedCriteria(criteria, f, value);
} catch (IllegalAccessException e) {
e.printStackTrace();
} catch (InvocationTargetException e) {
e.printStackTrace();
}
});
}
public Sort sort() {
String sorter = getSorter();
if (StringUtils.isEmpty(sorter)) return defaultSort();
HashMap sorterMap = JSONObject.parseObject(sorter, HashMap.class);
if (sorterMap.isEmpty()) return defaultSort();
List<Sort.Order> orders = new ArrayList<>(sorterMap.size());
sorterMap.forEach((k, v) -> {
Sort.Order order;
if (v.equals("descend")) {
order = Sort.Order.desc(k.toString());
} else {
order = Sort.Order.asc(k.toString());
}
orders.add(order);
});
if (orders.isEmpty()) return defaultSort();
return Sort.by(orders);
}
/**
* 由于数据库选型原因,与 MongoDB 查询条件强绑定
*/
public Criteria getQueryCriteria() {
Criteria criteria = new Criteria();
Class<? extends ZhSearchReq> reqClass = this.getClass();
Class<?> superclass = reqClass.getSuperclass();
Field[] superFields = new Field[0];
if (!superclass.equals(Object.class)) {
superFields = superclass.getDeclaredFields();
}
Field[] fields = reqClass.getDeclaredFields();
Arrays.stream(fields).forEach(f -> filedCriteria(criteria, f));
Arrays.stream(superFields).forEach(f -> filedCriteria(criteria, f));
if (StringUtils.isNotBlank(getStartTime()) && StringUtils.isNotBlank(getEndTime())) {
criteria.and("录入时间").gte(getStartTime()).lte(getEndTime());
}
// 对于动态参数的处理
if (MapUtil.isNotEmpty(extras)) {
extras.forEach((k, v) -> {
if (CollUtil.isNotEmpty(v)) {
// warning 与具体业务逻辑相关,可以前端确定公用的处理模型
if (!v.contains("全选")) {
if (v.size() == 1) {
// 根据业务逻辑确定单值的查询定义
criteria.and(k).is(v.get(0));
} else {
// 根据业务逻辑确定多值的查询定义
criteria.and(k).in(v);
}
}
}
});
}
return criteria;
}
/**
* 对于单个查询参数的处理
*/
protected void filedCriteria(Criteria criteria, Field f) {
String fName = f.getName();
if ("sorter".equals(fName)) return;
String typeName = f.getGenericType().getTypeName();
ZhBindAlias alias = f.getAnnotation(ZhBindAlias.class);
Object value = null;
try {
f.setAccessible(true);
value = f.get(this);
} catch (IllegalAccessException e) {
// 内部调用,不会有问题
e.printStackTrace();
}
if (Objects.isNull(value)) return;
if (Objects.nonNull(alias)) {
if (!alias.includeQuery()) return;
String where = alias.value();
if ("java.lang.String".equals(typeName) && !where.contains("时间")) {
// 单值采用前缀匹配查询
criteria.and(where).regex("^" + value);
} else if ("java.util.List<java.lang.String>".equals(typeName) && !where.contains("时间")) {
List<String> values = (List<String>) value;
if (!values.contains("全选")) {
// 多值采用 $in 查询
criteria.and(where).in(values);
}
}
} else {
// 等值查询
criteria.and(fName).is(value);
}
}
}
- 扩展查询参数类
@Data
public class CustomerSearchReq extends ZhSearchReq {
@ZhBindAlias("跟进状态")
private List<String> followState;
@ZhBindAlias("手机号")
private String mobile;
@ZhBindAlias("姓名")
private String name;
}
- 使用
@GetMapping()
@Operation(summary = "线索列表")
public R<PageResult<LinkedHashMap>> searchPage(@ZhBindConvertor CustomerSearchReq searchReq) {
return R.ok(customerSearchService.search(searchReq));
}
- 重构质量保证-测试用例
- 只写了集成测试用例,保证基本的全流程准确性
总结
涉及知识点
- Spring
- SpringMVC 参数绑定流程及自定义参数解析器实现
- GET 类型请求参数解析
- POST 类型请求参数解析
- 自定义参数解析器如何配置生效
- SpringMVC 参数绑定流程及自定义参数解析器实现
- 反射:
- 根据类型实例化对象
- 字段信息获取与对象字段设置
- 字段类型信息获取和区分各种不同类型
- 父类字段信息获取以及子类字段信息如何获取
- 对象示例方法信息获取和方法执行
- 自定义注解的使用
- 测试用例
- 使用了 testcontainers + docker mongodb + mockmvc 编写集成测试用例
- json 文件 + MongoTemplate#insert 完成测试前数据准备
- MongoTemplate#dropCollection 完成测试后数据清理
- MockMvc 对于响应数据的各种准确性断言
遇到的问题
-
父类方法反射获取子类字段的取值(子类实例调用时)
- 解决:忘记了 f.setAccessible(true);
- 没有 f.setAccessible(true) 的时候也可以使用调用反射方法 getter + field name 的方式获取值,但很不 clean 也会有更大性能开销?benchmark
-
spring doc openapi 参数转换器问题
- 通过 WebMvcConfigurer#addArgumentResolvers 解决
存在的问题及可以继续改进的地方
- 公用查询参数类只支持一层继承体系类的处理,可以根据业务实际需要做支持处理或做规范
- 公用查询参数类获取 query 条件的方法与数据库类型以及业务强绑定,可以在这一层根据实际需求再做一层抽象
- 由于业务逻辑简单,没有做单元测试,只做了集成测试
- 反射获取字段信息时没有做缓存,可以 benchmark 下看看对于应用的性能提升
参考及示例代码
代码:
custom_mvc_method_argument_resolver
参考:
Spring From the Trenches: Creating a Custom HandlerMethodArgumentResolver
Xuguozong commented
benchmark result should mark the running machine's hardware