diff --git a/license/crypto.go b/license/crypto.go index aa31ada..9dcdda2 100644 --- a/license/crypto.go +++ b/license/crypto.go @@ -1,47 +1,60 @@ package license import ( - "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/rsa" "crypto/sha256" - "crypto/x509" + "encoding/asn1" "encoding/base64" - "encoding/pem" + "math/big" ) -var ( - privateKey *ecdsa.PrivateKey -) +var privateKey *ecdsa.PrivateKey func init() { - privateKey, _ = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + var err error + privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } } -func SignPayload(payload []byte) (string, error) { - hash := sha256.Sum256(payload) +// ecdsaSignature 是 ASN.1 编码中使用的结构体 +type ecdsaSignature struct { + R, S *big.Int +} + +func SignPayload(message []byte) (string, error) { + hash := sha256.Sum256(message) r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) if err != nil { return "", err } - sig := append(r.Bytes(), s.Bytes()...) + sig, err := asn1.Marshal(ecdsaSignature{r, s}) + if err != nil { + return "", err + } return base64.StdEncoding.EncodeToString(sig), nil } -func VerifySignature(pub *rsa.PublicKey, message []byte, signatureBase64 string) bool { - signature, err := base64.StdEncoding.DecodeString(signatureBase64) +func VerifySignature(pub *ecdsa.PublicKey, message []byte, signatureBase64 string) bool { + sigBytes, err := base64.StdEncoding.DecodeString(signatureBase64) if err != nil { return false } - hashed := sha256.Sum256(message) - err = rsa.VerifyPKCS1v15(pub, crypto.SHA256, hashed[:], signature) - return err == nil + + var sig ecdsaSignature + _, err = asn1.Unmarshal(sigBytes, &sig) + if err != nil { + return false + } + + hash := sha256.Sum256(message) + return ecdsa.Verify(pub, hash[:], sig.R, sig.S) } -func ExportPublicKeyPEM() string { - pubKeyBytes, _ := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - return string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKeyBytes})) +func GetPublicKey() *ecdsa.PublicKey { + return &privateKey.PublicKey } diff --git a/license/service.go b/license/service.go index f1003ae..2c4eef2 100644 --- a/license/service.go +++ b/license/service.go @@ -13,15 +13,25 @@ func GenerateLicenseHandler(db storage.Database) fiber.Handler { return func(c *fiber.Ctx) error { var req LicenseRequest if err := c.BodyParser(&req); err != nil { - return fiber.ErrBadRequest + return fiber.NewError(fiber.StatusBadRequest, "Invalid request body") } - payloadBytes, _ := json.Marshal(req) - payloadB64 := base64.StdEncoding.EncodeToString(payloadBytes) - signature, _ := SignPayload(payloadBytes) + if req.MachineID == "" || req.Expiry == "" { + return fiber.NewError(fiber.StatusBadRequest, "Missing required fields") + } + + payloadBytes, err := json.Marshal(req) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "Failed to encode payload") + } + + signature, err := SignPayload(payloadBytes) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "Signing failed") + } licenseFile := LicenseFile{ - Payload: payloadB64, + Payload: base64.StdEncoding.EncodeToString(payloadBytes), Signature: signature, } @@ -33,25 +43,40 @@ func ActivateLicenseHandler(db storage.Database) fiber.Handler { return func(c *fiber.Ctx) error { var lf LicenseFile if err := c.BodyParser(&lf); err != nil { - return fiber.ErrBadRequest + return fiber.NewError(fiber.StatusBadRequest, "Invalid license format") + } + + payloadBytes, err := base64.StdEncoding.DecodeString(lf.Payload) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "Invalid base64 payload") + } + + if !VerifySignature(GetPublicKey(), payloadBytes, lf.Signature) { + return fiber.NewError(fiber.StatusUnauthorized, "Invalid license signature") } - payloadBytes, _ := base64.StdEncoding.DecodeString(lf.Payload) var req LicenseRequest - json.Unmarshal(payloadBytes, &req) + if err := json.Unmarshal(payloadBytes, &req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, "Malformed payload") + } + + if req.MachineID == "" { + return fiber.NewError(fiber.StatusBadRequest, "Missing machine ID") + } if db.HasActivated(req.MachineID) { - return fiber.NewError(403, "This machine is already activated.") + return fiber.NewError(fiber.StatusForbidden, "This machine is already activated") } - if !VerifySignature(&privateKey.PublicKey, payloadBytes, lf.Signature) { - return fiber.NewError(401, "Invalid license signature") + expiry, err := time.Parse("2006-01-02", req.Expiry) + if err != nil || time.Now().After(expiry) { + return fiber.NewError(fiber.StatusForbidden, "License is invalid or expired") } - licenseText := lf.Payload + "." + lf.Signature - db.SaveActivation(req.MachineID, licenseText) + // 绑定激活记录 + db.SaveActivation(req.MachineID, lf.Payload+"."+lf.Signature) - return c.SendString("License activated successfully.") + return c.JSON(fiber.Map{"status": "success", "message": "License activated successfully"}) } } @@ -59,33 +84,33 @@ func ValidateLicenseHandler(db storage.Database) fiber.Handler { return func(c *fiber.Ctx) error { var lf LicenseFile if err := c.BodyParser(&lf); err != nil { - return fiber.ErrBadRequest + return fiber.NewError(fiber.StatusBadRequest, "Invalid license format") } payloadBytes, err := base64.StdEncoding.DecodeString(lf.Payload) if err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid payload encoding") + return fiber.NewError(fiber.StatusBadRequest, "Invalid base64 payload") } - // 先验证签名是否真的是对 payloadBytes 签的 - if !VerifySignature(&privateKey.PublicKey, payloadBytes, lf.Signature) { + if !VerifySignature(GetPublicKey(), payloadBytes, lf.Signature) { return fiber.NewError(fiber.StatusUnauthorized, "Invalid license signature") } - // 验证通过后再解析 payload 内容 var req LicenseRequest if err := json.Unmarshal(payloadBytes, &req); err != nil { return fiber.NewError(fiber.StatusBadRequest, "Malformed payload") } expiry, err := time.Parse("2006-01-02", req.Expiry) - if err != nil { - return fiber.NewError(fiber.StatusBadRequest, "Invalid expiry date") - } - if time.Now().After(expiry) { + if err != nil || time.Now().After(expiry) { return fiber.NewError(fiber.StatusForbidden, "License expired") } - return c.JSON(fiber.Map{"valid": true, "features": req.Features}) + return c.JSON(fiber.Map{ + "valid": true, + "features": req.Features, + "machine": req.MachineID, + "expiry": req.Expiry, + }) } }