Refactor the comparison predicates, check for identity first.

This commit is contained in:
Dag-Erling Smørgrav 2017-03-29 19:56:53 +02:00
parent 47a0bf838f
commit a11c52e896
5 changed files with 48 additions and 16 deletions

View File

@ -43,15 +43,10 @@ int
mpi_cmp(const cryb_mpi *X, const cryb_mpi *Y)
{
if (X->neg) {
if (Y->neg)
return (-mpi_cmp_abs(X, Y));
else
return (-1);
} else {
if (Y->neg)
return (1);
else
return (mpi_cmp_abs(X, Y));
}
if (X == Y)
return (0);
else if (X->neg)
return (Y->neg ? -mpi_cmp_abs(X, Y) : -1);
else
return (Y->neg ? 1 : mpi_cmp_abs(X, Y));
}

View File

@ -44,7 +44,9 @@ mpi_cmp_abs(const cryb_mpi *X, const cryb_mpi *Y)
{
int i;
/* check width first */
/* check trivial cases first */
if (X == Y)
return (0);
if (X->msb > Y->msb)
return (1);
if (X->msb < Y->msb)

View File

@ -43,6 +43,9 @@ int
mpi_eq(const cryb_mpi *A, const cryb_mpi *B)
{
return (A->neg == B->neg && A->msb == B->msb &&
memcmp(A->words, B->words, (A->msb + 31) / 32) == 0);
if (A == B)
return (1);
if (A->neg != B->neg || A->msb != B->msb)
return (0);
return (memcmp(A->words, B->words, (A->msb + 31) / 32) == 0);
}

View File

@ -43,6 +43,9 @@ int
mpi_eq_abs(const cryb_mpi *A, const cryb_mpi *B)
{
return (A->msb == B->msb &&
memcmp(A->words, B->words, (A->msb + 31) / 32) == 0);
if (A == B)
return (1);
if (A->msb != B->msb)
return (0);
return (memcmp(A->words, B->words, (A->msb + 31) / 32) == 0);
}

View File

@ -142,6 +142,33 @@ t_mpi_eq(char **desc CRYB_UNUSED, void *arg)
return (ret);
}
/*
* Compare an MPI with itself
*/
static int
t_mpi_cmp_ident(char **desc CRYB_UNUSED, void *arg CRYB_UNUSED)
{
cryb_mpi a = CRYB_MPI_ZERO;
int ret = 1;
mpi_set(&a, CRYB_TO);
ret &= t_compare_i(0, mpi_cmp(&a, &a));
mpi_destroy(&a);
return (ret);
}
static int
t_mpi_eq_ident(char **desc CRYB_UNUSED, void *arg CRYB_UNUSED)
{
cryb_mpi a = CRYB_MPI_ZERO;
int ret = 1;
mpi_set(&a, CRYB_TO);
ret &= t_compare_i(1, mpi_eq(&a, &a));
mpi_destroy(&a);
return (ret);
}
/*
* Compare an MPI with an integer
*/
@ -187,9 +214,11 @@ t_prepare(int argc, char *argv[])
t_mpi_prepare();
/* comparison */
t_add_test(t_mpi_cmp_ident, NULL, "mpi cmp mpi (identity)");
for (i = 0; i < sizeof t_cmp_cases / sizeof t_cmp_cases[0]; ++i)
t_add_test(t_mpi_cmp, &t_cmp_cases[i],
"mpi cmp mpi (%s)", t_cmp_cases[i].desc);
t_add_test(t_mpi_eq_ident, NULL, "mpi eq mpi (identity)");
for (i = 0; i < sizeof t_cmp_cases / sizeof t_cmp_cases[0]; ++i)
t_add_test(t_mpi_eq, &t_cmp_cases[i],
"mpi eq mpi (%s)", t_cmp_cases[i].desc);