361 lines
9.7 KiB
Go
361 lines
9.7 KiB
Go
package ctwatch
|
|
|
|
import (
|
|
"fmt"
|
|
"bytes"
|
|
"errors"
|
|
"encoding/asn1"
|
|
"math/big"
|
|
"time"
|
|
"net"
|
|
)
|
|
|
|
// Object identifiers for the X.509 extensions and distinguished-name
// attribute types that this package recognizes.
var (
	// Certificate extensions (2.5.29.x).
	oidExtensionSubjectAltName   = asn1.ObjectIdentifier{2, 5, 29, 17}
	oidExtensionBasicConstraints = asn1.ObjectIdentifier{2, 5, 29, 19}

	// Distinguished-name attribute types (2.5.4.x), in arc order.
	oidCommonName         = asn1.ObjectIdentifier{2, 5, 4, 3}
	oidSerialNumber       = asn1.ObjectIdentifier{2, 5, 4, 5}
	oidCountry            = asn1.ObjectIdentifier{2, 5, 4, 6}
	oidLocality           = asn1.ObjectIdentifier{2, 5, 4, 7}
	oidProvince           = asn1.ObjectIdentifier{2, 5, 4, 8}
	oidStreetAddress      = asn1.ObjectIdentifier{2, 5, 4, 9}
	oidOrganization       = asn1.ObjectIdentifier{2, 5, 4, 10}
	oidOrganizationalUnit = asn1.ObjectIdentifier{2, 5, 4, 11}
	oidPostalCode         = asn1.ObjectIdentifier{2, 5, 4, 17}
)
|
|
|
|
// CertValidity is the decoded validity window of a certificate, as returned
// by (*TBSCertificate).ParseValidity.
type CertValidity struct {
	NotBefore, NotAfter time.Time
}
|
|
|
|
// basicConstraints mirrors the value of the Basic Constraints extension.
// Both members are optional. An absent path-length constraint decodes as -1
// (per the default tag). Field order and tags drive encoding/asn1 decoding
// and must not change.
type basicConstraints struct {
	IsCA       bool `asn1:"optional"`
	MaxPathLen int  `asn1:"optional,default:-1"`
}
|
|
|
|
// Extension is a single certificate extension as it appears in the
// extensions list of a TBSCertificate. It holds the extension OID, the
// critical flag (false when the field is absent), and the contents of the
// extnValue OCTET STRING, undecoded. Field order and tags drive
// encoding/asn1 decoding and must not change.
type Extension struct {
	Id       asn1.ObjectIdentifier
	Critical bool `asn1:"optional"`
	Value    []byte
}
|
|
|
|
// Tag numbers of the GeneralName alternatives, as stored in
// SubjectAltName.Type.
const (
	sanOtherName     = iota // 0: otherName
	sanRfc822Name           // 1: rfc822Name
	sanDNSName              // 2: dNSName
	sanX400Address          // 3: x400Address
	sanDirectoryName        // 4: directoryName
	sanEdiPartyName         // 5: ediPartyName
	sanURI                  // 6: uniformResourceIdentifier
	sanIPAddress            // 7: iPAddress
	sanRegisteredID         // 8: registeredID
)
|
|
// SubjectAltName is one entry from a subjectAltName extension. Type is the
// entry's tag number (compare with the san* constants), and Value holds the
// entry's content bytes, undecoded.
type SubjectAltName struct {
	Type  int
	Value []byte
}
|
|
|
|
// RDNSequence is a decoded distinguished name (subject or issuer): an
// ordered list of relative distinguished names.
type RDNSequence []RelativeDistinguishedNameSET

// RelativeDistinguishedNameSET holds the attributes of one RDN.
//
// The "SET" suffix is load-bearing. encoding/asn1 decodes a slice type whose
// name ends in "SET" as an ASN.1 SET instead of a SEQUENCE, so do not rename
// this type.
type RelativeDistinguishedNameSET []AttributeTypeAndValue

