From bb986d67a7b154bbcfe15e8b6fc36c4002eaaa35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=85=89=E6=98=A5?= Date: Mon, 15 Aug 2022 09:31:46 +0800 Subject: [PATCH] - update ctx --- const.go | 4 +++- context.go | 12 ++++++------ context_test.go | 2 +- gin_use.go | 4 ++-- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/const.go b/const.go index 8aaa36f..762b214 100644 --- a/const.go +++ b/const.go @@ -1,3 +1,5 @@ package gotrace_id -const Version = "1.0.4" +const Version = "1.0.5" + +const Nil = "%!s()" diff --git a/context.go b/context.go index 060d429..70fc385 100644 --- a/context.go +++ b/context.go @@ -8,22 +8,22 @@ import ( ) // CustomTraceIdContext 自定义设置跟踪编号上下文 -func CustomTraceIdContext() context.Context { - return context.WithValue(context.Background(), "trace_id", gostring.GetUuId()) +func CustomTraceIdContext(ctx context.Context) context.Context { + return context.WithValue(ctx, "trace_id", gostring.GetUuId()) } // SetGinTraceIdContext 设置跟踪编号上下文 -func SetGinTraceIdContext(c *gin.Context) context.Context { - return context.WithValue(context.Background(), "trace_id", GetGinTraceId(c)) +func SetGinTraceIdContext(ctx context.Context, c *gin.Context) context.Context { + return context.WithValue(ctx, "trace_id", GetGinTraceId(c)) } // GetTraceIdContext 通过上下文获取跟踪编号 func GetTraceIdContext(ctx context.Context) string { traceId := fmt.Sprintf("%s", ctx.Value("trace_id")) - if len(traceId) <= 0 { + if traceId == Nil { return "" } - if traceId == "%!s()" { + if len(traceId) <= 0 { return "" } return traceId diff --git a/context_test.go b/context_test.go index a952e97..630b147 100644 --- a/context_test.go +++ b/context_test.go @@ -6,7 +6,7 @@ import ( ) func TestContext(t *testing.T) { - ctx1 := CustomTraceIdContext() + ctx1 := CustomTraceIdContext(context.Background()) t.Log(ctx1) t.Log(GetTraceIdContext(ctx1)) ctx2 := context.Background() diff --git a/gin_use.go b/gin_use.go index 9bcea31..0d57327 100644 --- a/gin_use.go +++ b/gin_use.go @@ -22,10 +22,10 @@ func SetGinTraceId() gin.HandlerFunc { // GetGinTraceId 通过gin中间件获取跟踪编号 func GetGinTraceId(c *gin.Context) string { traceId := fmt.Sprintf("%s", c.MustGet("trace_id")) - if len(traceId) <= 0 { + if traceId == Nil { return "" } - if traceId == "%!s()" { + if len(traceId) <= 0 { return "" } return traceId