日常迭代管理后台时,有一个经常需要维护的需求,那就是查询数据列表,包含的一系列功能包括但不限于分页、筛选、排序、导出、批量操作等。这些功能虽然都是大同小异,但却需要在不同的场景下进行不同的配置,这就导致了代码的冗余和维护成本的增加。

最近在工作中发现 Go 语言的泛型功能可以很好地解决这个问题,通过泛型可以将一些通用的功能抽象出来,达到代码复用的目的。每次不同的数据场景只需要进行泛型推断,然后定义各自的查询细节,就可以实现不同场景下的查询功能,不需要重复编写相同的代码。

下面我们就来实现一个通用的列表查询构建器,不同的数据源通过策略模式进行注入,然后通过泛型变量来实现相同场景下的数据定义。话不多说,我们直接看代码。

假设我们现在参与了一个商城项目,在管理后台中需要查询商品列表,包括商品名称、价格、库存等信息。我们可以定义一个商品实体:

package model

import "gorm.io/gorm"

type Goods struct {
    Name   string  `gorm:"column:name;type:varchar(100);not null;comment:商品名称"`
    Price  float32 `gorm:"column:price;type:decimal unsigned;not null;comment:商品价格"`
    Stock  uint32  `gorm:"column:stock;type:int unsigned;not null;comment:商品库存"`
    Status uint32  `gorm:"column:status;type:tinyint unsigned;not null;default:1;comment:商品状态"`
    gorm.Model
}

func (*Goods) TableName() string {
    return "goods"
}

再将 goods.proto 文件定义如下:

syntax = "proto3";

package api.goods.v1;

import "google/api/annotations.proto";

option go_package = "kratos-demo/api/goods/v1;v1";
option java_multiple_files = true;
option java_package = "api.goods.v1";

service Goods {
    rpc ListGoods (ListGoodsRequest) returns (ListGoodsReply) {
        option (google.api.http) = {
            get: "/v1/goods"
        };
    };
}

message ListGoodsRequest {
    ListGoodsFilter filter = 1;
    QueryGoodsListSort sort = 2;
    uint32 start = 3;
    uint32 limit = 4;
}
message ListGoodsReply {
    repeated GoodsInfo goods = 1;
    uint32 total = 2;
}

message ListGoodsFilter {
    string name = 1;
    float min_price = 2;
    float max_price = 3;
}

enum QueryGoodsListSort {
    CREATED_AT = 0;
    PRICE = 1;
}

enum GoodsStatus {
    NONE = 0;
    NORMAL = 1;
    OFFLINE = 2;
}

message GoodsInfo {
    uint64 id = 1;
    string name = 2;
    float price = 3;
    uint32 stock = 4;
    GoodsStatus status = 5;
    string created_at = 6;
}

开始进行准备工作,我们使用 errgroup 来定义一个并发查询列表及总条数的工具方法:

package utils

import (
    "fmt"
    "golang.org/x/sync/errgroup"
    "runtime/debug"
)

// WaitAndGo 等待所有函数执行完毕
func WaitAndGo(fn ...func() error) error {
    defer func() {
        if err := recover(); err != nil {
            fmt.Printf("Panic: %+v\n %s", err, string(debug.Stack()))
        }
    }()
    
    var g errgroup.Group
    for _, f := range fn {
        g.Go(func() error {
            return f()
        })
    }
    return g.Wait()
}

首先先把数据底层查询逻辑实现,也就是构建器的定义,该构建器可实现 MySQL 及 MongoDB 的数据加载:

package builder

import (
    "context"
    "errors"
    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/options"
    "gorm.io/gorm"
    "kratos-demo/internal/utils"
)

// DBProxy 数据实例结构
type DBProxy struct {
    db      *gorm.DB
    mongodb *mongo.Collection // 需提前指定.Database("db_name").Collection("collection_name")
    // redis、elasticsearch...
}

// NewDBProxy 创建数据实例
func NewDBProxy(db *gorm.DB, mongodb *mongo.Collection) *DBProxy {
    return &DBProxy{
        db:      db,
        mongodb: mongodb,
    }
}

// QueryMiddleware 查询中间件类型定义
// 参数:
//    ctx: 上下文
//    builder: 查询构建器实例
//    next: 下一个中间件或最终查询处理器
//
// 返回:
//    []*R: 查询结果列表
//    int64: 总数
//    error: 错误信息
type QueryMiddleware[R any] func(
    ctx context.Context,
    builder *builder[R],
    next func(context.Context) ([]*R, int64, error),
) ([]*R, int64, error)

