certspotter/x509.go

408 lines
11 KiB
Go

// Copyright (C) 2016 Opsmate, Inc.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License, v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// This software is distributed WITHOUT A WARRANTY OF ANY KIND.
// See the Mozilla Public License for details.
package certspotter
import (
"bytes"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"math/big"
"net"
"time"
)
// Object identifiers for the certificate extensions and distinguished-name
// attribute types that this file interprets.
var (
	// Certificate extensions (arc 2.5.29).
	oidExtensionSubjectAltName   = asn1.ObjectIdentifier{2, 5, 29, 17}
	oidExtensionBasicConstraints = asn1.ObjectIdentifier{2, 5, 29, 19}

	// Distinguished-name attribute types (arc 2.5.4), listed in arc order.
	oidCommonName         = asn1.ObjectIdentifier{2, 5, 4, 3}
	oidSerialNumber       = asn1.ObjectIdentifier{2, 5, 4, 5}
	oidCountry            = asn1.ObjectIdentifier{2, 5, 4, 6}
	oidLocality           = asn1.ObjectIdentifier{2, 5, 4, 7}
	oidProvince           = asn1.ObjectIdentifier{2, 5, 4, 8}
	oidStreetAddress      = asn1.ObjectIdentifier{2, 5, 4, 9}
	oidOrganization       = asn1.ObjectIdentifier{2, 5, 4, 10}
	oidOrganizationalUnit = asn1.ObjectIdentifier{2, 5, 4, 11}
	oidPostalCode         = asn1.ObjectIdentifier{2, 5, 4, 17}
)
// CertValidity holds a certificate's decoded validity period, as returned
// by (*TBSCertificate).ParseValidity.
type CertValidity struct {
	// NotBefore is the decoded notBefore time.
	NotBefore time.Time
	// NotAfter is the decoded notAfter time.
	NotAfter time.Time
}
// basicConstraints is the decoding target for the value of a Basic
// Constraints extension (see ParseBasicConstraints). Field order matches the
// ASN.1 encoding and must not be changed.
type basicConstraints struct {
	// IsCA is left false when the field is absent from the encoding.
	IsCA bool `asn1:"optional"`
	// MaxPathLen is set to -1 when the field is absent, per the tag default.
	MaxPathLen int `asn1:"optional,default:-1"`
}
// Extension is one entry of a TBSCertificate's extensions list. Field order
// matches the ASN.1 encoding and must not be changed.
type Extension struct {
	// Id is the extension's object identifier.
	Id asn1.ObjectIdentifier
	// Critical is left false when the field is omitted from the encoding.
	Critical bool `asn1:"optional"`
	// Value is the content of the extension's OCTET STRING; methods such as
	// ParseBasicConstraints and ParseSubjectAltNames decode it further.
	Value []byte
}
// Tag numbers identifying the kind of each entry in a subjectAltName
// extension. SubjectAltName.Type, which ParseSANExtension fills from each
// entry's tag, is compared against these. They run consecutively from 0 to 8.
const (
	sanOtherName = iota
	sanRfc822Name
	sanDNSName
	sanX400Address
	sanDirectoryName
	sanEdiPartyName
	sanURI
	sanIPAddress
	sanRegisteredID
)
// SubjectAltName is one entry of a subjectAltName extension, as produced by
// ParseSANExtension.
type SubjectAltName struct {
	// Type is the entry's ASN.1 tag number; compare it with the san* constants.
	Type int
	// Value is the entry's content bytes, left undecoded.
	Value []byte
}
// RDNSequence is a decoded distinguished name, the form used for a
// certificate's subject and issuer: an ordered list of RDNs.
type RDNSequence []RelativeDistinguishedNameSET

// RelativeDistinguishedNameSET is a single RDN, which may hold more than one
// attribute type/value pair. The "SET" suffix on the type name is what makes
// encoding/asn1 treat it as an ASN.1 SET OF rather than a SEQUENCE OF.
type RelativeDistinguishedNameSET []AttributeTypeAndValue

