日常迭代管理后台时,有一个经常需要维护的需求,那就是查询数据列表,包含的一系列功能包括但不限于分页、筛选、排序、导出、批量操作等。这些功能虽然都是大同小异,但却需要在不同的场景下进行不同的配置,这就导致了代码的冗余和维护成本的增加。
最近在工作中发现 Go 语言的泛型功能可以很好地解决这个问题,通过泛型可以将一些通用的功能抽象出来,达到代码复用的目的。每次不同的数据场景只需要进行泛型推断,然后定义各自的查询细节,就可以实现不同场景下的查询功能,不需要重复编写相同的代码。
下面我们就来实现一个通用的列表查询构建器,不同的数据源通过策略模式进行注入,然后通过泛型变量来实现相同场景下的数据定义。话不多说,我们直接看代码。
假设我们现在参与了一个商城项目,在管理后台中需要查询商品列表,包括商品名称、价格、库存等信息。我们可以定义一个商品实体:
package model
import "gorm.io/gorm"
type Goods struct {
Name string `gorm:"column:name;type:varchar(100);not null;comment:商品名称"`
Price float32 `gorm:"column:price;type:decimal unsigned;not null;comment:商品价格"`
Stock uint32 `gorm:"column:stock;type:int unsigned;not null;comment:商品库存"`
Status uint32 `gorm:"column:status;type:tinyint unsigned;not null;default:1;comment:商品状态"`
gorm.Model
}
func (*Goods) TableName() string {
return "goods"
}
再将 goods.proto 文件定义如下:
syntax = "proto3";
package api.goods.v1;
import "google/api/annotations.proto";
option go_package = "kratos-demo/api/goods/v1;v1";
option java_multiple_files = true;
option java_package = "api.goods.v1";
service Goods {
rpc ListGoods (ListGoodsRequest) returns (ListGoodsReply) {
option (google.api.http) = {
get: "/v1/goods"
};
};
}
message ListGoodsRequest {
ListGoodsFilter filter = 1;
QueryGoodsListSort sort = 2;
uint32 start = 3;
uint32 limit = 4;
}
message ListGoodsReply {
repeated GoodsInfo goods = 1;
uint32 total = 2;
}
message ListGoodsFilter {
string name = 1;
float min_price = 2;
float max_price = 3;
}
enum QueryGoodsListSort {
CREATED_AT = 0;
PRICE = 1;
}
enum GoodsStatus {
NONE = 0;
NORMAL = 1;
OFFLINE = 2;
}
message GoodsInfo {
uint64 id = 1;
string name = 2;
float price = 3;
uint32 stock = 4;
GoodsStatus status = 5;
string created_at = 6;
}
开始进行准备工作,我们使用 errgroup 来定义一个并发查询列表及总条数的工具方法:
package utils
import (
"fmt"
"golang.org/x/sync/errgroup"
"runtime/debug"
)
// WaitAndGo 等待所有函数执行完毕
func WaitAndGo(fn ...func() error) error {
defer func() {
if err := recover(); err != nil {
fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
}
}()
var g errgroup.Group
for _, f := range fn {
g.Go(func() error {
return f()
})
}
return g.Wait()
}
首先先把数据底层查询逻辑实现,也就是构建器的定义,该构建器可实现 MySQL 及 MongoDB 的数据加载:
package builder
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gorm.io/gorm"
"kratos-demo/internal/utils"
)
// DBProxy 数据实例结构
type DBProxy struct {
db *gorm.DB
mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name")
// redis、elasticsearch...
}
// NewDBProxy 创建数据实例
func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy {
return &DBProxy{
db: db,
mongodb: mongodb,
}
}
// QueryMiddleware 查询中间件类型定义
// 参数:
// ctx: 上下文
// builder: 查询构建器实例
// next: 下一个中间件或最终查询处理器
//
// 返回:
// []*R: 查询结果列表
// int64: 总数
// error: 错误信息
type QueryMiddleware[R any] func(
ctx context.Context,
builder *builder[R],
next func(context.Context) ([]*R, int64, error),
) ([]*R, int64, error)
// builder 查询构建器,使用泛型支持多种实体类型
// 泛型参数:
// R: 查询结果的实体类型
type builder[R any] struct {
data *DBProxy
start uint32
limit uint32
needTotal bool
strategy QueryListStrategy[R] // 查询策略
middlewares []QueryMiddleware[R] // 中间件链
filter func(context.Context) (any, error)
sort func() any
}
// SetFilter 设置过滤条件生成函数
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] {
b.filter = filter
return b
}
// SetSort 设置排序条件生成函数
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetSort(sort func() any) *builder[R] {
b.sort = sort
return b
}
// SetStrategy 设置查询列表策略
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] {
b.strategy = strategy
return b
}
// Use 添加中间件
// 返回支持链式调用的构建器实例
func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] {
b.middlewares = append(b.middlewares, middleware)
return b
}
// getQueryStrategy 获取查询列表策略
// 如果没有设置策略,则根据数据源自动选择策略
func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) {
if b.strategy != nil {
return b.strategy, nil
}
if b.data == nil {
return nil, errors.New("no data source provided")
}
switch {
case b.data.db != nil:
return NewQueryGormListStrategy[R](), nil
case b.data.mongodb != nil:
return NewQueryMongoListStrategy[R](), nil
default:
return nil, errors.New("query strategy not set and no valid DB found")
}
}
// QueryList 执行查询列表操作
// 返回值与中间件类型相同,list []R 查询结果列表
func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) {
// 尝试自动推断策略类型
strategy, err := b.getQueryStrategy()
if err != nil {
return nil, 0, err
}
// 构建中间件链
next := func(ctx context.Context) ([]*R, int64, error) {
return strategy.QueryList(ctx, b)
}
for i := len(b.middlewares) - 1; i >= 0; i-- {
next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) {
return func(ctx context.Context) ([]*R, int64, error) {
return mw(ctx, b, fn)
}
}(b.middlewares[i], next)
}
return next(ctx)
}
// QueryListStrategy 查询列表策略
type QueryListStrategy[R any] interface {
QueryList(context.Context, *builder[R]) ([]*R, int64, error)
}
// QueryGormListStrategy GORM 查询策略实现
type QueryGormListStrategy[R any] struct{}
// NewQueryGormListStrategy 创建 GORM 查询策略实例
func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] {
return &QueryGormListStrategy[R]{}
}
// QueryList 实现 GORM 查询逻辑
func (s *QueryGormListStrategy[R]) QueryList(
ctx context.Context,
builder *builder[R],
) (list []*R, total int64, err error) {
filterScope, err := builder.filter(ctx)
if err != nil {
return nil, 0, err
}
sortScope := builder.sort()
// 验证过滤条件和排序条件的类型有效性
for _, scope := range []any{filterScope, sortScope} {
if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok {
return nil, 0, errors.New("invalid scope")
}
}
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
if err := utils.WaitAndGo(func() error {
limit := builder.limit
if builder.limit < 1 {
limit = defaultLimit
}
return builder.data.db.WithContext(ctx).
Model(&list).
Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB)).
Offset(int(builder.start)).
Limit(int(limit)).
Find(&list).
Error
}, func() error {
if !builder.needTotal {
return nil
}
return builder.data.db.WithContext(ctx).
Model(&list).
Scopes(filterScope.(func(*gorm.DB) *gorm.DB)).
Count(&total).
Error
}); err != nil {
return nil, 0, err
}
return list, total, nil
}
// QueryMongoListStrategy MongoDB 查询策略实现
type QueryMongoListStrategy[R any] struct{}
// NewQueryMongoListStrategy 创建 MongoDB 查询策略实例
func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] {
return &QueryMongoListStrategy[R]{}
}
// QueryList 实现 MongoDB 查询逻辑
func (s *QueryMongoListStrategy[R]) QueryList(
ctx context.Context,
builder *builder[R],
) (list []*R, total int64, err error) {
filterOpt, err := builder.filter(ctx)
if err != nil {
return nil, 0, err
}
sortOpt := builder.sort()
// 验证过滤条件和排序条件的类型有效性
for _, opt := range []any{filterOpt, sortOpt} {
_, mOk := opt.(bson.M)
_, dOk := opt.(bson.D)
if !mOk && !dOk {
return nil, 0, errors.New("invalid option")
}
}
// 使用 WaitAndGo 并行执行数据查询和总数统计操作
if err := utils.WaitAndGo(func() error {
limit := builder.limit
if builder.limit < 1 {
limit = defaultLimit
}
findOpt := options.Find().
SetLimit(int64(limit)).
SetSkip(int64(builder.start)).
SetSort(sortOpt)
cursor, err := builder.data.mongodb.
Find(ctx, filterOpt, findOpt)
if err != nil {
return err
}
defer func(cursor *mongo.Cursor, ctx context.Context) {
_ = cursor.Close(ctx)
}(cursor, ctx)
if err := cursor.All(ctx, &list); err != nil {
return err
}
return nil
}, func() error {
if !builder.needTotal {
return nil
}
total, err = builder.data.mongodb.
CountDocuments(ctx, filterOpt)
if err != nil {
return err
}
return nil
}); err != nil {
return nil, 0, err
}
return list, total, nil
}
接着我们可以将构建器的参数通过函数选项模式进行配置,定义一个通用选项结构:
package builder
import (
pb "kratos-demo/api/goods/v1"
)
const (
defaultStart = 0 // 默认从第0条开始
defaultLimit = 10 // 默认每页10条
defaultNeedTotal = true // 默认需要总数
)
// Filter 定义过滤条件的通用接口类型
type Filter any
// Sort 定义排序条件的通用接口类型
type Sort any
// QueryListOptions 定义了查询列表的通用选项接口
// 泛型参数:
// F - 过滤条件类型参数
// S - 排序条件类型参数
type QueryListOptions[F Filter, S Sort] interface {
GetData() *DBProxy
GetFilter() *F
GetSort() S
GetStart() uint32
GetLimit() uint32
GetNeedTotal() bool
}
// BaseQueryListOptions 实现了QueryListOptions接口的基础结构体
// 包含查询列表所需的所有基本选项
type BaseQueryListOptions[F Filter, S Sort] struct {
data *DBProxy // 数据实例
filter *F // 过滤条件生成函数
sort S // 排序条件生成函数
start uint32 // 分页起始位置
limit uint32 // 每页数据条数
needTotal bool // 是否需要查询总数
}
func (opts *BaseQueryListOptions[F, S]) GetData() *DBProxy {
return opts.data
}
func (opts *BaseQueryListOptions[F, S]) GetFilter() *F {
return opts.filter
}
func (opts *BaseQueryListOptions[F, S]) GetSort() S {
return opts.sort
}
func (opts *BaseQueryListOptions[F, S]) GetStart() uint32 {
return opts.start
}
func (opts *BaseQueryListOptions[F, S]) GetLimit() uint32 {
return opts.limit
}
func (opts *BaseQueryListOptions[F, S]) GetNeedTotal() bool {
return opts.needTotal
}
// QueryOption 定义用于配置查询选项的函数类型
type QueryOption[F Filter, S Sort] func(options *BaseQueryListOptions[F, S])
// LoadQueryOptions 加载并应用查询选项
// 参数:
// opts - 可变数量的查询选项函数
//
// 返回:
// 配置好的BaseQueryListOptions实例
func LoadQueryOptions[F Filter, S Sort](opts ...QueryOption[F, S]) BaseQueryListOptions[F, S] {
// 初始化默认选项
options := BaseQueryListOptions[F, S]{
start: defaultStart,
limit: defaultLimit,
needTotal: defaultNeedTotal,
}
// 应用所有选项函数
for _, opt := range opts {
opt(&options)
}
return options
}
func WithData[F Filter, S Sort](data *DBProxy) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.data = data
}
}
func WithFilter[F Filter, S Sort](filter *F) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.filter = filter
}
}
func WithSort[F Filter, S Sort](sort S) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.sort = sort
}
}
func WithStart[F Filter, S Sort](start uint32) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.start = start
}
}
func WithLimit[F Filter, S Sort](limit uint32) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.limit = limit
}
}
func WithNeedTotal[F Filter, S Sort](needTotal bool) QueryOption[F, S] {
return func(o *BaseQueryListOptions[F, S]) {
o.needTotal = needTotal
}
}
然后完善业务层处理:
package builder
import (
"context"
"gorm.io/gorm"
pb "kratos-demo/api/goods/v1"
"kratos-demo/internal/model"
)
// QueryGoods 查询商品列表
type QueryGoods struct {
builder[model.Goods]
filter *pb.ListGoodsFilter
sort pb.QueryGoodsListSort
}
// NewQueryGoods 创建查询商品列表实例
func NewQueryGoods(opts ...QueryOption[pb.ListGoodsFilter, pb.QueryGoodsListSort]) *QueryGoods {
options := LoadQueryOptions(opts...)
return &QueryGoods{
builder: builder[model.Goods]{
data: options.GetData(),
start: options.GetStart(),
limit: options.GetLimit(),
needTotal: options.GetNeedTotal(),
},
filter: options.GetFilter(),
sort: options.GetSort(),
}
}
// getQueryListFilter 获取查询列表的过滤器
// 可在过滤器中查询其他服务的数据,ctx方便链路追踪
func (query *QueryGoods) getQueryListFilter(context.Context) (any, error) {
return func(db *gorm.DB) *gorm.DB {
if query.filter.GetName() != "" {
db.Where("name = ?", query.filter.GetName())
}
if query.filter.GetMinPrice() > 0 {
db.Where("price >=?", query.filter.GetMinPrice())
}
if query.filter.GetMaxPrice() > 0 {
db.Where("price <=?", query.filter.GetMaxPrice())
}
return db
}, nil
}
// getQueryListSort 获取查询列表的排序
func (query *QueryGoods) getQueryListSort() any {
return func(db *gorm.DB) *gorm.DB {
return map[pb.QueryGoodsListSort]*gorm.DB{
pb.QueryGoodsListSort_CREATED_AT: db.Order("created_at desc"),
pb.QueryGoodsListSort_PRICE: db.Order("price desc"),
}[query.sort]
}
}
// QueryList 查询列表
func (query *QueryGoods) QueryList(ctx context.Context) (list []*model.Goods, total int64, err error) {
return query.builder.
SetFilter(query.getQueryListFilter).
SetSort(query.getQueryListSort).
QueryList(ctx)
}
最后我们可以在数据层进行调用:
package data
import (
"context"
"github.com/go-kratos/kratos/v2/log"
pb "kratos-demo/api/goods/v1"
"kratos-demo/internal/biz"
"kratos-demo/internal/data/builder"
"kratos-demo/internal/model"
)
type goodsRepo struct {
data *Data
log *log.Helper
}
func NewGoodsRepo(data *Data, logger log.Logger) biz.GoodsRepo {
return &goodsRepo{data: data, log: log.NewHelper(logger)}
}
func (r *goodsRepo) ListGoods(ctx context.Context, req *pb.ListGoodsRequest) ([]*model.Goods, uint32, error) {
// 定义类型别名,简化泛型类型的推断
type filter = pb.ListGoodsFilter
type sort = pb.QueryGoodsListSort
list, total, err := builder.NewQueryGoods(
builder.WithData[filter, sort](builder.NewDBProxy(r.data.db, nil)),
builder.WithFilter[filter, sort](req.GetFilter()),
builder.WithSort[filter, sort](req.GetSort()),
builder.WithStart[filter, sort](req.GetStart()),
builder.WithLimit[filter, sort](req.GetLimit()),
).QueryList(ctx)
if err != nil {
r.log.WithContext(ctx).Errorf("list goods err: %+v", err)
return nil, 0, err
}
return list, uint32(total), nil
}
通过泛型的层层封装,一个商品列表查询的完整流程就完成了。后续如果需要增加其他查询列表的场景,只需要依葫芦画瓢定义好通用的逻辑,然后实现对应的 GetQueryListFilter 和 GetQueryListSort 方法即可。