zze.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. /*
  2. * Copyright (c) 2000-2018, 达梦数据库有限公司.
  3. * All rights reserved.
  4. */
  5. package security
  6. import (
  7. "bytes"
  8. "crypto/aes"
  9. "crypto/cipher"
  10. "crypto/des"
  11. "crypto/md5"
  12. "crypto/rc4"
  13. "errors"
  14. "reflect"
  15. )
  // SymmCipher is a symmetric cipher bound to a single algorithm, key and
// work mode. Construct it with NewSymmCipher.
type SymmCipher struct {
	// encryptCipher and decryptCipher each hold either a cipher.BlockMode
	// (ECB/CBC work modes) or a cipher.Stream (CFB/OFB work modes, or RC4).
	encryptCipher, decryptCipher interface{}

	key   []byte       // key exactly as passed to NewSymmCipher (not copied)
	block cipher.Block // block cipher algorithm (AES/DES/3DES); left nil for RC4

	algorithmType, workMode int  // algorithm and work-mode parts of the algorithm ID
	needPadding             bool // true when the work mode needs PKCS padding (ECB, CBC)
}
  25. func NewSymmCipher(algorithmID int, key []byte) (SymmCipher, error) {
  26. var sc SymmCipher
  27. var err error
  28. sc.key = key
  29. sc.algorithmType = algorithmID & ALGO_MASK
  30. sc.workMode = algorithmID & WORK_MODE_MASK
  31. switch sc.algorithmType {
  32. case AES128:
  33. if sc.block, err = aes.NewCipher(key[:16]); err != nil {
  34. return sc, err
  35. }
  36. case AES192:
  37. if sc.block, err = aes.NewCipher(key[:24]); err != nil {
  38. return sc, err
  39. }
  40. case AES256:
  41. if sc.block, err = aes.NewCipher(key[:32]); err != nil {
  42. return sc, err
  43. }
  44. case DES:
  45. if sc.block, err = des.NewCipher(key[:8]); err != nil {
  46. return sc, err
  47. }
  48. case DES3:
  49. var tripleDESKey []byte
  50. tripleDESKey = append(tripleDESKey, key[:16]...)
  51. tripleDESKey = append(tripleDESKey, key[:8]...)
  52. if sc.block, err = des.NewTripleDESCipher(tripleDESKey); err != nil {
  53. return sc, err
  54. }
  55. case RC4:
  56. if sc.encryptCipher, err = rc4.NewCipher(key[:16]); err != nil {
  57. return sc, err
  58. }
  59. if sc.decryptCipher, err = rc4.NewCipher(key[:16]); err != nil {
  60. return sc, err
  61. }
  62. return sc, nil
  63. default:
  64. return sc, errors.New("invalidCipher")
  65. }
  66. blockSize := sc.block.BlockSize()
  67. if sc.encryptCipher, err = sc.getEncrypter(sc.workMode, sc.block, defaultIV[:blockSize]); err != nil {
  68. return sc, err
  69. }
  70. if sc.decryptCipher, err = sc.getDecrypter(sc.workMode, sc.block, defaultIV[:blockSize]); err != nil {
  71. return sc, err
  72. }
  73. return sc, nil
  74. }
  75. func (sc SymmCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
  76. // 执行过加密后,IV值变了,需要重新初始化encryptCipher对象(因为没有类似resetIV的方法)
  77. if sc.algorithmType != RC4 {
  78. sc.encryptCipher, _ = sc.getEncrypter(sc.workMode, sc.block, defaultIV[:sc.block.BlockSize()])
  79. } else {
  80. sc.encryptCipher, _ = rc4.NewCipher(sc.key[:16])
  81. }
  82. // 填充
  83. var paddingtext = make([]byte, len(plaintext))
  84. copy(paddingtext, plaintext)
  85. if sc.needPadding {
  86. paddingtext = pkcs5Padding(paddingtext)
  87. }
  88. ret := make([]byte, len(paddingtext))
  89. if v, ok := sc.encryptCipher.(cipher.Stream); ok {
  90. v.XORKeyStream(ret, paddingtext)
  91. } else if v, ok := sc.encryptCipher.(cipher.BlockMode); ok {
  92. v.CryptBlocks(ret, paddingtext)
  93. }
  94. // md5摘要
  95. if genDigest {
  96. digest := md5.Sum(plaintext)
  97. encrypt := ret
  98. ret = make([]byte, len(encrypt)+len(digest))
  99. copy(ret[:len(encrypt)], encrypt)
  100. copy(ret[len(encrypt):], digest[:])
  101. }
  102. return ret
  103. }
  104. func (sc SymmCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
  105. // 执行过解密后,IV值变了,需要重新初始化decryptCipher对象(因为没有类似resetIV的方法)
  106. if sc.algorithmType != RC4 {
  107. sc.decryptCipher, _ = sc.getDecrypter(sc.workMode, sc.block, defaultIV[:sc.block.BlockSize()])
  108. } else {
  109. sc.decryptCipher, _ = rc4.NewCipher(sc.key[:16])
  110. }
  111. var ret []byte
  112. if checkDigest {
  113. var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:]
  114. ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE]
  115. ret = sc.decrypt(ret)
  116. var msgDigest = md5.Sum(ret)
  117. if !reflect.DeepEqual(msgDigest[:], digest) {
  118. return nil, errors.New("Decrypt failed/Digest not match\n")
  119. }
  120. } else {
  121. ret = sc.decrypt(ciphertext)
  122. }
  123. return ret, nil
  124. }
  125. func (sc SymmCipher) decrypt(ciphertext []byte) []byte {
  126. ret := make([]byte, len(ciphertext))
  127. if v, ok := sc.decryptCipher.(cipher.Stream); ok {
  128. v.XORKeyStream(ret, ciphertext)
  129. } else if v, ok := sc.decryptCipher.(cipher.BlockMode); ok {
  130. v.CryptBlocks(ret, ciphertext)
  131. }
  132. // 去除填充
  133. if sc.needPadding {
  134. ret = pkcs5UnPadding(ret)
  135. }
  136. return ret
  137. }
  138. func (sc *SymmCipher) getEncrypter(workMode int, block cipher.Block, iv []byte) (ret interface{}, err error) {
  139. switch workMode {
  140. case ECB_MODE:
  141. ret = NewECBEncrypter(block)
  142. sc.needPadding = true
  143. case CBC_MODE:
  144. ret = cipher.NewCBCEncrypter(block, iv)
  145. sc.needPadding = true
  146. case CFB_MODE:
  147. ret = cipher.NewCFBEncrypter(block, iv)
  148. sc.needPadding = false
  149. case OFB_MODE:
  150. ret = cipher.NewOFB(block, iv)
  151. sc.needPadding = false
  152. default:
  153. err = errors.New("invalidCipherMode")
  154. }
  155. return
  156. }
  157. func (sc *SymmCipher) getDecrypter(workMode int, block cipher.Block, iv []byte) (ret interface{}, err error) {
  158. switch workMode {
  159. case ECB_MODE:
  160. ret = NewECBDecrypter(block)
  161. sc.needPadding = true
  162. case CBC_MODE:
  163. ret = cipher.NewCBCDecrypter(block, iv)
  164. sc.needPadding = true
  165. case CFB_MODE:
  166. ret = cipher.NewCFBDecrypter(block, iv)
  167. sc.needPadding = false
  168. case OFB_MODE:
  169. ret = cipher.NewOFB(block, iv)
  170. sc.needPadding = false
  171. default:
  172. err = errors.New("invalidCipherMode")
  173. }
  174. return
  175. }
  // pkcs77Padding appends PKCS#7 padding so that len(result) is a multiple of