// AttributeTypeAndValue is one attribute of an RDN. The value is kept as raw
// ASN.1 and decoded on demand with decodeASN1String (see ParseCNs and String).
type AttributeTypeAndValue struct {
	Type  asn1.ObjectIdentifier
	Value asn1.RawValue
}
// ParseRDNSequence decodes rdnsBytes as exactly one DER-encoded distinguished
// name. Any bytes left over after the name cause an error.
func ParseRDNSequence(rdnsBytes []byte) (RDNSequence, error) {
	var seq RDNSequence
	trailing, err := asn1.Unmarshal(rdnsBytes, &seq)
	if err != nil {
		return nil, errors.New("failed to parse RDNSequence: " + err.Error())
	}
	if len(trailing) > 0 { // XXX: too strict?
		return nil, fmt.Errorf("trailing data after RDNSequence: %v", trailing)
	}
	return seq, nil
}
// MarshalRDNSequence encodes rdns as DER with encoding/asn1. It is the
// encoding-side counterpart of ParseRDNSequence.
func MarshalRDNSequence(rdns RDNSequence) ([]byte, error) {
	der, err := asn1.Marshal(rdns)
	if err != nil {
		return nil, err
	}
	return der, nil
}
// TBSCertificate is the "to be signed" part of a certificate. Most fields
// are kept as raw ASN.1 (asn1.RawValue) and decoded on demand by the Parse*
// methods. Field order matches the ASN.1 encoding and must not be changed.
type TBSCertificate struct {
	// Raw is the complete encoding of the TBSCertificate SEQUENCE.
	Raw asn1.RawContent

	// Version is the encoded version number: 0 for v1, 1 for v2, 2 for v3.
	// The field is OPTIONAL with DEFAULT v1, and v1 is encoded as 0, so an
	// absent version must decode as 0. A default of 1 would misreport every
	// v1 certificate as v2.
	Version            int            `asn1:"optional,explicit,default:0,tag:0"`
	SerialNumber       asn1.RawValue
	SignatureAlgorithm asn1.RawValue
	Issuer             asn1.RawValue
	Validity           asn1.RawValue
	Subject            asn1.RawValue
	PublicKey          asn1.RawValue
	UniqueId           asn1.BitString `asn1:"optional,tag:1"`
	SubjectUniqueId    asn1.BitString `asn1:"optional,tag:2"`
	Extensions         []Extension    `asn1:"optional,explicit,tag:3"`
}
// Certificate is the outer structure of a certificate, split into its three
// top-level fields without decoding them. Use ParseTBSCertificate,
// ParseSignatureAlgorithm and ParseSignatureValue to decode each part.
// Field order matches the ASN.1 encoding and must not be changed.
type Certificate struct {
	// Raw is the complete encoding of the certificate SEQUENCE.
	Raw                asn1.RawContent
	TBSCertificate     asn1.RawValue
	SignatureAlgorithm asn1.RawValue
	SignatureValue     asn1.RawValue
}
// ParseCNs returns the decoded value of every commonName attribute in rdns,
// in the order they appear.
//
// Every attribute of every RDN is examined, not just the first, so a CN
// carried inside a multi-valued RDN is not missed. It returns an error if
// any commonName value cannot be decoded as a string.
func (rdns RDNSequence) ParseCNs() ([]string, error) {
	var cns []string
	for _, rdn := range rdns {
		for _, atv := range rdn {
			if !atv.Type.Equal(oidCommonName) {
				continue
			}
			cnString, err := decodeASN1String(&atv.Value)
			if err != nil {
				return nil, errors.New("Error decoding CN: " + err.Error())
			}
			cns = append(cns, cnString)
		}
	}
	return cns, nil
}
// rdnLabel returns the conventional short label (such as "CN" or "O") for a
// distinguished-name attribute type. For types it does not recognise, it
// returns the OID in dotted-decimal form.
func rdnLabel(oid asn1.ObjectIdentifier) string {
	known := [...]struct {
		id    asn1.ObjectIdentifier
		label string
	}{
		{oidCountry, "C"},
		{oidOrganization, "O"},
		{oidOrganizationalUnit, "OU"},
		{oidCommonName, "CN"},
		{oidSerialNumber, "serialNumber"},
		{oidLocality, "L"},
		{oidProvince, "ST"},
		{oidStreetAddress, "street"},
		{oidPostalCode, "postalCode"},
	}
	for _, entry := range known {
		if oid.Equal(entry.id) {
			return entry.label
		}
	}
	return oid.String()
}
// String renders rdns as a comma-separated list of LABEL=value pairs, in
// sequence order. Empty RDNs are skipped.
//
// The attributes of a multi-valued RDN are joined with "+", so none of them
// is dropped. A value that decodeASN1String cannot decode is rendered as its
// raw encoded bytes.
func (rdns RDNSequence) String() string {
	var buf bytes.Buffer
	for _, rdn := range rdns {
		if len(rdn) == 0 {
			continue
		}
		if buf.Len() != 0 {
			buf.WriteString(", ")
		}
		for i, atv := range rdn {
			if i > 0 {
				buf.WriteString("+")
			}
			buf.WriteString(rdnLabel(atv.Type))
			buf.WriteString("=")
			valueString, err := decodeASN1String(&atv.Value)
			if err == nil {
				buf.WriteString(valueString) // TODO: escape non-printable characters, '\', ',', and '+'
			} else {
				fmt.Fprintf(&buf, "%v", atv.Value.FullBytes)
			}
		}
	}
	return buf.String()
}
// String formats the entry for display:
//   - DNS names as "DNS:<name>";
//   - 4- or 16-byte IP entries as "IP:<address>";
//   - IP entries of any other length as "IP:" followed by the raw bytes;
//   - any other type as "<tag>:<raw bytes>".
func (san SubjectAltName) String() string {
	if san.Type == sanDNSName {
		return "DNS:" + string(san.Value) // TODO: escape non-printable characters, '\', and ','
	}
	if san.Type != sanIPAddress {
		// TODO: support other types of SANs
		return fmt.Sprintf("%d:%v", san.Type, san.Value)
	}
	switch len(san.Value) {
	case net.IPv4len, net.IPv6len:
		return "IP:" + net.IP(san.Value).String()
	default:
		return fmt.Sprintf("IP:%v", san.Value)
	}
}
// ParseTBSCertificate decodes tbsBytes as exactly one DER TBSCertificate.
// Any bytes that follow it cause an error.
func ParseTBSCertificate(tbsBytes []byte) (*TBSCertificate, error) {
	tbs := new(TBSCertificate)
	trailing, err := asn1.Unmarshal(tbsBytes, tbs)
	switch {
	case err != nil:
		return nil, errors.New("failed to parse TBS: " + err.Error())
	case len(trailing) > 0: // XXX: too strict?
		return nil, fmt.Errorf("trailing data after TBS: %v", trailing)
	}
	return tbs, nil
}
// ParseValidity decodes the certificate's validity period. The validity
// SEQUENCE must not be followed by extra bytes, and both of its times are
// converted with decodeASN1Time.
func (tbs *TBSCertificate) ParseValidity() (*CertValidity, error) {
	var raw struct {
		NotBefore asn1.RawValue
		NotAfter  asn1.RawValue
	}
	trailing, err := asn1.Unmarshal(tbs.Validity.FullBytes, &raw)
	if err != nil {
		return nil, errors.New("failed to parse validity: " + err.Error())
	}
	if len(trailing) > 0 {
		return nil, fmt.Errorf("trailing data after validity: %v", trailing)
	}

	notBefore, err := decodeASN1Time(&raw.NotBefore)
	if err != nil {
		return nil, errors.New("failed to decode notBefore time: " + err.Error())
	}
	notAfter, err := decodeASN1Time(&raw.NotAfter)
	if err != nil {
		return nil, errors.New("failed to decode notAfter time: " + err.Error())
	}
	return &CertValidity{NotBefore: notBefore, NotAfter: notAfter}, nil
}
// ParseBasicConstraints reports whether the certificate is a CA according to
// its Basic Constraints extension(s).
//
// The result is nil when no such extension is present, and a pointer to true
// or false when every present extension agrees. It returns an error when the
// extensions disagree or any of them fails to decode.
func (tbs *TBSCertificate) ParseBasicConstraints() (*bool, error) {
	// Some certs in the wild have multiple BasicConstraints extensions (is there anything
	// that CAs haven't screwed up???), so we process all of them and only choke if they
	// are contradictory (which has not been observed...yet).
	sawCA, sawNonCA := false, false
	for _, ext := range tbs.GetExtension(oidExtensionBasicConstraints) {
		var bc basicConstraints
		trailing, err := asn1.Unmarshal(ext.Value, &bc)
		if err != nil {
			return nil, errors.New("failed to parse Basic Constraints: " + err.Error())
		}
		if len(trailing) > 0 {
			return nil, fmt.Errorf("trailing data after Basic Constraints: %v", trailing)
		}
		if bc.IsCA {
			sawCA = true
		} else {
			sawNonCA = true
		}
	}

	switch {
	case sawCA && sawNonCA:
		return nil, fmt.Errorf("Certificate has more than one Basic Constraints extension and they are contradictory")
	case sawCA || sawNonCA:
		// Exactly one kind was seen, so sawCA alone is the answer.
		result := sawCA
		return &result, nil
	default:
		return nil, nil
	}
}
// ParseSerialNumber decodes the certificate serial number as an arbitrary-
// precision integer. The INTEGER must not be followed by extra bytes.
func (tbs *TBSCertificate) ParseSerialNumber() (*big.Int, error) {
	serial := new(big.Int)
	trailing, err := asn1.Unmarshal(tbs.SerialNumber.FullBytes, &serial)
	if err != nil {
		return nil, errors.New("failed to parse serial number: " + err.Error())
	}
	if len(trailing) > 0 {
		return nil, fmt.Errorf("trailing data after serial number: %v", trailing)
	}
	return serial, nil
}
// GetRawPublicKey returns the complete encoding (tag, length and contents)
// of the certificate's public key field, without decoding it.
func (tbs *TBSCertificate) GetRawPublicKey() []byte {
	return tbs.PublicKey.FullBytes
}

