| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | package main | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	"go/format" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"os" | 
					
						
							|  |  |  | 	"path/filepath" | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	"regexp" | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	"strings" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-10 17:25:23 +08:00
										 |  |  | 	"git.apinb.com/bsm-tools/protoc-gen-slc/tpl" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	"git.apinb.com/bsm-sdk/core/utils" | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	"golang.org/x/mod/modfile" | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	"google.golang.org/protobuf/compiler/protogen" | 
					
						
							|  |  |  | 	"google.golang.org/protobuf/types/pluginpb" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
// ServicesName accumulates the Go name of every service seen across all
// processed proto files. generateFiles appends to it, and main passes the
// collected list to generateNewServerFile to emit the registration calls.
var ServicesName []string
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | func main() { | 
					
						
							|  |  |  | 	protogen.Options{}.Run(func(gen *protogen.Plugin) error { | 
					
						
							|  |  |  | 		gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		if !utils.PathExists("./internal") { | 
					
						
							|  |  |  | 			os.MkdirAll("./internal", 777) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if !utils.PathExists("./internal/server") { | 
					
						
							|  |  |  | 			os.MkdirAll("./internal/server", 777) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if !utils.PathExists("./internal/logic") { | 
					
						
							|  |  |  | 			os.MkdirAll("./internal/logic", 777) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 		for _, f := range gen.Files { | 
					
						
							|  |  |  | 			if len(f.Services) == 0 { | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			if err := generateFiles(gen, f); err != nil { | 
					
						
							|  |  |  | 				return err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 		err := generateNewServerFile(ServicesName) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 		return nil | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func generateFiles(gen *protogen.Plugin, file *protogen.File) error { | 
					
						
							|  |  |  | 	for _, service := range file.Services { | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 		ServicesName = append(ServicesName, service.GoName) | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 		// Generate server file
 | 
					
						
							|  |  |  | 		if err := generateServerFile(gen, file, service); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 		// Generate logic file
 | 
					
						
							|  |  |  | 		if err := generateLogicFile(gen, file, service); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func generateNewServerFile(services []string) error { | 
					
						
							|  |  |  | 	moduleName := getModuleName() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	//create new.go
 | 
					
						
							|  |  |  | 	code := tpl.NewFile | 
					
						
							|  |  |  | 	newImports := []string{ | 
					
						
							|  |  |  | 		"pb \"" + moduleName + "/pb\"", | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	code = strings.ReplaceAll(code, "{import}", strings.Join(newImports, "\n")) | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// register grpc
 | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	var register []string | 
					
						
							|  |  |  | 	for _, service := range services { | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 		register = append(register, "pb.Register"+service+"Server(srv.Grpc, New"+service+"Server())") | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	code = strings.ReplaceAll(code, "{register}", strings.Join(register, "\n")) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	// register grpc gw
 | 
					
						
							|  |  |  | 	var gw []string | 
					
						
							|  |  |  | 	for _, service := range services { | 
					
						
							|  |  |  | 		gw = append(gw, "pb.Register"+service+"HandlerFromEndpoint(srv.Ctx, srv.Mux, addr, opts)") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	code = strings.ReplaceAll(code, "{gw}", strings.Join(gw, "\n")) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	// 格式化代码
 | 
					
						
							|  |  |  | 	formattedCode, err := format.Source([]byte(code)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("failed to format generated code: %w", err) | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	StringToFile("./internal/server/new.go", string(formattedCode)) | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func generateServerFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error { | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	filename := fmt.Sprintf("./internal/server/%s_server.go", strings.ToLower(service.GoName)) | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	moduleName := getModuleName() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	//create servers.
 | 
					
						
							|  |  |  | 	code := tpl.Server | 
					
						
							|  |  |  | 	imports := []string{ | 
					
						
							|  |  |  | 		"\"" + moduleName + "/internal/logic/" + strings.ToLower(service.GoName) + "\"", | 
					
						
							|  |  |  | 		"pb \"" + moduleName + "/pb\"", | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n")) | 
					
						
							|  |  |  | 	code = strings.ReplaceAll(code, "{service}", service.GoName) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var codeMethods []string | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	for _, method := range service.Methods { | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 		commit := strings.TrimSpace(method.Comments.Leading.String()) | 
					
						
							|  |  |  | 		methodCode := tpl.Method | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{service}", service.GoName) | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{serviceLower}", strings.ToLower(service.GoName)) | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{func}", method.GoName) | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{comment}", commit) | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{input}", method.Input.GoIdent.GoName) | 
					
						
							|  |  |  | 		methodCode = strings.ReplaceAll(methodCode, "{output}", method.Output.GoIdent.GoName) | 
					
						
							|  |  |  | 		codeMethods = append(codeMethods, methodCode) | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	code = strings.ReplaceAll(code, "{method}", strings.Join(codeMethods, "\n")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// 格式化代码
 | 
					
						
							|  |  |  | 	formattedCode, err := format.Source([]byte(code)) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return fmt.Errorf("failed to format generated code: %w", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	StringToFile(filename, string(formattedCode)) | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func generateLogicFile(gen *protogen.Plugin, file *protogen.File, service *protogen.Service) error { | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	logicPath := "./internal/logic/" + strings.ToLower(service.GoName) | 
					
						
							|  |  |  | 	if !utils.PathExists(logicPath) { | 
					
						
							|  |  |  | 		os.MkdirAll(logicPath, os.ModePerm) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	moduleName := getModuleName() | 
					
						
							|  |  |  | 	for _, method := range service.Methods { | 
					
						
							|  |  |  | 		filename := fmt.Sprintf("%s/%s.go", logicPath, toSnakeCase(method.GoName)) | 
					
						
							|  |  |  | 		if utils.PathExists(filename) { | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		code := tpl.LogicFile | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		code = strings.ReplaceAll(code, "{methodName}", strings.ToLower(service.GoName)) | 
					
						
							|  |  |  | 		imports := []string{ | 
					
						
							|  |  |  | 			"pb \"" + moduleName + "/pb\"", | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-10 17:25:23 +08:00
										 |  |  | 		if strings.ToLower(method.Input.GoIdent.GoName) == "identrequest" || strings.ToLower(method.Input.GoIdent.GoName) == "fetchrequest" { | 
					
						
							|  |  |  | 			if strings.ToLower(method.Input.GoIdent.GoName) == "identrequest" { | 
					
						
							|  |  |  | 				imports = append(imports, "\"git.apinb.com/bsm-sdk/core/errcode\"") | 
					
						
							|  |  |  | 				code = strings.ReplaceAll(code, "{valid}", tpl.ValidCode) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			if strings.ToLower(method.Input.GoIdent.GoName) == "fetchrequest" { | 
					
						
							|  |  |  | 				code = strings.ReplaceAll(code, "{valid}", tpl.FetchValidCode) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			code = strings.ReplaceAll(code, "{valid}", "// TODO: valid code") | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if strings.ToLower(method.Output.GoIdent.GoName) == "statusreply" { | 
					
						
							|  |  |  | 			imports = append(imports, "\"time\"") | 
					
						
							|  |  |  | 			code = strings.ReplaceAll(code, "{return}", tpl.StatusReplyCode) | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			code = strings.ReplaceAll(code, "{return}", "return ") | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 		code = strings.ReplaceAll(code, "{import}", strings.Join(imports, "\n")) | 
					
						
							|  |  |  | 		commit := strings.TrimSpace(method.Comments.Leading.String()) | 
					
						
							|  |  |  | 		code = strings.ReplaceAll(code, "{func}", method.GoName) | 
					
						
							|  |  |  | 		code = strings.ReplaceAll(code, "{comment}", commit) | 
					
						
							|  |  |  | 		code = strings.ReplaceAll(code, "{input}", method.Input.GoIdent.GoName) | 
					
						
							|  |  |  | 		code = strings.ReplaceAll(code, "{output}", method.Output.GoIdent.GoName) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-10 17:25:23 +08:00
										 |  |  | 		formattedCode, err := format.Source([]byte(code)) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("failed to format generated code: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-10 17:25:23 +08:00
										 |  |  | 		StringToFile(filename, string(formattedCode)) | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-01 17:37:16 +08:00
										 |  |  | func toSnakeCase(str string) string { | 
					
						
							|  |  |  | 	// Use a regular expression to find uppercase letters and insert an underscore before them
 | 
					
						
							|  |  |  | 	re := regexp.MustCompile("([a-z0-9])([A-Z])") | 
					
						
							|  |  |  | 	snake := re.ReplaceAllString(str, "${1}_${2}") | 
					
						
							|  |  |  | 	// Convert the entire string to lowercase
 | 
					
						
							|  |  |  | 	return strings.ToLower(snake) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | func methodSignature(g *protogen.GeneratedFile, method *protogen.Method) string { | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 	return fmt.Sprintf("%s(ctx context.Context, req pb%s) (*%s, error)", | 
					
						
							| 
									
										
										
										
											2025-03-28 20:54:27 +08:00
										 |  |  | 		method.GoName, | 
					
						
							|  |  |  | 		method.Input.GoIdent, | 
					
						
							|  |  |  | 		method.Output.GoIdent) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func fullMethodName(file *protogen.File, service *protogen.Service, method *protogen.Method) string { | 
					
						
							|  |  |  | 	return fmt.Sprintf("/%s.%s/%s", | 
					
						
							|  |  |  | 		file.Proto.GetPackage(), | 
					
						
							|  |  |  | 		service.GoName, | 
					
						
							|  |  |  | 		method.GoName) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2025-03-29 17:38:05 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func getModuleName() (modulePath string) { | 
					
						
							|  |  |  | 	// 获取当前工作目录
 | 
					
						
							|  |  |  | 	cwd, err := os.Getwd() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		fmt.Errorf("failed to get current working directory: %w", err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// 读取 go.mod 文件
 | 
					
						
							|  |  |  | 	modFilePath := filepath.Join(cwd, "go.mod") | 
					
						
							|  |  |  | 	modFileBytes, err := os.ReadFile(modFilePath) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		fmt.Errorf("failed to read go.mod file: %w", err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// 解析 go.mod 文件
 | 
					
						
							|  |  |  | 	modFile, err := modfile.Parse(modFilePath, modFileBytes, nil) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		fmt.Errorf("failed to parse go.mod file: %w", err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// 获取模块路径
 | 
					
						
							|  |  |  | 	return modFile.Module.Mod.Path | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // 将字符串写入文件
 | 
					
						
							|  |  |  | func StringToFile(path, content string) error { | 
					
						
							|  |  |  | 	startF, err := os.Create(path) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return errors.New("os.Create create file " + path + " error:" + err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer startF.Close() | 
					
						
							|  |  |  | 	_, err = io.WriteString(startF, content) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return errors.New("io.WriteString to " + path + " error:" + err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2025-04-10 17:25:23 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | // parseOptions 解析注释中的选项字符串并返回一个 map
 | 
					
						
							|  |  |  | func parseOptions(comment string) map[string]string { | 
					
						
							|  |  |  | 	// 去掉注释符号和分号
 | 
					
						
							|  |  |  | 	comment = strings.Trim(comment, " //;") | 
					
						
							|  |  |  | 	// 按逗号分割选项
 | 
					
						
							|  |  |  | 	options := strings.Split(comment, ",") | 
					
						
							|  |  |  | 	result := make(map[string]string) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, opt := range options { | 
					
						
							|  |  |  | 		// 按等号分割键和值
 | 
					
						
							|  |  |  | 		parts := strings.SplitN(opt, "=", 2) | 
					
						
							|  |  |  | 		if len(parts) == 2 { | 
					
						
							|  |  |  | 			key := strings.TrimSpace(parts[0]) | 
					
						
							|  |  |  | 			value := strings.TrimSpace(parts[1]) | 
					
						
							|  |  |  | 			result[key] = value | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return result | 
					
						
							|  |  |  | } |