Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,27 @@ type ConfigNetwork struct {
}

type ConfigHost struct {
ID string
Name string
IPAddress string
ID string
Name string
IPAddresses []string
}

type ConfigEndpointOIDC struct {
Email string
ExpiresAt *time.Time
}

// mergeIPAddresses returns the plural field if populated, otherwise wraps the singular value.
func mergeIPAddresses(plural []string, singular string) []string {
if len(plural) > 0 {
return plural
}
if singular != "" {
return []string{singular}
}
return nil
}

// Enroll issues an enrollment request against the REST API using the given enrollment code, passing along a locally
// generated DH X25519 public key to be signed by the CA, and an Ed 25519 public key for future API call authentication.
// On success it returns the Nebula config generated by the server, a Nebula private key PEM to be inserted into the
Expand Down Expand Up @@ -178,9 +189,9 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
Name: r.Network.Name,
},
Host: ConfigHost{
ID: r.HostID,
Name: r.Host.Name,
IPAddress: r.Host.IPAddress,
ID: r.HostID,
Name: r.Host.Name,
IPAddresses: mergeIPAddresses(r.Host.IPAddresses, r.Host.IPAddress),
},
}

Expand Down Expand Up @@ -352,9 +363,9 @@ func (c *Client) DoUpdate(ctx context.Context, creds keys.Credentials) ([]byte,
Name: result.Network.Name,
},
Host: ConfigHost{
ID: result.Host.ID,
Name: result.Host.Name,
IPAddress: result.Host.IPAddress,
ID: result.Host.ID,
Name: result.Host.Name,
IPAddresses: mergeIPAddresses(result.Host.IPAddresses, result.Host.IPAddress),
},
}

Expand Down Expand Up @@ -460,9 +471,9 @@ func (c *Client) DoConfigUpdate(ctx context.Context, creds keys.Credentials) ([]
Name: result.Network.Name,
},
Host: ConfigHost{
ID: result.Host.ID,
Name: result.Host.Name,
IPAddress: result.Host.IPAddress,
ID: result.Host.ID,
Name: result.Host.Name,
IPAddresses: mergeIPAddresses(result.Host.IPAddresses, result.Host.IPAddress),
},
}