// GetRawSubject returns the complete encoding of the certificate's subject
// name, without decoding it. ParseSubject decodes it.
func (tbs *TBSCertificate) GetRawSubject() []byte {
	return tbs.Subject.FullBytes
}

// GetRawIssuer returns the complete encoding of the certificate's issuer
// name, without decoding it. ParseIssuer decodes it.
func (tbs *TBSCertificate) GetRawIssuer() []byte {
	return tbs.Issuer.FullBytes
}
// ParseSubject decodes the certificate's subject distinguished name.
func (tbs *TBSCertificate) ParseSubject() (RDNSequence, error) {
	name, err := ParseRDNSequence(tbs.GetRawSubject())
	if err == nil {
		return name, nil
	}
	return nil, errors.New("failed to parse certificate subject: " + err.Error())
}

// ParseIssuer decodes the certificate's issuer distinguished name.
func (tbs *TBSCertificate) ParseIssuer() (RDNSequence, error) {
	name, err := ParseRDNSequence(tbs.GetRawIssuer())
	if err == nil {
		return name, nil
	}
	return nil, errors.New("failed to parse certificate issuer: " + err.Error())
}
// ParseSubjectCommonNames decodes the certificate subject and returns its
// commonName values. Subject decoding errors are returned unchanged; CN
// decoding errors are prefixed with context.
func (tbs *TBSCertificate) ParseSubjectCommonNames() ([]string, error) {
	subject, err := tbs.ParseSubject()
	if err != nil {
		return nil, err
	}
	names, cnErr := subject.ParseCNs()
	if cnErr != nil {
		return nil, errors.New("failed to process certificate subject: " + cnErr.Error())
	}
	return names, nil
}
// ParseSubjectAltNames collects the entries of every subjectAltName
// extension on the certificate, in extension order. It returns an empty,
// non-nil slice when there are none. A decoding error from any extension is
// returned as-is.
func (tbs *TBSCertificate) ParseSubjectAltNames() ([]SubjectAltName, error) {
	collected := make([]SubjectAltName, 0)
	for _, ext := range tbs.GetExtension(oidExtensionSubjectAltName) {
		next, err := ParseSANExtension(collected, ext.Value)
		if err != nil {
			return nil, err
		}
		collected = next
	}
	return collected, nil
}
// GetExtension returns every extension whose OID equals id, in their
// original order. The result is nil when nothing matches. Duplicates are
// returned rather than rejected, since some certificates repeat extensions.
func (tbs *TBSCertificate) GetExtension(id asn1.ObjectIdentifier) []Extension {
	var matches []Extension
	for i := range tbs.Extensions {
		if tbs.Extensions[i].Id.Equal(id) {
			matches = append(matches, tbs.Extensions[i])
		}
	}
	return matches
}
// ParseCertificate splits certBytes into the three top-level certificate
// fields without decoding them further. certBytes must contain exactly one
// certificate and nothing after it.
func ParseCertificate(certBytes []byte) (*Certificate, error) {
	cert := &Certificate{}
	rest, err := asn1.Unmarshal(certBytes, cert)
	if err != nil {
		return nil, errors.New("failed to parse certificate: " + err.Error())
	}
	if len(rest) != 0 {
		// XXX: too strict?
		return nil, fmt.Errorf("trailing data after certificate: %v", rest)
	}
	return cert, nil
}
// GetRawTBSCertificate returns the complete encoding (tag, length and
// contents) of the certificate's TBSCertificate field, without decoding it.
func (cert *Certificate) GetRawTBSCertificate() []byte {
	return cert.TBSCertificate.FullBytes
}