// builder 查询构建器,使用泛型支持多种实体类型
// 泛型参数:
//    R: 查询结果的实体类型
type builder[R any] struct {
    data        *DBProxy
    start       uint32
    limit       uint32
    needTotal   bool
    strategy    QueryListStrategy[R] // 查询策略
    middlewares []QueryMiddleware[R] // 中间件链

    filter func(context.Context) (any, error)
    sort   func() any
}

// SetFilter 设置过滤条件生成函数
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetFilter(filter func(context.Context) (any, error)) *builder[R] {
    b.filter = filter
    return b
}

// SetSort 设置排序条件生成函数
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetSort(sort func() any) *builder[R] {
    b.sort = sort
    return b
}

// SetStrategy 设置查询列表策略
// 返回支持链式调用的构建器实例
func (b *builder[R]) SetStrategy(strategy QueryListStrategy[R]) *builder[R] {
    b.strategy = strategy
    return b
}

// Use 添加中间件
// 返回支持链式调用的构建器实例
func (b *builder[R]) Use(middleware QueryMiddleware[R]) *builder[R] {
    b.middlewares = append(b.middlewares, middleware)
    return b
}

// getQueryStrategy 获取查询列表策略
// 如果没有设置策略,则根据数据源自动选择策略
func (b *builder[R]) getQueryStrategy() (QueryListStrategy[R], error) {
    if b.strategy != nil {
        return b.strategy, nil
    }
    if b.data == nil {
        return nil, errors.New("no data source provided")
    }

    switch {
    case b.data.db != nil:
        return NewQueryGormListStrategy[R](), nil
    case b.data.mongodb != nil:
        return NewQueryMongoListStrategy[R](), nil
    default:
        return nil, errors.New("query strategy not set and no valid DB found")
    }
}

// QueryList 执行查询列表操作
// 返回值与中间件类型相同,list []R 查询结果列表
func (b *builder[R]) QueryList(ctx context.Context) ([]*R, int64, error) {
    // 尝试自动推断策略类型
    strategy, err := b.getQueryStrategy()
    if err != nil {
        return nil, 0, err
    }

    // 构建中间件链
    next := func(ctx context.Context) ([]*R, int64, error) {
        return strategy.QueryList(ctx, b)
    }

    for i := len(b.middlewares) - 1; i >= 0; i-- {
        next = func(mw QueryMiddleware[R], fn func(context.Context) ([]*R, int64, error)) func(context.Context) ([]*R, int64, error) {
            return func(ctx context.Context) ([]*R, int64, error) {
                return mw(ctx, b, fn)
            }
        }(b.middlewares[i], next)
    }

    return next(ctx)
}

// QueryListStrategy 查询列表策略
type QueryListStrategy[R any] interface {
    QueryList(context.Context, *builder[R]) ([]*R, int64, error)
}

// QueryGormListStrategy GORM 查询策略实现
type QueryGormListStrategy[R any] struct{}

// NewQueryGormListStrategy 创建 GORM 查询策略实例
func NewQueryGormListStrategy[R any]() *QueryGormListStrategy[R] {
    return &QueryGormListStrategy[R]{}
}

// QueryList 实现 GORM 查询逻辑
func (s *QueryGormListStrategy[R]) QueryList(
    ctx context.Context,
    builder *builder[R],
) (list []*R, total int64, err error) {
    filterScope, err := builder.filter(ctx)
    if err != nil {
        return nil, 0, err
    }

    sortScope := builder.sort()
    // 验证过滤条件和排序条件的类型有效性
    for _, scope := range []any{filterScope, sortScope} {
        if _, ok := scope.(func(*gorm.DB) *gorm.DB); !ok {
            return nil, 0, errors.New("invalid scope")
        }
    }

    // 使用 WaitAndGo 并行执行数据查询和总数统计操作
    if err := utils.WaitAndGo(func() error {
        limit := builder.limit
        if builder.limit < 1 {
            limit = defaultLimit
        }

        return builder.data.db.WithContext(ctx).
            Model(&list).
            Scopes(filterScope.(func(*gorm.DB) *gorm.DB), sortScope.(func(*gorm.DB) *gorm.DB)).
            Offset(int(builder.start)).
            Limit(int(limit)).
            Find(&list).
            Error
    }, func() error {
        if !builder.needTotal {
            return nil
        }

        return builder.data.db.WithContext(ctx).
            Model(&list).
            Scopes(filterScope.(func(*gorm.DB) *gorm.DB)).
            Count(&total).
            Error
    }); err != nil {
        return nil, 0, err
    }

    return list, total, nil
}

// QueryMongoListStrategy MongoDB 查询策略实现
type QueryMongoListStrategy[R any] struct{}