Expand Down
197 changes: 195 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestEnroll(t *testing.T) {
assert.Equal(t, netName, meta.Network.Name)
assert.Equal(t, hostID, meta.Host.ID)
assert.Equal(t, hostName, meta.Host.Name)
assert.Equal(t, hostIP, meta.Host.IPAddress)
assert.Equal(t, []string{hostIP}, meta.Host.IPAddresses)
assert.Equal(t, oidcEmail, meta.EndpointOIDC.Email)
assert.WithinDuration(t, oidcExpiresAt, *meta.EndpointOIDC.ExpiresAt, 1*time.Second)

Expand Down Expand Up @@ -438,7 +438,7 @@ func TestDoUpdate(t *testing.T) {
assert.Equal(t, netName, meta.Network.Name)
assert.Equal(t, hostID, meta.Host.ID)
assert.Equal(t, hostName, meta.Host.Name)
assert.Equal(t, hostIP, meta.Host.IPAddress)
assert.Equal(t, []string{hostIP}, meta.Host.IPAddresses)
assert.Equal(t, oidcEmail, meta.EndpointOIDC.Email)
assert.Nil(t, meta.EndpointOIDC.ExpiresAt)

Expand Down Expand Up @@ -1206,6 +1206,199 @@ func TestDownloads(t *testing.T) {
assert.Equal(t, "0.5.1", resp.VersionInfo.Latest.Mobile)
}

func TestEnroll_PluralMeta(t *testing.T) {
t.Parallel()

useragent := "testClient"
ts := dnapitest.NewServer(useragent)
client := NewClient(useragent, ts.URL)
t.Cleanup(func() { ts.Close() })

code := "abcdef"
orgID := "foobaz"
orgName := "foobar's foo org"
netID := "qux"
netName := "the best network"
netCIDRs := []string{"192.168.100.0/24", "10.0.0.0/16"}
hostID := "foobar"
hostName := "foo host"
hostIPs := []string{"192.168.100.1", "10.0.0.1"}
counter := uint(5)
ca, _ := dnapitest.NebulaCACert()
caPEM, err := ca.MarshalPEM()
require.NoError(t, err)

ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
cfg, err := yaml.Marshal(m{
"pki": m{"ca": string(caPEM)},
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: hostID,
Counter: counter,
Config: cfg,
TrustedKeys: ca.MarshalPublicKeyPEM(),
Organization: message.HostOrgMetadata{
ID: orgID,
Name: orgName,
},
Network: message.HostNetworkMetadata{
ID: netID,
Name: netName,
Curve: message.NetworkCurve25519,
CIDRs: netCIDRs,
},
Host: message.HostHostMetadata{
ID: hostID,
Name: hostName,
IPAddresses: hostIPs,
},
},
})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, _, _, meta, err := client.Enroll(ctx, testutil.NewTestLogger(), code)
require.NoError(t, err)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

// test meta
assert.Equal(t, orgID, meta.Org.ID)
assert.Equal(t, orgName, meta.Org.Name)
assert.Equal(t, netID, meta.Network.ID)
assert.Equal(t, netName, meta.Network.Name)
assert.Equal(t, hostID, meta.Host.ID)
assert.Equal(t, hostName, meta.Host.Name)
assert.Equal(t, hostIPs, meta.Host.IPAddresses)
}

func TestDoUpdate_PluralMeta(t *testing.T) {
t.Parallel()

useragent := "testClient"
ts := dnapitest.NewServer(useragent)
t.Cleanup(func() { ts.Close() })

ca, caPrivkey := dnapitest.NebulaCACert()
caPEM, err := ca.MarshalPEM()
require.NoError(t, err)

c := NewClient(useragent, ts.URL)

code := "foobar"
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
cfg, err := yaml.Marshal(m{
"pki": m{"ca": string(caPEM)},
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Config: cfg,
TrustedKeys: ca.MarshalPublicKeyPEM(),
Organization: message.HostOrgMetadata{
ID: "foobaz",
Name: "foobar's foo org",
},
Network: message.HostNetworkMetadata{
ID: "qux",
Name: "the best network",
Curve: message.NetworkCurve25519,
CIDR: "192.168.100.0/24",
},
Host: message.HostHostMetadata{
ID: "foobar",
Name: "foo host",
IPAddress: "192.168.100.2",
},
},
})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, _, creds, _, err := c.Enroll(ctx, testutil.NewTestLogger(), code)
require.NoError(t, err)

orgID := "foobaz"
orgName := "foobar's foo org"
netID := "qux"
netName := "the best network"
netCIDRs := []string{"192.168.100.0/24", "10.0.0.0/16"}
hostID := "foobar"
hostName := "foo host"
hostIPs := []string{"192.168.100.1", "10.0.0.1"}

ts.ExpectDNClientRequest(message.DoUpdate, http.StatusOK, func(r message.RequestWrapper) []byte {
newConfigResponse := message.DoUpdateResponse{
Config: dnapitest.NebulaCfg(caPEM),
Counter: 3,
Nonce: dnapitest.GetNonce(r),
TrustedKeys: ca.MarshalPublicKeyPEM(),
Organization: message.HostOrgMetadata{
ID: orgID,
Name: orgName,
},
Network: message.HostNetworkMetadata{
ID: netID,
Name: netName,
Curve: message.NetworkCurve25519,
CIDRs: netCIDRs,
},
Host: message.HostHostMetadata{
ID: hostID,
Name: hostName,
IPAddresses: hostIPs,
},
}
rawRes := jsonMarshal(newConfigResponse)

return jsonMarshal(message.SignedResponseWrapper{
Data: message.SignedResponse{
Version: 1,
Message: rawRes,
Signature: ed25519.Sign(caPrivkey, rawRes),
},
})
})

ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, _, _, meta, err := c.DoUpdate(ctx, *creds)
require.NoError(t, err)
assert.Empty(t, ts.Errors())
assert.Equal(t, 0, ts.RequestsRemaining())

// test meta
assert.Equal(t, orgID, meta.Org.ID)
assert.Equal(t, orgName, meta.Org.Name)
assert.Equal(t, netID, meta.Network.ID)
assert.Equal(t, netName, meta.Network.Name)
assert.Equal(t, hostID, meta.Host.ID)
assert.Equal(t, hostName, meta.Host.Name)
assert.Equal(t, hostIPs, meta.Host.IPAddresses)
}

func TestNebulaPemBanners(t *testing.T) {
const NebulaECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY"
const NebulaEd25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
Expand Down
8 changes: 5 additions & 3 deletions message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,15 @@ type HostNetworkMetadata struct {
Name string `json:"name"`
Curve NetworkCurve `json:"curve"`
CIDR string `json:"cidr"`
CIDRs []string `json:"cidrs"`
}

// HostHostMetadata is included in EnrollResponseData.
type HostHostMetadata struct {
ID string `json:"id"`
Name string `json:"name"`
IPAddress string `json:"ipAddress"`
ID string `json:"id"`
Name string `json:"name"`
IPAddress string `json:"ipAddress"`
IPAddresses []string `json:"ipAddresses"`
}

// HostEndpointOIDCMetadata is included in EnrollResponseData.
Expand Down
Loading