From 616b928e68256f3c2d392a3520b866bffc8bdd67 Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Tue, 10 Feb 2026 11:28:16 -0600 Subject: [PATCH 1/2] azure byol --- licensing/azure.go | 30 +++++ licensing/azure_test.go | 256 ++++++++++++++++++++++++++++++++++++++ licensing/gcp_test.go | 4 +- licensing/license.go | 26 ++-- licensing/license_test.go | 89 +++++++++++-- 5 files changed, 382 insertions(+), 23 deletions(-) create mode 100644 licensing/azure_test.go diff --git a/licensing/azure.go b/licensing/azure.go index 339ee2e..f88d7db 100644 --- a/licensing/azure.go +++ b/licensing/azure.go @@ -79,3 +79,33 @@ func getAzureInstanceType(client http.Client) string { } return instanceMetadata.Compute.VMSize } + +func getAzureInstancePlan(client http.Client) Plan { + metadataEndpoint := "http://" + MetadataIP + "/metadata/instance/compute?api-version=2021-02-01" + req, err := http.NewRequest("GET", metadataEndpoint, nil) + if err != nil { + return Plan{} + } + + req.Header.Add("Metadata", "true") + + resp, err := client.Do(req) + if err != nil { + return Plan{} + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return Plan{} + } + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return Plan{} + } + var instanceComputeMetadata Compute + err = json.Unmarshal(bodyBytes, &instanceComputeMetadata) + if err != nil { + return Plan{} + } + return instanceComputeMetadata.Plan +} diff --git a/licensing/azure_test.go b/licensing/azure_test.go new file mode 100644 index 0000000..83edaed --- /dev/null +++ b/licensing/azure_test.go @@ -0,0 +1,256 @@ +package licensing + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "testing" +) + +// roundTripperFunc lets us stub http.Client.Do() without spinning up a server. +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func httpResp(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} + +func Test_isOnAzure(t *testing.T) { + origIP := MetadataIP + MetadataIP = "169.254.169.254" + t.Cleanup(func() { MetadataIP = origIP }) + + t.Run("returns true on 200 from /metadata/versions and sends Metadata header", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodGet { + t.Fatalf("expected GET, got %s", req.Method) + } + if req.URL.String() != "http://169.254.169.254/metadata/versions" { + t.Fatalf("unexpected url: %s", req.URL.String()) + } + if got := req.Header.Get("Metadata"); got != "true" { + t.Fatalf("expected Metadata:true header, got %q", got) + } + return httpResp(200, `["2021-02-01","2025-04-07"]`), nil + }), + } + + if got := isOnAzure(client); got != true { + t.Fatalf("expected true, got %v", got) + } + }) + + t.Run("returns false on non-200", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return httpResp(404, `{"error":"not found"}`), nil + }), + } + + if got := isOnAzure(client); got != false { + t.Fatalf("expected false, got %v", got) + } + }) + + t.Run("returns false on transport error", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("network down") + }), + } + + if got := isOnAzure(client); got != false { + t.Fatalf("expected false, got %v", got) + } + }) +} + +func Test_GetMaxUsersAzure(t *testing.T) { + tests := []struct { + name string + instanceType string + want int + }{ + {"empty defaults to 3", "", 3}, + // typical Azure VM size string contains CPU count as the first number: D2s_v3 -> 2 -> 50 users + {"parses cpu count from Standard_D2s_v3", "Standard_D2s_v3", 50}, + // special-case: if first extracted number is 0 => 15 + {"cpu count 0 special-cases to 15", "Standard_D0s_v3", 15}, + // your code strips leading version prefix matching ^.*v[0-9]+# + {"strips version prefix", "foo-v12#Standard_D4s_v3", 100}, + {"no digits falls back to 3", "Standard_Whatever", 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetMaxUsersAzure(tt.instanceType); got != tt.want { + t.Fatalf("GetMaxUsersAzure(%q)=%d, want %d", tt.instanceType, got, tt.want) + } + }) + } +} + +func Test_getAzureInstanceType(t *testing.T) { + origIP := MetadataIP + MetadataIP = "169.254.169.254" + t.Cleanup(func() { MetadataIP = origIP }) + + // Realistic (trimmed) sample based on Microsoft Learn IMDS "instance" response: + // it returns an object with top-level "compute", and compute includes "vmSize" and "plan". + const instanceJSON = `{ + "compute": { + "azEnvironment": "AZUREPUBLICCLOUD", + "location": "westus", + "name": "examplevmname", + "plan": { "name": "planName", "product": "planProduct", "publisher": "planPublisher" }, + "vmSize": "Standard_D2s_v3" + } + }` + + t.Run("returns vmSize on 200 and valid JSON", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "http://169.254.169.254/metadata/instance?api-version=2021-02-01" { + t.Fatalf("unexpected url: %s", req.URL.String()) + } + if got := req.Header.Get("Metadata"); got != "true" { + t.Fatalf("expected Metadata:true header, got %q", got) + } + return httpResp(200, instanceJSON), nil + }), + } + + if got := getAzureInstanceType(client); got != "Standard_D2s_v3" { + t.Fatalf("expected Standard_D2s_v3, got %q", got) + } + }) + + t.Run("returns empty string on non-200", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return httpResp(500, `{"error":"boom"}`), nil + }), + } + if got := getAzureInstanceType(client); got != "" { + t.Fatalf("expected empty string, got %q", got) + } + }) + + t.Run("returns empty string on invalid JSON", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return httpResp(200, `{not-json`), nil + }), + } + if got := getAzureInstanceType(client); got != "" { + t.Fatalf("expected empty string, got %q", got) + } + }) +} + +func Test_getAzureInstancePlan(t *testing.T) { + origIP := MetadataIP + MetadataIP = "169.254.169.254" + t.Cleanup(func() { MetadataIP = origIP }) + + // Realistic (trimmed) sample based on Microsoft Learn IMDS "compute" endpoint: + // /metadata/instance/compute returns the compute object (not wrapped). + const computeJSON = `{ + "azEnvironment": "AZUREPUBLICCLOUD", + "location": "westus", + "name": "examplevmname", + "plan": { "name": "planName", "product": "planProduct", "publisher": "planPublisher" }, + "vmSize": "Standard_D2s_v3" + }` + + t.Run("returns plan on 200 and valid JSON", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "http://169.254.169.254/metadata/instance/compute?api-version=2021-02-01" { + t.Fatalf("unexpected url: %s", req.URL.String()) + } + if got := req.Header.Get("Metadata"); got != "true" { + t.Fatalf("expected Metadata:true header, got %q", got) + } + return httpResp(200, computeJSON), nil + }), + } + + got := getAzureInstancePlan(client) + if got.Name != "planName" || got.Product != "planProduct" || got.Publisher != "planPublisher" { + t.Fatalf("unexpected plan: %#v", got) + } + }) + + t.Run("returns empty Plan on invalid JSON", func(t *testing.T) { + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return httpResp(200, `{not-json`), nil + }), + } + got := getAzureInstancePlan(client) + if got != (Plan{}) { + t.Fatalf("expected empty plan, got %#v", got) + } + }) + + t.Run("returns empty Plan on read error", func(t *testing.T) { + // Force ReadAll to fail by returning a Body that errors. + errBody := io.NopCloser(&errorReader{}) + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errBody, + Header: make(http.Header), + }, nil + }), + } + got := getAzureInstancePlan(client) + if got != (Plan{}) { + t.Fatalf("expected empty plan, got %#v", got) + } + }) +} + +type errorReader struct{} + +func (r *errorReader) Read(_ []byte) (int, error) { return 0, errors.New("read failed") } +func (r *errorReader) Close() error { return nil } + +// Optional: a sanity test that the fake client can verify headers across endpoints. +func Test_fakeClientRejectsMissingMetadataHeader(t *testing.T) { + origIP := MetadataIP + MetadataIP = "169.254.169.254" + t.Cleanup(func() { MetadataIP = origIP }) + + // Here we simulate a transport that enforces the header; the production code SHOULD set it. + client := http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header.Get("Metadata") != "true" { + return httpResp(400, `{"error":"missing Metadata header"}`), nil + } + return httpResp(200, `[]`), nil + }), + } + + if got := isOnAzure(client); got != true { + // If this fails, your code isn't setting Metadata:true for /metadata/versions. + t.Fatalf("expected true, got %v", got) + } +} + +// Small helper if you want to create responses with bytes.Reader bodies in other tests. +func bodyFromBytes(b []byte) io.ReadCloser { + return io.NopCloser(bytes.NewReader(b)) +} diff --git a/licensing/gcp_test.go b/licensing/gcp_test.go index 37ca71e..9db4c81 100644 --- a/licensing/gcp_test.go +++ b/licensing/gcp_test.go @@ -22,9 +22,9 @@ func TestGuessInfrastructureGCP(t *testing.T) { })) defer ts.Close() - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "gcp" { t.Fatalf("wrong infra returned: %s", infra) diff --git a/licensing/license.go b/licensing/license.go index 2dac6dd..db0ff11 100644 --- a/licensing/license.go +++ b/licensing/license.go @@ -14,11 +14,8 @@ import ( var MetadataIP = "169.254.169.254" var licenseURL = "https://in4it-vpn-server.s3.amazonaws.com/licenses" -func guessInfrastructure() string { +func guessInfrastructure(client http.Client) string { // check whether we are on AWS, Azure, DigitalOcean or something undefined - client := http.Client{ - Timeout: 5 * time.Second, - } if isOnAWSMarketPlace(client) { return "aws-marketplace" @@ -29,7 +26,12 @@ func guessInfrastructure() string { } if isOnAzure(client) { - return "azure" + if plan := getAzureInstancePlan(client); plan.Name != "" { + if plan.Publisher == "in4it" && plan.Name == "vpn-server-plan" && plan.Product == "vpn-server" { + return "azure" + } + } + return "azure-byol" } if isOnDigitalOcean(client) { @@ -47,9 +49,11 @@ func GetInstanceType() (string, string) { client := http.Client{ Timeout: 5 * time.Second, } - switch guessInfrastructure() { + switch guessInfrastructure(client) { case "azure": return "azure", getAzureInstanceType(client) + case "azure-byol": + return "azure-byol", getAzureInstanceType(client) case "aws-marketplace": return "aws-marketplace", getAWSInstanceType(client) case "aws": @@ -82,11 +86,11 @@ func getMaxUsers(storage storage.ReadWriter, cloudType, instanceType string) int Timeout: 5 * time.Second, } return GetMaxUsersDigitalOceanBYOL(client, storage) - case "gcp": - client := http.Client{ - Timeout: 5 * time.Second, - } - return GetMaxUsersGCPBYOL(client, storage) + case "gcp": + client := http.Client{ + Timeout: 5 * time.Second, + } + return GetMaxUsersGCPBYOL(client, storage) case "": client := http.Client{ Timeout: 5 * time.Second, diff --git a/licensing/license_test.go b/licensing/license_test.go index 8bd3b74..f31522e 100644 --- a/licensing/license_test.go +++ b/licensing/license_test.go @@ -119,23 +119,92 @@ func TestGetMaxUsersAWS(t *testing.T) { func TestGuessInfrastructureAzure(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.RequestURI == "/metadata/versions" { + switch r.RequestURI { + case "/metadata/versions": w.WriteHeader(http.StatusOK) return + case "/metadata/instance/compute?api-version=2021-02-01": + w.Write([]byte(`{ + "azEnvironment": "AzurePublicCloud", + "customData": "", + "evictionPolicy": "", + "isHostCompatibilityLayerVm": "false", + "licenseType": "", + "location": "eastus", + "name": "vpn-test", + "offer": "vpn-server", + "osProfile": { + "adminUsername": "azureuser", + "computerName": "vpn-test", + "disablePasswordAuthentication": "true" + }, + "osType": "Linux", + "placementGroupId": "", + "plan": { + "name": "vpn-server-plan", + "product": "vpn-server", + "publisher": "in4it" + } +}`)) + return } - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusNotFound) })) defer ts.Close() - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "azure" { t.Fatalf("wrong infra returned: %s", infra) } } +func TestGuessInfrastructureAzureBYOL(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/metadata/versions": + w.WriteHeader(http.StatusOK) + return + case "/metadata/instance/compute?api-version=2021-02-01": + w.Write([]byte(`{ + "azEnvironment": "AzurePublicCloud", + "customData": "", + "evictionPolicy": "", + "isHostCompatibilityLayerVm": "false", + "licenseType": "", + "location": "eastus", + "name": "vpn-test", + "offer": "vpn-server", + "osProfile": { + "adminUsername": "azureuser", + "computerName": "vpn-test", + "disablePasswordAuthentication": "true" + }, + "osType": "Linux", + "placementGroupId": "", + "plan": { + "name": "vpn-server-byol-plan", + "product": "vpn-server", + "publisher": "in4it" + } +}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") + + infra := guessInfrastructure(*ts.Client()) + + if infra != "azure-byol" { + t.Fatalf("wrong infra returned: %s", infra) + } +} + func TestGuessInfrastructureAWSMarketplace(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI == "/metadata/versions" { @@ -172,7 +241,7 @@ func TestGuessInfrastructureAWSMarketplace(t *testing.T) { MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "aws-marketplace" { t.Fatalf("wrong infra returned: %s", infra) @@ -206,7 +275,7 @@ func TestGuessInfrastructureAWS(t *testing.T) { MetadataIP = strings.Replace(ts.URL, "http://", "", -1) - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "aws" { t.Fatalf("wrong infra returned: %s", infra) @@ -223,9 +292,9 @@ func TestGuessInfrastructureOther(t *testing.T) { })) defer ts.Close() - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "" { t.Fatalf("wrong infra returned: %s", infra) @@ -294,9 +363,9 @@ func TestGuessInfrastructureDigitalOcean(t *testing.T) { })) defer ts.Close() - MetadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") - infra := guessInfrastructure() + infra := guessInfrastructure(*ts.Client()) if infra != "digitalocean" { t.Fatalf("wrong infra returned: %s", infra) From 8937db06669bbb15ab5f343d4c1b5db961d95ad9 Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Tue, 10 Feb 2026 14:20:50 -0600 Subject: [PATCH 2/2] azure byol licensing --- licensing/azure.go | 62 +++++++++++++++-- licensing/license.go | 5 ++ licensing/license_test.go | 137 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 6 deletions(-) diff --git a/licensing/azure.go b/licensing/azure.go index f88d7db..a0da95a 100644 --- a/licensing/azure.go +++ b/licensing/azure.go @@ -2,10 +2,14 @@ package licensing import ( "encoding/json" + "fmt" "io" "net/http" "regexp" "strconv" + + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/storage" ) func isOnAzure(client http.Client) bool { @@ -81,31 +85,77 @@ func getAzureInstanceType(client http.Client) string { } func getAzureInstancePlan(client http.Client) Plan { + instanceComputeMetadata := getAzureComputeMetadata(client) + return instanceComputeMetadata.Plan +} + +func getAzureComputeMetadata(client http.Client) Compute { metadataEndpoint := "http://" + MetadataIP + "/metadata/instance/compute?api-version=2021-02-01" req, err := http.NewRequest("GET", metadataEndpoint, nil) if err != nil { - return Plan{} + return Compute{} } req.Header.Add("Metadata", "true") resp, err := client.Do(req) if err != nil { - return Plan{} + return Compute{} } defer resp.Body.Close() if resp.StatusCode != 200 { - return Plan{} + return Compute{} } bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return Plan{} + return Compute{} } var instanceComputeMetadata Compute err = json.Unmarshal(bodyBytes, &instanceComputeMetadata) if err != nil { - return Plan{} + return Compute{} } - return instanceComputeMetadata.Plan + return instanceComputeMetadata +} + +func GetMaxUsersAzureBYOL(client http.Client, storage storage.ReadWriter) int { + userLicense := 3 + + licenseKey, err := getAzureLicenseKey(storage, client) + if err != nil { + logging.DebugLog(fmt.Errorf("get azure license error: %s", err)) + return userLicense + } + + license, err := getLicense(client, licenseKey) + if err != nil { + logging.DebugLog(fmt.Errorf("getLicense error: %s", err)) + return userLicense + } + + return license.Users +} + +func getAzureLicenseKey(storage storage.ReadWriter, client http.Client) (string, error) { + identifier, err := getAzureIdentifier(client) + if err != nil { + logging.DebugLog(fmt.Errorf("License generation error (identifier error): %s", err)) + return "", err + } + + licenseKey, err := getLicenseKeyFromFile(storage) + if err != nil { + return "", err + } + + return generateLicenseKey(licenseKey, identifier), nil +} + +func getAzureIdentifier(client http.Client) (string, error) { + computeMetadata := getAzureComputeMetadata(client) + if computeMetadata.VMID != "" { + return computeMetadata.VMID, nil + } + return "", fmt.Errorf("could not get identifier from azure metadata") } diff --git a/licensing/license.go b/licensing/license.go index db0ff11..aabd4f9 100644 --- a/licensing/license.go +++ b/licensing/license.go @@ -74,6 +74,11 @@ func getMaxUsers(storage storage.ReadWriter, cloudType, instanceType string) int switch cloudType { case "azure": return GetMaxUsersAzure(instanceType) + case "azure-byol": + client := http.Client{ + Timeout: 5 * time.Second, + } + return GetMaxUsersAzureBYOL(client, storage) case "aws-marketplace": return GetMaxUsersAWS(instanceType) case "aws": diff --git a/licensing/license_test.go b/licensing/license_test.go index f31522e..18fd322 100644 --- a/licensing/license_test.go +++ b/licensing/license_test.go @@ -205,6 +205,143 @@ func TestGuessInfrastructureAzureBYOL(t *testing.T) { } } +func TestGetMaxUsersAzureBYOL(t *testing.T) { + azureVMID := "282cfd4d-e384-4ed3-8b33-af7d3b84dc3a" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/metadata/versions": + w.WriteHeader(http.StatusOK) + return + case "/metadata/instance/compute?api-version=2021-02-01": + w.Write([]byte(`{ + "azEnvironment": "AzurePublicCloud", + "customData": "", + "evictionPolicy": "", + "isHostCompatibilityLayerVm": "false", + "licenseType": "", + "location": "eastus", + "name": "vpn-test", + "offer": "vpn-server", + "osProfile": { + "adminUsername": "azureuser", + "computerName": "vpn-test", + "disablePasswordAuthentication": "true" + }, + "osType": "Linux", + "placementGroupId": "", + "plan": { + "name": "vpn-server-byol-plan", + "product": "vpn-server", + "publisher": "in4it" + }, + "tags": "", + "tagsList": [], + "userData": "", + "version": "1.1.12", + "vmId": "` + azureVMID + `", + "vmScaleSetName": "", + "vmSize": "Standard_D2_v4", + "zone": "1" +}`)) + return + } + + // return license + h := sha256.New() + h.Write([]byte(azureVMID)) + + if r.RequestURI == fmt.Sprintf("/license-1234556-license-%x", h.Sum(nil)) { + w.Write([]byte(`{"users": 50}`)) + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") + + testCases := map[string]int{ + "vm1": 50, + "vm2": 50, + } + licenseURL = ts.URL + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) + } + for _, v := range testCases { + if v2, _ := GetMaxUsers(mockStorage); v2 != v { + t.Fatalf("Wrong output: %d vs %d", v2, v) + } + } +} + +func TestGetMaxUsersAzureBYOLNoLicense(t *testing.T) { + azureVMID := "282cfd4d-e384-4ed3-8b33-af7d3b84dc3a" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/metadata/versions": + w.WriteHeader(http.StatusOK) + return + case "/metadata/instance/compute?api-version=2021-02-01": + w.Write([]byte(`{ + "azEnvironment": "AzurePublicCloud", + "customData": "", + "evictionPolicy": "", + "isHostCompatibilityLayerVm": "false", + "licenseType": "", + "location": "eastus", + "name": "vpn-test", + "offer": "vpn-server", + "osProfile": { + "adminUsername": "azureuser", + "computerName": "vpn-test", + "disablePasswordAuthentication": "true" + }, + "osType": "Linux", + "placementGroupId": "", + "plan": { + "name": "vpn-server-byol-plan", + "product": "vpn-server", + "publisher": "in4it" + }, + "tags": "", + "tagsList": [], + "userData": "", + "version": "1.1.12", + "vmId": "` + azureVMID + `", + "vmScaleSetName": "", + "vmSize": "Standard_D2_v4", + "zone": "1" +}`)) + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + MetadataIP = strings.ReplaceAll(ts.URL, "http://", "") + + testCases := map[string]int{ + "vm1": 3, + "vm2": 3, + } + licenseURL = ts.URL + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-12345567-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) + } + for _, v := range testCases { + if v2, _ := GetMaxUsers(mockStorage); v2 != v { + t.Fatalf("Wrong output: %d vs %d", v2, v) + } + } +} + func TestGuessInfrastructureAWSMarketplace(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI == "/metadata/versions" {