1. 抽象基类定义（SearchQueryEngine）
1 public abstract class SearchQueryEngine<T> {
2
3 @Autowired
4 protected ElasticsearchTemplate elasticsearchTemplate;
5
6 public abstract int saveOrUpdate(List<T> list);
7
8 public abstract <R> List<R> aggregation(T query, Class<R> clazz);
9
10 public abstract <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId);
11
12 public abstract <R> List<R> find(T query, Class<R> clazz, int size);
13
14 public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable);
15
16 public abstract <R> R sum(T query, Class<R> clazz);
17
18 protected Document getDocument(T t) {
19 Document annotation = t.getClass().getAnnotation(Document.class);
20 if (annotation == null) {
21 throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName());
22 }
23 return annotation;
24 }
25
26 /**
27 * 获取字段名,若设置column则返回该值
28 *
29 * @param field
30 * @param column
31 * @return
32 */
33 protected String getFieldName(Field field, String column) {
34 return StringUtils.isNotBlank(column) ? column : field.getName();
35 }
36
37 /**
38 * 设置属性值
39 *
40 * @param field
41 * @param obj
42 * @param value
43 */
44 protected void setFieldValue(Field field, Object obj, Object value) {
45 boolean isAccessible = field.isAccessible();
46 field.setAccessible(true);
47 try {
48 switch (field.getType().getSimpleName()) {
49 case "BigDecimal":
50 field.set(obj, new BigDecimal(value.toString()).setScale(5, BigDecimal.ROUND_HALF_UP));
51 break;
52 case "Long":
53 field.set(obj, new Long(value.toString()));
54 break;
55 case "Integer":
56 field.set(obj, new Integer(value.toString()));
57 break;
58 case "Date":
59 field.set(obj, new Date(Long.valueOf(value.toString())));
60 break;
61 default:
62 field.set(obj, value);
63 }
64 } catch (IllegalAccessException e) {
65 throw new SearchQueryBuildException(e);
66 } finally {
67 field.setAccessible(isAccessible);
68 }
69 }
70
71 /**
72 * 获取字段值
73 *
74 * @param field
75 * @param obj
76 * @return
77 */
78 protected Object getFieldValue(Field field, Object obj) {
79 boolean isAccessible = field.isAccessible();
80 field.setAccessible(true);
81 try {
82 return field.get(obj);
83 } catch (IllegalAccessException e) {
84 throw new SearchQueryBuildException(e);
85 } finally {
86 field.setAccessible(isAccessible);
87 }
88 }
89
90 /**
91 * 转换为es识别的value值
92 *
93 * @param value
94 * @return
95 */
96 protected Object formatValue(Object value) {
97 if (value instanceof Date) {
98 return ((Date) value).getTime();
99 } else {
100 return value;
101 }
102 }
103
104 /**
105 * 获取索引分区数
106 *
107 * @param t
108 * @return
109 */
110 protected int getNumberOfShards(T t) {
111 return Integer.parseInt(elasticsearchTemplate.getSetting(getDocument(t).index()).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS).toString());
112 }
113 }
2. 抽象基类的默认实现（SimpleSearchQueryEngine）
1 @Component
2 @ComponentScan
3 public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> {
4
5 private int numberOfRowsPerScan = 10;
6
7 @Override
8 public int saveOrUpdate(List<T> list) {
9 if (CollectionUtils.isEmpty(list)) {
10 return 0;
11 }
12
13 T base = list.get(0);
14 Field id = null;
15 for (Field field : base.getClass().getDeclaredFields()) {
16 BusinessID businessID = field.getAnnotation(BusinessID.class);
17 if (businessID != null) {
18 id = field;
19 break;
20 }
21 }
22 if (id == null) {
23 throw new SearchQueryBuildException("Can't find @BusinessID on " + base.getClass().getName());
24 }
25
26 Document document = getDocument(base);
27 List<UpdateQuery> bulkIndex = new ArrayList<>();
28 for (T t : list) {
29 UpdateQuery updateQuery = new UpdateQuery();
30 updateQuery.setIndexName(document.index());
31 updateQuery.setType(document.type());
32 updateQuery.setId(getFieldValue(id, t).toString());
33 updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId()).doc(JSONObject.toJSONString(t, SerializerFeature.WriteMapNullValue)));
34 updateQuery.setDoUpsert(true);
35 updateQuery.setClazz(t.getClass());
36 bulkIndex.add(updateQuery);
37 }
38 elasticsearchTemplate.bulkUpdate(bulkIndex);
39 return list.size();
40 }
41
42 @Override
43 public <R> List<R> aggregation(T query, Class<R> clazz) {
44 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
45 nativeSearchQueryBuilder.addAggregation(buildGroupBy(query));
46 Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
47 try {
48 return transformList(null, aggregations, clazz.newInstance(), new ArrayList());
49 } catch (Exception e) {
50 throw new SearchResultBuildException(e);
51 }
52 }
53
54 /**
55 * 将Aggregations转为List
56 *
57 * @param terms
58 * @param aggregations
59 * @param baseObj
60 * @param resultList
61 * @param <R>
62 * @return
63 * @throws NoSuchFieldException
64 * @throws IllegalAccessException
65 * @throws InstantiationException
66 */
67 private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList) throws NoSuchFieldException, IllegalAccessException, InstantiationException {
68 for (String column : aggregations.asMap().keySet()) {
69 Aggregation childAggregation = aggregations.get(column);
70 if (childAggregation instanceof InternalSum) {
71 // 使用@Sum
72 if (!(terms instanceof InternalSum)) {
73 R targetObj = (R) baseObj.getClass().newInstance();
74 BeanUtils.copyProperties(baseObj, targetObj);
75 resultList.add(targetObj);
76 }
77 setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1), ((InternalSum) childAggregation).getValue());
78 terms = childAggregation;
79 } else {
80 Terms childTerms = (Terms) childAggregation;
81 for (Terms.Bucket bucket : childTerms.getBuckets()) {
82 if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) {
83 // 未使用@Sum
84 R targetObj = (R) baseObj.getClass().newInstance();
85 BeanUtils.copyProperties(baseObj, targetObj);
86 setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey());
87 resultList.add(targetObj);
88 } else {
89 setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey());
90 transformList(childTerms, bucket.getAggregations(), baseObj, resultList);
91 }
92 }
93 }
94 }
95 return resultList;
96 }
97
98 @Override
99 public <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId) {
100 if (pageable.getPageSize() % numberOfRowsPerScan > 0) {
101 throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan);
102 }
103 SearchQuery searchQuery = buildNativeSearchQueryBuilder(query).withPageable(new PageRequest(pageable.getPageNumber(), numberOfRowsPerScan / getNumberOfShards(query), pageable.getSort())).build();
104 if (StringUtils.isEmpty(scrollId.getValue())) {
105 scrollId.setValue(elasticsearchTemplate.scan(searchQuery, 10000l, false));
106 }
107 Page<R> page = elasticsearchTemplate.scroll(scrollId.getValue(), 10000l, clazz);
108 if (page == null || page.getContent().size() == 0) {
109 elasticsearchTemplate.clearScroll(scrollId.getValue());
110 }
111 return page;
112 }
113
114 @Override
115 public <R> List<R> find(T query, Class<R> clazz, int size) {
116 // Caused by: QueryPhaseExecutionException[Result window is too large, from + size must be less than or equal to: [10000] but was [2147483647].
117 // See the scroll api for a more efficient way to request large data sets. This limit can be set by changing the [index.max_result_window] index level parameter.]
118 if (size % numberOfRowsPerScan > 0) {
119 throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan);
120 }
121 int pageNum = 0;
122 List<R> result = new ArrayList<>();
123 ScrollId scrollId = new ScrollId();
124 while (true) {
125 Page<R> page = scroll(query, clazz, new PageRequest(pageNum, numberOfRowsPerScan), scrollId);
126 if (page != null && page.getContent().size() > 0) {
127 result.addAll(page.getContent());
128 } else {
129 break;
130 }
131 if (result.size() >= size) {
132 break;
133 } else {
134 pageNum++;
135 }
136 }
137 elasticsearchTemplate.clearScroll(scrollId.getValue());
138 return result;
139 }
140
141 @Override
142 public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) {
143 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable);
144 return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz);
145 }
146
147 @Override
148 public <R> R sum(T query, Class<R> clazz) {
149 NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
150 for (SumBuilder sumBuilder : getSumBuilderList(query)) {
151 nativeSearchQueryBuilder.addAggregation(sumBuilder);
152 }
153 Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
154 try {
155 return transformSumResult(aggregations, clazz);
156 } catch (Exception e) {
157 throw new SearchResultBuildException(e);
158 }
159 }
160
161 private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz) throws IllegalAccessException, InstantiationException, NoSuchFieldException {
162 R targetObj = clazz.newInstance();
163 for (Aggregation sum : aggregations.asList()) {
164 if (sum instanceof InternalSum) {
165 setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue());
166 }
167 }
168 return targetObj;
169 }
170
171 private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) {
172 Document document = getDocument(query);
173 NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder()
174 .withIndices(document.index())
175 .withTypes(document.type());
176
177 QueryBuilder whereBuilder = buildBoolQuery(query);
178 if (whereBuilder != null) {
179 nativeSearchQueryBuilder.withQuery(whereBuilder);
180 }
181
182 return nativeSearchQueryBuilder;
183 }
184
185 /**
186 * 布尔查询构建
187 *
188 * @param query
189 * @return
190 */
191 private BoolQueryBuilder buildBoolQuery(T query) {
192 BoolQueryBuilder boolQueryBuilder = boolQuery();
193 buildMatchQuery(boolQueryBuilder, query);
194 buildRangeQuery(boolQueryBuilder, query);
195 BoolQueryBuilder queryBuilder = boolQuery().must(boolQueryBuilder);
196 return queryBuilder;
197 }
198
199 /**
200 * and or 查询构建
201 *
202 * @param boolQueryBuilder
203 * @param query
204 */
205 private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) {
206 Class clazz = query.getClass();
207 for (Field field : clazz.getDeclaredFields()) {
208 MatchQuery annotation = field.getAnnotation(MatchQuery.class);
209 Object value = getFieldValue(field, query);
210 if (annotation == null || value == null) {
211 continue;
212 }
213 if (Container.must.equals(annotation.container())) {
214 boolQueryBuilder.must(matchQuery(getFieldName(field, annotation.column()), formatValue(value)));
215 } else if (should.equals(annotation.container())) {
216 if (value instanceof Collection) {
217 BoolQueryBuilder shouldQueryBuilder = boolQuery();
218 Collection tmp = (Collection) value;
219 for (Object obj : tmp) {
220 shouldQueryBuilder.should(matchQuery(getFieldName(field, annotation.column()), formatValue(obj)));
221 }
222 boolQueryBuilder.must(shouldQueryBuilder);
223 } else {
224 boolQueryBuilder.must(boolQuery().should(matchQuery(getFieldName(field, annotation.column()), formatValue(value))));
225 }
226 }
227 }
228 }
229
230 /**
231 * 范围查询构建
232 *
233 * @param boolQueryBuilder
234 * @param query
235 */
236 private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) {
237 Class clazz = query.getClass();
238 for (Field field : clazz.getDeclaredFields()) {
239 RangeQuery annotation = field.getAnnotation(RangeQuery.class);
240 Object value = getFieldValue(field, query);
241 if (annotation == null || value == null) {
242 continue;
243 }
244 if (Operator.gt.equals(annotation.operator())) {
245 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gt(formatValue(value)));
246 } else if (Operator.gte.equals(annotation.operator())) {
247 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gte(formatValue(value)));
248 } else if (Operator.lt.equals(annotation.operator())) {
249 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lt(formatValue(value)));
250 } else if (Operator.lte.equals(annotation.operator())) {
251 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lte(formatValue(value)));
252 }
253 }
254 }
255
256 /**
257 * Sum构建
258 *
259 * @param query
260 * @return
261 */
262 private List<SumBuilder> getSumBuilderList(T query) {
263 List<SumBuilder> list = new ArrayList<>();
264 Class clazz = query.getClass();
265 for (Field field : clazz.getDeclaredFields()) {
266 Sum annotation = field.getAnnotation(Sum.class);
267 if (annotation == null) {
268 continue;
269 }
270 list.add(AggregationBuilders.sum(field.getName()).field(field.getName()));
271 }
272 if (CollectionUtils.isEmpty(list)) {
273 throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName());
274 }
275 return list;
276 }
277
278
279 /**
280 * GroupBy构建
281 *
282 * @param query
283 * @return
284 */
285 private TermsBuilder buildGroupBy(T query) {
286 List<Field> sumList = new ArrayList<>();
287 Object groupByCollection = null;
288 Class clazz = query.getClass();
289 for (Field field : clazz.getDeclaredFields()) {
290 Sum sumAnnotation = field.getAnnotation(Sum.class);
291 if (sumAnnotation != null) {
292 sumList.add(field);
293 }
294 GroupBy groupByannotation = field.getAnnotation(GroupBy.class);
295 Object value = getFieldValue(field, query);
296 if (groupByannotation == null || value == null) {
297 continue;
298 } else if (!(value instanceof Collection)) {
299 throw new SearchQueryBuildException("GroupBy filed must be collection");
300 } else if (CollectionUtils.isEmpty((Collection<String>) value)) {
301 continue;
302 } else if (groupByCollection != null) {
303 throw new SearchQueryBuildException("Only one @GroupBy is allowed");
304 } else {
305 groupByCollection = value;
306 }
307 }
308 Iterator<String> iterator = ((Collection<String>) groupByCollection).iterator();
309 TermsBuilder termsBuilder = recursiveAddAggregation(iterator, sumList);
310 return termsBuilder;
311 }
312
313 /**
314 * 添加Aggregation
315 *
316 * @param iterator
317 * @return
318 */
319 private TermsBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) {
320 String groupBy = iterator.next();
321 TermsBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(0);
322 if (iterator.hasNext()) {
323 termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList));
324 } else {
325 for (Field field : sumList) {
326 termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName()));
327 }
328 sumList.clear();
329 }
330 return termsBuilder.order(Terms.Order.term(true));
331 }
3. 用于保存 scrollId 的值对象（ScrollId）
import lombok.Data;
@Data
public class ScrollId {
private String value;
}
4. 查询条件使用的枚举类（Operator：范围比较运算符；Container：must/should 组合方式）
/**
 * Comparison applied by a {@code @RangeQuery} field; each constant selects the range-query bound
 * of the same name.
 */
public enum Operator {
    /** Strictly greater than the field value. */
    gt,
    /** Greater than or equal to the field value. */
    gte,
    /** Strictly less than the field value. */
    lt,
    /** Less than or equal to the field value. */
    lte;
}
/**
 * How a {@code @MatchQuery} field is combined into the bool query.
 */
public enum Container {
    /** The match clause is required. */
    must,
    /** Any one of the field's values (one clause per element for a collection) must match. */
    should;
}