// AttributeTypeAndValue is one attribute of an RDN. Value is kept as raw
// ASN.1 and decoded on demand.
type AttributeTypeAndValue struct {
	Type  asn1.ObjectIdentifier
	Value asn1.RawValue
}
|
|
|
|
type TBSCertificate struct {
|
|
Raw asn1.RawContent
|
|
|
|
Version int `asn1:"optional,explicit,default:1,tag:0"`
|
|
SerialNumber asn1.RawValue
|
|
SignatureAlgorithm asn1.RawValue
|
|
Issuer asn1.RawValue
|
|
Validity asn1.RawValue
|
|
Subject asn1.RawValue
|
|
PublicKey asn1.RawValue
|
|
UniqueId asn1.BitString `asn1:"optional,tag:1"`
|
|
SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"`
|
|
Extensions []Extension `asn1:"optional,explicit,tag:3"`
|
|
}
|
|
|
|
// Certificate is the outer X.509 Certificate SEQUENCE. Each component is
// kept as raw ASN.1; use ParseTBSCertificate to decode the signed part.
// Field order drives encoding/asn1 decoding and must not change.
type Certificate struct {
	// Raw holds the complete encoding (tag, length and contents) of the
	// certificate.
	Raw asn1.RawContent

	TBSCertificate     asn1.RawValue
	SignatureAlgorithm asn1.RawValue
	SignatureValue     asn1.RawValue
}
|
|
|
|
|
|
func (rdns RDNSequence) ParseCNs () ([]string, error) {
|
|
var cns []string
|
|
|
|
for _, rdn := range rdns {
|
|
if len(rdn) == 0 {
|
|
continue
|
|
}
|
|
atv := rdn[0]
|
|
if atv.Type.Equal(oidCommonName) {
|
|
cnString, err := decodeASN1String(&atv.Value)
|
|
if err != nil {
|
|
return nil, errors.New("Error decoding CN: " + err.Error())
|
|
}
|
|
cns = append(cns, cnString)
|
|
}
|
|
}
|
|
|
|
return cns, nil
|
|
}
|
|
|
|
func rdnLabel (oid asn1.ObjectIdentifier) string {
|
|
switch {
|
|
case oid.Equal(oidCountry): return "C"
|
|
case oid.Equal(oidOrganization): return "O"
|
|
case oid.Equal(oidOrganizationalUnit): return "OU"
|
|
case oid.Equal(oidCommonName): return "CN"
|
|
case oid.Equal(oidSerialNumber): return "serialNumber"
|
|
case oid.Equal(oidLocality): return "L"
|
|
case oid.Equal(oidProvince): return "ST"
|
|
case oid.Equal(oidStreetAddress): return "street"
|
|
case oid.Equal(oidPostalCode): return "postalCode"
|
|
}
|
|
return oid.String()
|
|
}
|
|
|
|
func (rdns RDNSequence) String () string {
|
|
var buf bytes.Buffer
|
|
|
|
for _, rdn := range rdns {
|
|
if len(rdn) == 0 {
|
|
continue
|
|
}
|
|
atv := rdn[0]
|
|
|
|
if buf.Len() != 0 {
|
|
buf.WriteString(", ")
|
|
}
|
|
buf.WriteString(rdnLabel(atv.Type))
|
|
buf.WriteString("=")
|
|
valueString, err := decodeASN1String(&atv.Value)
|
|
if err == nil {
|
|
buf.WriteString(valueString) // TODO: escape non-printable characters, '\', and ','
|
|
} else {
|
|
fmt.Fprintf(&buf, "%v", atv.Value.FullBytes)
|
|
}
|
|
}
|
|
|
|
return buf.String()
|
|
}
|
|
|
|
func (san SubjectAltName) String () string {
|
|
switch san.Type {
|
|
case sanDNSName:
|
|
return "DNS:" + string(san.Value) // TODO: escape non-printable characters, '\', and ','
|
|
case sanIPAddress:
|
|
if len(san.Value) == 4 || len(san.Value) == 16 {
|
|
return "IP:" + net.IP(san.Value).String()
|
|
} else {
|
|
return fmt.Sprintf("IP:%v", san.Value)
|
|
}
|
|
default:
|
|
// TODO: support other types of SANs
|
|
return fmt.Sprintf("%d:%v", san.Type, san.Value)
|
|
}
|
|
}
|
|
|
|
func ParseTBSCertificate (tbsBytes []byte) (*TBSCertificate, error) {
|
|
var tbs TBSCertificate
|
|
if rest, err := asn1.Unmarshal(tbsBytes, &tbs); err != nil {
|
|
return nil, errors.New("failed to parse TBS: " + err.Error())
|
|
} else if len(rest) > 0 {
|
|
return nil, fmt.Errorf("trailing data after TBS: %v", rest)
|
|
}
|
|
return &tbs, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseValidity () (*CertValidity, error) {
|
|
var rawValidity struct {
|
|
NotBefore asn1.RawValue
|
|
NotAfter asn1.RawValue
|
|
}
|
|
if rest, err := asn1.Unmarshal(tbs.Validity.FullBytes, &rawValidity); err != nil {
|
|
return nil, errors.New("failed to parse validity: " + err.Error())
|
|
} else if len(rest) > 0 {
|
|
return nil, fmt.Errorf("trailing data after validity: %v", rest)
|
|
}
|
|
|
|
var validity CertValidity
|
|
var err error
|
|
if validity.NotBefore, err = decodeASN1Time(&rawValidity.NotBefore); err != nil {
|
|
return nil, errors.New("failed to decode notBefore time: " + err.Error())
|
|
}
|
|
if validity.NotAfter, err = decodeASN1Time(&rawValidity.NotAfter); err != nil {
|
|
return nil, errors.New("failed to decode notAfter time: " + err.Error())
|
|
}
|
|
|
|
return &validity, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseBasicConstraints () (*bool, error) {
|
|
isCA := false
|
|
isNotCA := false
|
|
|
|
// Some certs in the wild have multiple BasicConstraints extensions (is there anything
|
|
// that CAs haven't screwed up???), so we process all of them and only choke if they
|
|
// are contradictory (which has not been observed...yet).
|
|
for _, ext := range tbs.GetExtension(oidExtensionBasicConstraints) {
|
|
var constraints basicConstraints
|
|
if rest, err := asn1.Unmarshal(ext.Value, &constraints); err != nil {
|
|
return nil, errors.New("failed to parse Basic Constraints: " + err.Error())
|
|
} else if len(rest) > 0 {
|
|
return nil, fmt.Errorf("trailing data after Basic Constraints: %v", rest)
|
|
}
|
|
|
|
if constraints.IsCA {
|
|
isCA = true
|
|
} else {
|
|
isNotCA = true
|
|
}
|
|
}
|
|
|
|
if !isCA && !isNotCA {
|
|
return nil, nil
|
|
} else if isCA && !isNotCA {
|
|
trueValue := true
|
|
return &trueValue, nil
|
|
} else if !isCA && isNotCA {
|
|
falseValue := false
|
|
return &falseValue, nil
|
|
} else {
|
|
return nil, fmt.Errorf("Certificate has more than one Basic Constraints extension and they are contradictory")
|
|
}
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseSerialNumber () (*big.Int, error) {
|
|
serialNumber := big.NewInt(0)
|
|
if rest, err := asn1.Unmarshal(tbs.SerialNumber.FullBytes, &serialNumber); err != nil {
|
|
return nil, errors.New("failed to parse serial number: " + err.Error())
|
|
} else if len(rest) > 0 {
|
|
return nil, fmt.Errorf("trailing data after serial number: %v", rest)
|
|
}
|
|
return serialNumber, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) GetRawPublicKey () []byte {
|
|
return tbs.PublicKey.FullBytes
|
|
}
|
|
|
|
func (tbs *TBSCertificate) GetRawSubject () []byte {
|
|
return tbs.Subject.FullBytes
|
|
}
|
|
|
|
func (tbs *TBSCertificate) GetRawIssuer () []byte {
|
|
return tbs.Issuer.FullBytes
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseSubject () (RDNSequence, error) {
|
|
var subject RDNSequence
|
|
if rest, err := asn1.Unmarshal(tbs.GetRawSubject(), &subject); err != nil {
|
|
return nil, errors.New("failed to parse certificate subject: " + err.Error())
|
|
} else if len(rest) != 0 {
|
|
return nil, fmt.Errorf("trailing data in certificate subject: %v", rest)
|
|
}
|
|
return subject, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseIssuer () (RDNSequence, error) {
|
|
var issuer RDNSequence
|
|
if rest, err := asn1.Unmarshal(tbs.GetRawIssuer(), &issuer); err != nil {
|
|
return nil, errors.New("failed to parse certificate issuer: " + err.Error())
|
|
} else if len(rest) != 0 {
|
|
return nil, fmt.Errorf("trailing data in certificate issuer: %v", rest)
|
|
}
|
|
return issuer, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseSubjectCommonNames () ([]string, error) {
|
|
subject, err := tbs.ParseSubject()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cns, err := subject.ParseCNs()
|
|
if err != nil {
|
|
return nil, errors.New("failed to process certificate subject: " + err.Error())
|
|
}
|
|
|
|
return cns, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) ParseSubjectAltNames () ([]SubjectAltName, error) {
|
|
sans := []SubjectAltName{}
|
|
|
|
for _, sanExt := range tbs.GetExtension(oidExtensionSubjectAltName) {
|
|
var err error
|
|
sans, err = parseSANExtension(sans, sanExt.Value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return sans, nil
|
|
}
|
|
|
|
func (tbs *TBSCertificate) GetExtension (id asn1.ObjectIdentifier) []Extension {
|
|
var exts []Extension
|
|
for _, ext := range tbs.Extensions {
|
|
if ext.Id.Equal(id) {
|
|
exts = append(exts, ext)
|
|
}
|
|
}
|
|
return exts
|
|
}
|
|
|
|
|
|
func ParseCertificate (certBytes []byte) (*Certificate, error) {
|
|
var cert Certificate
|
|
if rest, err := asn1.Unmarshal(certBytes, &cert); err != nil {
|
|
return nil, errors.New("failed to parse certificate: " + err.Error())
|
|
} else if len(rest) > 0 {
|
|
return nil, fmt.Errorf("trailing data after certificate: %v", rest)
|
|
}
|
|
return &cert, nil
|
|
}
|
|
|
|
func (cert *Certificate) GetRawTBSCertificate () []byte {
|
|
return cert.TBSCertificate.FullBytes
|
|
}
|
|
|
|
func (cert *Certificate) ParseTBSCertificate () (*TBSCertificate, error) {
|
|
return ParseTBSCertificate(cert.GetRawTBSCertificate())
|
|
}
|
|
|
|
func parseSANExtension (sans []SubjectAltName, value []byte) ([]SubjectAltName, error) {
|
|
var seq asn1.RawValue
|
|
if rest, err := asn1.Unmarshal(value, &seq); err != nil {
|
|
return nil, errors.New("failed to parse subjectAltName extension: " + err.Error())
|
|
} else if len(rest) != 0 {
|
|
// Don't complain if the SAN is followed by exactly one zero byte,
|
|
// which is a common error.
|
|
if !(len(rest) == 1 && rest[0] == 0) {
|
|
return nil, fmt.Errorf("trailing data in subjectAltName extension: %v", rest)
|
|
}
|
|
}
|
|
if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 {
|
|
return nil, errors.New("failed to parse subjectAltName extension: bad SAN sequence")
|
|
}
|
|
|
|
rest := seq.Bytes
|
|
for len(rest) > 0 {
|
|
var val asn1.RawValue
|
|
var err error
|
|
rest, err = asn1.Unmarshal(rest, &val)
|
|
if err != nil {
|
|
return nil, errors.New("failed to parse subjectAltName extension item: " + err.Error())
|
|
}
|
|
sans = append(sans, SubjectAltName{Type: val.Tag, Value: val.Bytes})
|
|
}
|
|
|
|
return sans, nil
|
|
}
|
|
|