Home | History | Annotate | Download | only in x509
      1 // Copyright 2017 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build dragonfly freebsd linux netbsd openbsd solaris
      6 
      7 package x509
      8 
      9 import (
     10 	"fmt"
     11 	"os"
     12 	"testing"
     13 )
     14 
     15 const (
     16 	testDir     = "testdata"
     17 	testDirCN   = "test-dir"
     18 	testFile    = "test-file.crt"
     19 	testFileCN  = "test-file"
     20 	testMissing = "missing"
     21 )
     22 
     23 func TestEnvVars(t *testing.T) {
     24 	testCases := []struct {
     25 		name    string
     26 		fileEnv string
     27 		dirEnv  string
     28 		files   []string
     29 		dirs    []string
     30 		cns     []string
     31 	}{
     32 		{
     33 			// Environment variables override the default locations preventing fall through.
     34 			name:    "override-defaults",
     35 			fileEnv: testMissing,
     36 			dirEnv:  testMissing,
     37 			files:   []string{testFile},
     38 			dirs:    []string{testDir},
     39 			cns:     nil,
     40 		},
     41 		{
     42 			// File environment overrides default file locations.
     43 			name:    "file",
     44 			fileEnv: testFile,
     45 			dirEnv:  "",
     46 			files:   nil,
     47 			dirs:    nil,
     48 			cns:     []string{testFileCN},
     49 		},
     50 		{
     51 			// Directory environment overrides default directory locations.
     52 			name:    "dir",
     53 			fileEnv: "",
     54 			dirEnv:  testDir,
     55 			files:   nil,
     56 			dirs:    nil,
     57 			cns:     []string{testDirCN},
     58 		},
     59 		{
     60 			// File & directory environment overrides both default locations.
     61 			name:    "file+dir",
     62 			fileEnv: testFile,
     63 			dirEnv:  testDir,
     64 			files:   nil,
     65 			dirs:    nil,
     66 			cns:     []string{testFileCN, testDirCN},
     67 		},
     68 		{
     69 			// Environment variable empty / unset uses default locations.
     70 			name:    "empty-fall-through",
     71 			fileEnv: "",
     72 			dirEnv:  "",
     73 			files:   []string{testFile},
     74 			dirs:    []string{testDir},
     75 			cns:     []string{testFileCN, testDirCN},
     76 		},
     77 	}
     78 
     79 	// Save old settings so we can restore before the test ends.
     80 	origCertFiles, origCertDirectories := certFiles, certDirectories
     81 	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
     82 	defer func() {
     83 		certFiles = origCertFiles
     84 		certDirectories = origCertDirectories
     85 		os.Setenv(certFileEnv, origFile)
     86 		os.Setenv(certDirEnv, origDir)
     87 	}()
     88 
     89 	for _, tc := range testCases {
     90 		t.Run(tc.name, func(t *testing.T) {
     91 			if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
     92 				t.Fatalf("setenv %q failed: %v", certFileEnv, err)
     93 			}
     94 			if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
     95 				t.Fatalf("setenv %q failed: %v", certDirEnv, err)
     96 			}
     97 
     98 			certFiles, certDirectories = tc.files, tc.dirs
     99 
    100 			r, err := loadSystemRoots()
    101 			if err != nil {
    102 				t.Fatal("unexpected failure:", err)
    103 			}
    104 
    105 			if r == nil {
    106 				if tc.cns == nil {
    107 					// Expected nil
    108 					return
    109 				}
    110 				t.Fatal("nil roots")
    111 			}
    112 
    113 			// Verify that the returned certs match, otherwise report where the mismatch is.
    114 			for i, cn := range tc.cns {
    115 				if i >= len(r.certs) {
    116 					t.Errorf("missing cert %v @ %v", cn, i)
    117 				} else if r.certs[i].Subject.CommonName != cn {
    118 					fmt.Printf("%#v\n", r.certs[0].Subject)
    119 					t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
    120 				}
    121 			}
    122 			if len(r.certs) > len(tc.cns) {
    123 				t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
    124 			}
    125 		})
    126 	}
    127 }
    128