// ParseTBSCertificate decodes the certificate's TBSCertificate field by
// passing its raw encoding to the package-level ParseTBSCertificate.
func (cert *Certificate) ParseTBSCertificate() (*TBSCertificate, error) {
	return ParseTBSCertificate(cert.GetRawTBSCertificate())
}
// ParseSignatureAlgorithm decodes the certificate's outer signatureAlgorithm
// field. The AlgorithmIdentifier must not be followed by extra bytes.
func (cert *Certificate) ParseSignatureAlgorithm() (*pkix.AlgorithmIdentifier, error) {
	var algo pkix.AlgorithmIdentifier
	trailing, err := asn1.Unmarshal(cert.SignatureAlgorithm.FullBytes, &algo)
	if err != nil {
		return nil, errors.New("failed to parse signature algorithm: " + err.Error())
	}
	if len(trailing) > 0 {
		return nil, fmt.Errorf("trailing data after signature algorithm: %v", trailing)
	}
	return &algo, nil
}
// ParseSignatureValue decodes the certificate's signature BIT STRING and
// returns its bits right-aligned into bytes (see asn1.BitString.RightAlign).
// The BIT STRING must not be followed by extra bytes.
func (cert *Certificate) ParseSignatureValue() ([]byte, error) {
	var bits asn1.BitString
	trailing, err := asn1.Unmarshal(cert.SignatureValue.FullBytes, &bits)
	if err != nil {
		return nil, errors.New("failed to parse signature value: " + err.Error())
	}
	if len(trailing) > 0 {
		return nil, fmt.Errorf("trailing data after signature value: %v", trailing)
	}
	return bits.RightAlign(), nil
}
// ParseSANExtension decodes the value of one subjectAltName extension,
// appends its entries to sans, and returns the extended slice.
//
// value must hold a single universal, constructed SEQUENCE. A single zero
// byte after it is tolerated; any other trailing data is an error. Each entry
// is recorded with its tag number as Type and its content bytes, undecoded,
// as Value. On error the result is nil.
func ParseSANExtension(sans []SubjectAltName, value []byte) ([]SubjectAltName, error) {
	var outer asn1.RawValue
	trailing, err := asn1.Unmarshal(value, &outer)
	if err != nil {
		return nil, errors.New("failed to parse subjectAltName extension: " + err.Error())
	}
	// Don't complain if the SAN is followed by exactly one zero byte,
	// which is a common error.
	loneZero := len(trailing) == 1 && trailing[0] == 0
	if len(trailing) != 0 && !loneZero {
		return nil, fmt.Errorf("trailing data in subjectAltName extension: %v", trailing) // XXX: too strict?
	}

	isSequence := outer.IsCompound && outer.Tag == asn1.TagSequence && outer.Class == asn1.ClassUniversal
	if !isSequence {
		return nil, errors.New("failed to parse subjectAltName extension: bad SAN sequence") // XXX: too strict?
	}

	remaining := outer.Bytes
	for len(remaining) > 0 {
		var entry asn1.RawValue
		remaining, err = asn1.Unmarshal(remaining, &entry)
		if err != nil {
			return nil, errors.New("failed to parse subjectAltName extension item: " + err.Error())
		}
		sans = append(sans, SubjectAltName{Type: entry.Tag, Value: entry.Bytes})
	}
	return sans, nil
}