275 lines
6.3 KiB
Go
275 lines
6.3 KiB
Go
// Package rsa implements RSA encryption, decryption, signing and verification on top of crypto/rsa.
|
|
package rsa
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"fmt"
|
|
"math/big"
|
|
|
|
"github.com/golang-module/dongle/openssl"
|
|
)
|
|
|
|
// Error constructors shared by the KeyPair methods. Each call builds a fresh
// error value carrying a fixed message.
var (
	// invalidPublicKeyError is returned when the configured public key fails
	// validation or cannot be parsed.
	invalidPublicKeyError = func() error {
		return fmt.Errorf("rsa: invalid public key, please make sure the public key is valid")
	}

	// invalidPrivateKeyError is returned when the configured private key fails
	// validation or cannot be parsed.
	invalidPrivateKeyError = func() error {
		return fmt.Errorf("rsa: invalid private key, please make sure the private key is valid")
	}

	// unsupportedHashError is returned when the configured hash algorithm is
	// rejected by IsSupportedHash.
	unsupportedHashError = func() error {
		return fmt.Errorf("rsa: invalid hash function, the hash function is unsupported")
	}
)
|
|
|
|
// KeyPair bundles the key material and digest algorithm used by the RSA
// helpers in this package. Keys are stored as the raw bytes handed to
// SetPublicKey / SetPrivateKey and are parsed again by every operation
// (presumably PEM-encoded, as accepted by the openssl package; confirm there).
type KeyPair struct {
	publicKey, privateKey []byte

	// hash is the digest used by SignByPrivateKey and VerifyByPublicKey.
	hash crypto.Hash
}
|
|
|
|
// NewKeyPair returns a new KeyPair instance.
|
|
// 初始化 keyPair 结构体
|
|
func NewKeyPair() *KeyPair {
|
|
return &KeyPair{}
|
|
}
|
|
|
|
// SetPublicKey sets public key.
|
|
// 设置公钥
|
|
func (k *KeyPair) SetPublicKey(publicKey []byte) {
|
|
k.publicKey = publicKey
|
|
}
|
|
|
|
// SetPrivateKey sets private key.
|
|
// 设置私钥
|
|
func (k *KeyPair) SetPrivateKey(privateKey []byte) {
|
|
k.privateKey = privateKey
|
|
}
|
|
|
|
// SetHash sets hash algorithm.
|
|
// 设置哈希算法
|
|
func (k *KeyPair) SetHash(hash crypto.Hash) {
|
|
k.hash = hash
|
|
}
|
|
|
|
// EncryptByPublicKey encrypts by public key.
|
|
// 通过公钥加密
|
|
func (k *KeyPair) EncryptByPublicKey(src []byte) (dst []byte, err error) {
|
|
dst = []byte("")
|
|
if len(src) == 0 {
|
|
return
|
|
}
|
|
if !openssl.RSA.IsPublicKey(k.publicKey) {
|
|
err = invalidPublicKeyError()
|
|
return
|
|
}
|
|
pub, err := openssl.RSA.ParsePublicKey(k.publicKey)
|
|
if err != nil {
|
|
err = invalidPublicKeyError()
|
|
return
|
|
}
|
|
buffer := bytes.NewBufferString("")
|
|
for _, chunk := range bytesSplit(src, pub.Size()-11) {
|
|
dst, err = rsa.EncryptPKCS1v15(rand.Reader, pub, chunk)
|
|
buffer.Write(dst)
|
|
}
|
|
dst = buffer.Bytes()
|
|
return
|
|
}
|
|
|
|
// DecryptByPrivateKey encrypts by private key.
|
|
// 通过私钥解密
|
|
func (k *KeyPair) DecryptByPrivateKey(src []byte) (dst []byte, err error) {
|
|
dst = []byte("")
|
|
if len(src) == 0 {
|
|
return
|
|
}
|
|
if !openssl.RSA.IsPrivateKey(k.privateKey) {
|
|
err = invalidPrivateKeyError()
|
|
return
|
|
}
|
|
pri, err := openssl.RSA.ParsePrivateKey(k.privateKey)
|
|
if err != nil {
|
|
err = invalidPrivateKeyError()
|
|
return
|
|
}
|
|
buffer := bytes.NewBufferString("")
|
|
for _, chunk := range bytesSplit(src, pri.Size()) {
|
|
dst, err = rsa.DecryptPKCS1v15(rand.Reader, pri, chunk)
|
|
buffer.Write(dst)
|
|
}
|
|
dst = buffer.Bytes()
|
|
return
|
|
}
|
|
|
|
// EncryptByPrivateKey encrypts by private key.
|
|
// 通过私钥加密
|
|
func (k *KeyPair) EncryptByPrivateKey(src []byte) (dst []byte, err error) {
|
|
dst = []byte("")
|
|
if len(src) == 0 {
|
|
return
|
|
}
|
|
|
|
if !openssl.RSA.IsPrivateKey(k.privateKey) {
|
|
err = invalidPrivateKeyError()
|
|
return
|
|
}
|
|
pri, err := openssl.RSA.ParsePrivateKey(k.privateKey)
|
|
if err != nil {
|
|
err = invalidPrivateKeyError()
|
|
return
|
|
}
|
|
buffer := bytes.NewBufferString("")
|
|
for _, chunk := range bytesSplit(src, pri.Size()-11) {
|
|
dst, err = rsa.SignPKCS1v15(nil, pri, crypto.Hash(0), chunk)
|
|
buffer.Write(dst)
|
|
}
|
|
dst = buffer.Bytes()
|
|
return
|
|
}
|
|
|
|
// DecryptByPublicKey encrypts by public key.
|
|
// 通过公钥解密
|
|
func (k *KeyPair) DecryptByPublicKey(src []byte) (dst []byte, err error) {
|
|
dst = []byte("")
|
|
if len(src) == 0 {
|
|
return
|
|
}
|
|
|
|
if !openssl.RSA.IsPublicKey(k.publicKey) {
|
|
err = invalidPublicKeyError()
|
|
return
|
|
}
|
|
pub, err := openssl.RSA.ParsePublicKey(k.publicKey)
|
|
if err != nil {
|
|
err = invalidPublicKeyError()
|
|
return
|
|
}
|
|
buffer := bytes.NewBufferString("")
|
|
bigInt := new(big.Int)
|
|
for _, chunk := range bytesSplit(src, pub.Size()) {
|
|
bigInt.Exp(new(big.Int).SetBytes(chunk), big.NewInt(int64(pub.E)), pub.N)
|
|
dst = leftUnPad(leftPad(bigInt.Bytes(), pub.Size()))
|
|
buffer.Write(dst)
|
|
}
|
|
dst = buffer.Bytes()
|
|
return
|
|
}
|
|
|
|
// SignByPrivateKey signs by private key.
|
|
// 通过私钥签名
|
|
func (k *KeyPair) SignByPrivateKey(src []byte) (dst []byte, err error) {
|
|
dst = []byte("")
|
|
pri, err := openssl.RSA.ParsePrivateKey(k.privateKey)
|
|
if err != nil {
|
|
err = invalidPrivateKeyError()
|
|
return
|
|
}
|
|
if !k.IsSupportedHash() {
|
|
err = unsupportedHashError()
|
|
return
|
|
}
|
|
hasher := k.hash.New()
|
|
hasher.Write(src)
|
|
hashed := hasher.Sum(nil)
|
|
dst, err = rsa.SignPKCS1v15(rand.Reader, pri, k.hash, hashed)
|
|
return
|
|
}
|
|
|
|
// VerifyByPublicKey verify by public key.
|
|
// 通过公钥验签
|
|
func (k *KeyPair) VerifyByPublicKey(src, sign []byte) (err error) {
|
|
pub, err := openssl.RSA.ParsePublicKey(k.publicKey)
|
|
if err != nil {
|
|
err = invalidPublicKeyError()
|
|
return
|
|
}
|
|
if !k.IsSupportedHash() {
|
|
err = unsupportedHashError()
|
|
return
|
|
}
|
|
hasher := k.hash.New()
|
|
hasher.Write(src)
|
|
hashed := hasher.Sum(nil)
|
|
return rsa.VerifyPKCS1v15(pub, k.hash, hashed, sign)
|
|
}
|
|
|
|
// IsPublicKey whether is a public key.
|
|
// 是否是公钥
|
|
func (k *KeyPair) IsPublicKey() bool {
|
|
return openssl.RSA.IsPublicKey(k.publicKey)
|
|
}
|
|
|
|
// IsPrivateKey whether is a private key.
|
|
// 是否是私钥
|
|
func (k *KeyPair) IsPrivateKey() bool {
|
|
return openssl.RSA.IsPrivateKey(k.privateKey)
|
|
}
|
|
|
|
// IsSupportedHash whether is a supported hash algorithm.
|
|
// 判断是否是支持的哈希算法
|
|
func (k *KeyPair) IsSupportedHash() bool {
|
|
hashes := []crypto.Hash{
|
|
crypto.MD5, crypto.SHA1, crypto.SHA224, crypto.SHA256, crypto.SHA384, crypto.SHA512, crypto.RIPEMD160,
|
|
}
|
|
for _, hash := range hashes {
|
|
if hash == k.hash {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// leftPad returns a new slice of length size holding src right-aligned, with
// the leading bytes zero-filled. If src is already size bytes or longer, a
// copy of src is returned unchanged instead of panicking on a negative offset.
// The result never aliases src.
func leftPad(src []byte, size int) (dst []byte) {
	if len(src) >= size {
		dst = make([]byte, len(src))
		copy(dst, src)
		return
	}
	dst = make([]byte, size)
	copy(dst[size-len(src):], src)
	return
}
|
|
|
|
// leftUnPad strips a PKCS#1 v1.5 type-1 style prefix (0x00 0x01 0xFF..0xFF 0x00)
// from src and returns a copy of the remaining bytes.
//
// The first two bytes are always skipped, then every following 0xFF. If the
// first non-0xFF byte equals src[0], it is treated as the separator and
// src[1] more bytes are skipped (1 for well-formed type-1 input). Malformed
// input whose computed offset runs past the end, or input shorter than two
// bytes, yields an empty slice rather than a panic.
func leftUnPad(src []byte) (dst []byte) {
	n := len(src)
	if n < 2 {
		return []byte{}
	}
	offset := 2
	for i := 2; i < n; i++ {
		if src[i] == 0xff {
			offset++
			continue
		}
		if src[i] == src[0] {
			offset += int(src[1])
		}
		break
	}
	if offset > n {
		offset = n
	}
	dst = make([]byte, n-offset)
	copy(dst, src[offset:])
	return
}
|
|
|
|
// bytesSplit splits buf into consecutive chunks of size bytes; only the final
// chunk may be shorter. Chunks share buf's backing array but have their
// capacity capped, so appending to one cannot overwrite the next.
// An empty buf yields an empty (non-nil) result. A non-positive size cannot
// split anything, so a non-empty buf is returned as a single chunk instead of
// panicking.
func bytesSplit(buf []byte, size int) [][]byte {
	if len(buf) == 0 {
		return [][]byte{}
	}
	if size <= 0 {
		return [][]byte{buf}
	}
	chunks := make([][]byte, 0, (len(buf)+size-1)/size)
	for len(buf) > size {
		chunks = append(chunks, buf[:size:size])
		buf = buf[size:]
	}
	return append(chunks, buf)
}
|