// NewQueryMongoListStrategy 创建 MongoDB 查询策略实例
func NewQueryMongoListStrategy[R any]() *QueryMongoListStrategy[R] {
    return &QueryMongoListStrategy[R]{}
}

// QueryList 实现 MongoDB 查询逻辑
func (s *QueryMongoListStrategy[R]) QueryList(
    ctx context.Context,
    builder *builder[R],
) (list []*R, total int64, err error) {
    filterOpt, err := builder.filter(ctx)
    if err != nil {
        return nil, 0, err
    }

    sortOpt := builder.sort()
    // 验证过滤条件和排序条件的类型有效性
    for _, opt := range []any{filterOpt, sortOpt} {
        _, mOk := opt.(bson.M)
        _, dOk := opt.(bson.D)
        if !mOk && !dOk {
            return nil, 0, errors.New("invalid option")
        }
    }

    // 使用 WaitAndGo 并行执行数据查询和总数统计操作
    if err := utils.WaitAndGo(func() error {
        limit := builder.limit
        if builder.limit < 1 {
            limit = defaultLimit
        }

        findOpt := options.Find().
            SetLimit(int64(limit)).
            SetSkip(int64(builder.start)).
            SetSort(sortOpt)
        cursor, err := builder.data.mongodb.
            Find(ctx, filterOpt, findOpt)
        if err != nil {
            return err
        }
        defer func(cursor *mongo.Cursor, ctx context.Context) {
            _ = cursor.Close(ctx)
        }(cursor, ctx)

        if err := cursor.All(ctx, &list); err != nil {
            return err
        }

        return nil
    }, func() error {
        if !builder.needTotal {
            return nil
        }

        total, err = builder.data.mongodb.
            CountDocuments(ctx, filterOpt)
        if err != nil {
            return err
        }

        return nil
    }); err != nil {
        return nil, 0, err
    }

    return list, total, nil
}

接着我们可以将构建器的参数通过函数选项模式进行配置,定义一个通用选项结构:

package builder

import (
    pb "kratos-demo/api/goods/v1"
)

const (
    defaultStart     = 0    // 默认从第0条开始
    defaultLimit     = 10   // 默认每页10条
    defaultNeedTotal = true // 默认需要总数
)

// Filter 定义过滤条件的通用接口类型
type Filter any

// Sort 定义排序条件的通用接口类型
type Sort any

// QueryListOptions 定义了查询列表的通用选项接口
// 泛型参数:
//    F - 过滤条件类型参数
//    S - 排序条件类型参数
type QueryListOptions[F Filter, S Sort] interface {
    GetData() *DBProxy
    GetFilter() *F
    GetSort() S
    GetStart() uint32
    GetLimit() uint32
    GetNeedTotal() bool
}

// BaseQueryListOptions 实现了QueryListOptions接口的基础结构体
// 包含查询列表所需的所有基本选项
type BaseQueryListOptions[F Filter, S Sort] struct {
    data      *DBProxy // 数据实例
    filter    *F       // 过滤条件生成函数
    sort      S        // 排序条件生成函数
    start     uint32   // 分页起始位置
    limit     uint32   // 每页数据条数
    needTotal bool     // 是否需要查询总数
}

func (opts *BaseQueryListOptions[F, S]) GetData() *DBProxy {
    return opts.data
}

func (opts *BaseQueryListOptions[F, S]) GetFilter() *F {
    return opts.filter
}

func (opts *BaseQueryListOptions[F, S]) GetSort() S {
    return opts.sort
}

func (opts *BaseQueryListOptions[F, S]) GetStart() uint32 {
    return opts.start
}

func (opts *BaseQueryListOptions[F, S]) GetLimit() uint32 {
    return opts.limit
}

func (opts *BaseQueryListOptions[F, S]) GetNeedTotal() bool {
    return opts.needTotal
}

// QueryOption 定义用于配置查询选项的函数类型
type QueryOption[F Filter, S Sort] func(options *BaseQueryListOptions[F, S])

// LoadQueryOptions 加载并应用查询选项
// 参数:
//    opts - 可变数量的查询选项函数
//
// 返回:
//    配置好的BaseQueryListOptions实例
func LoadQueryOptions[F Filter, S Sort](opts ...QueryOption[F, S]) BaseQueryListOptions[F, S] {
    // 初始化默认选项
    options := BaseQueryListOptions[F, S]{
        start:     defaultStart,
        limit:     defaultLimit,
        needTotal: defaultNeedTotal,
    }

    // 应用所有选项函数
    for _, opt := range opts {
        opt(&options)
    }

    return options
}

