zzf.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /*
  2. * Copyright (c) 2000-2018, 达梦数据库有限公司.
  3. * All rights reserved.
  4. */
  5. package security
  6. import (
  7. "crypto/md5"
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "unsafe"
  12. )
  13. type ThirdPartCipher struct {
  14. encryptType int // 外部加密算法id
  15. encryptName string // 外部加密算法名称
  16. hashType int
  17. key []byte
  18. cipherCount int // 外部加密算法个数
  19. //innerId int // 外部加密算法内部id
  20. blockSize int // 分组块大小
  21. khSize int // key/hash大小
  22. }
  23. func NewThirdPartCipher(encryptType int, key []byte, cipherPath string, hashType int) (ThirdPartCipher, error) {
  24. var tpc = ThirdPartCipher{
  25. encryptType: encryptType,
  26. key: key,
  27. hashType: hashType,
  28. cipherCount: -1,
  29. }
  30. var err error
  31. err = initThirdPartCipher(cipherPath)
  32. if err != nil {
  33. return tpc, err
  34. }
  35. tpc.getCount()
  36. if err = tpc.getInfo(); err != nil {
  37. return tpc, err
  38. }
  39. return tpc, nil
  40. }
  41. func (tpc *ThirdPartCipher) getCount() int {
  42. if tpc.cipherCount == -1 {
  43. tpc.cipherCount = cipherGetCount()
  44. }
  45. return tpc.cipherCount
  46. }
  47. func (tpc *ThirdPartCipher) getInfo() error {
  48. var cipher_id, ty, blk_size, kh_size int
  49. //var strptr, _ = syscall.UTF16PtrFromString(tpc.encryptName)
  50. var strptr *uint16 = new(uint16)
  51. for i := 1; i <= tpc.getCount(); i++ {
  52. cipherGetInfo(uintptr(i), uintptr(unsafe.Pointer(&cipher_id)), uintptr(unsafe.Pointer(&strptr)),
  53. uintptr(unsafe.Pointer(&ty)), uintptr(unsafe.Pointer(&blk_size)), uintptr(unsafe.Pointer(&kh_size)))
  54. if tpc.encryptType == cipher_id {
  55. tpc.blockSize = blk_size
  56. tpc.khSize = kh_size
  57. tpc.encryptName = string(uintptr2bytes(uintptr(unsafe.Pointer(strptr))))
  58. return nil
  59. }
  60. }
  61. return fmt.Errorf("ThirdPartyCipher: cipher id:%d not found", tpc.encryptType)
  62. }
  63. func (tpc ThirdPartCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
  64. var tmp_para uintptr
  65. cipherEncryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
  66. ciphertextLen := cipherGetCipherTextSize(uintptr(tpc.encryptType), tmp_para, uintptr(len(plaintext)))
  67. ciphertext := make([]byte, ciphertextLen)
  68. ret := cipherEncrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)),
  69. uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)))
  70. ciphertext = ciphertext[:ret]
  71. cipherClean(uintptr(tpc.encryptType), tmp_para)
  72. // md5摘要
  73. if genDigest {
  74. digest := md5.Sum(plaintext)
  75. encrypt := ciphertext
  76. ciphertext = make([]byte, len(encrypt)+len(digest))
  77. copy(ciphertext[:len(encrypt)], encrypt)
  78. copy(ciphertext[len(encrypt):], digest[:])
  79. }
  80. return ciphertext
  81. }
  82. func (tpc ThirdPartCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
  83. var ret []byte
  84. if checkDigest {
  85. var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:]
  86. ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE]
  87. ret = tpc.decrypt(ret)
  88. var msgDigest = md5.Sum(ret)
  89. if !reflect.DeepEqual(msgDigest[:], digest) {
  90. return nil, errors.New("Decrypt failed/Digest not match\n")
  91. }
  92. } else {
  93. ret = tpc.decrypt(ciphertext)
  94. }
  95. return ret, nil
  96. }
  97. func (tpc ThirdPartCipher) decrypt(ciphertext []byte) []byte {
  98. var tmp_para uintptr
  99. cipherDecryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
  100. plaintext := make([]byte, len(ciphertext))
  101. ret := cipherDecrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)),
  102. uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)))
  103. plaintext = plaintext[:ret]
  104. cipherClean(uintptr(tpc.encryptType), tmp_para)
  105. return plaintext
  106. }
  107. func addBufSize(buf []byte, newCap int) []byte {
  108. newBuf := make([]byte, newCap)
  109. copy(newBuf, buf)
  110. return newBuf
  111. }
  112. func uintptr2bytes(p uintptr) []byte {
  113. buf := make([]byte, 64)
  114. i := 0
  115. for b := (*byte)(unsafe.Pointer(p)); *b != 0; i++ {
  116. if i > cap(buf) {
  117. buf = addBufSize(buf, i*2)
  118. }
  119. buf[i] = *b
  120. // byte占1字节
  121. p++
  122. b = (*byte)(unsafe.Pointer(p))
  123. }
  124. return buf[:i]
  125. }