// blocksize. It appends between 1 and blocksize bytes, each equal to the
// pad length; already-aligned input gains a whole extra block. The result is
// produced with append, so it may share ciphertext's backing array.
func pkcs77Padding(ciphertext []byte, blocksize int) []byte {
	n := blocksize - len(ciphertext)%blocksize
	return append(ciphertext, bytes.Repeat([]byte{byte(n)}, n)...)
}
  // pkcs7UnPadding strips PKCS#7 padding, reading the pad length from the last
// byte. Input that cannot carry such padding is returned unchanged instead
// of panicking on an out-of-range index or slice bound. That covers empty
// input and input whose last byte exceeds its length, both of which corrupt
// or hostile ciphertext can produce after decryption. As before, a pad byte
// of 0 also leaves the data unchanged, and the pad bytes themselves are not
// verified.
func pkcs7UnPadding(origData []byte) []byte {
	length := len(origData)
	if length == 0 {
		return origData
	}
	unpadding := int(origData[length-1])
	if unpadding > length {
		return origData
	}
	return origData[:length-unpadding]
}
  188. // 补码
  189. func pkcs5Padding(ciphertext []byte) []byte {
  190. return pkcs77Padding(ciphertext, 8)
  191. }
  192. // 去码
  193. func pkcs5UnPadding(ciphertext []byte) []byte {
  194. return pkcs7UnPadding(ciphertext)
  195. }