func WithData[F Filter, S Sort](data *DBProxy) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.data = data
    }
}

func WithFilter[F Filter, S Sort](filter *F) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.filter = filter
    }
}

func WithSort[F Filter, S Sort](sort S) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.sort = sort
    }
}

func WithStart[F Filter, S Sort](start uint32) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.start = start
    }
}

func WithLimit[F Filter, S Sort](limit uint32) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.limit = limit
    }
}

func WithNeedTotal[F Filter, S Sort](needTotal bool) QueryOption[F, S] {
    return func(o *BaseQueryListOptions[F, S]) {
        o.needTotal = needTotal
    }
}

然后完善业务层处理:

package builder

import (
    "context"
    "gorm.io/gorm"
    pb "kratos-demo/api/goods/v1"
    "kratos-demo/internal/model"
)

// QueryGoods 查询商品列表
type QueryGoods struct {
    builder[model.Goods]
    filter *pb.ListGoodsFilter
    sort   pb.QueryGoodsListSort
}

// NewQueryGoods 创建查询商品列表实例
func NewQueryGoods(opts ...QueryOption[pb.ListGoodsFilter, pb.QueryGoodsListSort]) *QueryGoods {
    options := LoadQueryOptions(opts...)
    return &QueryGoods{
        builder: builder[model.Goods]{
            data:      options.GetData(),
            start:     options.GetStart(),
            limit:     options.GetLimit(),
            needTotal: options.GetNeedTotal(),
        },
        filter: options.GetFilter(),
        sort:   options.GetSort(),
    }
}

// getQueryListFilter 获取查询列表的过滤器
// 可在过滤器中查询其他服务的数据,ctx方便链路追踪
func (query *QueryGoods) getQueryListFilter(context.Context) (any, error) {
    return func(db *gorm.DB) *gorm.DB {
        if query.filter.GetName() != "" {
            db.Where("name = ?", query.filter.GetName())
        }

        if query.filter.GetMinPrice() > 0 {
            db.Where("price >=?", query.filter.GetMinPrice())
        }

        if query.filter.GetMaxPrice() > 0 {
            db.Where("price <=?", query.filter.GetMaxPrice())
        }

        return db
    }, nil
}

// getQueryListSort 获取查询列表的排序
func (query *QueryGoods) getQueryListSort() any {
    return func(db *gorm.DB) *gorm.DB {
        return map[pb.QueryGoodsListSort]*gorm.DB{
            pb.QueryGoodsListSort_CREATED_AT: db.Order("created_at desc"),
            pb.QueryGoodsListSort_PRICE:      db.Order("price desc"),
        }[query.sort]
    }
}

// QueryList 查询列表
func (query *QueryGoods) QueryList(ctx context.Context) (list []*model.Goods, total int64, err error) {
    return query.builder.
        SetFilter(query.getQueryListFilter).
        SetSort(query.getQueryListSort).
        QueryList(ctx)
}

最后我们可以在数据层进行调用:

package data

import (
    "context"
    "github.com/go-kratos/kratos/v2/log"
    pb "kratos-demo/api/goods/v1"
    "kratos-demo/internal/biz"
    "kratos-demo/internal/data/builder"
    "kratos-demo/internal/model"
)

type goodsRepo struct {
    data *Data
    log  *log.Helper
}

func NewGoodsRepo(data *Data, logger log.Logger) biz.GoodsRepo {
    return &goodsRepo{data: data, log: log.NewHelper(logger)}
}

func (r *goodsRepo) ListGoods(ctx context.Context, req *pb.ListGoodsRequest) ([]*model.Goods, uint32, error) {
    // 定义类型别名,简化泛型类型的推断
    type filter = pb.ListGoodsFilter
    type sort = pb.QueryGoodsListSort

    list, total, err := builder.NewQueryGoods(
        builder.WithData[filter, sort](builder.NewDBProxy(r.data.db, nil)),
        builder.WithFilter[filter, sort](req.GetFilter()),
        builder.WithSort[filter, sort](req.GetSort()),
        builder.WithStart[filter, sort](req.GetStart()),
        builder.WithLimit[filter, sort](req.GetLimit()),
    ).QueryList(ctx)
    if err != nil {
        r.log.WithContext(ctx).Errorf("list goods err: %+v", err)
        return nil, 0, err
    }

    return list, uint32(total), nil
}

通过泛型的层层封装,一个商品列表查询的完整流程就完成了。后续如果需要增加其他查询列表的场景,只需要依葫芦画瓢定义好通用的逻辑,然后实现对应的 GetQueryListFilter 和 GetQueryListSort 方法即可。